You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

283 lines
11 KiB

import asyncio
import base64
import logging
import os
import random
import re
import traceback
from asyncio import sleep
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from time import time
from typing import List
import boto3
import ffmpeg
from PIL import Image
import httpx
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher, filters, exceptions, F
from aiogram.enums import ChatAction, ParseMode
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, BufferedInputFile, \
CallbackQuery
import dotenv
from e621 import E621, E621Post
# Load .env first so every os.environ lookup below can see its values.
dotenv.load_dotenv('.env')
# REDIS_URL is optional; the default is the previously hard-coded address.
redis = aioredis.from_url(os.environ.get('REDIS_URL', 'redis://localhost'))
e621 = E621()
logging.basicConfig(level=logging.INFO)
bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher()
# Command handlers only react in chats whose id is listed in USERS
# (comma-separated integers). Blank entries, e.g. from a trailing comma,
# are ignored instead of crashing int('') at import time.
ChatFilter = F.chat.id.in_({int(uid) for uid in os.environ['USERS'].split(',') if uid.strip()})
async def _run_blocking(func, *args):
    """Run the synchronous callable ``func(*args)`` in the loop's default executor.

    ffmpeg transcoding, the S3 upload and Pillow processing are blocking.
    Calling them directly from a coroutine would stall the whole event loop,
    including update polling, for their full duration.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


def _build_caption(post: E621Post, tag_list: List[str]) -> str:
    """Return the HTML caption for *post*.

    The caption lists which entries of *tag_list* the post matched ("None" if
    there are none). It then shows the post's artist, character and copyright
    tags, omitting lines with no tags, and ends with the post URL.

    Tag text is HTML-escaped because the caption is sent with ParseMode.HTML.
    e621 tag names may contain ``<``, ``>`` or ``&`` (e.g. ``<3``), which
    Telegram would otherwise reject as malformed markup.
    """
    from html import escape

    def field(label: str, tags, sep: str) -> str:
        return f'{label}: <b>{escape(sep.join(tags), quote=False)}</b>'

    monitored_tags = set(post.tags.flatten()) & set(tag_list)
    lines = [field('Monitored tags', monitored_tags or ['None'], ' ')]
    if post.tags.artist:
        lines.append(field('Artist', post.tags.artist, ' '))
    if post.tags.character:
        lines.append(field('Character', post.tags.character, ', '))
    if post.tags.copyright:
        lines.append(field('Copyright', post.tags.copyright, ', '))
    lines.append(f'\nhttps://e621.net/posts/{post.id}')
    return '\n'.join(lines)


def _transcode_and_upload(data: bytes, ext: str, post_id) -> str:
    """Convert a webm/gif payload to H.264 MP4, upload it to S3 and return its public URL.

    Blocking (runs ffmpeg and boto3); call it through _run_blocking.
    Raises ffmpeg.Error if transcoding fails.
    """
    with TemporaryDirectory() as td:
        src_path = Path(td) / f'video.{ext}'
        mp4_path = Path(td) / 'video.mp4'
        src_path.write_bytes(data)
        # Pad width/height up to even numbers with black borders; x264's
        # default 4:2:0 output cannot encode odd frame dimensions.
        (
            ffmpeg
            .input(str(src_path))
            .filter('pad', width='ceil(iw/2)*2', height='ceil(ih/2)*2', x='0', y='0', color='Black')
            .output(str(mp4_path), vcodec='libx264', crf='26')
            .run()
        )
        # This runs in a worker thread. Use a dedicated Session instead of
        # boto3's shared default session, which is not thread-safe.
        s3 = boto3.session.Session().client(
            's3',
            aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
            aws_secret_access_key=os.environ['AWS_SECRET_KEY'],
        )
        bucket = os.environ['AWS_S3_BUCKET']
        upload_filename = f'e621-{post_id}-{int(time())}.mp4'
        s3.upload_file(str(mp4_path), bucket, upload_filename,
                       ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
    # Leaving the TemporaryDirectory block deletes both temporary files.
    return f'https://{bucket}.s3.amazonaws.com/{upload_filename}'


def _prepare_photo(data: bytes, max_side: int = 2000) -> bytes:
    """Re-encode an image payload as JPEG bytes.

    Images larger than *max_side* on either side are downscaled with their
    aspect ratio kept. Transparency is flattened onto a white background,
    since JPEG has no alpha channel. Blocking; call it through _run_blocking.
    """
    image = Image.open(BytesIO(data)).convert('RGBA')
    width, height = image.size
    if width > max_side or height > max_side:
        ratio = max_side / max(width, height)
        image = image.resize((int(width * ratio), int(height * ratio)), Image.LANCZOS)
        logging.info('Resizing from %dx%d to %dx%d', width, height, *image.size)
    background = Image.new('RGBA', image.size, (255, 255, 255))
    composite = Image.alpha_composite(background, image).convert('RGB')
    out = BytesIO()
    composite.save(out, format='JPEG')
    return out.getvalue()


async def send_post(post: E621Post, tag_list: List[str]):
    """Forward *post* to the chat in SEND_CHAT and record its id in ``e621:sent``.

    - webm/gif files are transcoded to MP4, uploaded to S3 and sent as a link.
    - png/jpg files are sent as a JPEG photo with Safe/NSFW forwarding buttons.
    - Other extensions are not sent but are still recorded as sent.

    Download (httpx), Telegram API and ffmpeg failures are logged and the post
    is left unrecorded, so a later check may pick it up again.

    :param post: the e621 post to forward.
    :param tag_list: subscribed tags; only used to show which ones matched.
    """
    chat_id = int(os.environ['SEND_CHAT'])
    await bot.send_chat_action(chat_id, action=ChatAction.TYPING)
    caption = _build_caption(post, tag_list)
    if not post.file.url:
        # NOTE(review): nothing is sent or recorded here, so the post is
        # considered again on every check — confirm this is intended.
        return
    try:
        logging.warning(post.file.url)
        async with httpx.AsyncClient() as client:
            response = await client.get(post.file.url)
        # Without this, an HTTP error page would be processed as the media file.
        response.raise_for_status()
        ext = post.file.ext
        if ext in ('webm', 'gif'):
            url = await _run_blocking(_transcode_and_upload, response.content, ext, post.id)
            await bot.send_message(chat_id, f'{url}\n\n' + caption, parse_mode=ParseMode.HTML)
        elif ext in ('png', 'jpg'):
            markup = InlineKeyboardMarkup(inline_keyboard=[[
                InlineKeyboardButton(text='Safe', callback_data='send pics'),
                InlineKeyboardButton(text='NSFW', callback_data='send nsfw'),
            ]])
            logging.warning('compressing')
            photo = await _run_blocking(_prepare_photo, response.content)
            await bot.send_photo(chat_id,
                                 BufferedInputFile(photo, 'file.jpg'),
                                 caption=caption,
                                 parse_mode=ParseMode.HTML,
                                 reply_markup=markup)
        await redis.sadd('e621:sent', post.id)
    except (exceptions.TelegramAPIError, httpx.HTTPError, ffmpeg.Error) as e:
        logging.exception(e)
async def check_updates():
    """Query e621 for every subscribed tag and forward posts not yet sent.

    Subscribed tags (Redis set ``e621:subs``) are shuffled and queried in
    chunks of 40, each tag prefixed with ``~`` (e621's "any of" search
    operator). The whole run holds the Redis lock ``e621:update``, the same
    lock mark_old_as_sent takes, so the two never interleave.
    """
    chunk_size = 40  # presumably e621's per-search tag limit — TODO confirm
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        random.shuffle(tag_list)
        for start in range(0, len(tag_list), chunk_size):
            tags = ' '.join(f'~{tag}' for tag in tag_list[start:start + chunk_size])
            logging.warning(tags)
            posts = await e621.get_posts(tags)
            if not posts:
                # An empty result for one chunk must not skip the remaining chunks.
                continue
            already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
            # Walk the results back to front, presumably so older posts are
            # sent first (assumes e621 returns newest first — TODO confirm).
            for post, sent in reversed(list(zip(posts, already_sent))):
                if sent:
                    continue
                await send_post(post, tag_list)
                await sleep(1)
@dp.message(filters.Command('add'), ChatFilter)
async def add_tag(msg: Message):
    """Handle ``/add tag1 [tag2 ...]``: subscribe to each given tag.

    The posts e621.get_posts currently returns for each tag are recorded in
    ``e621:sent`` before subscribing. This keeps the next check from
    forwarding that existing backlog.
    """
    tags = msg.text.split()[1:]
    if not tags:
        await msg.reply('Please provide tag to subscribe to')
        return
    for tag in tags:
        posts = await e621.get_posts(tag)
        if posts:
            # SADD with no members is a Redis error, so skip tags without posts.
            await redis.sadd('e621:sent', *[post.id for post in posts])
        await redis.sadd('e621:subs', tag)
    await msg.reply(f'Tags {" ".join(tags)} added')
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
    """Handle ``/mark_old_as_sent``: record current posts of every subscription as sent.

    The command holds the ``e621:update`` Redis lock (shared with
    check_updates), and edits a progress message after each tag.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        total = len(tag_list)
        m = await msg.reply(f'0/{total} tags have old posts marked as sent')
        for i, tag in enumerate(tag_list, 1):
            posts = await e621.get_posts(tag)
            if posts:
                # SADD with no members is a Redis error, so skip tags without posts.
                await redis.sadd('e621:sent', *[post.id for post in posts])
            await m.edit_text(f'{i}/{total} tags have old posts marked as sent')
            await sleep(1)
        await m.edit_text(f'Done marking old posts as sent for {total} tags')
@dp.message(filters.Command(re.compile(r'del_\S+')), ChatFilter)
async def del_tag(msg: Message):
    """Handle the ``/del_<tag>`` shortcut that /list prints next to each tag.

    The tag is everything after the ``/del_`` prefix. It is removed from
    ``e621:subs`` only if it contains no spaces and is currently subscribed.
    """
    tag = msg.text[len('/del_'):]
    if not tag:
        problem = 'Please provide tag to unsubscribe from'
    elif ' ' in tag:
        problem = 'Tag should not contain spaces'
    elif not await redis.sismember('e621:subs', tag):
        problem = 'Tag not found'
    else:
        problem = None
    if problem is not None:
        await msg.reply(problem)
        return
    await redis.srem('e621:subs', tag)
    await msg.reply(f'Tag {tag} removed')
@dp.message(filters.Command('del'), ChatFilter)
async def del_command(msg: Message):
    """Handle ``/del tag1 [tag2 ...]``: unsubscribe from each given tag.

    Tags that were not subscribed are silently ignored.
    """
    tags = msg.text.split()[1:]
    if not tags:
        # Previously said "subscribe to" (copy-pasted from /add).
        await msg.reply('Please provide tag to unsubscribe from')
        return
    # One SREM with all members; the list is non-empty here.
    await redis.srem('e621:subs', *tags)
    await msg.reply(f'Tags {" ".join(tags)} removed')
@dp.message(filters.Command('list'), ChatFilter)
async def list_tags(msg: Message):
    """Handle ``/list``: reply with all subscribed tags, sorted.

    Each tag is shown with a ``/del_<tag>`` shortcut. The list is split across
    several replies so that each chunk's body stays within 2000 characters,
    well under Telegram's 4096-character message limit. A single entry longer
    than that is sent on its own rather than preceded by an empty reply.
    """
    max_chunk_len = 2000
    tags = sorted(t.decode() for t in await redis.smembers('e621:subs'))
    chunk = []
    chunk_len = 0  # len('\n'.join(chunk)), tracked incrementally to stay O(n)
    for tag in tags:
        entry = f'- {tag} [/del_{tag}]'
        new_len = chunk_len + len(entry) + (1 if chunk else 0)
        if chunk and new_len > max_chunk_len:
            await msg.reply('Monitored tags:\n\n' + '\n'.join(chunk))
            chunk, chunk_len = [entry], len(entry)
        else:
            chunk.append(entry)
            chunk_len = new_len
    await msg.reply('Monitored tags:\n\n' + '\n'.join(chunk))
@dp.message(filters.Command('update'), ChatFilter)
async def update(msg: Message):
    """Handle ``/update``: run one subscription check right away.

    No reply is sent to *msg*. The only visible result is whatever posts the
    check forwards.
    """
    await check_updates()
@dp.message(filters.Command('test'), ChatFilter)
async def test(msg: Message):
    """Handle ``/test <post id>``: forward one specific post through send_post.

    No subscribed tags are passed, so the caption's monitored-tags line reads
    "None".
    """
    post_id = ' '.join(msg.text.split()[1:])
    if not post_id:
        await msg.reply('Please provide post id')
        return
    query = f'id:{post_id}'
    found = await e621.get_posts(query)
    print(query)
    if not found:
        await msg.reply('Post not found')
        return
    await send_post(found[0], [])
@dp.callback_query(F.data.startswith('send '))
async def send_callback(cq: CallbackQuery):
    """Forward the photo of the pressed message to a bots_rpc destination.

    Callback data has the form ``send <destination>`` (e.g. ``send pics``).
    The photo is downloaded, base64-encoded and submitted as a JSON-RPC
    ``post_photo`` call. The button press is answered with 'Sent',
    'Duplicate' or an error notice.
    """
    try:
        # Everything that can fail is inside the try, so the press is always
        # answered and the client's loading spinner never hangs.
        _, destination = cq.data.split()
        img_bytes = BytesIO()
        # The last PhotoSize is conventionally the largest available.
        await bot.download(cq.message.photo[-1], img_bytes)
        data = base64.b64encode(img_bytes.getvalue()).decode()
        async with httpx.AsyncClient() as client:
            r = await client.post(f'https://bots.bakatrouble.me/bots_rpc/{destination}/', json={
                "method": "post_photo",
                "params": [data, True],
                "jsonrpc": "2.0",
                "id": 0,
            })
        resp = r.json()
        if 'result' in resp and resp['result'] == True:
            await cq.answer('Sent')
        elif 'result' in resp and resp['result'] == 'duplicate':
            await cq.answer('Duplicate')
        else:
            raise Exception(resp)
    except Exception:
        # Deliberately broad (user-facing boundary), but not a bare except:
        # asyncio.CancelledError must still propagate.
        logging.exception('Failed to forward photo for callback %r', cq.data)
        await cq.answer('An error has occurred, check logs')
async def background_on_start():
    """Run check_updates forever, pausing 600 seconds between passes.

    It first deletes the ``e621:update`` key, presumably to clear a lock left
    behind by a previous process (TODO confirm). An exception from one pass
    is logged and does not stop the loop.
    """
    await redis.delete('e621:update')
    pause_seconds = 600
    while True:
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as err:  # keep the periodic loop alive
            logging.exception(err)
        logging.warning('Sleeping...')
        await asyncio.sleep(pause_seconds)
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected
# while it is still running.
_background_tasks = set()


async def on_bot_startup():
    """Dispatcher startup hook: launch the periodic update checker in the background."""
    task = asyncio.create_task(background_on_start())
    _background_tasks.add(task)
    # Drop the reference once the task finishes so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
async def main():
    """Register the startup hook, then long-poll Telegram for updates until stopped."""
    startup_observer = dp.startup
    startup_observer.register(on_bot_startup)
    await dp.start_polling(bot)


if __name__ == '__main__':
    # Script entry point: run the bot's event loop.
    asyncio.run(main())