"""Telegram bot that relays e621 posts matching subscribed tags.

Subscribed tags live in the Redis set ``e621:subs``; IDs of posts already
delivered are recorded in ``e621:sent``. A background task runs
``check_updates`` every ``POLL_INTERVAL`` seconds, and the users listed in
the ``USERS`` env var can manage subscriptions with ``/add``, ``/del``,
``/del_<tag>``, ``/list``, ``/update`` and ``/test``.
"""
import asyncio
import html
import logging
import os
import random
from asyncio import sleep
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryFile, TemporaryDirectory
from time import time
from typing import List

import boto3
import ffmpeg
from PIL import Image
import httpx
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher
from aiogram.dispatcher import filters
from aiogram.types import Message, ParseMode, ChatActions, InputFile
from aiogram.utils import executor, exceptions
import dotenv

from e621 import E621, E621Post

dotenv.load_dotenv('.env')

redis = aioredis.from_url('redis://localhost')
e621 = E621()

logging.basicConfig(level=logging.INFO)

bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher(bot)
# Telegram user IDs allowed to issue commands (comma-separated in USERS).
user_ids = list(map(int, os.environ['USERS'].split(',')))

# Number of subscribed tags OR-ed (``~tag``) into a single e621 search.
TAGS_PER_QUERY = 40
# Images larger than this many bytes are re-encoded before sending
# (Telegram's photo upload limit is 10 MB).
PHOTO_SIZE_LIMIT = 10_000_000
# Longest side, in pixels, of a re-encoded image.
MAX_IMAGE_DIMENSION = 2000
# Soft character limit of each /list reply.
LIST_CHUNK_CHARS = 2000
# Seconds between background update checks.
POLL_INTERVAL = 600

# Strong references to fire-and-forget tasks: the event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


def _build_caption(post: E621Post, tag_list: List[str]) -> str:
    """Return the HTML caption for *post*.

    Tag names are HTML-escaped because the caption is sent with
    ``ParseMode.HTML``; a raw ``<`` or ``&`` in a tag would otherwise make
    Telegram reject the message.
    """
    def esc(text: str) -> str:
        return html.escape(text, quote=False)

    monitored_tags = set(post.tags.flatten()) & set(tag_list)
    artist_tags = post.tags.artist
    character_tags = post.tags.character
    copyright_tags = post.tags.copyright
    lines = [
        f'Monitored tags: {esc(" ".join(monitored_tags)) or "None"}',
        artist_tags and f'Artist: {esc(" ".join(artist_tags))}',
        character_tags and f'Character: {esc(", ".join(character_tags))}',
        copyright_tags and f'Copyright: {esc(", ".join(copyright_tags))}',
        f'\nhttps://e621.net/posts/{post.id}',
    ]
    # Empty tag groups evaluate to a falsy value and are dropped.
    return '\n'.join(line for line in lines if line)


def _transcode_to_mp4(src_path: Path, mp4_path: Path) -> None:
    """Re-encode *src_path* into an H.264 MP4 at *mp4_path* (blocking)."""
    (
        ffmpeg
        .input(str(src_path))
        # Pad odd widths/heights up to even numbers, which 4:2:0 requires.
        .filter('pad', **{
            'width': 'ceil(iw/2)*2',
            'height': 'ceil(ih/2)*2',
            'x': '0',
            'y': '0',
            'color': 'Black',
        })
        # Force yuv420p: GIF sources otherwise tend to be encoded as 4:4:4,
        # which many browsers and players cannot play.
        .output(str(mp4_path), vcodec='libx264', crf='26', pix_fmt='yuv420p')
        .run()
    )


def _upload_to_s3(path: Path, key: str) -> str:
    """Upload *path* to the AWS_S3_BUCKET as public-read MP4; return its URL (blocking)."""
    s3 = boto3.client('s3',
                      aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
                      aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
    bucket = os.environ['AWS_S3_BUCKET']
    s3.upload_file(str(path), bucket, key,
                   ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
    return f'https://{bucket}.s3.amazonaws.com/{key}'


async def _send_video(post: E621Post, content: bytes, caption: str) -> None:
    """Transcode an animated post to MP4, host it on S3 and send the link."""
    loop = asyncio.get_running_loop()
    with TemporaryDirectory() as td:
        src_path = Path(td) / f'video.{post.file.ext}'
        mp4_path = Path(td) / 'video.mp4'
        src_path.write_bytes(content)
        # ffmpeg and boto3 block; run them in a worker thread so the bot
        # keeps serving commands during a long transcode/upload.
        await loop.run_in_executor(None, _transcode_to_mp4, src_path, mp4_path)
        upload_filename = f'e621-{post.id}-{int(time())}.mp4'
        url = await loop.run_in_executor(None, _upload_to_s3, mp4_path, upload_filename)
    await bot.send_message(int(os.environ['SEND_CHAT']),
                           f'{url}\n\n' + caption,
                           parse_mode=ParseMode.HTML)


def _compress_image(data: bytes) -> BytesIO:
    """Flatten *data* onto white, shrink it to MAX_IMAGE_DIMENSION, re-encode as JPEG."""
    dl_im = Image.open(BytesIO(data)).convert('RGBA')
    size = dl_im.size
    if max(size) > MAX_IMAGE_DIMENSION:
        ratio = MAX_IMAGE_DIMENSION / max(size)
        dl_im = dl_im.resize((int(size[0] * ratio), int(size[1] * ratio)), Image.LANCZOS)
        logging.info('Resizing from %sx%s to %sx%s', size[0], size[1], dl_im.size[0], dl_im.size[1])
    # JPEG has no alpha channel, so composite transparent areas onto white.
    background = Image.new('RGBA', dl_im.size, (255, 255, 255))
    composite = Image.alpha_composite(background, dl_im).convert('RGB')
    out = BytesIO()
    composite.save(out, format='JPEG')
    out.seek(0)
    out.name = 'file.jpg'
    return out


async def _send_image(post: E621Post, content: bytes, caption: str) -> None:
    """Send a still image post, re-encoding it first if it exceeds PHOTO_SIZE_LIMIT."""
    if post.file.size > PHOTO_SIZE_LIMIT:
        logging.warning('compressing')
        # Decoding/resizing a huge image is CPU-heavy; keep it off the loop.
        photo = await asyncio.get_running_loop().run_in_executor(None, _compress_image, content)
    else:
        photo = BytesIO(content)
        photo.name = f'file.{post.file.ext}'
    await bot.send_photo(int(os.environ['SEND_CHAT']), photo,
                         caption=caption, parse_mode=ParseMode.HTML)


async def send_post(post: E621Post, tag_list: List[str]):
    """Deliver *post* to the SEND_CHAT chat and record it in ``e621:sent``.

    webm/gif posts are transcoded to MP4, uploaded to S3 and sent as a link;
    png/jpg posts are sent as photos. Posts without a file URL are skipped
    and not recorded. Download, transcoding, image and Telegram errors are
    logged instead of raised, so one bad post cannot abort an update run;
    such a post is not recorded and will be retried next time.

    :param post: the e621 post to send.
    :param tag_list: subscribed tags; those present on the post are listed
        in the caption as "Monitored tags".
    """
    await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatActions.TYPING)
    if not post.file.url:
        return
    caption = _build_caption(post, tag_list)
    try:
        logging.warning(post.file.url)
        async with httpx.AsyncClient() as client:
            response = await client.get(post.file.url)
        # Don't forward an HTTP error page as if it were the media file.
        response.raise_for_status()
        if post.file.ext in ('webm', 'gif'):
            await _send_video(post, response.content, caption)
        elif post.file.ext in ('png', 'jpg'):
            await _send_image(post, response.content, caption)
        await redis.sadd('e621:sent', post.id)
    except (exceptions.TelegramAPIError, httpx.HTTPError, ffmpeg.Error, OSError):
        logging.exception('Failed to send post %s', post.id)


async def check_updates():
    """Search e621 for every subscribed tag and send posts not yet delivered.

    Holds the Redis lock ``e621:update`` so the background loop and the
    ``/update`` command never run concurrently.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        random.shuffle(tag_list)
        for start in range(0, len(tag_list), TAGS_PER_QUERY):
            batch = tag_list[start:start + TAGS_PER_QUERY]
            tags = ' '.join(f'~{tag}' for tag in batch)
            logging.warning(tags)
            posts = await e621.get_posts(tags)
            if not posts:
                # An empty batch must not end the scan; later batches still
                # need checking.
                continue
            already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
            # Walk the results back to front (presumably oldest first, if
            # e621 returns newest first — TODO confirm).
            for post, sent in reversed(list(zip(posts, already_sent))):
                if sent:
                    continue
                await send_post(post, tag_list)
                await sleep(1)


@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add'])
async def add_tag(msg: Message):
    """``/add tag1 tag2 ...`` — subscribe to one or more tags."""
    args = msg.get_args()
    tags = args.split() if args else []
    if not tags:
        await msg.reply('Please provide tag to subscribe to')
        return
    await redis.sadd('e621:subs', *tags)
    await msg.reply(f'Tags {args} added')


@dp.message_handler(filters.IDFilter(chat_id=user_ids), regexp=r'^\/del_\S+$')
async def del_tag(msg: Message):
    """``/del_<tag>`` — unsubscribe from one tag (the clickable links in /list)."""
    tag = msg.text[len('/del_'):]
    if not tag:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in tag:
        await msg.reply('Tag should not contain spaces')
        return
    if not await redis.sismember('e621:subs', tag):
        await msg.reply('Tag not found')
        return
    await redis.srem('e621:subs', tag)
    await msg.reply(f'Tag {tag} removed')


@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['del'])
async def del_command(msg: Message):
    """``/del tag1 tag2 ...`` — unsubscribe from one or more tags."""
    args = msg.get_args()
    tags = args.split() if args else []
    if not tags:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    await redis.srem('e621:subs', *tags)
    await msg.reply(f'Tags {args} removed')


@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['list'])
async def list_tags(msg: Message):
    """``/list`` — reply with all subscribed tags, split into several messages if long."""
    tags = sorted(t.decode() for t in await redis.smembers('e621:subs'))
    chunk: List[str] = []
    for tag in tags:
        entry = f'- {tag} [/del_{tag}]'
        # Flush before this entry would push the message past the limit.
        if chunk and len('\n'.join(chunk + [entry])) > LIST_CHUNK_CHARS:
            await msg.reply('Monitored tags:\n\n' + '\n'.join(chunk))
            chunk = []
        chunk.append(entry)
    await msg.reply('Monitored tags:\n\n' + '\n'.join(chunk))


@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['update'])
async def update(msg: Message):
    """``/update`` — run an update check now (waits if one is already running)."""
    await check_updates()


@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test'])
async def test(msg: Message):
    """``/test <post id>`` — send a single post, without monitored-tag info."""
    args = msg.get_args()
    if not args:
        await msg.reply('Please provide post id')
        return
    logging.info('id:%s', args)
    posts = await e621.get_posts(f'id:{args}')
    if not posts:
        await msg.reply('Post not found')
        return
    await send_post(posts[0], [])


async def background_on_start():
    """Run check_updates forever, every POLL_INTERVAL seconds, logging any failure."""
    # Presumably clears a lock left behind by a previous run that exited
    # while holding it.
    await redis.delete('e621:update')
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as e:
            # Top-level boundary: keep polling whatever went wrong.
            logging.exception(e)
        logging.warning('Sleeping...')
        await asyncio.sleep(POLL_INTERVAL)


async def on_bot_startup(dp: Dispatcher):
    """aiogram startup hook: launch the background polling task."""
    task = asyncio.create_task(background_on_start())
    # Keep a strong reference so the task isn't garbage-collected.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


if __name__ == '__main__':
    executor.start_polling(dp, on_startup=on_bot_startup)