"""Telegram bot that monitors e621 tag subscriptions and forwards new posts.

Subscriptions are stored in redis (``e621:subs``); already-delivered post ids
in ``e621:sent``; the last seen post-version id in ``e621:last_version``.
Videos/gifs are transcoded to mp4 with ffmpeg and served via S3; images are
resized/flattened and sent as Telegram photos.
"""
import asyncio
import base64
import datetime
import logging
import os
import random
import re
import traceback
from asyncio import sleep
from io import BytesIO
from itertools import chain, count
from pathlib import Path
from tempfile import TemporaryDirectory
from time import time
from typing import Iterable, List

import boto3
import dotenv
import ffmpeg
import httpx
import redis.asyncio as aioredis
from PIL import Image
from aiogram import Bot, Dispatcher, F, exceptions, filters
from aiogram.enums import ChatAction, ParseMode
from aiogram.types import (
    BufferedInputFile,
    CallbackQuery,
    InlineKeyboardButton,
    InlineKeyboardMarkup,
    Message,
)

from e621 import E621, E621Post, E621PostFile, E621PostVersion

dotenv.load_dotenv('.env')

redis = aioredis.from_url('redis://localhost')
e621 = E621()
logging.basicConfig(level=logging.INFO)
bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher()

# Only whitelisted chat ids (comma-separated in USERS) may issue commands.
ChatFilter = F.chat.id.in_(set(map(int, os.environ['USERS'].split(','))))

# How many posts / tags are processed per request page.
PAGE_SIZE = 20


def format_tags(tags: Iterable[str]) -> str:
    """Render tags as Telegram-safe hashtags; 'None' when the iterable is empty."""
    return ' '.join(f'#{re.sub(r"[^0-9a-zA-Z_]", "_", tag)}' for tag in tags) or 'None'


def recover_url(file: E621PostFile) -> str:
    """Rebuild a static CDN url from the file's md5 when the API omits ``url``."""
    return f'https://static1.e621.net/data/{file.md5[:2]}/{file.md5[2:4]}/{file.md5}.{file.ext}'


async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
    """Deliver a single post to the SEND_CHAT channel.

    webm/gif files are transcoded to mp4, uploaded to S3 and sent as a link;
    png/jpg files are downscaled to <=2000px, flattened onto white and sent
    as a photo with NSFW/Safe forwarding buttons.  Errors are logged, never
    raised — one bad post must not abort the update loop.
    """
    try:
        logging.warning(f'Sending post #{post.id}')
        await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)

        # Which subscribed tag groups does this post fully satisfy?
        flat_post_tags = set(post.tags.flatten())
        monitored_tags = set()
        for tag_group in tag_list:
            if set(tag_group) <= flat_post_tags:
                monitored_tags.update(tag_group)

        artist_tags = post.tags.artist
        character_tags = post.tags.character
        copyright_tags = post.tags.copyright
        caption = '\n'.join(l for l in [
            f'Monitored tags: {format_tags(monitored_tags)}',
            artist_tags and f'Artist: {format_tags(artist_tags)}',
            character_tags and f'Character: {format_tags(character_tags)}',
            copyright_tags and f'Copyright: {format_tags(copyright_tags)}',
            f'\nhttps://e621.net/posts/{post.id}',
        ] if l)

        if not post.file.url:
            # The API hides urls for some posts; reconstruct from the md5.
            post.file.url = recover_url(post.file)

        try:
            logging.warning(post.file.url)
            async with httpx.AsyncClient() as client:
                file = BytesIO((await client.get(post.file.url)).content)
                file.name = f'file.{post.file.ext}'
                file.seek(0)

            if post.file.ext in ('webm', 'gif'):
                with TemporaryDirectory() as td:
                    src_path = Path(td) / f'video.{post.file.ext}'
                    mp4_path = Path(td) / 'video.mp4'
                    with open(src_path, 'wb') as webm:
                        webm.write(file.read())
                    # Pad to even dimensions — libx264 requires w/h % 2 == 0.
                    cmd = ffmpeg.input(str(src_path)).output(
                        str(mp4_path),
                        vf='pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2:x=0:y=0:color=Black',
                        vcodec='libx264',
                        crf='26',
                    )
                    logging.info('ffmpeg ' + ' '.join(cmd.get_args()))
                    cmd.run()

                    # Telegram bots cannot upload big videos directly; host on S3
                    # and send the public link instead.
                    s3 = boto3.client(
                        's3',
                        aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
                        aws_secret_access_key=os.environ['AWS_SECRET_KEY'],
                    )
                    bucket = os.environ['AWS_S3_BUCKET']
                    upload_filename = f'e621-{post.id}-{int(time())}.mp4'
                    s3.upload_file(
                        mp4_path, bucket, upload_filename,
                        ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'},
                    )
                    await bot.send_message(
                        int(os.environ['SEND_CHAT']),
                        f'https://{bucket}.s3.amazonaws.com/{upload_filename}\n\n' + caption,
                        parse_mode=ParseMode.HTML,
                    )
                    src_path.unlink()
                    mp4_path.unlink()

            elif post.file.ext in ('png', 'jpg'):
                markup = InlineKeyboardMarkup(inline_keyboard=[[
                    InlineKeyboardButton(text='NSFW', callback_data='send nsfw'),
                    InlineKeyboardButton(text='Safe', callback_data='send pics'),
                ]])
                logging.warning('compressing')
                dl_im = Image.open(file).convert('RGBA')
                size = dl_im.size
                if size[0] > 2000 or size[1] > 2000:
                    # Keep aspect ratio; Telegram rejects overly large photos.
                    larger_dimension = max(size)
                    ratio = 2000 / larger_dimension
                    dl_im = dl_im.resize(
                        (int(size[0] * ratio), int(size[1] * ratio)), Image.LANCZOS)
                    logging.warning(
                        f'Resizing from {size[0]}x{size[1]} to {dl_im.size[0]}x{dl_im.size[1]}')
                # Flatten transparency onto white before JPEG encoding.
                im = Image.new('RGBA', dl_im.size, (255, 255, 255))
                composite = Image.alpha_composite(im, dl_im).convert('RGB')
                file = BytesIO()
                composite.save(file, format='JPEG')
                file.seek(0)
                await bot.send_photo(
                    int(os.environ['SEND_CHAT']),
                    BufferedInputFile(file.read(), 'file.jpg'),
                    caption=caption,
                    parse_mode=ParseMode.HTML,
                    reply_markup=markup,
                )

            await redis.sadd('e621:sent', post.id)
        except Exception as e:
            logging.exception(e)
    except Exception as e:
        logging.exception(e)


async def check_updates():
    """Poll e621 post versions since the stored cursor and deliver matches.

    Serialized via the ``e621:update`` redis lock so manual commands and the
    background loop never interleave.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        matched_posts = []
        tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs'))
        tag_list_flat = set(chain.from_iterable(tag_list))

        last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
        post_versions: List[E621PostVersion] = []
        logging.warning(f'Getting post versions from id {last_post_version}')
        for page in count(1):
            logging.warning(f'Loading page {page}')
            post_versions_page = await e621.get_post_versions(last_post_version, page)
            post_versions += post_versions_page
            # With no cursor yet we only take the first page to seed the cursor.
            if not last_post_version or not post_versions_page:
                break

        # Oldest first, so the cursor advances monotonically.
        for post_version in post_versions[::-1]:
            if post_version.id > last_post_version:
                last_post_version = post_version.id
            # Cheap pre-filter: only versions that *added* a monitored tag.
            if not tag_list_flat & set(post_version.added_tags):
                continue
            post_tags = set(post_version.tags.split())
            for tag_group in tag_list:
                if set(tag_group) <= post_tags:
                    matched_posts.append(post_version.post_id)
                    break

        matched_posts.sort()
        if matched_posts:
            already_sent: List = await redis.smismember('e621:sent', matched_posts)
            posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent)
                             if not sent]
            logging.warning(f'Found {len(posts_to_send)} posts')
            for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
                chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
                posts = await e621.get_posts(
                    'order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
                for i, post in enumerate(posts):
                    logging.warning(
                        f'Sending post {post_chunk_idx + i + 1}/{len(posts_to_send)}')
                    await send_post(post, tag_list)
                    await redis.sadd('e621:sent', post.id)
                    await sleep(1)

        await redis.set('e621:last_version', last_post_version)


@dp.message(filters.Command('resend_after'), ChatFilter)
async def resend_after(msg: Message):
    """Re-send every subscribed post created after the given unix timestamp."""
    try:
        timestamp = int(msg.text.split()[1])
    except Exception:
        traceback.print_exc()
        await msg.reply('Invalid timestamp or not provided')
        return
    async with redis.lock('e621:update'):
        tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
        for i, tag in enumerate(tag_list):
            await msg.reply(f'Checking tag {tag} ({i+1}/{len(tag_list)})',
                            parse_mode=ParseMode.HTML)
            posts = []
            page = 1
            while True:
                page_posts = await e621.get_posts(tag, page)
                if not page_posts:
                    break
                for post in page_posts:
                    if datetime.datetime.fromisoformat(post.created_at).timestamp() < timestamp:
                        break
                    posts.append(post)
                page += 1
            # Oldest first so the channel reads chronologically.
            for post in posts[::-1]:
                await send_post(post, tag_list)
    await msg.reply('Finished')


@dp.message(filters.Command('add'), ChatFilter)
async def add_tag(msg: Message):
    """Subscribe to each given tag; current posts are pre-marked as sent."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tag to subscribe to')
        return
    for tag in args.split():
        posts = await e621.get_posts(tag)
        # Seed the sent-set so the subscription only fires for *new* posts.
        await redis.sadd('e621:sent', *[post.id for post in posts])
        await redis.sadd('e621:subs', tag)
    await msg.reply(f'Tags {args} added')


@dp.message(filters.Command('add_tags'), ChatFilter)
async def add_tags(msg: Message):
    """Subscribe to a tag *group* (all tags must match the same post)."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tags to subscribe to')
        return
    tags = args.split()
    tags.sort()  # canonical order so duplicates collapse in the redis set
    tags = ' '.join(tags)
    posts = await e621.get_posts(tags)
    await redis.sadd('e621:sent', *[post.id for post in posts])
    await redis.sadd('e621:subs', tags)
    await msg.reply(f'Tag group {tags} added', parse_mode=ParseMode.HTML)


@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
    """Mark every currently-visible post of every subscription as sent."""
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        m = await msg.reply(f'0/{len(tag_list)} tags have old posts marked as sent')
        for i, tag in enumerate(tag_list, 1):
            posts = await e621.get_posts(tag)
            await redis.sadd('e621:sent', *[post.id for post in posts])
            await m.edit_text(f'{i}/{len(tag_list)} tags have old posts marked as sent')
            await sleep(1)
        await m.edit_text(f'Done marking old posts as sent for {len(tag_list)} tags')


@dp.message(filters.Command(re.compile(r'del_\S+')), ChatFilter)
async def del_tag(msg: Message):
    """Handle the /del_<tag> shortcut links emitted by /list."""
    args = msg.text[5:]  # strip the '/del_' prefix
    if not args:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in args:
        await msg.reply('Tag should not contain spaces')
        return
    if not await redis.sismember('e621:subs', args):
        await msg.reply('Tag not found')
        return
    await redis.srem('e621:subs', args)
    await msg.reply(f'Tag {args} removed')


@dp.message(filters.Command('del'), ChatFilter)
async def del_command(msg: Message):
    """Unsubscribe from each given tag."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        # Fixed: previous message wrongly said "subscribe to".
        await msg.reply('Please provide tag to unsubscribe from')
        return
    for tag in args.split():
        await redis.srem('e621:subs', tag)
    await msg.reply(f'Tags {args} removed')


@dp.message(filters.Command('list'), ChatFilter)
async def list_tags(msg: Message):
    """List subscriptions, chunked into replies under Telegram's length limit."""
    tags = [t.decode() for t in await redis.smembers('e621:subs')]
    tags.sort()
    lines = []
    for tag in tags:
        entry = f'- {tag} [/del_{tag}]'
        if len('\n'.join(lines + [entry])) > 2000:
            # Flush the current chunk and start a new one with this entry.
            await msg.reply('Monitored tags:\n\n' + '\n'.join(lines))
            lines = [entry]
        else:
            lines.append(entry)
    await msg.reply('Monitored tags:\n\n' + '\n'.join(lines))


@dp.message(filters.Command('update'), ChatFilter)
async def update(msg: Message):
    """Trigger an update check manually."""
    await check_updates()


@dp.message(filters.Command('test'), ChatFilter)
async def test(msg: Message):
    """Send a single post by id through the normal delivery pipeline."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide post id')
        return
    post = await e621.get_posts(f'id:{args}')
    print(f'id:{args}')
    if not post:
        await msg.reply('Post not found')
        return
    tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
    await send_post(post[0], tag_list)


@dp.callback_query(F.data.startswith('send '))
async def send_callback(cq: CallbackQuery):
    """Forward the tapped photo to the chosen downstream bot via JSON-RPC."""
    _, destination = cq.data.split()
    img_bytes = BytesIO()
    await bot.download(cq.message.photo[-1], img_bytes)
    img_bytes.seek(0)
    data = base64.b64encode(img_bytes.read()).decode()
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(
                f'https://bots.bakatrouble.me/bots_rpc/{destination}/',
                json={
                    "method": "post_photo",
                    "params": [data, True],
                    "jsonrpc": "2.0",
                    "id": 0,
                })
            resp = r.json()
            result = resp.get('result')
            if result is True:
                await cq.answer('Sent')
            elif result == 'duplicate':
                await cq.answer('Duplicate')
            else:
                raise Exception(resp)
        except Exception:
            traceback.print_exc()
            await cq.answer('An error has occurred, check logs')


async def background_on_start():
    """Background loop: check for updates every 10 minutes, forever."""
    # Drop a stale lock left over from an unclean shutdown.
    await redis.delete('e621:update')
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as e:
            logging.exception(e)
        logging.warning('Sleeping...')
        await asyncio.sleep(600)


async def on_bot_startup():
    asyncio.create_task(background_on_start())


async def main():
    dp.startup.register(on_bot_startup)
    await dp.start_polling(bot)


if __name__ == '__main__':
    asyncio.run(main())