188 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			188 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| import logging
 | |
| import os
 | |
| from io import BytesIO
 | |
| from PIL import Image
 | |
| 
 | |
| import httpx
 | |
| import redis.asyncio as aioredis
 | |
| from aiogram import Bot, Dispatcher
 | |
| from aiogram.dispatcher import filters
 | |
| from aiogram.types import Message, ParseMode
 | |
| from aiogram.utils import executor, exceptions
 | |
| import dotenv
 | |
| 
 | |
| from e621 import E621
 | |
| 
 | |
# Load variables from the local .env file first: BOT_TOKEN and USERS are read
# from os.environ a few lines below.
dotenv.load_dotenv('.env')

# Shared async Redis connection and e621 API client used by every handler.
redis = aioredis.from_url('redis://localhost')
e621 = E621()

logging.basicConfig(level=logging.INFO)
bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher(bot)

# USERS is a comma-separated list of numeric ids; the command handlers in this
# file pass it to IDFilter(chat_id=...) to restrict who may use them.
user_ids = [int(raw_id) for raw_id in os.environ['USERS'].split(',')]
 | |
| 
 | |
| 
 | |
def _build_caption(post, subscribed):
    """Return the HTML caption for *post*.

    The caption lists the post's tags that are also in *subscribed* (a set of
    tag names), then the artist, character and copyright tags when present,
    and finally a link to the post.  Tag text is HTML-escaped because the
    caption is sent with ``ParseMode.HTML``; an unescaped ``<`` or ``&`` in a
    tag name would otherwise make Telegram reject the whole message.
    """
    # Function-scope import: keeps this block self-contained.
    from html import escape

    def field(label, tags, separator):
        return f'{label}: <b>{escape(separator.join(tags), quote=False)}</b>'

    # Sorted so the caption is deterministic (set iteration order is not).
    monitored = sorted(set(post.tags.flatten()) & subscribed)
    lines = [field('Monitored tags', monitored, ' ')]
    for label, tags, separator in (
        ('Artist', post.tags.artist, ' '),
        ('Character', post.tags.character, ', '),
        ('Copyright', post.tags.copyright, ', '),
    ):
        if tags:
            lines.append(field(label, tags, separator))
    lines.append(f'\nhttps://e621.net/posts/{post.id}')
    return '\n'.join(lines)


def _compress_photo(file):
    """Re-encode the image in *file* as a JPEG of at most 2000px per side.

    Transparency is flattened onto a white background.  Returns a fresh,
    rewound ``BytesIO`` named ``file.jpg``.
    """
    logging.warning('compressing')
    image = Image.open(file).convert('RGBA')
    larger_dimension = max(image.size)
    if larger_dimension > 2000:
        ratio = 2000 / larger_dimension
        image = image.resize(
            (int(image.size[0] * ratio), int(image.size[1] * ratio)),
            Image.LANCZOS,
        )
    background = Image.new('RGBA', image.size, (255, 255, 255))
    flattened = Image.alpha_composite(background, image).convert('RGB')
    compressed = BytesIO()
    compressed.save = None  # placeholder removed below; see next line
    del compressed.save
    flattened.save(compressed, format='JPEG')
    compressed.seek(0)
    compressed.name = 'file.jpg'
    return compressed


async def _send_post(client, post, caption):
    """Download *post*'s file with *client* and send it to the SEND_CHAT chat.

    webm files go out as video, gif as animation, png/jpg as photo.  A webm of
    50_000_000 bytes or more is not downloaded; a text message with its URL is
    sent instead.  Other extensions are logged and nothing is sent.

    Raises:
        httpx.HTTPError: the download failed or returned an error status.
        aiogram.utils.exceptions.TelegramAPIError: Telegram rejected the send.
    """
    from html import escape

    ext = post.file.ext
    if ext not in ('webm', 'gif', 'png', 'jpg'):
        # Nothing below can send this type, so skip the pointless download.
        logging.warning('Unsupported file type %s, nothing sent', ext)
        return

    chat_id = int(os.environ['SEND_CHAT'])

    if ext == 'webm' and post.file.size >= 50_000_000:
        # The caption contains <b> markup, so it must be sent as HTML;
        # without parse_mode the tags would be shown literally.
        await bot.send_message(
            chat_id,
            f'File is too large: {escape(post.file.url, quote=False)}\n\n' + caption,
            parse_mode=ParseMode.HTML,
        )
        return

    response = await client.get(post.file.url)
    # Do not upload an HTTP error page to Telegram as if it were media.
    response.raise_for_status()
    file = BytesIO(response.content)
    file.name = f'file.{ext}'

    if ext == 'webm':
        await bot.send_video(chat_id,
                             file,
                             width=post.file.width,
                             height=post.file.height,
                             thumb=post.preview.url,
                             caption=caption,
                             parse_mode=ParseMode.HTML)
    elif ext == 'gif':
        await bot.send_animation(chat_id,
                                 file,
                                 width=post.file.width,
                                 height=post.file.height,
                                 thumb=post.preview.url,
                                 caption=caption,
                                 parse_mode=ParseMode.HTML)
    else:
        # NOTE(review): 10000000 is presumably Telegram's photo size limit - confirm.
        if post.file.size > 10000000:
            file = _compress_photo(file)
        await bot.send_photo(chat_id,
                             file,
                             caption=caption,
                             parse_mode=ParseMode.HTML)


async def check_updates():
    """Forward not-yet-sent posts for every subscribed tag to the SEND_CHAT chat.

    Runs under the Redis lock ``e621:update``.  The tags in the ``e621:subs``
    set are queried in sorted chunks of 40 (each tag prefixed with ``~``), and
    every returned post whose id is not in ``e621:sent`` is sent and then
    recorded there.  A Telegram or download error for one post is logged and
    that post is left unrecorded, so a later run tries it again.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        tag_list = sorted(t.decode() for t in await redis.smembers('e621:subs'))
        subscribed = set(tag_list)
        # One HTTP client for the whole run instead of one per post.
        async with httpx.AsyncClient() as client:
            for start in range(0, len(tag_list), 40):
                tags = ' '.join(f'~{tag}' for tag in tag_list[start:start + 40])
                logging.warning(tags)
                posts = await e621.get_posts(tags)
                if not posts:
                    # An empty chunk must not abort the remaining chunks.
                    continue
                already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
                for post, sent in zip(posts, already_sent):
                    if sent or not post.file.url:
                        continue
                    caption = _build_caption(post, subscribed)
                    try:
                        logging.warning(post.file.url)
                        await _send_post(client, post, caption)
                        await redis.sadd('e621:sent', post.id)
                    except (exceptions.TelegramAPIError, httpx.HTTPError) as e:
                        # One bad post must not stop the rest of the run.
                        logging.exception(e)
 | |
| 
 | |
| 
 | |
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add'])
async def add_tag(msg: Message):
    """Handle ``/add <tag> [<tag> ...]``: subscribe to each given tag.

    Every whitespace-separated argument is added to the Redis set
    ``e621:subs``.  With no arguments the user is asked to provide a tag.
    """
    requested = msg.get_args()
    if requested:
        for name in requested.split():
            await redis.sadd('e621:subs', name)
        await msg.reply(f'Tags {requested} added')
    else:
        await msg.reply('Please provide tag to subscribe to')
 | |
| 
 | |
| 
 | |
@dp.message_handler(filters.IDFilter(chat_id=user_ids), regexp=r'^\/del_\S+$')
async def del_tag(msg: Message):
    """Handle ``/del_<tag>``: unsubscribe from the tag embedded in the command.

    Replies with an error when the tag is empty, contains a space, or is not
    in the Redis set ``e621:subs``; otherwise removes it and confirms.
    """
    tag = msg.text[len('/del_'):]
    if not tag:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in tag:
        await msg.reply('Tag should not contain spaces')
        return
    # SREM returns the number of members it removed, so a single call both
    # deletes the tag and tells us whether it was subscribed. This avoids the
    # separate SISMEMBER round trip and the gap between check and removal.
    if not await redis.srem('e621:subs', tag):
        await msg.reply('Tag not found')
        return
    await msg.reply(f'Tag {tag} removed')
 | |
| 
 | |
| 
 | |
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['del'])
async def del_command(msg: Message):
    """Handle ``/del <tag> [<tag> ...]``: unsubscribe from each given tag.

    Every whitespace-separated argument is removed from the Redis set
    ``e621:subs``.  With no usable arguments the user is asked for a tag.
    """
    args = msg.get_args()
    # Splitting first also rejects whitespace-only arguments, which would
    # otherwise report success without removing anything.
    tags = (args or '').split()
    if not tags:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    # SREM accepts several members, so one round trip removes them all.
    await redis.srem('e621:subs', *tags)
    await msg.reply(f'Tags {args} removed')
 | |
| 
 | |
| 
 | |
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['list'])
async def list_tags(msg: Message):
    """Handle ``/list``: reply with every subscribed tag, sorted by name.

    Each tag is shown with its ``/del_<tag>`` command.  The listing is split
    across several replies so that the joined entries of one reply never
    exceed 2000 characters.
    """
    members = await redis.smembers('e621:subs')
    names = sorted(raw.decode() for raw in members)

    page = []
    for name in names:
        entry = f'- {name} [/del_{name}]'
        if len('\n'.join([*page, entry])) <= 2000:
            page.append(entry)
            continue
        # Adding this entry would overflow the page: send what we have and
        # start the next page with it.
        body = '\n'.join(page)
        await msg.reply(f'Monitored tags:\n\n{body}')
        page = [entry]

    body = '\n'.join(page)
    await msg.reply(f'Monitored tags:\n\n{body}')
 | |
| 
 | |
| 
 | |
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['update'])
async def update(msg: Message):
    """Handle ``/update``: run check_updates() immediately, without replying."""
    await check_updates()
 | |
| 
 | |
| 
 | |
async def background_on_start():
    """Poll for new posts forever, pausing 600 seconds between runs.

    The ``e621:update`` key (the name check_updates() locks on) is deleted
    once up front.  Any exception from a run is logged and the loop carries
    on with the next cycle.
    """
    await redis.delete('e621:update')

    pause_seconds = 600
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as error:
            # Top-level boundary: log and keep polling.
            logging.exception(error)
        logging.warning('Sleeping...')
        await asyncio.sleep(pause_seconds)
 | |
| 
 | |
| 
 | |
# Strong references to fire-and-forget tasks.  The event loop only keeps weak
# references, so a task nobody else holds may be garbage-collected mid-run.
_background_tasks = set()


async def on_bot_startup(dp: Dispatcher):
    """Startup hook: launch the background polling loop as an asyncio task.

    Args:
        dp: the dispatcher passed in by the executor; not used here.
    """
    task = asyncio.create_task(background_on_start())
    _background_tasks.add(task)
    # Drop the reference once the task finishes so the set cannot grow.
    task.add_done_callback(_background_tasks.discard)
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Script entry point: start long polling; on_bot_startup is passed as the
    # executor's startup hook.
    executor.start_polling(dp, on_startup=on_bot_startup)
 |