You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

383 lines
15 KiB

12 months ago
import asyncio
import base64
import datetime
import html
import logging
import os
import random
import re
import traceback
from asyncio import sleep
from io import BytesIO
from itertools import count
from pathlib import Path
from tempfile import TemporaryDirectory
from time import time
from typing import List, Iterable
9 months ago
import boto3
import ffmpeg
from PIL import Image
12 months ago
import httpx
12 months ago
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher, filters, exceptions, F
from aiogram.enums import ChatAction, ParseMode
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, BufferedInputFile, \
CallbackQuery
12 months ago
import dotenv
12 months ago
from e621 import E621, E621Post, E621PostFile, E621PostVersion
12 months ago
12 months ago
# Populate os.environ from .env first: every lookup below reads from it.
dotenv.load_dotenv('.env')

# Shared async Redis client (subscriptions, sent-post set, update lock) and e621 API client.
redis = aioredis.from_url('redis://localhost')
e621 = E621()
logging.basicConfig(level=logging.INFO)

bot = Bot(token=os.environ['BOT_TOKEN'])
dp = Dispatcher()

# Command handlers only react in chats whose ids appear in the comma-separated USERS variable.
_allowed_chat_ids = {int(chat_id) for chat_id in os.environ['USERS'].split(',')}
ChatFilter = F.chat.id.in_(_allowed_chat_ids)

# Number of post ids requested per e621 search when sending matched posts.
PAGE_SIZE = 20
12 months ago
def format_tags(tags: Iterable[str]):
4 months ago
return ' '.join(f'#{re.sub(r"[^0-9a-zA-Z_]", "_", tag)}' for tag in tags) or 'None'
def recover_url(file: E621PostFile):
    """Rebuild a post file's static1.e621.net URL from its MD5 hash and extension.

    ``send_post`` calls this when the post's ``file.url`` is empty. The path is
    ``data/<md5[0:2]>/<md5[2:4]>/<md5>.<ext>``.
    """
    md5 = file.md5
    return 'https://static1.e621.net/data/{}/{}/{}.{}'.format(md5[:2], md5[2:4], md5, file.ext)
def _build_caption(post: E621Post, tag_list: Iterable[Iterable[str]]) -> str:
    """Build the HTML caption: matched subscription tags, artist/character/copyright tags and a post link."""
    flat_post_tags = set(post.tags.flatten())
    monitored_tags = set()
    for tag_group in tag_list:
        group = set(tag_group)
        # A subscription matches only if every tag of its group is on the post.
        if group <= flat_post_tags:
            monitored_tags.update(group)
    lines = [f'Monitored tags: <b>{format_tags(sorted(monitored_tags))}</b>']
    for label, tags in (('Artist', post.tags.artist),
                        ('Character', post.tags.character),
                        ('Copyright', post.tags.copyright)):
        if tags:
            lines.append(f'{label}: <b>{format_tags(tags)}</b>')
    lines.append(f'\nhttps://e621.net/posts/{post.id}')
    return '\n'.join(lines)


def _transcode_and_upload(data: bytes, ext: str, post_id) -> str:
    """Convert a webm/gif to H.264 MP4 with ffmpeg, upload it to S3 and return its public URL.

    Blocking (runs ffmpeg and boto3 synchronously); call it from a worker thread.
    """
    with TemporaryDirectory() as td:
        src_path = Path(td) / f'video.{ext}'
        mp4_path = Path(td) / 'video.mp4'
        src_path.write_bytes(data)
        cmd = ffmpeg.input(str(src_path)).output(
            str(mp4_path),
            # Pad width/height up to the next even number with black pixels.
            vf='pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2:x=0:y=0:color=Black',
            vcodec='libx264',
            crf='26',
        )
        logging.info('ffmpeg ' + ' '.join(cmd.get_args()))
        cmd.run()
        s3 = boto3.client('s3', aws_access_key_id=os.environ['AWS_ACCESS_KEY'],
                          aws_secret_access_key=os.environ['AWS_SECRET_KEY'])
        bucket = os.environ['AWS_S3_BUCKET']
        upload_filename = f'e621-{post_id}-{int(time())}.mp4'
        s3.upload_file(str(mp4_path), bucket, upload_filename,
                       ExtraArgs={'ACL': 'public-read', 'ContentType': 'video/mp4'})
    # The temporary directory (and both files in it) is removed on exit from the with-block.
    return f'https://{bucket}.s3.amazonaws.com/{upload_filename}'


def _to_jpeg(data: bytes, max_side: int = 2000) -> bytes:
    """Downscale an image to fit *max_side* x *max_side*, flatten transparency onto white and encode as JPEG.

    Blocking (CPU-bound PIL work); call it from a worker thread.
    """
    image = Image.open(BytesIO(data)).convert('RGBA')
    width, height = image.size
    if width > max_side or height > max_side:
        ratio = max_side / max(width, height)
        image = image.resize((int(width * ratio), int(height * ratio)), Image.LANCZOS)
        logging.warning(f'Resizing from {width}x{height} to {image.size[0]}x{image.size[1]}')
    background = Image.new('RGBA', image.size, (255, 255, 255))
    composite = Image.alpha_composite(background, image).convert('RGB')
    out = BytesIO()
    composite.save(out, format='JPEG')
    return out.getvalue()


async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
    """Download *post*'s file and forward it to the chat given by the SEND_CHAT env variable.

    webm/gif files are transcoded to MP4, uploaded to S3 and sent as a link.
    png/jpg files are re-encoded as JPEG and sent as a photo with NSFW/Safe
    forwarding buttons. Other extensions are logged and skipped.

    After the branch runs without raising, the post id is added to the
    ``e621:sent`` Redis set; this includes the skipped case. Any exception is
    logged and swallowed, so callers never see a failure.
    """
    try:
        logging.warning(f'Sending post #{post.id}')
        chat_id = int(os.environ['SEND_CHAT'])
        await bot.send_chat_action(chat_id, action=ChatAction.TYPING)
        caption = _build_caption(post, tag_list)

        if not post.file.url:
            post.file.url = recover_url(post.file)
        logging.warning(post.file.url)
        async with httpx.AsyncClient() as client:
            response = await client.get(post.file.url)
            # Fail loudly instead of feeding an HTTP error page to ffmpeg/PIL.
            response.raise_for_status()
        data = response.content
        ext = post.file.ext

        # ffmpeg, boto3 and PIL are blocking; run them off the event loop so polling keeps working.
        loop = asyncio.get_running_loop()
        if ext in ('webm', 'gif'):
            video_url = await loop.run_in_executor(None, _transcode_and_upload, data, ext, post.id)
            await bot.send_message(chat_id, f'{video_url}\n\n' + caption, parse_mode=ParseMode.HTML)
        elif ext in ('png', 'jpg'):
            markup = InlineKeyboardMarkup(inline_keyboard=[[
                InlineKeyboardButton(text='NSFW', callback_data='send nsfw'),
                InlineKeyboardButton(text='Safe', callback_data='send pics'),
            ]])
            logging.warning('compressing')
            jpeg = await loop.run_in_executor(None, _to_jpeg, data)
            await bot.send_photo(chat_id,
                                 BufferedInputFile(jpeg, 'file.jpg'),
                                 caption=caption,
                                 parse_mode=ParseMode.HTML,
                                 reply_markup=markup)
        else:
            logging.warning(f'Unsupported file type {ext!r} for post #{post.id}, skipping')
        await redis.sadd('e621:sent', post.id)
    except Exception as e:
        logging.exception(e)
9 months ago
12 months ago
async def check_updates():
    """Scan e621 post versions newer than the saved cursor and send newly matching posts.

    Runs under the ``e621:update`` Redis lock. Subscriptions come from the
    ``e621:subs`` set. Each entry is one tag or a space-separated group of
    tags, and all of a group's tags must be present for a match. A version is
    considered only if it *added* at least one monitored tag. Matching posts
    not already in ``e621:sent`` are fetched in chunks of ``PAGE_SIZE`` and
    sent. Finally the highest version id seen is stored in
    ``e621:last_version``.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        logging.warning('Lock acquired...')
        tag_list = {tuple(t.decode().split()) for t in await redis.smembers('e621:subs')}
        tag_list_flat = {tag for group in tag_list for tag in group}
        last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
        post_versions: List[E621PostVersion] = []
        logging.warning(f'Getting post versions from id {last_post_version}')
        for page in count(1):
            logging.warning(f'Loading page {page}')
            post_versions_page = await e621.get_post_versions(last_post_version, page)
            post_versions += post_versions_page
            # Without a saved cursor (first run) only the first page is read.
            if not last_post_version or not post_versions_page:
                break

        # A post edited several times can match in more than one version; a set
        # guarantees each post is queued (and therefore sent) only once.
        matched_post_ids = set()
        for post_version in reversed(post_versions):
            last_post_version = max(last_post_version, post_version.id)
            if not tag_list_flat & set(post_version.added_tags):
                continue
            post_tags = set(post_version.tags.split())
            if any(set(tag_group) <= post_tags for tag_group in tag_list):
                matched_post_ids.add(post_version.post_id)
        matched_posts = sorted(matched_post_ids)

        if matched_posts:
            already_sent: List = await redis.smismember('e621:sent', matched_posts)
            posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent]
            logging.warning(f'Found {len(posts_to_send)} posts')
            for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
                chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
                posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
                for i, post in enumerate(posts):
                    logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(posts_to_send)}')
                    await send_post(post, tag_list)
                    # send_post swallows its own errors; marking here means a
                    # failed delivery is not retried on the next run.
                    await redis.sadd('e621:sent', post.id)
                    await sleep(1)
        await redis.set('e621:last_version', last_post_version)
12 months ago
@dp.message(filters.Command('resend_after'), ChatFilter)
async def resend_after(msg: Message):
    """Re-send every subscribed post created at or after a Unix timestamp.

    Usage: ``/resend_after <unix_timestamp>``. For each subscription, result
    pages are read until the first post older than the timestamp is found.
    This assumes ``get_posts`` returns posts newest-first; TODO confirm
    against the e621 client. The collected posts are then sent oldest-first,
    one per second.
    """
    try:
        timestamp = int(msg.text.split()[1])
    except (IndexError, ValueError):
        await msg.reply('Invalid timestamp or not provided')
        return
    async with redis.lock('e621:update'):
        tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
        for i, tag in enumerate(tag_list):
            # Queries are plain space-separated strings, as in every other get_posts call.
            query = ' '.join(tag)
            await msg.reply(f'Checking tag <b>{html.escape(query)}</b> ({i + 1}/{len(tag_list)})',
                            parse_mode=ParseMode.HTML)
            posts = []
            page = 1
            reached_older = False
            # Stop paginating entirely once an older post is seen, not just the current page.
            while not reached_older:
                page_posts = await e621.get_posts(query, page)
                if not page_posts:
                    break
                for post in page_posts:
                    if datetime.datetime.fromisoformat(post.created_at).timestamp() < timestamp:
                        reached_older = True
                        break
                    posts.append(post)
                page += 1
            for post in reversed(posts):
                await send_post(post, tag_list)
                await sleep(1)
    await msg.reply('Finished')
@dp.message(filters.Command('add'), ChatFilter)
async def add_tag(msg: Message):
    """Subscribe to one or more individual tags: ``/add tag1 tag2 ...``.

    For each tag, the posts ``get_posts`` currently returns are marked as sent
    before subscribing, so only posts that appear later are forwarded.
    """
    tags = msg.text.split()[1:]
    if not tags:
        await msg.reply('Please provide tag to subscribe to')
        return
    for tag in tags:
        posts = await e621.get_posts(tag)
        # SADD requires at least one member; a tag with no posts would raise a Redis error.
        if posts:
            await redis.sadd('e621:sent', *[post.id for post in posts])
        await redis.sadd('e621:subs', tag)
    await msg.reply(f'Tags {" ".join(tags)} added')
@dp.message(filters.Command('add_tags'), ChatFilter)
async def add_tags(msg: Message):
    """Subscribe to a tag *group*: ``/add_tags tag1 tag2 ...``.

    The tags are sorted and stored as one space-separated ``e621:subs`` entry.
    A post matches only when it carries all of them. Posts currently returned
    for the group are marked as sent first.
    """
    tags = ' '.join(sorted(msg.text.split()[1:]))
    if not tags:
        await msg.reply('Please provide tags to subscribe to')
        return
    posts = await e621.get_posts(tags)
    # SADD requires at least one member; a group with no posts would raise a Redis error.
    if posts:
        await redis.sadd('e621:sent', *[post.id for post in posts])
    await redis.sadd('e621:subs', tags)
    # e621 tags may contain HTML-special characters (e.g. "<3"); escape for ParseMode.HTML.
    await msg.reply(f'Tag group <code>{html.escape(tags)}</code> added', parse_mode=ParseMode.HTML)
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
    """Mark the posts currently returned for every subscription as already sent.

    Runs under the ``e621:update`` lock. Progress is reported by editing a
    single reply message, with a one-second pause between subscriptions.
    """
    logging.warning('Waiting for lock...')
    async with redis.lock('e621:update'):
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
        m = await msg.reply(f'0/{len(tag_list)} tags have old posts marked as sent')
        for i, tag in enumerate(tag_list, 1):
            posts = await e621.get_posts(tag)
            # SADD requires at least one member; an empty result would raise a Redis error.
            if posts:
                await redis.sadd('e621:sent', *[post.id for post in posts])
            await m.edit_text(f'{i}/{len(tag_list)} tags have old posts marked as sent')
            await sleep(1)
        await m.edit_text(f'Done marking old posts as sent for {len(tag_list)} tags')
@dp.message(filters.Command(re.compile(r'del_\S+')), ChatFilter)
async def del_tag(msg: Message):
    """Unsubscribe from one tag via a ``/del_<tag>`` command (the form /list prints)."""
    # Everything after the 5-character "/del_" prefix is the tag.
    tag = msg.text[len('/del_'):]
    error = None
    if not tag:
        error = 'Please provide tag to unsubscribe from'
    elif ' ' in tag:
        error = 'Tag should not contain spaces'
    elif not await redis.sismember('e621:subs', tag):
        error = 'Tag not found'
    if error is not None:
        await msg.reply(error)
        return
    await redis.srem('e621:subs', tag)
    await msg.reply(f'Tag {tag} removed')
12 months ago
@dp.message(filters.Command('del'), ChatFilter)
async def del_command(msg: Message):
    """Unsubscribe from one or more individual tags: ``/del tag1 tag2 ...``.

    Each word is removed from ``e621:subs`` on its own. Names that are not
    subscribed are ignored.
    """
    tags = msg.text.split()[1:]
    if not tags:
        await msg.reply('Please provide tag to unsubscribe from')
        return
    await redis.srem('e621:subs', *tags)
    await msg.reply(f'Tags {" ".join(tags)} removed')
@dp.message(filters.Command('list'), ChatFilter)
async def list_tags(msg: Message):
    """Reply with all subscriptions, sorted, each with a ``/del_<tag>`` shortcut.

    Entries are batched so that the joined entry text of one reply stays
    within 2000 characters; a new reply starts when the next entry would
    exceed that.
    """
    members = await redis.smembers('e621:subs')
    entries = [f'- {tag} [/del_{tag}]' for tag in sorted(m.decode() for m in members)]
    batch = []
    for entry in entries:
        if len('\n'.join(batch + [entry])) <= 2000:
            batch.append(entry)
            continue
        body = '\n'.join(batch)
        await msg.reply(f'Monitored tags:\n\n{body}')
        batch = [entry]
    body = '\n'.join(batch)
    await msg.reply(f'Monitored tags:\n\n{body}')
@dp.message(filters.Command('update'), ChatFilter)
async def update(msg: Message):
    """Run one update check right now, outside the 10-minute background schedule."""
    # The message itself carries no arguments; it only triggers the check.
    return await check_updates()
@dp.message(filters.Command('test'), ChatFilter)
async def test(msg: Message):
    """Send a single post through send_post as a manual check: ``/test <post_id>``."""
    args = ' '.join(msg.text.split()[1:])
    if not args:
        await msg.reply('Please provide post id')
        return
    query = f'id:{args}'
    # Was a stray debug print(); keep the query visible in the log instead.
    logging.info(query)
    posts = await e621.get_posts(query)
    if not posts:
        await msg.reply('Post not found')
        return
    tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
    await send_post(posts[0], tag_list)
9 months ago
@dp.callback_query(F.data.startswith('send '))
async def send_callback(cq: CallbackQuery):
    """Forward the photo of the pressed message to the bots_rpc endpoint named in the callback data.

    Callback data has the form ``'send <destination>'``; send_post creates
    buttons with ``nsfw`` and ``pics``. The largest photo size is downloaded
    and base64-encoded, then posted as a JSON-RPC ``post_photo`` call. The
    callback is answered with 'Sent', 'Duplicate' or an error notice.
    """
    try:
        _, destination = cq.data.split()
        img_bytes = BytesIO()
        await bot.download(cq.message.photo[-1], img_bytes)
        data = base64.b64encode(img_bytes.getvalue()).decode()
        async with httpx.AsyncClient() as client:
            r = await client.post(f'https://bots.bakatrouble.me/bots_rpc/{destination}/', json={
                "method": "post_photo",
                "params": [data, True],
                "jsonrpc": "2.0",
                "id": 0,
            })
        resp = r.json()
        result = resp.get('result')
        if result is True:
            await cq.answer('Sent')
        elif result == 'duplicate':
            await cq.answer('Duplicate')
        else:
            raise RuntimeError(resp)
    except Exception:
        # `except Exception` (not bare) so task cancellation still propagates.
        logging.exception(f'Failed to forward photo for callback {cq.data!r}')
        await cq.answer('An error has occurred, check logs')
12 months ago
async def background_on_start():
    """Run check_updates forever, pausing 600 seconds between runs.

    Deletes the ``e621:update`` lock key first, so a lock left over from an
    earlier run cannot block the checks. An exception from one check is
    logged and the loop continues.
    """
    async def run_one_check():
        logging.warning('Checking updates...')
        try:
            await check_updates()
        except Exception as exc:
            logging.exception(exc)

    await redis.delete('e621:update')
    while True:
        await run_one_check()
        logging.warning('Sleeping...')
        await asyncio.sleep(600)
12 months ago
# Strong references to background tasks. The event loop keeps only weak
# references, so an unreferenced task can be garbage-collected mid-execution.
_background_tasks = set()


async def on_bot_startup():
    """Dispatcher startup hook: launch the periodic update loop as a background task."""
    task = asyncio.create_task(background_on_start())
    _background_tasks.add(task)
    # Drop the reference once the task finishes so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
async def main():
    """Register the startup hook, then long-poll Telegram until the process is stopped."""
    startup_hooks = dp.startup
    startup_hooks.register(on_bot_startup)
    polling = dp.start_polling(bot)
    await polling


if __name__ == '__main__':
    asyncio.run(main())