326 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			326 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import asyncio
import base64
import datetime
import html
import logging
import os
import random
import re
import traceback
from asyncio import sleep
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from time import time
from typing import List, Iterable

import boto3
import dotenv
import ffmpeg
import httpx
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher, filters, exceptions, F
from aiogram.enums import ChatAction, ParseMode
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, BufferedInputFile, \
    CallbackQuery
from PIL import Image

from e621 import E621, E621Post, E621PostFile
 | 
						|
 | 
						|
dotenv.load_dotenv('.env')

# Redis connection; REDIS_URL overrides the default local instance.
redis = aioredis.from_url(os.environ.get('REDIS_URL', 'redis://localhost'))
e621 = E621()

logging.basicConfig(level=logging.INFO)
bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher()

# USERS is a comma-separated list of chat ids allowed to use the bot's commands.
# Blank entries (e.g. from a trailing comma) are skipped instead of crashing on int('').
ChatFilter = F.chat.id.in_({int(chat_id) for chat_id in os.environ['USERS'].split(',') if chat_id.strip()})

# Number of subscribed tags combined (each prefixed with '~') into one e621 search query.
PAGE_SIZE = 20
 | 
						|
 | 
						|
 | 
						|
def format_tags(tags: Iterable[str]):
    """Render tags as space-separated ``#hashtags`` for a Telegram caption.

    Every character outside ``[0-9a-zA-Z_]`` is replaced with ``_`` so each
    tag becomes a single hashtag token. Returns ``'None'`` when *tags* is
    empty.
    """
    hashtags = []
    for raw_tag in tags:
        cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", raw_tag)
        hashtags.append(f'#{cleaned}')
    if not hashtags:
        return 'None'
    return ' '.join(hashtags)
 | 
						|
 | 
						|
 | 
						|
def recover_url(file: E621PostFile):
    """Rebuild a post file's static e621 URL from its MD5 digest and extension.

    ``send_post`` uses this when ``post.file.url`` is empty. The path layout
    is ``/data/<md5[0:2]>/<md5[2:4]>/<md5>.<ext>``.
    """
    digest = file.md5
    return 'https://static1.e621.net/data/{}/{}/{}.{}'.format(
        digest[:2], digest[2:4], digest, file.ext,
    )
 | 
						|
 | 
						|
 | 
						|
def _build_caption(post: E621Post, tag_list: List[str]) -> str:
    """Return the HTML caption for *post*.

    The first line lists the subscribed tags the post matched. The next lines
    give its artist/character/copyright tags; each of these lines is omitted
    when its group is empty. The caption ends with a link to the post page.
    """
    monitored_tags = set(post.tags.flatten()) & set(tag_list)
    artist_tags = post.tags.artist
    character_tags = post.tags.character
    copyright_tags = post.tags.copyright
    lines = [
        f'Monitored tags: <b>{format_tags(monitored_tags)}</b>',
        artist_tags and f'Artist: <b>{format_tags(artist_tags)}</b>',
        character_tags and f'Character: <b>{format_tags(character_tags)}</b>',
        copyright_tags and f'Copyright: <b>{format_tags(copyright_tags)}</b>',
        f'\nhttps://e621.net/posts/{post.id}',
    ]
    return '\n'.join(line for line in lines if line)


async def _download(url: str) -> bytes:
    """Fetch *url* and return the response body, raising on HTTP error statuses."""
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.content


def _transcode_and_upload(data: bytes, ext: str, post_id) -> str:
    """Transcode a webm/gif to H.264 MP4, upload it to S3 and return its public URL.

    Blocking (ffmpeg subprocess plus network upload), so callers run it in an
    executor instead of on the event loop.
    """
    with TemporaryDirectory() as td:
        src_path = Path(td) / f'video.{ext}'
        mp4_path = Path(td) / 'video.mp4'
        src_path.write_bytes(data)
        cmd = ffmpeg.input(str(src_path)).output(
            str(mp4_path),
            # Pad to even dimensions; libx264 rejects odd widths/heights.
            vf='pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2:x=0:y=0:color=Black',
            vcodec='libx264',
            crf='26',
        )
        logging.info('ffmpeg ' + ' '.join(cmd.get_args()))
        cmd.run()
        # A dedicated Session: boto3's default session is not thread-safe, and
        # this function runs on executor threads.
        s3 = boto3.Session().client('s3',
                                    aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
                                    aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
        bucket = os.environ['AWS_S3_BUCKET']
        upload_filename = f'e621-{post_id}-{int(time())}.mp4'
        # Older boto3 releases reject pathlib.Path filenames, so pass a str.
        s3.upload_file(str(mp4_path), bucket, upload_filename,
                       ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
    # TemporaryDirectory removes both temp files on exit.
    return f'https://{bucket}.s3.amazonaws.com/{upload_filename}'


def _prepare_photo(data: bytes) -> bytes:
    """Re-encode an image as JPEG for Telegram.

    The image is scaled down so its longer side is at most 2000 px, and any
    transparency is flattened onto a white background. Blocking (CPU-bound
    PIL work), so callers run it in an executor.
    """
    dl_im = Image.open(BytesIO(data)).convert('RGBA')
    width, height = dl_im.size
    if width > 2000 or height > 2000:
        ratio = 2000 / max(width, height)
        dl_im = dl_im.resize((int(width * ratio), int(height * ratio)), Image.LANCZOS)
        logging.info(f'Resizing from {width}x{height} to {dl_im.size[0]}x{dl_im.size[1]}')
    background = Image.new('RGBA', dl_im.size, (255, 255, 255, 255))
    composite = Image.alpha_composite(background, dl_im).convert('RGB')
    out = BytesIO()
    composite.save(out, format='JPEG')
    return out.getvalue()


async def send_post(post: E621Post, tag_list: List[str]):
    """Forward one e621 post to the ``SEND_CHAT`` chat and record it in ``e621:sent``.

    How each file type is handled:

    * webm/gif files are transcoded to MP4, uploaded to S3 and sent as a link
      above the caption.
    * png/jpg files are re-encoded as JPEG and sent as a photo with
      NSFW/Safe buttons.
    * Any other extension is skipped but still recorded as sent.

    Any failure is logged and swallowed. In that case the post is not
    recorded, so a later update run will retry it.

    :param post: the post to forward.
    :param tag_list: all subscribed tags; used to show which ones the post matched.
    """
    try:
        logging.warning(f'Sending post #{post.id}')
        chat_id = int(os.environ['SEND_CHAT'])
        await bot.send_chat_action(chat_id, action=ChatAction.TYPING)
        caption = _build_caption(post, tag_list)

        if not post.file.url:
            post.file.url = recover_url(post.file)
        logging.warning(post.file.url)

        ext = post.file.ext
        loop = asyncio.get_running_loop()
        if ext in ('webm', 'gif'):
            data = await _download(post.file.url)
            # ffmpeg + S3 upload block; run them off the event loop so the bot
            # keeps handling other updates meanwhile.
            video_url = await loop.run_in_executor(None, _transcode_and_upload, data, ext, post.id)
            await bot.send_message(chat_id,
                                   f'{video_url}\n\n' + caption,
                                   parse_mode=ParseMode.HTML)
        elif ext in ('png', 'jpg'):
            data = await _download(post.file.url)
            markup = InlineKeyboardMarkup(inline_keyboard=[[
                InlineKeyboardButton(text='NSFW', callback_data='send nsfw'),
                InlineKeyboardButton(text='Safe', callback_data='send pics'),
            ]])
            logging.warning('compressing')
            photo = await loop.run_in_executor(None, _prepare_photo, data)
            await bot.send_photo(chat_id,
                                 BufferedInputFile(photo, 'file.jpg'),
                                 caption=caption,
                                 parse_mode=ParseMode.HTML,
                                 reply_markup=markup)
        else:
            logging.warning(f'Skipping unsupported file type {ext!r} for post #{post.id}')

        await redis.sadd('e621:sent', post.id)
    except Exception as e:
        logging.exception(e)
 | 
						|
 | 
						|
 | 
						|
async def check_updates():
    """Query e621 for the subscribed tags and forward posts not yet sent.

    Tags from the ``e621:subs`` Redis set are shuffled and queried in chunks
    of ``PAGE_SIZE``, each chunk as one search of ``~tag`` terms. Posts whose
    ids are already in ``e621:sent`` are skipped. The rest are sent in reverse
    result order (presumably oldest first, assuming e621 returns newest
    first), with a 1 s pause between sends.

    The whole run holds the ``e621:update`` Redis lock, so it never overlaps
    other holders of that lock.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        random.shuffle(tag_list)
        for tl_idx in range(0, len(tag_list), PAGE_SIZE):
            tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
            logging.warning(tags)
            posts = await e621.get_posts(tags)
            if not posts:
                # An empty chunk must not end the run; move on to the next chunk.
                continue
            already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
            for post, sent in reversed(list(zip(posts, already_sent))):
                if sent:
                    continue
                await send_post(post, tag_list)
                await sleep(1)
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('resend_after'), ChatFilter)
async def resend_after(msg: Message):
    """Handle ``/resend_after <unix_timestamp>``: re-send recent posts for every subscribed tag.

    For each tag, result pages are fetched until a post created before the
    timestamp appears. This assumes results are ordered newest first. The
    collected posts are then sent in reverse order.

    A post matching several tags is sent only once per run. The run holds the
    ``e621:update`` lock.
    """
    try:
        timestamp = int(msg.text.split()[1])
    except (IndexError, ValueError):
        traceback.print_exc()
        await msg.reply('Invalid timestamp or not provided')
        return

    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        resent_ids = set()
        for i, tag in enumerate(tag_list, 1):
            # Escape the tag: it goes into an HTML-parsed message, and tags such
            # as '<3' would otherwise break parsing.
            await msg.reply(f'Checking tag <b>{html.escape(tag)}</b> ({i}/{len(tag_list)})',
                            parse_mode=ParseMode.HTML)
            posts = []
            page = 1
            reached_older = False
            # Stop paging once an older post is seen, not just stop reading the
            # current page; otherwise the tag's whole history gets fetched.
            while not reached_older:
                page_posts = await e621.get_posts(tag, page)
                if not page_posts:
                    break
                for post in page_posts:
                    if datetime.datetime.fromisoformat(post.created_at).timestamp() < timestamp:
                        reached_older = True
                        break
                    posts.append(post)
                page += 1
            for post in reversed(posts):
                if post.id in resent_ids:
                    continue
                resent_ids.add(post.id)
                await send_post(post, tag_list)
                await sleep(1)  # same pacing as check_updates
    await msg.reply('Finished')
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('add'), ChatFilter)
async def add_tag(msg: Message):
    """Handle ``/add tag1 tag2 ...``: subscribe to one or more tags.

    The posts currently returned by ``e621.get_posts(tag)`` are marked as
    sent first, so the existing backlog is not forwarded.
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tag to subscribe to')
        return
    for tag in args.split():
        posts = await e621.get_posts(tag)
        if posts:
            # SADD with no members is a Redis error. Skip it for tags with no
            # posts (new or misspelled) so the subscription is still recorded.
            await redis.sadd('e621:sent', *[post.id for post in posts])
        await redis.sadd('e621:subs', tag)
    await msg.reply(f'Tags {args} added')
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
    """Handle ``/mark_old_as_sent``: mark current posts of every subscribed tag as sent.

    Holds the ``e621:update`` lock for the whole run. A single progress
    message is edited after each tag, with a 1 s pause between tags.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        total = len(tag_list)
        m = await msg.reply(f'0/{total} tags have old posts marked as sent')
        for i, tag in enumerate(tag_list, 1):
            posts = await e621.get_posts(tag)
            if posts:
                # SADD needs at least one member. An empty result would raise
                # and abort the run for all remaining tags.
                await redis.sadd('e621:sent', *[post.id for post in posts])
            await m.edit_text(f'{i}/{total} tags have old posts marked as sent')
            await sleep(1)
        await m.edit_text(f'Done marking old posts as sent for {total} tags')
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command(re.compile(r'del_\S+')), ChatFilter)
async def del_tag(msg: Message):
    """Handle ``/del_<tag>`` shortcuts (as printed by ``/list``): unsubscribe from one tag."""
    args = msg.text[5:]  # drop the leading '/del_'
    if not args:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in args:
        await msg.reply('Tag should not contain spaces')
        return
    # '/del_tag@BotName' carries the bot mention after '@'; drop it.
    tag = args.partition('@')[0]
    # SREM returns the number of members removed, so the existence check and
    # the removal happen in a single atomic call.
    if not await redis.srem('e621:subs', tag):
        await msg.reply('Tag not found')
        return
    await msg.reply(f'Tag {tag} removed')
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('del'), ChatFilter)
async def del_command(msg: Message):
    """Handle ``/del tag1 tag2 ...``: unsubscribe from one or more tags.

    Tags that were not subscribed are ignored silently.
    """
    tags = msg.text.split()[1:]
    if not tags:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    await redis.srem('e621:subs', *tags)
    args = ' '.join(tags)
    await msg.reply(f'Tags {args} removed')
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('list'), ChatFilter)
async def list_tags(msg: Message):
    """Handle ``/list``: reply with all subscribed tags in sorted order.

    Each tag gets a clickable ``/del_<tag>`` shortcut. Entries are split
    across several replies so each reply's list stays within 2000 characters;
    a single oversized entry still gets its own reply.
    """
    tags = sorted([t.decode() for t in await redis.smembers('e621:subs')])
    chunk = []
    for tag in tags:
        entry = f'- {tag} [/del_{tag}]'
        # Flush only a non-empty chunk; otherwise an oversized first entry
        # would produce a header-only message.
        if chunk and len('\n'.join(chunk + [entry])) > 2000:
            body = '\n'.join(chunk)
            await msg.reply(f'Monitored tags:\n\n{body}')
            chunk = [entry]
        else:
            chunk.append(entry)
    body = '\n'.join(chunk)
    await msg.reply(f'Monitored tags:\n\n{body}')
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('update'), ChatFilter)
async def update(msg: Message):
    """Handle ``/update``: run one update check right away.

    Blocks until ``check_updates`` acquires the ``e621:update`` lock and
    finishes. No reply is sent.
    """
    await check_updates()
 | 
						|
 | 
						|
 | 
						|
@dp.message(filters.Command('test'), ChatFilter)
async def test(msg: Message):
    """Handle ``/test <post id>``: forward one post by id, with no monitored tags."""
    post_id = ' '.join(msg.text.split()[1:])
    if not post_id:
        await msg.reply('Please provide post id')
        return
    query = f'id:{post_id}'
    found = await e621.get_posts(query)
    print(query)
    if not found:
        await msg.reply('Post not found')
        return
    await send_post(found[0], [])
 | 
						|
 | 
						|
 | 
						|
# Destinations the NSFW/Safe buttons can produce ('send nsfw' / 'send pics').
# Any other value is rejected instead of being interpolated into the RPC URL.
_SEND_DESTINATIONS = frozenset({'nsfw', 'pics'})


@dp.callback_query(F.data.startswith('send '))
async def send_callback(cq: CallbackQuery):
    """Handle a NSFW/Safe button press by relaying the message's photo.

    The largest photo size is downloaded, base64-encoded and submitted to the
    ``bots_rpc/<destination>/`` JSON-RPC endpoint (``post_photo``). The
    callback is then answered with ``Sent``, ``Duplicate`` or an error notice.
    Every step runs inside the error handler, so the user always gets an
    answer.
    """
    try:
        _, destination = cq.data.split()
        if destination not in _SEND_DESTINATIONS:
            raise ValueError(f'Unknown send destination: {destination!r}')
        img_bytes = BytesIO()
        await bot.download(cq.message.photo[-1], img_bytes)
        data = base64.b64encode(img_bytes.getvalue()).decode()
        async with httpx.AsyncClient() as client:
            r = await client.post(f'https://bots.bakatrouble.me/bots_rpc/{destination}/', json={
                "method": "post_photo",
                "params": [data, True],
                "jsonrpc": "2.0",
                "id": 0,
            })
        resp = r.json()
        if 'result' in resp and resp['result'] == True:
            await cq.answer('Sent')
        elif 'result' in resp and resp['result'] == 'duplicate':
            await cq.answer('Duplicate')
        else:
            raise Exception(resp)
    except Exception:
        # `except Exception` (not bare `except`) so task cancellation propagates.
        logging.exception('Failed to relay photo from callback')
        await cq.answer('An error has occurred, check logs')
 | 
						|
 | 
						|
 | 
						|
async def background_on_start():
    """Run ``check_updates`` forever: once immediately, then every 600 s.

    First deletes the ``e621:update`` key, which is the name of the Redis
    lock ``check_updates`` takes. Presumably this clears a lock left behind
    by a previous process; verify. Errors from a single check are logged and
    do not stop the loop.
    """
    await redis.delete('e621:update')
    poll_interval = 600
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as err:
            logging.exception(err)
        logging.warning('Sleeping...')
        await asyncio.sleep(poll_interval)
 | 
						|
 | 
						|
 | 
						|
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so an unreferenced task can be garbage-collected mid-execution.
_background_tasks = set()


async def on_bot_startup():
    """Dispatcher startup hook: start the periodic update loop as a background task."""
    task = asyncio.create_task(background_on_start())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
 | 
						|
 | 
						|
 | 
						|
async def main():
    """Register the startup hook, then long-poll Telegram until the process stops."""
    dp.startup.register(on_bot_startup)
    await dp.start_polling(bot)


if __name__ == '__main__':
    # Script entry point: run the bot's event loop.
    asyncio.run(main())
 |