You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

220 lines
8.3 KiB

import asyncio
import html
import logging
import os
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryFile, TemporaryDirectory
from time import time
from typing import List

import boto3
import dotenv
import ffmpeg
import httpx
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher
from aiogram.dispatcher import filters
from aiogram.types import Message, ParseMode, ChatActions, InputFile
from aiogram.utils import executor, exceptions
from PIL import Image

from e621 import E621, E621Post
# Load .env first: every os.environ lookup below may depend on it.
dotenv.load_dotenv('.env')

# REDIS_URL is optional and defaults to a Redis server on localhost.
redis = aioredis.from_url(os.environ.get('REDIS_URL', 'redis://localhost'))
e621 = E621()
logging.basicConfig(level=logging.INFO)
bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher(bot)
# Chat ids accepted by the command handlers' IDFilter, from the comma-separated
# USERS variable. Blank entries (e.g. from a trailing comma) are skipped rather
# than crashing int().
user_ids = [int(uid) for uid in os.environ['USERS'].split(',') if uid.strip()]
def _build_caption(post: E621Post, tag_list: List[str]) -> str:
    """Return the HTML caption for *post*.

    The caption has these parts, in order:

    * the subscribed tags the post matched (or "None");
    * its artist, character and copyright tags, each line omitted when empty;
    * the post link.

    Tag names are HTML-escaped because the caption is sent with
    ``ParseMode.HTML``. A raw ``<`` or ``&`` in a tag would otherwise make
    Telegram reject the message.
    """
    def joined(tags, sep):
        return sep.join(html.escape(tag, quote=False) for tag in tags)

    monitored_tags = sorted(set(post.tags.flatten()) & set(tag_list))
    lines = [f'Monitored tags: <b>{joined(monitored_tags, " ") or "None"}</b>']
    if post.tags.artist:
        lines.append(f'Artist: <b>{joined(post.tags.artist, " ")}</b>')
    if post.tags.character:
        lines.append(f'Character: <b>{joined(post.tags.character, ", ")}</b>')
    if post.tags.copyright:
        lines.append(f'Copyright: <b>{joined(post.tags.copyright, ", ")}</b>')
    lines.append(f'\nhttps://e621.net/posts/{post.id}')
    return '\n'.join(lines)


async def _download_post_file(post: E621Post) -> BytesIO:
    """Download the post's file into an in-memory buffer named ``file.<ext>``.

    Raises ``httpx.HTTPError`` on transport failures and on non-2xx
    responses, so an error page is never forwarded as if it were the media.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(post.file.url)
    response.raise_for_status()
    file = BytesIO(response.content)
    # Give the upload a filename carrying the original extension.
    file.name = f'file.{post.file.ext}'
    return file


def _convert_and_upload_webm(s3, webm_bytes: bytes, post_id) -> str:
    """Transcode WebM bytes to H.264 MP4 with ffmpeg and upload the MP4 to S3.

    This call blocks (ffmpeg subprocess plus boto3 transfer), so callers run it
    in an executor thread. It returns the public URL of the uploaded object.
    """
    bucket = os.environ['AWS_S3_BUCKET']
    upload_filename = f'e621-{post_id}-{int(time())}.mp4'
    # The temporary directory and both files in it are removed on exit.
    with TemporaryDirectory() as td:
        webm_path = Path(td) / 'video.webm'
        mp4_path = Path(td) / 'video.mp4'
        webm_path.write_bytes(webm_bytes)
        (
            ffmpeg
            .input(str(webm_path))
            .output(str(mp4_path), vcodec='libx264', crf='26')
            .run()
        )
        # str() for boto3/s3transfer versions that only accept string filenames.
        s3.upload_file(str(mp4_path), bucket, upload_filename,
                       ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
    return f'https://{bucket}.s3.amazonaws.com/{upload_filename}'


def _compress_image(file: BytesIO, max_side: int = 2000) -> BytesIO:
    """Re-encode an image as JPEG, downscaled so neither side exceeds *max_side*.

    Transparent areas are flattened onto white, because JPEG has no alpha
    channel. The work is CPU-bound, so callers run it in an executor thread.
    """
    image = Image.open(file).convert('RGBA')
    if max(image.size) > max_side:
        ratio = max_side / max(image.size)
        image = image.resize(
            (int(image.size[0] * ratio), int(image.size[1] * ratio)),
            Image.LANCZOS,
        )
    background = Image.new('RGBA', image.size, (255, 255, 255, 255))
    composite = Image.alpha_composite(background, image).convert('RGB')
    out = BytesIO()
    composite.save(out, format='JPEG')
    out.seek(0)
    out.name = 'file.jpg'
    return out


async def send_post(post: E621Post, tag_list: List[str]):
    """Forward one e621 post to the ``SEND_CHAT`` chat and record it as sent.

    How each file type is sent:

    * webm: transcoded to MP4, uploaded to S3 and sent as a link.
    * gif: sent as an animation.
    * png/jpg: sent as a photo. Files over 10,000,000 bytes are first
      re-encoded as a JPEG of at most 2000px per side.

    Posts with any other extension are not sent but are still recorded.
    Posts without a file URL are neither sent nor recorded.

    The post id is added to ``e621:sent`` only after the send succeeds. Any
    failure while sending is logged and swallowed, so one bad post cannot
    abort the rest of a polling batch. The post stays unrecorded and is
    retried on the next poll.

    Args:
        post: the post to forward.
        tag_list: subscribed tags; the ones the post matches are listed in the caption.
    """
    chat_id = int(os.environ['SEND_CHAT'])
    await bot.send_chat_action(chat_id, action=ChatActions.TYPING)
    if not post.file.url:
        return
    caption = _build_caption(post, tag_list)
    # NOTE(review): Telegram caps photo/animation captions at 1024 characters;
    # posts with very long character/copyright lists may still be rejected.
    loop = asyncio.get_running_loop()
    try:
        logging.warning(post.file.url)
        file = await _download_post_file(post)
        ext = post.file.ext
        if ext == 'webm':
            # The client is built on the loop thread and only used from the worker:
            # boto3 clients are thread-safe, the default session is not.
            s3 = boto3.client('s3',
                              aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
                              aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
            # ffmpeg and the S3 upload block, so keep them off the event loop.
            url = await loop.run_in_executor(
                None, _convert_and_upload_webm, s3, file.getvalue(), post.id)
            await bot.send_message(chat_id,
                                   f'{url}\n\n' + caption,
                                   parse_mode=ParseMode.HTML)
        elif ext == 'gif':
            await bot.send_animation(chat_id,
                                     file,
                                     width=post.file.width,
                                     height=post.file.height,
                                     thumb=post.preview.url,
                                     caption=caption,
                                     parse_mode=ParseMode.HTML)
        elif ext in ('png', 'jpg'):
            if post.file.size > 10000000:
                logging.warning('compressing')
                file = await loop.run_in_executor(None, _compress_image, file)
            await bot.send_photo(chat_id,
                                 file,
                                 caption=caption,
                                 parse_mode=ParseMode.HTML)
        else:
            logging.warning('Unsupported file extension %r for post %s', ext, post.id)
        await redis.sadd('e621:sent', post.id)
    except asyncio.CancelledError:
        # Never swallow cancellation (it subclasses Exception before Python 3.8).
        raise
    except Exception:
        # Per-post boundary: log it and let the caller move on to the next post.
        logging.exception('Failed to send post %s', post.id)
async def check_updates():
    """Query e621 for posts matching any subscribed tag and forward the new ones.

    How it works:

    * Runs while holding the ``e621:update`` Redis lock, so concurrent callers
      (the periodic poll and the /update command) run one at a time.
    * Subscribed tags from ``e621:subs`` are sorted and queried in chunks of
      40 ``~tag`` terms.
    * Posts whose ids are already in ``e621:sent`` are skipped.
    """
    chunk_size = 40  # NOTE(review): presumably e621's per-search tag limit — confirm
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        raw_tags = await redis.smembers('e621:subs')
        tag_list = sorted(t.decode() for t in raw_tags)
        for start in range(0, len(tag_list), chunk_size):
            tags = ' '.join(f'~{tag}' for tag in tag_list[start:start + chunk_size])
            logging.warning(tags)
            posts = await e621.get_posts(tags)
            if not posts:
                # An empty result for this chunk must not stop the remaining
                # chunks from being checked.
                continue
            already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
            for post, sent in zip(posts, already_sent):
                if not sent:
                    await send_post(post, tag_list)
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add'])
async def add_tag(msg: Message):
    """Handle ``/add tag1 tag2 ...`` by adding each tag to the ``e621:subs`` set.

    Replies with the raw argument text on success, or with a usage hint when
    no arguments were given.
    """
    args = msg.get_args()
    if args:
        for new_tag in args.split():
            await redis.sadd('e621:subs', new_tag)
        await msg.reply(f'Tags {args} added')
    else:
        await msg.reply('Please provide tag to subscribe to')
@dp.message_handler(filters.IDFilter(chat_id=user_ids), regexp=r'^\/del_\S+$')
async def del_tag(msg: Message):
    """Handle the ``/del_<tag>`` shortcut (as printed by /list): unsubscribe one tag.

    Replies "Tag not found" when the tag is not in ``e621:subs``.
    """
    args = msg.text[len('/del_'):]
    # The handler regexp already guarantees a non-empty tag without whitespace;
    # these guards only matter if that pattern is ever relaxed.
    if not args:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in args:
        await msg.reply('Tag should not contain spaces')
        return
    # SREM returns how many members it removed, so one call both checks
    # membership and deletes, with no window between the two.
    if not await redis.srem('e621:subs', args):
        await msg.reply('Tag not found')
        return
    await msg.reply(f'Tag {args} removed')
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['del'])
async def del_command(msg: Message):
    """Handle ``/del tag1 tag2 ...`` by removing each tag from ``e621:subs``.

    Tags that were not subscribed are silently ignored. The reply echoes the
    raw argument text.
    """
    args = msg.get_args()
    tags = args.split() if args else []
    if not tags:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    # A single SREM call with every member; non-members are a no-op.
    await redis.srem('e621:subs', *tags)
    await msg.reply(f'Tags {args} removed')
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['list'])
async def list_tags(msg: Message):
    """Reply with every subscribed tag, in sorted order.

    Each tag is shown with its ``/del_<tag>`` shortcut. Whenever adding the
    next entry would push a message body past 2000 characters, the current
    batch is sent and a new one is started. The last batch is always sent,
    even when it is empty.
    """
    raw_tags = await redis.smembers('e621:subs')
    subscribed = sorted(t.decode() for t in raw_tags)
    batch = []
    for name in subscribed:
        entry = f'- {name} [/del_{name}]'
        if len('\n'.join([*batch, entry])) > 2000:
            body = '\n'.join(batch)
            await msg.reply(f'Monitored tags:\n\n{body}')
            batch = []
        batch.append(entry)
    body = '\n'.join(batch)
    await msg.reply(f'Monitored tags:\n\n{body}')
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['update'])
async def update(msg: Message):
    """Handle ``/update`` by running one check_updates() pass immediately.

    No reply is sent, and the triggering message is otherwise unused.
    """
    return await check_updates()
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test'])
async def test(msg: Message):
    """Handle ``/test <post id>`` by sending that single post right away.

    The post is passed to send_post with an empty tag list. Non-numeric
    arguments are rejected with a usage hint.
    """
    args = (msg.get_args() or '').strip()
    # Accept only a bare numeric id, so other search syntax cannot be
    # appended to the `id:` query.
    if not args.isdecimal():
        await msg.reply('Please provide post id')
        return
    logging.info('Looking up post id:%s', args)
    posts = await e621.get_posts(f'id:{args}')
    if not posts:
        await msg.reply('Post not found')
        return
    await send_post(posts[0], [])
async def background_on_start():
    """Run check_updates() forever, immediately and then every 600 seconds.

    First deletes the ``e621:update`` key, which is the lock check_updates()
    acquires. This presumably discards a lock left behind by a previous run.
    Exceptions from a pass are logged and polling continues.
    """
    poll_interval = 600

    async def poll_once():
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as failure:
            logging.exception(failure)

    await redis.delete('e621:update')
    while True:
        await poll_once()
        logging.warning('Sleeping...')
        await asyncio.sleep(poll_interval)
async def on_bot_startup(dp: Dispatcher):
    """aiogram startup hook: start the e621 polling loop as a background task.

    The task is kept in a module-level global because the event loop holds
    only weak references to tasks. An unreferenced task can be garbage
    collected mid-execution.
    """
    global _background_task
    _background_task = asyncio.create_task(background_on_start())
if __name__ == '__main__':
    # Long-poll Telegram for updates; on_bot_startup launches the background
    # e621 poller once the event loop is running.
    executor.start_polling(
        dp,
        on_startup=on_bot_startup,
    )