380 lines
15 KiB
Python
380 lines
15 KiB
Python
import asyncio
|
|
import base64
|
|
import datetime
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import traceback
|
|
from asyncio import sleep
|
|
from io import BytesIO
|
|
from itertools import count
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from time import time
|
|
from typing import List, Iterable
|
|
|
|
import boto3
|
|
import ffmpeg
|
|
from PIL import Image
|
|
|
|
import httpx
|
|
import redis.asyncio as aioredis
|
|
from aiogram import Bot, Dispatcher, filters, exceptions, F
|
|
from aiogram.enums import ChatAction, ParseMode
|
|
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, BufferedInputFile, \
|
|
CallbackQuery
|
|
import dotenv
|
|
|
|
from e621 import E621, E621Post, E621PostFile, E621PostVersion
|
|
|
|
# Load secrets (BOT_TOKEN, USERS, SEND_CHAT, AWS_* ...) from the local .env file.
dotenv.load_dotenv('.env')

# Shared async clients: redis stores subscriptions/sent-post state,
# e621 is the imageboard API wrapper.
redis = aioredis.from_url('redis://localhost')
e621 = E621()

logging.basicConfig(level=logging.INFO)
bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher()

# Restrict command handlers to the whitelisted chat ids in USERS (comma-separated ints).
ChatFilter = F.chat.id.in_(set(map(int, os.environ['USERS'].split(','))))

# Chunk size for e621 API queries and message pagination.
PAGE_SIZE = 20
|
|
|
|
|
|
def format_tags(tags: Iterable[str]):
    """Format an iterable of tag names as space-separated hashtags.

    Characters outside [0-9a-zA-Z_] are replaced with underscores so the
    result is safe to use as Telegram hashtags.  Returns the literal
    string 'None' when *tags* is empty.
    """
    hashtags = ['#' + re.sub(r"[^0-9a-zA-Z_]", "_", name) for name in tags]
    if not hashtags:
        return 'None'
    return ' '.join(hashtags)
|
|
|
|
|
|
def recover_url(file: E621PostFile):
    """Reconstruct the static e621 CDN URL for a post file.

    Some API responses omit ``file.url``; the canonical location can be
    rebuilt from the file's md5 (sharded by its first two byte pairs)
    and extension.
    """
    digest = file.md5
    return 'https://static1.e621.net/data/{}/{}/{}.{}'.format(
        digest[:2], digest[2:4], digest, file.ext)
|
|
|
|
|
|
async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
    """Deliver one e621 post to the configured Telegram chat (SEND_CHAT).

    webm/gif files are converted to mp4 via ffmpeg and uploaded to S3 (a
    public link is sent as a text message); png/jpg files are flattened
    to JPEG and sent as a photo with NSFW/Safe forwarding buttons.
    All failures are logged and swallowed so one bad post cannot stop a
    batch run.

    :param post: the e621 post to deliver
    :param tag_list: iterable of tag groups; a group "matches" when all
        of its tags are present on the post
    """
    try:
        logging.warning(f'Sending post #{post.id}')
        await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)
        # Collect every subscribed tag group fully contained in the post's tags.
        flat_post_tags = set(post.tags.flatten())
        monitored_tags = set()
        for tag_group in tag_list:
            if set(tag_group) <= flat_post_tags:
                monitored_tags.update(tag_group)
        artist_tags = post.tags.artist
        character_tags = post.tags.character
        copyright_tags = post.tags.copyright
        # Build the HTML caption; falsy entries (empty tag categories) are dropped
        # by the `if l` filter.
        caption = '\n'.join(l for l in [
            f'Monitored tags: <b>{format_tags(monitored_tags)}</b>',
            artist_tags and f'Artist: <b>{format_tags(artist_tags)}</b>',
            character_tags and f'Character: <b>{format_tags(character_tags)}</b>',
            copyright_tags and f'Copyright: <b>{format_tags(copyright_tags)}</b>',
            f'\nhttps://e621.net/posts/{post.id}'
        ] if l)
        # The API sometimes omits the file URL; rebuild it from the md5.
        if not post.file.url:
            post.file.url = recover_url(post.file)
        try:
            logging.warning(post.file.url)
            # Download the media entirely into memory.
            async with httpx.AsyncClient() as client:
                file = BytesIO()
                file.write((await client.get(post.file.url)).content)
                file.name = f'file.{post.file.ext}'
                file.seek(0)
            if post.file.ext in ('webm', 'gif'):
                # Video path: convert to mp4, upload to S3, send the link.
                with TemporaryDirectory() as td:
                    src_path = Path(td) / f'video.{post.file.ext}'
                    mp4_path = Path(td) / 'video.mp4'
                    with open(src_path, 'wb') as webm:
                        webm.write(file.read())
                    video_input = ffmpeg\
                        .input(str(src_path))
                    # Pad to even dimensions (libx264 requires them), modest CRF.
                    cmd = video_input \
                        .output(str(mp4_path),
                                vf='pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2:x=0:y=0:color=Black',
                                vcodec='libx264',
                                crf='26')
                    logging.info('ffmpeg ' + ' '.join(cmd.get_args()))
                    # NOTE(review): cmd.run() is a blocking subprocess call inside an
                    # async handler — it stalls the event loop while transcoding.
                    cmd.run()
                    s3 = boto3.client('s3', aws_access_key_id=os.environ['AWS_ACCESS_KEY'], aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
                    bucket = os.environ['AWS_S3_BUCKET']
                    upload_filename = f'e621-{post.id}-{int(time())}.mp4'
                    s3.upload_file(mp4_path, bucket, upload_filename, ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
                    await bot.send_message(int(os.environ['SEND_CHAT']),
                                           f'https://{bucket}.s3.amazonaws.com/{upload_filename}\n\n' + caption,
                                           parse_mode=ParseMode.HTML)
                    # Explicit cleanup; TemporaryDirectory would remove these anyway.
                    src_path.unlink()
                    mp4_path.unlink()
            elif post.file.ext in ('png', 'jpg'):
                # Image path: flatten to JPEG and send with forwarding buttons.
                markup = InlineKeyboardMarkup(inline_keyboard=[[
                    InlineKeyboardButton(text='NSFW', callback_data='send nsfw'),
                    InlineKeyboardButton(text='Safe', callback_data='send pics'),
                ]])
                # if post.file.size > 10000000:
                logging.warning('compressing')
                dl_im = Image.open(file).convert('RGBA')
                size = dl_im.size
                # Downscale so the longer side is at most 2000px (Telegram-friendly).
                if size[0] > 2000 or size[1] > 2000:
                    larger_dimension = max(size)
                    ratio = 2000 / larger_dimension
                    dl_im = dl_im.resize((int(size[0] * ratio), int(size[1] * ratio)),
                                         Image.LANCZOS)
                    logging.warning(f'Resizing from {size[0]}x{size[1]} to {dl_im.size[0]}x{dl_im.size[1]}')
                # Composite onto a white background to drop the alpha channel,
                # then re-encode as JPEG.
                im = Image.new('RGBA', dl_im.size, (255, 255, 255))
                composite = Image.alpha_composite(im, dl_im).convert('RGB')
                file = BytesIO()
                composite.save(file, format='JPEG')
                file.seek(0)
                await bot.send_photo(int(os.environ['SEND_CHAT']),
                                     BufferedInputFile(file.read(), 'file.jpg'),
                                     caption=caption,
                                     parse_mode=ParseMode.HTML,
                                     reply_markup=markup)
            # Record delivery so update runs never send this post again.
            await redis.sadd('e621:sent', post.id)
        except Exception as e:
            logging.exception(e)
    except Exception as e:
        logging.exception(e)
|
|
|
|
|
|
async def check_updates():
    """Poll e621 post versions and deliver posts matching any subscribed tag group.

    Serialized by the 'e621:update' redis lock so only one run is active.
    The newest seen post-version id is persisted under 'e621:last_version';
    delivered post ids accumulate in the 'e621:sent' set.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        matched_posts = []
        # Subscriptions are stored as space-separated tag groups; each becomes
        # a tuple of tags that must all be present on a post.
        tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs'))
        last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
        post_versions: List[E621PostVersion] = []
        logging.warning(f'Getting post versions from id {last_post_version}')
        # Page through every post version newer than the checkpoint.  On the
        # very first run (checkpoint 0) the loop stops after one page.
        for page in count(1):
            logging.warning(f'Loading page {page}')
            post_versions_page = await e621.get_post_versions(last_post_version, page)
            post_versions += post_versions_page
            if not last_post_version or not post_versions_page:
                break
        # Walk oldest-first: advance the checkpoint and collect post ids whose
        # tags fully contain at least one subscribed tag group.
        for post_version in post_versions[::-1]:
            if post_version.id > last_post_version:
                last_post_version = post_version.id
            post_tags = set(post_version.tags.split())
            for tag_group in tag_list:
                if set(tag_group) <= post_tags:
                    matched_posts.append(post_version.post_id)
                    break
        matched_posts.sort()
        if matched_posts:
            # Drop ids that were already delivered (one SMISMEMBER round trip).
            already_sent: List = await redis.smismember('e621:sent', matched_posts)
            posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent]
            logging.warning(f'Found {len(posts_to_send)} posts')
            # Fetch and deliver in PAGE_SIZE chunks; sleep between posts to
            # stay under Telegram/e621 rate limits.
            for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
                chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
                posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
                for i, post in enumerate(posts):
                    logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(posts_to_send)}')
                    await send_post(post, tag_list)
                    await redis.sadd('e621:sent', post.id)
                    await sleep(1)
        # Persist the checkpoint only after the whole batch was processed.
        await redis.set('e621:last_version', last_post_version)

    # Legacy search-based update flow, kept for reference:
    # random.shuffle(tag_list)
    # for tl_idx in range(0, len(tag_list), PAGE_SIZE):
    #     tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
    #     logging.warning(tags)
    #     posts = await e621.get_posts(tags)
    #     if not posts:
    #         return
    #     already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
    #     # last_index = len(posts)
    #     # if already_sent.count(True):
    #     #     last_index = already_sent.index(True)
    #     # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
    #     for i in list(range(len(posts)))[::-1]:
    #         if already_sent[i]:
    #             continue
    #         await send_post(posts[i], tag_list)
    #         await sleep(1)
|
|
|
|
|
|
@dp.message(filters.Command('resend_after'), ChatFilter)
async def resend_after(msg: Message):
    """Re-send every subscribed post created after a given unix timestamp.

    Usage: /resend_after <unix_timestamp>
    Walks each subscription's search results newest-first and delivers the
    collected posts oldest-first via send_post().
    """
    try:
        timestamp = int(msg.text.split()[1])
    except (IndexError, ValueError):
        # Missing or non-integer argument.
        traceback.print_exc()
        await msg.reply('Invalid timestamp or not provided')
        return

    async with redis.lock('e621:update'):
        tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
        for i, tag in enumerate(tag_list):
            await msg.reply(f'Checking tag <b>{tag}</b> ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
            posts = []
            page = 1
            reached_old = False
            while not reached_old:
                # `tag` is a tuple of tags; other get_posts() call sites pass a
                # space-separated query string, so join it back together.
                page_posts = await e621.get_posts(' '.join(tag), page)
                if not page_posts:
                    break
                for post in page_posts:
                    if datetime.datetime.fromisoformat(post.created_at).timestamp() < timestamp:
                        # Results are paged newest-first, so everything after
                        # this point is older than the cutoff — stop paging.
                        reached_old = True
                        break
                    posts.append(post)
                page += 1
            # Deliver in chronological order.
            for post in posts[::-1]:
                await send_post(post, tag_list)
    await msg.reply('Finished')
|
|
|
|
|
|
@dp.message(filters.Command('add'), ChatFilter)
async def add_tag(msg: Message):
    """Subscribe to one or more individual tags.

    Usage: /add <tag> [<tag> ...]
    Each tag becomes its own subscription; its current posts are
    pre-marked as sent so only future posts get delivered.
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tag to subscribe to')
        return
    for tag in args.split():
        posts = await e621.get_posts(tag)
        # SADD requires at least one member — guard tags with no posts yet.
        if posts:
            await redis.sadd('e621:sent', *[post.id for post in posts])
        await redis.sadd('e621:subs', tag)
    await msg.reply(f'Tags {args} added')
|
|
|
|
|
|
@dp.message(filters.Command('add_tags'), ChatFilter)
async def add_tags(msg: Message):
    """Subscribe to a tag group (all tags must be present on a post).

    Usage: /add_tags <tag> <tag> ...
    Tags are sorted so the same group is always stored canonically in
    redis; current matching posts are pre-marked as sent so only new
    posts get delivered.
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide tags to subscribe to')
        return
    tags = ' '.join(sorted(args.split()))
    posts = await e621.get_posts(tags)
    # SADD requires at least one member — guard searches with no results.
    if posts:
        await redis.sadd('e621:sent', *[post.id for post in posts])
    await redis.sadd('e621:subs', tags)
    await msg.reply(f'Tag group <code>{tags}</code> added', parse_mode=ParseMode.HTML)
|
|
|
|
|
|
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
    """Mark the current posts of every subscription as already sent.

    Prevents historical posts from being delivered by future update runs.
    Progress is reported by editing a single status message.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        m = await msg.reply(f'0/{len(tag_list)} tags have old posts marked as sent')
        for i, tag in enumerate(tag_list, 1):
            posts = await e621.get_posts(tag)
            # SADD requires at least one member — skip tags with no posts.
            if posts:
                await redis.sadd('e621:sent', *[post.id for post in posts])
            await m.edit_text(f'{i}/{len(tag_list)} tags have old posts marked as sent')
            # Pace requests to respect API rate limits.
            await sleep(1)
        await m.edit_text(f'Done marking old posts as sent for {len(tag_list)} tags')
|
|
|
|
|
|
@dp.message(filters.Command(re.compile(r'del_\S+')), ChatFilter)
async def del_tag(msg: Message):
    """Handle /del_<tag>: remove a single subscription by name.

    The tag is taken from the command text itself (everything after the
    '/del_' prefix, 5 characters).
    """
    tag = msg.text[5:]
    if not tag:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    if ' ' in tag:
        await msg.reply('Tag should not contain spaces')
        return
    if not await redis.sismember('e621:subs', tag):
        await msg.reply('Tag not found')
        return
    await redis.srem('e621:subs', tag)
    await msg.reply(f'Tag {tag} removed')
|
|
|
|
|
|
@dp.message(filters.Command('del'), ChatFilter)
async def del_command(msg: Message):
    """Unsubscribe from one or more tags.

    Usage: /del <tag> [<tag> ...]
    Silently ignores tags that were not subscribed (SREM is a no-op then).
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        # Fixed copy-paste message: this handler removes subscriptions.
        await msg.reply('Please provide tag to unsubscribe from')
        return
    for tag in args.split():
        await redis.srem('e621:subs', tag)
    await msg.reply(f'Tags {args} removed')
|
|
|
|
|
|
@dp.message(filters.Command('list'), ChatFilter)
async def list_tags(msg: Message):
    """Reply with all subscriptions, one '/del_<tag>' shortcut per line.

    The listing is split across multiple messages whenever adding another
    entry would push the body past ~2000 characters (Telegram limit
    headroom).
    """
    subscriptions = sorted(t.decode() for t in await redis.smembers('e621:subs'))
    chunk = []
    for tag in subscriptions:
        entry = f'- {tag} [/del_{tag}]'
        if len('\n'.join(chunk + [entry])) > 2000:
            # Flush the current page and start a new one with this entry.
            body = "\n".join(chunk)
            await msg.reply(f'Monitored tags:\n\n{body}')
            chunk = [entry]
        else:
            chunk.append(entry)
    body = "\n".join(chunk)
    await msg.reply(f'Monitored tags:\n\n{body}')
|
|
|
|
|
|
@dp.message(filters.Command('update'), ChatFilter)
async def update(msg: Message):
    # Manual trigger for the same update routine the background task runs.
    await check_updates()
|
|
|
|
|
|
@dp.message(filters.Command('test'), ChatFilter)
async def test(msg: Message):
    """Send a single post by id through the normal delivery pipeline.

    Usage: /test <post_id> — useful for checking formatting and media
    handling without waiting for an update run.
    """
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide post id')
        return
    found = await e621.get_posts(f'id:{args}')
    print(f'id:{args}')
    if not found:
        await msg.reply('Post not found')
        return
    tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
    await send_post(found[0], tag_list)
|
|
|
|
|
|
@dp.callback_query(F.data.startswith('send '))
async def send_callback(cq: CallbackQuery):
    """Forward the tapped message's photo to another bot via JSON-RPC.

    Callback data is 'send <destination>' (set by the NSFW/Safe buttons in
    send_post); the destination selects the RPC endpoint channel.
    """
    _, destination = cq.data.split()
    # Re-download the largest stored size of the message's photo from Telegram.
    img_bytes = BytesIO()
    await bot.download(cq.message.photo[-1], img_bytes)
    img_bytes.seek(0)
    # The RPC endpoint expects the image as base64 text.
    data = base64.b64encode(img_bytes.read()).decode()
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(f'https://bots.bakatrouble.me/bots_rpc/{destination}/', json={
                "method": "post_photo",
                # NOTE(review): second param appears to be a flag consumed by the
                # remote bot — semantics not visible from here; confirm upstream.
                "params": [data, True],
                "jsonrpc": "2.0",
                "id": 0,
            })
            resp = r.json()
            if 'result' in resp and resp['result'] == True:
                await cq.answer('Sent')
            elif 'result' in resp and resp['result'] == 'duplicate':
                await cq.answer('Duplicate')
            else:
                # Unexpected/error payload — surface it through the except path.
                raise Exception(resp)
        except:
            # Best-effort: log and tell the user instead of crashing the handler.
            traceback.print_exc()
            await cq.answer('An error has occurred, check logs')
|
|
|
|
|
|
async def background_on_start():
    """Background loop: poll for new posts every 10 minutes, forever."""
    # Drop a possibly-stale update lock left over from a previous process.
    await redis.delete('e621:update')
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as e:
            # Keep the loop alive no matter what a single run throws.
            logging.exception(e)
        logging.warning('Sleeping...')
        await asyncio.sleep(600)
|
|
|
|
|
|
async def on_bot_startup():
    # Launch the periodic update loop once the dispatcher starts.
    asyncio.create_task(background_on_start())
|
|
|
|
|
|
async def main():
    """Wire the startup hook and run the bot with long polling."""
    dp.startup.register(on_bot_startup)
    await dp.start_polling(bot)
|
|
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    asyncio.run(main())
|