e621_bot/main.py

224 lines
8.3 KiB
Python
Raw Normal View History

2023-03-22 15:07:42 +00:00
import asyncio
import html
import logging
import os
import random
from asyncio import sleep
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryFile, TemporaryDirectory
from time import time
from typing import List

import boto3
import dotenv
import ffmpeg
import httpx
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher
from aiogram.dispatcher import filters
from aiogram.types import Message, ParseMode, ChatActions, InputFile
from aiogram.utils import executor, exceptions
from PIL import Image

from e621 import E621, E621Post
2023-03-22 15:07:42 +00:00
2023-03-22 16:10:44 +00:00
dotenv.load_dotenv('.env')

# Module-wide clients, created once at import time.
redis = aioredis.from_url('redis://localhost')
e621 = E621()
logging.basicConfig(level=logging.INFO)

bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher(bot)

# Telegram ids allowed to use the management commands.
# USERS is a comma-separated list of integers.
user_ids = [int(raw_id) for raw_id in os.environ['USERS'].split(',')]
2023-04-14 08:24:20 +00:00
2023-03-22 15:07:42 +00:00
2023-05-31 21:04:20 +00:00
def _build_caption(post: E621Post, tag_list: List[str]) -> str:
    """Return the Telegram-HTML caption for *post*.

    The caption lists which subscribed tags from *tag_list* matched, then the
    artist, character and copyright tags (each line omitted when empty), then
    the post URL. Tag text is HTML-escaped: tags may contain ``<``, ``>`` or
    ``&``, which Telegram's HTML parse mode would otherwise reject.
    """
    # sorted() keeps the caption stable; set iteration order is arbitrary.
    monitored = ' '.join(sorted(set(post.tags.flatten()) & set(tag_list))) or 'None'
    lines = [f'Monitored tags: <b>{html.escape(monitored, quote=False)}</b>']
    for label, tags, sep in (
        ('Artist', post.tags.artist, ' '),
        ('Character', post.tags.character, ', '),
        ('Copyright', post.tags.copyright, ', '),
    ):
        if tags:
            lines.append(f'{label}: <b>{html.escape(sep.join(tags), quote=False)}</b>')
    lines.append(f'\nhttps://e621.net/posts/{post.id}')
    return '\n'.join(lines)


def _transcode_to_mp4(src_path: Path, mp4_path: Path) -> None:
    """Re-encode *src_path* (webm/gif) to an H.264 MP4 written to *mp4_path*.

    Blocking: runs the ffmpeg binary, so call it from an executor.
    Raises ffmpeg.Error if ffmpeg exits with a non-zero status.
    """
    (
        ffmpeg
        .input(str(src_path))
        # H.264 in 4:2:0 needs even dimensions: pad odd sizes with one black
        # row/column at the bottom/right instead of rescaling.
        .filter('pad', width='ceil(iw/2)*2', height='ceil(ih/2)*2',
                x='0', y='0', color='Black')
        # Force 4:2:0; for GIF input ffmpeg otherwise picks a 4:4:4 pixel
        # format that many players cannot decode.
        .output(str(mp4_path), vcodec='libx264', crf='26', pix_fmt='yuv420p')
        .run()
    )


def _to_telegram_jpeg(data: bytes, max_side: int = 2000) -> BytesIO:
    """Decode image *data*, shrink it to fit *max_side* px and re-encode as JPEG.

    Transparent areas are composited onto an opaque white background. The
    returned buffer is rewound and named ``file.jpg`` (presumably used by
    aiogram as the upload filename). CPU-bound: call it from an executor.
    """
    image = Image.open(BytesIO(data)).convert('RGBA')
    width, height = image.size
    longest = max(width, height)
    if longest > max_side:
        ratio = max_side / longest
        # max(1, ...) avoids a zero-pixel side on extreme aspect ratios,
        # which resize() would reject.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        image = image.resize(new_size, Image.LANCZOS)
        logging.info('Resizing from %dx%d to %dx%d', width, height, *new_size)
    background = Image.new('RGBA', image.size, (255, 255, 255))
    flattened = Image.alpha_composite(background, image).convert('RGB')
    out = BytesIO()
    flattened.save(out, format='JPEG')
    out.seek(0)
    out.name = 'file.jpg'
    return out


async def send_post(post: E621Post, tag_list: List[str]):
    """Forward *post* to the SEND_CHAT chat and record it in ``e621:sent``.

    webm/gif files are transcoded to MP4, uploaded to the AWS_S3_BUCKET bucket
    with a public-read ACL, and sent as a link followed by the caption. png/jpg
    files are sent as a photo with the caption. Other extensions produce no
    message but are still recorded as sent. Posts without a file URL are
    skipped and NOT recorded, exactly as before.

    Telegram, download (HTTP), ffmpeg and file/image (OSError) failures are
    logged and the post is left unrecorded, so it is retried on a later check
    instead of aborting the whole update run. Other errors (e.g. missing
    environment variables, S3 client errors) propagate to the caller.
    """
    chat_id = int(os.environ['SEND_CHAT'])
    await bot.send_chat_action(chat_id, action=ChatActions.TYPING)
    caption = _build_caption(post, tag_list)
    if not post.file.url:
        return
    try:
        logging.info(post.file.url)
        async with httpx.AsyncClient() as client:
            response = await client.get(post.file.url)
        # Without this an error page would be handed to ffmpeg/PIL as media.
        response.raise_for_status()
        data = response.content
        loop = asyncio.get_running_loop()
        if post.file.ext in ('webm', 'gif'):
            with TemporaryDirectory() as td:
                src_path = Path(td) / f'video.{post.file.ext}'
                mp4_path = Path(td) / 'video.mp4'
                src_path.write_bytes(data)
                # ffmpeg and the S3 upload are blocking; keep them off the
                # event loop so the bot stays responsive while they run.
                await loop.run_in_executor(None, _transcode_to_mp4, src_path, mp4_path)
                s3 = boto3.client('s3',
                                  aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
                                  aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
                bucket = os.environ['AWS_S3_BUCKET']
                upload_filename = f'e621-{post.id}-{int(time())}.mp4'
                await loop.run_in_executor(None, lambda: s3.upload_file(
                    str(mp4_path), bucket, upload_filename,
                    ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'}))
            # TemporaryDirectory removes both files on exit.
            await bot.send_message(chat_id,
                                   f'https://{bucket}.s3.amazonaws.com/{upload_filename}\n\n' + caption,
                                   parse_mode=ParseMode.HTML)
        elif post.file.ext in ('png', 'jpg'):
            photo = await loop.run_in_executor(None, _to_telegram_jpeg, data)
            # NOTE(review): Telegram limits photo captions to 1024 characters;
            # posts with very long tag lists will fail here and be retried.
            await bot.send_photo(chat_id, photo, caption=caption,
                                 parse_mode=ParseMode.HTML)
        await redis.sadd('e621:sent', post.id)
    except (exceptions.TelegramAPIError, httpx.HTTPError, ffmpeg.Error, OSError):
        logging.exception('Failed to send post %s', post.id)
2023-03-22 15:07:42 +00:00
async def check_updates():
    """Query e621 for all subscribed tags and forward posts not sent yet.

    Runs under the ``e621:update`` Redis lock, so the periodic task and the
    /update command never check concurrently. Subscribed tags are shuffled
    and queried in batches of 40 as OR-terms (``~tag``). Within each batch,
    posts are forwarded in reverse of the order e621 returned them, with a
    one-second pause after each forwarded post.
    """
    batch_size = 40  # presumably e621's per-search tag limit
    logging.info('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.info('Lock acquired...')
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        random.shuffle(tag_list)
        for start in range(0, len(tag_list), batch_size):
            tags = ' '.join(f'~{tag}' for tag in tag_list[start:start + batch_size])
            logging.info(tags)
            posts = await e621.get_posts(tags)
            if not posts:
                # An empty batch must not end the whole check: the remaining
                # batches still have to be queried.
                continue
            already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
            for post, sent in zip(reversed(posts), reversed(already_sent)):
                if sent:
                    continue
                await send_post(post, tag_list)
                await sleep(1)
2023-03-22 15:07:42 +00:00
2023-04-14 08:32:30 +00:00
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add'])
async def add_tag(msg: Message):
    """Handle ``/add tag [tag ...]``: subscribe to each whitespace-separated tag."""
    args = msg.get_args()
    if not args:
        await msg.reply('Please provide tag to subscribe to')
        return
    # A single SADD with all members instead of one round trip per tag.
    await redis.sadd('e621:subs', *args.split())
    await msg.reply(f'Tags {args} added')
2023-04-14 08:32:30 +00:00
@dp.message_handler(filters.IDFilter(chat_id=user_ids), regexp=r'^\/del_\S+$')
async def del_tag(msg: Message):
    """Handle ``/del_<tag>`` (the links printed by /list): unsubscribe from one tag."""
    tag = msg.text[len('/del_'):]
    # The regexp filter already guarantees a non-empty tag without spaces;
    # these checks remain as a safety net.
    if not tag:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in tag:
        await msg.reply('Tag should not contain spaces')
        return
    # SREM returns the number of members removed, so the membership test and
    # the removal happen in one command instead of SISMEMBER + SREM.
    if not await redis.srem('e621:subs', tag):
        await msg.reply('Tag not found')
        return
    await msg.reply(f'Tag {tag} removed')
2023-03-22 15:07:42 +00:00
2023-04-14 08:32:30 +00:00
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['del'])
async def del_command(msg: Message):
    """Handle ``/del tag [tag ...]``: unsubscribe from each whitespace-separated tag.

    Replies with the same confirmation whether or not the tags were subscribed.
    """
    args = msg.get_args()
    if not args:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    # A single SREM with all members instead of one round trip per tag.
    await redis.srem('e621:subs', *args.split())
    await msg.reply(f'Tags {args} removed')
2023-04-14 08:32:30 +00:00
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['list'])
async def list_tags(msg: Message):
    """Handle ``/list``: reply with all subscribed tags in alphabetical order.

    Each entry carries a ``/del_<tag>`` command for one-tap removal. Entries
    are split across several replies so that the joined entries of each reply
    stay within 2000 characters, well under Telegram's 4096-character limit.
    A single entry longer than the limit is still sent on its own.
    """
    max_chunk_len = 2000
    tags = sorted(t.decode() for t in await redis.smembers('e621:subs'))

    async def send_chunk(entries: List[str]):
        body = '\n'.join(entries)
        await msg.reply(f'Monitored tags:\n\n{body}')

    chunk: List[str] = []
    chunk_len = 0  # always equals len('\n'.join(chunk))
    for tag in tags:
        entry = f'- {tag} [/del_{tag}]'
        # +1 for the newline joining this entry to the previous one.
        new_len = chunk_len + len(entry) + (1 if chunk else 0)
        if chunk and new_len > max_chunk_len:
            await send_chunk(chunk)
            chunk, new_len = [], len(entry)
        chunk.append(entry)
        chunk_len = new_len
    await send_chunk(chunk)
2023-04-14 08:32:30 +00:00
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['update'])
async def update(msg: Message):
    """Handle ``/update``: run a subscription check right away.

    Waits on the same update lock as the periodic background check, and sends
    no reply of its own.
    """
    return await check_updates()
2023-05-31 21:04:20 +00:00
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test'])
async def test(msg: Message):
    """Handle ``/test <post id>``: forward a single post regardless of subscriptions.

    The post is passed to send_post with an empty subscription list, so its
    caption shows no monitored tags. send_post also records it in ``e621:sent``.
    """
    args = msg.get_args()
    if not args:
        await msg.reply('Please provide post id')
        return
    query = f'id:{args}'
    posts = await e621.get_posts(query)
    logging.info(query)
    if not posts:
        await msg.reply('Post not found')
        return
    await send_post(posts[0], [])
2023-03-22 15:07:42 +00:00
async def background_on_start():
    """Run check_updates forever, pausing 600 seconds after each run.

    It first deletes the ``e621:update`` key, so a lock left behind by a
    previous process cannot block the checks. An exception from a check is
    logged and the loop carries on.
    """
    await redis.delete('e621:update')
    pause_seconds = 600
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as err:
            # Top-level boundary: one failed check must not stop polling.
            logging.exception(err)
        logging.warning('Sleeping...')
        await asyncio.sleep(pause_seconds)
2023-03-22 15:07:42 +00:00
# Strong references to running background tasks: the event loop only keeps
# weak references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


async def on_bot_startup(dp: Dispatcher):
    """aiogram startup hook: launch the periodic update loop in the background."""
    task = asyncio.create_task(background_on_start())
    _background_tasks.add(task)
    # Drop the reference once the task finishes so the set does not grow.
    task.add_done_callback(_background_tasks.discard)


if __name__ == '__main__':
    executor.start_polling(dp, on_startup=on_bot_startup)