diff --git a/e621.py b/e621.py index 8c6f671..f4e2249 100644 --- a/e621.py +++ b/e621.py @@ -132,6 +132,35 @@ class E621Post: duration: Optional[float] +@dataclass_json +@dataclass +class E621PostVersion: + id: int + post_id: int + tags: str + updater_id: int + updated_at: str + rating: Rating + parent_id: Optional[int] + source: Optional[str] + description: str + reason: Optional[str] + locked_tags: Optional[str] + added_tags: List[str] + removed_tags: List[str] + added_locked_tags: List[str] + removed_locked_tags: List[str] + rating_changed: bool + parent_changed: bool + source_changed: bool + description_changed: bool + version: int + obsolete_added_tags: str + obsolete_removed_tags: str + unchanged_tags: str + updater_name: str + + class E621: def __init__(self): self.client = httpx.AsyncClient(headers={'user-agent': 'bot/1.0 (bakatrouble)'}, base_url='https://e621.net') @@ -142,3 +171,7 @@ class E621: async def get_post(self, post_id: str) -> E621Post: return (await self.get_posts(f'id:{post_id}'))[0] + + async def get_post_versions(self, start_id=0, page=1, limit=320) -> List[E621PostVersion]: + r = (await self.client.get('/post_versions.json', params={'search[start_id]': start_id, 'limit': limit})).json() + return [E621PostVersion.from_dict(p) for p in r] diff --git a/main.py b/main.py index 6784965..8daeade 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ import re import traceback from asyncio import sleep from io import BytesIO +from itertools import count from pathlib import Path from tempfile import TemporaryDirectory from time import time @@ -25,7 +26,7 @@ from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, B CallbackQuery import dotenv -from e621 import E621, E621Post, E621PostFile +from e621 import E621, E621Post, E621PostFile, E621PostVersion dotenv.load_dotenv('.env') @@ -49,11 +50,15 @@ def recover_url(file: E621PostFile): return f'https://static1.e621.net/data/{file.md5[:2]}/{file.md5[2:4]}/{file.md5}.{file.ext}' -async def send_post(post: E621Post, tag_list: List[str]): +async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]): try: logging.warning(f'Sending post #{post.id}') await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING) - monitored_tags = set(post.tags.flatten()) & set(tag_list) + flat_post_tags = set(post.tags.flatten()) + monitored_tags = set() + for tag_group in tag_list: + if set(tag_group) <= flat_post_tags: + monitored_tags.update(tag_group) artist_tags = post.tags.artist character_tags = post.tags.character copyright_tags = post.tags.copyright @@ -133,24 +138,55 @@ async def check_updates(): logging.warning('Waiting for lock...') async with redis.lock('e621:update'): logging.warning('Lock acquired...') - tag_list = [t.decode() for t in await redis.smembers('e621:subs')] - random.shuffle(tag_list) - for tl_idx in range(0, len(tag_list), PAGE_SIZE): - tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE]) - logging.warning(tags) - posts = await e621.get_posts(tags) - if not posts: - return - already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts]) - # last_index = len(posts) - # if already_sent.count(True): - # last_index = already_sent.index(True) - # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))]) - for i in list(range(len(posts)))[::-1]: - if already_sent[i]: - continue - await send_post(posts[i], tag_list) - await sleep(1) + matched_posts = [] + tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs')) + last_post_version = int((await redis.get('e621:last_version') or b'0').decode()) + post_versions: List[E621PostVersion] = [] + for page in count(1): + post_versions_page = await e621.get_post_versions(last_post_version, page) + post_versions += post_versions_page + if not last_post_version or not post_versions_page: + break + for post_version in post_versions[::-1]: + if post_version.id > last_post_version: + last_post_version = post_version.id + post_tags = set(post_version.tags.split()) + for tag_group in tag_list: + if set(tag_group) <= post_tags: + matched_posts.append(post_version.post_id) + break + matched_posts.sort() + if matched_posts: + logging.warning(f'Found {len(matched_posts)} posts') + already_sent: List = await redis.smismember('e621:sent', matched_posts) + posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent] + for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE): + chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE] + posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk)) + for i, post in enumerate(posts): + logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(matched_posts)}') + await send_post(post, tag_list) + await redis.sadd('e621:sent', post.id) + await sleep(1) + await redis.set('e621:last_version', last_post_version) + + # random.shuffle(tag_list) + # for tl_idx in range(0, len(tag_list), PAGE_SIZE): + # tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE]) + # logging.warning(tags) + # posts = await e621.get_posts(tags) + # if not posts: + # return + # already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts]) + # # last_index = len(posts) + # # if already_sent.count(True): + # # last_index = already_sent.index(True) + # # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))]) + # for i in list(range(len(posts)))[::-1]: + # if already_sent[i]: + # continue + # await send_post(posts[i], tag_list) + # await sleep(1) @dp.message(filters.Command('resend_after'), ChatFilter) @@ -163,7 +199,7 @@ async def resend_after(msg: Message): return async with redis.lock('e621:update'): - tag_list = [t.decode() for t in await redis.smembers('e621:subs')] + tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')] for i, tag in enumerate(tag_list): await msg.reply(f'Checking tag {tag} ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML) posts = [] @@ -195,6 +231,21 @@ async def add_tag(msg: Message): await msg.reply(f'Tags {args} added') +@dp.message(filters.Command('add_tags'), ChatFilter) +async def add_tags(msg: Message): + args = ' '.join(msg.text.split()[1:]) + if not args: + await msg.reply('Please provide tags to subscribe to') + return + tags = args.split() + tags.sort() + tags = ' '.join(tags) + posts = await e621.get_posts(tags) + await redis.sadd('e621:sent', *[post.id for post in posts]) + await redis.sadd('e621:subs', tags) + await msg.reply(f'Tag group {tags} added', parse_mode=ParseMode.HTML) + + @dp.message(filters.Command('mark_old_as_sent'), ChatFilter) async def mark_old_as_sent(msg: Message): logging.warning('Waiting for lock...') @@ -270,7 +321,8 @@ async def test(msg: Message): if not post: await msg.reply('Post not found') return - await send_post(post[0], []) + tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')] + await send_post(post[0], tag_list) @dp.callback_query(F.data.startswith('send '))