e621_bot/main.py

383 lines
15 KiB
Python
Raw Permalink Normal View History

2023-03-22 15:07:42 +00:00
import asyncio
2023-09-03 22:57:05 +00:00
import base64
2023-11-03 12:17:24 +00:00
import datetime
2023-03-22 16:10:44 +00:00
import logging
import os
2023-07-23 21:50:51 +00:00
import random
2023-09-03 22:57:05 +00:00
import re
import traceback
2023-08-13 19:23:10 +00:00
from asyncio import sleep
from io import BytesIO
2024-02-25 08:49:15 +00:00
from itertools import count
2023-05-31 21:32:11 +00:00
from pathlib import Path
2023-09-03 22:57:05 +00:00
from tempfile import TemporaryDirectory
2023-05-31 21:32:11 +00:00
from time import time
2023-11-03 11:32:23 +00:00
from typing import List, Iterable
2023-05-31 21:04:20 +00:00
2023-05-31 21:32:11 +00:00
import boto3
import ffmpeg
2023-03-30 17:50:09 +00:00
from PIL import Image
2023-03-22 15:07:42 +00:00
import httpx
2023-03-22 15:07:42 +00:00
import redis.asyncio as aioredis
2023-09-03 22:57:05 +00:00
from aiogram import Bot, Dispatcher, filters, exceptions, F
from aiogram.enums import ChatAction, ParseMode
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, BufferedInputFile, \
CallbackQuery
2023-03-22 16:10:44 +00:00
import dotenv
2023-03-22 15:07:42 +00:00
2024-02-25 08:49:15 +00:00
from e621 import E621, E621Post, E621PostFile, E621PostVersion
2023-03-22 15:07:42 +00:00
2023-03-22 16:10:44 +00:00
dotenv.load_dotenv('.env')

# Shared clients / globals used by every handler below.
redis = aioredis.from_url('redis://localhost')
e621 = E621()
logging.basicConfig(level=logging.INFO)

bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher()

# Restrict all commands to the whitelisted chat ids (comma-separated USERS env var).
ChatFilter = F.chat.id.in_(set(map(int, os.environ['USERS'].split(','))))

# Chunk size for batched post fetching/sending in check_updates().
PAGE_SIZE = 20
2023-03-22 15:07:42 +00:00
2023-11-03 11:32:23 +00:00
def format_tags(tags: Iterable[str]):
    """Render *tags* as space-separated Telegram hashtags.

    Characters outside [0-9a-zA-Z_] are replaced with underscores so the
    result is a valid hashtag.  Returns the literal string 'None' when
    *tags* is empty.
    """
    hashtags = []
    for tag in tags:
        sanitized = re.sub(r"[^0-9a-zA-Z_]", "_", tag)
        hashtags.append('#' + sanitized)
    if not hashtags:
        return 'None'
    return ' '.join(hashtags)
2023-11-03 11:32:23 +00:00
2023-11-03 12:57:29 +00:00
def recover_url(file: E621PostFile):
    """Rebuild the static CDN URL for a post file from its md5 and extension.

    e621 omits ``file.url`` for some posts (e.g. ones hidden from anonymous
    users); the static host path is derivable from the md5 prefix.
    """
    md5 = file.md5
    return 'https://static1.e621.net/data/{}/{}/{}.{}'.format(md5[:2], md5[2:4], md5, file.ext)
2023-11-03 12:57:29 +00:00
2024-02-25 08:49:15 +00:00
async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
    """Deliver one e621 post to the SEND_CHAT Telegram chat.

    webm/gif files are converted to mp4 with ffmpeg and uploaded to S3 (a
    public link is sent); png/jpg files are recompressed to JPEG and sent as
    a photo with NSFW/Safe forwarding buttons.  Errors are logged and
    swallowed so one bad post cannot abort a whole update run.
    """
    try:
        logging.warning(f'Sending post #{post.id}')
        await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)

        # A subscription group matches when every one of its tags is on the
        # post; collect the union of all matching groups for the caption.
        flat_post_tags = set(post.tags.flatten())
        monitored_tags = set()
        for tag_group in tag_list:
            if set(tag_group) <= flat_post_tags:
                monitored_tags.update(tag_group)

        artist_tags = post.tags.artist
        character_tags = post.tags.character
        copyright_tags = post.tags.copyright
        # Falsy entries (empty tag categories) drop out of the caption.
        caption = '\n'.join(l for l in [
            f'Monitored tags: <b>{format_tags(monitored_tags)}</b>',
            artist_tags and f'Artist: <b>{format_tags(artist_tags)}</b>',
            character_tags and f'Character: <b>{format_tags(character_tags)}</b>',
            copyright_tags and f'Copyright: <b>{format_tags(copyright_tags)}</b>',
            f'\nhttps://e621.net/posts/{post.id}'
        ] if l)

        # The API sometimes omits the file url; rebuild it from the md5.
        if not post.file.url:
            post.file.url = recover_url(post.file)
        try:
            logging.warning(post.file.url)
            async with httpx.AsyncClient() as client:
                file = BytesIO()
                file.write((await client.get(post.file.url)).content)
                file.name = f'file.{post.file.ext}'
                file.seek(0)
                if post.file.ext in ('webm', 'gif'):
                    # Convert to mp4 and upload to S3; Telegram handles mp4
                    # links much better than raw webm/gif uploads.
                    with TemporaryDirectory() as td:
                        src_path = Path(td) / f'video.{post.file.ext}'
                        mp4_path = Path(td) / 'video.mp4'
                        with open(src_path, 'wb') as webm:
                            webm.write(file.read())
                        video_input = ffmpeg\
                            .input(str(src_path))
                        # Pad to even dimensions: libx264 requires w/h % 2 == 0.
                        cmd = video_input \
                            .output(str(mp4_path),
                                    vf='pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2:x=0:y=0:color=Black',
                                    vcodec='libx264',
                                    crf='26')
                        logging.info('ffmpeg ' + ' '.join(cmd.get_args()))
                        cmd.run()
                        s3 = boto3.client('s3', aws_access_key_id=os.environ['AWS_ACCESS_KEY'], aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
                        bucket = os.environ['AWS_S3_BUCKET']
                        upload_filename = f'e621-{post.id}-{int(time())}.mp4'
                        s3.upload_file(mp4_path, bucket, upload_filename, ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
                        await bot.send_message(int(os.environ['SEND_CHAT']),
                                               f'https://{bucket}.s3.amazonaws.com/{upload_filename}\n\n' + caption,
                                               parse_mode=ParseMode.HTML)
                        # NOTE(review): redundant — TemporaryDirectory removes
                        # its contents on exit anyway.
                        src_path.unlink()
                        mp4_path.unlink()
                elif post.file.ext in ('png', 'jpg'):
                    # Buttons forward the photo to another channel via RPC
                    # (see send_callback below).
                    markup = InlineKeyboardMarkup(inline_keyboard=[[
                        InlineKeyboardButton(text='NSFW', callback_data='send nsfw'),
                        InlineKeyboardButton(text='Safe', callback_data='send pics'),
                    ]])
                    # if post.file.size > 10000000:
                    logging.warning('compressing')
                    # Recompress unconditionally: cap the longest side at
                    # 2000px and flatten alpha onto white before JPEG save.
                    dl_im = Image.open(file).convert('RGBA')
                    size = dl_im.size
                    if size[0] > 2000 or size[1] > 2000:
                        larger_dimension = max(size)
                        ratio = 2000 / larger_dimension
                        dl_im = dl_im.resize((int(size[0] * ratio), int(size[1] * ratio)),
                                             Image.LANCZOS)
                        logging.warning(f'Resizing from {size[0]}x{size[1]} to {dl_im.size[0]}x{dl_im.size[1]}')
                    im = Image.new('RGBA', dl_im.size, (255, 255, 255))
                    composite = Image.alpha_composite(im, dl_im).convert('RGB')
                    file = BytesIO()
                    composite.save(file, format='JPEG')
                    file.seek(0)
                    await bot.send_photo(int(os.environ['SEND_CHAT']),
                                         BufferedInputFile(file.read(), 'file.jpg'),
                                         caption=caption,
                                         parse_mode=ParseMode.HTML,
                                         reply_markup=markup)
                # Mark as sent only after a successful delivery attempt.
                await redis.sadd('e621:sent', post.id)
        except Exception as e:
            logging.exception(e)
    except Exception as e:
        logging.exception(e)
2023-05-31 21:04:20 +00:00
2023-03-22 15:07:42 +00:00
async def check_updates():
    """Poll e621 post versions since the stored cursor and send matching posts.

    Subscriptions live in the redis set 'e621:subs'; each member is a
    space-separated tag group, all of which must be present on a post to
    match.  The scan cursor is persisted in 'e621:last_version' and
    already-delivered posts are tracked in 'e621:sent'.  A redis lock
    serializes concurrent runs (manual /update vs. the background loop).
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        matched_posts = []
        # Tag groups as tuples, plus a flat set for a cheap
        # "did this version add any monitored tag at all?" pre-filter.
        tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs'))
        tag_list_flat = set(sum(tag_list, ()))

        last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
        post_versions: List[E621PostVersion] = []
        logging.warning(f'Getting post versions from id {last_post_version}')
        for page in count(1):
            logging.warning(f'Loading page {page}')
            post_versions_page = await e621.get_post_versions(last_post_version, page)
            post_versions += post_versions_page
            # On a cold start (cursor == 0) only the first page is taken;
            # otherwise paginate until the API returns an empty page.
            if not last_post_version or not post_versions_page:
                break
        # Process oldest-first, advancing the cursor as we go.
        for post_version in post_versions[::-1]:
            if post_version.id > last_post_version:
                last_post_version = post_version.id
            # Skip versions that did not add any monitored tag.
            if not bool(tag_list_flat & set(post_version.added_tags)):
                continue
            post_tags = set(post_version.tags.split())
            for tag_group in tag_list:
                if set(tag_group) <= post_tags:
                    matched_posts.append(post_version.post_id)
                    break
        matched_posts.sort()
        if matched_posts:
            # Filter out posts already delivered in previous runs.
            already_sent: List = await redis.smismember('e621:sent', matched_posts)
            posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent]
            logging.warning(f'Found {len(posts_to_send)} posts')
            # Fetch full post objects in PAGE_SIZE chunks via an id: query.
            for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
                chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
                posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
                for i, post in enumerate(posts):
                    logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(posts_to_send)}')
                    await send_post(post, tag_list)
                    await redis.sadd('e621:sent', post.id)
                    # Pace deliveries to avoid Telegram/e621 rate limits.
                    await sleep(1)
        await redis.set('e621:last_version', last_post_version)
        # Previous search-based implementation, kept for reference:
        # random.shuffle(tag_list)
        # for tl_idx in range(0, len(tag_list), PAGE_SIZE):
        #     tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
        #     logging.warning(tags)
        #     posts = await e621.get_posts(tags)
        #     if not posts:
        #         return
        #     already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
        #     # last_index = len(posts)
        #     # if already_sent.count(True):
        #     #     last_index = already_sent.index(True)
        #     # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
        #     for i in list(range(len(posts)))[::-1]:
        #         if already_sent[i]:
        #             continue
        #         await send_post(posts[i], tag_list)
        #         await sleep(1)
2023-03-22 15:07:42 +00:00
2023-11-03 12:17:24 +00:00
@dp.message(filters.Command('resend_after'), ChatFilter)
async def resend_after(msg: Message):
    """Handle /resend_after <unix_ts>: re-send every subscribed post newer than the timestamp.

    Walks each subscription's search results page by page (newest first) and
    stops as soon as a post older than the cutoff is seen, then sends the
    collected posts oldest-first.
    """
    try:
        timestamp = int(msg.text.split()[1])
    except Exception:  # narrowed from bare except: missing or non-integer argument
        traceback.print_exc()
        await msg.reply('Invalid timestamp or not provided')
        return
    async with redis.lock('e621:update'):
        tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
        for i, tag in enumerate(tag_list):
            await msg.reply(f'Checking tag <b>{tag}</b> ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
            posts = []
            page = 1
            reached_cutoff = False
            # BUGFIX: previously only the inner loop broke on the cutoff, so
            # pagination continued through the tag's entire history.
            while not reached_cutoff:
                # BUGFIX: the tag group tuple is joined into a query string
                # (other call sites pass string queries to get_posts).
                page_posts = await e621.get_posts(' '.join(tag), page)
                if not page_posts:
                    break
                for post in page_posts:
                    # Results are assumed newest-first, so the first too-old
                    # post ends the scan for this tag group.
                    if datetime.datetime.fromisoformat(post.created_at).timestamp() < timestamp:
                        reached_cutoff = True
                        break
                    posts.append(post)
                page += 1
            # Send oldest-first so the chat reads chronologically.
            for post in posts[::-1]:
                await send_post(post, tag_list)
    await msg.reply('Finished')
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command('add'), ChatFilter)
async def add_tag(msg: Message):
    """Handle /add <tag...>: subscribe to each given tag individually.

    Existing posts for each tag are pre-marked as sent so only future
    posts get delivered.
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tag to subscribe to')
        return
    for tag in args.split():
        posts = await e621.get_posts(tag)
        # BUGFIX: SADD with zero members raises in redis-py — guard empty results.
        if posts:
            await redis.sadd('e621:sent', *[post.id for post in posts])
        await redis.sadd('e621:subs', tag)
    await msg.reply(f'Tags {args} added')
2024-02-25 08:49:15 +00:00
@dp.message(filters.Command('add_tags'), ChatFilter)
async def add_tags(msg: Message):
    """Handle /add_tags <tag...>: subscribe to a tag GROUP (all tags must match).

    The group is stored as a single space-separated, sorted member of
    'e621:subs' so the same set of tags always serializes identically.
    Existing posts matching the group are pre-marked as sent.
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tags to subscribe to')
        return
    tags = args.split()
    tags.sort()
    tags = ' '.join(tags)
    posts = await e621.get_posts(tags)
    # BUGFIX: SADD with zero members raises in redis-py — guard empty results.
    if posts:
        await redis.sadd('e621:sent', *[post.id for post in posts])
    await redis.sadd('e621:subs', tags)
    await msg.reply(f'Tag group <code>{tags}</code> added', parse_mode=ParseMode.HTML)
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
    """Handle /mark_old_as_sent: flag all currently-matching posts of every
    subscription as already sent, with a live progress message."""
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        m = await msg.reply(f'0/{len(tag_list)} tags have old posts marked as sent')
        for i, tag in enumerate(tag_list, 1):
            posts = await e621.get_posts(tag)
            # BUGFIX: SADD with zero members raises in redis-py — guard empty results.
            if posts:
                await redis.sadd('e621:sent', *[post.id for post in posts])
            await m.edit_text(f'{i}/{len(tag_list)} tags have old posts marked as sent')
            # Pace API calls / message edits to avoid rate limits.
            await sleep(1)
        await m.edit_text(f'Done marking old posts as sent for {len(tag_list)} tags')
2023-08-13 20:03:37 +00:00
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command(re.compile(r'del_\S+')), ChatFilter)
async def del_tag(msg: Message):
    """Handle /del_<tag>: remove a single-tag subscription from 'e621:subs'."""
    tag = msg.text[5:]  # strip the leading '/del_' prefix
    if not tag:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in tag:
        await msg.reply('Tag should not contain spaces')
        return
    is_subscribed = await redis.sismember('e621:subs', tag)
    if not is_subscribed:
        await msg.reply('Tag not found')
        return
    await redis.srem('e621:subs', tag)
    await msg.reply(f'Tag {tag} removed')
2023-03-22 15:07:42 +00:00
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command('del'), ChatFilter)
async def del_command(msg: Message):
    """Handle /del <tag...>: remove each given tag from the subscription set."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        # BUGFIX: message was copy-pasted from /add ('subscribe to').
        await msg.reply('Please provide tag to unsubscribe from')
        return
    for tag in args.split():
        await redis.srem('e621:subs', tag)
    await msg.reply(f'Tags {args} removed')
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command('list'), ChatFilter)
async def list_tags(msg: Message):
    """Handle /list: reply with all subscriptions (one /del_ shortcut per tag),
    split into multiple messages so each stays under ~2000 characters."""
    tags = [t.decode() for t in await redis.smembers('e621:subs')]
    tags.sort()
    lines = []
    for tag in tags:
        entry = f'- {tag} [/del_{tag}]'
        # Flush the accumulated chunk before it outgrows a safe message size.
        if len('\n'.join(lines + [entry])) > 2000:
            joined = '\n'.join(lines)
            await msg.reply(f'Monitored tags:\n\n{joined}')
            lines = [entry]
        else:
            # Reuse `entry` (previously duplicated the f-string literal).
            lines.append(entry)
    joined = '\n'.join(lines)
    await msg.reply(f'Monitored tags:\n\n{joined}')
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command('update'), ChatFilter)
async def update(msg: Message):
    """Handle /update: trigger an immediate update check (same job the
    background loop runs every 10 minutes)."""
    await check_updates()
2023-09-03 22:57:05 +00:00
@dp.message(filters.Command('test'), ChatFilter)
async def test(msg: Message):
    """Handle /test <id>: fetch one post by id and push it through send_post."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide post id')
        return
    post = await e621.get_posts(f'id:{args}')
    print(f'id:{args}')
    if not post:
        await msg.reply('Post not found')
        return
    # Rebuild the subscription groups so the caption shows matched tags.
    raw_subs = await redis.smembers('e621:subs')
    tag_list = [tuple(member.decode().split()) for member in raw_subs]
    await send_post(post[0], tag_list)
2023-05-31 21:04:20 +00:00
2023-09-03 22:57:05 +00:00
@dp.callback_query(F.data.startswith('send '))
async def send_callback(cq: CallbackQuery):
    """Forward the tapped photo to another bot channel via JSON-RPC.

    Callback data is 'send <destination>' (set up in send_post's buttons);
    the largest photo size is downloaded, base64-encoded and POSTed to the
    bots_rpc endpoint for that destination.
    """
    _, destination = cq.data.split()
    img_bytes = BytesIO()
    # photo[-1] is the largest available resolution of the message photo.
    await bot.download(cq.message.photo[-1], img_bytes)
    img_bytes.seek(0)
    data = base64.b64encode(img_bytes.read()).decode()
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(f'https://bots.bakatrouble.me/bots_rpc/{destination}/', json={
                "method": "post_photo",
                "params": [data, True],
                "jsonrpc": "2.0",
                "id": 0,
            })
            resp = r.json()
            # The RPC returns True on success, 'duplicate' if already posted.
            if 'result' in resp and resp['result'] == True:
                await cq.answer('Sent')
            elif 'result' in resp and resp['result'] == 'duplicate':
                await cq.answer('Duplicate')
            else:
                # Any other payload (e.g. an 'error' member) is a failure.
                raise Exception(resp)
        # NOTE(review): bare except acts as a boundary handler here — it logs
        # and reports via the callback answer; consider narrowing.
        except:
            traceback.print_exc()
            await cq.answer('An error has occurred, check logs')
2023-03-22 15:07:42 +00:00
async def background_on_start():
    """Background loop: run check_updates() every 10 minutes, forever."""
    # Drop a stale update lock possibly left behind by a crashed previous run.
    await redis.delete('e621:update')
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as e:
            # Never let one failed cycle kill the loop.
            logging.exception(e)
        logging.warning('Sleeping...')
        await asyncio.sleep(600)
2023-03-22 15:07:42 +00:00
2023-09-03 22:57:05 +00:00
async def on_bot_startup():
    # Launch the periodic update checker alongside long polling.
    asyncio.create_task(background_on_start())
2023-09-03 22:57:05 +00:00
async def main():
    """Wire up the startup hook and run the bot with long polling."""
    dp.startup.register(on_bot_startup)
    await dp.start_polling(bot)


if __name__ == '__main__':
    asyncio.run(main())