diff --git a/e621.py b/e621.py
index 8c6f671..f4e2249 100644
--- a/e621.py
+++ b/e621.py
@@ -132,6 +132,35 @@ class E621Post:
duration: Optional[float]
+@dataclass_json
+@dataclass
+class E621PostVersion:
+ id: int
+ post_id: int
+ tags: str
+ updater_id: int
+ updated_at: str
+ rating: Rating
+ parent_id: Optional[int]
+ source: Optional[str]
+ description: str
+ reason: Optional[str]
+ locked_tags: Optional[str]
+ added_tags: List[str]
+ removed_tags: List[str]
+ added_locked_tags: List[str]
+ removed_locked_tags: List[str]
+ rating_changed: bool
+ parent_changed: bool
+ source_changed: bool
+ description_changed: bool
+ version: int
+ obsolete_added_tags: str
+ obsolete_removed_tags: str
+ unchanged_tags: str
+ updater_name: str
+
+
class E621:
def __init__(self):
self.client = httpx.AsyncClient(headers={'user-agent': 'bot/1.0 (bakatrouble)'}, base_url='https://e621.net')
@@ -142,3 +171,7 @@ class E621:
async def get_post(self, post_id: str) -> E621Post:
return (await self.get_posts(f'id:{post_id}'))[0]
+
+ async def get_post_versions(self, start_id=0, page=1, limit=320) -> List[E621PostVersion]:
+ r = (await self.client.get('/post_versions.json', params={'search[start_id]': start_id, 'limit': limit})).json()
+ return [E621PostVersion.from_dict(p) for p in r]
diff --git a/main.py b/main.py
index 6784965..8daeade 100644
--- a/main.py
+++ b/main.py
@@ -8,6 +8,7 @@ import re
import traceback
from asyncio import sleep
from io import BytesIO
+from itertools import count
from pathlib import Path
from tempfile import TemporaryDirectory
from time import time
@@ -25,7 +26,7 @@ from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, B
CallbackQuery
import dotenv
-from e621 import E621, E621Post, E621PostFile
+from e621 import E621, E621Post, E621PostFile, E621PostVersion
dotenv.load_dotenv('.env')
@@ -49,11 +50,15 @@ def recover_url(file: E621PostFile):
return f'https://static1.e621.net/data/{file.md5[:2]}/{file.md5[2:4]}/{file.md5}.{file.ext}'
-async def send_post(post: E621Post, tag_list: List[str]):
+async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
try:
logging.warning(f'Sending post #{post.id}')
await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)
- monitored_tags = set(post.tags.flatten()) & set(tag_list)
+ flat_post_tags = set(post.tags.flatten())
+ monitored_tags = set()
+ for tag_group in tag_list:
+ if set(tag_group) <= flat_post_tags:
+ monitored_tags.update(tag_group)
artist_tags = post.tags.artist
character_tags = post.tags.character
copyright_tags = post.tags.copyright
@@ -133,24 +138,55 @@ async def check_updates():
logging.warning('Waiting for lock...')
async with redis.lock('e621:update'):
logging.warning('Lock acquired...')
- tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
- random.shuffle(tag_list)
- for tl_idx in range(0, len(tag_list), PAGE_SIZE):
- tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
- logging.warning(tags)
- posts = await e621.get_posts(tags)
- if not posts:
- return
- already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
- # last_index = len(posts)
- # if already_sent.count(True):
- # last_index = already_sent.index(True)
- # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
- for i in list(range(len(posts)))[::-1]:
- if already_sent[i]:
- continue
- await send_post(posts[i], tag_list)
- await sleep(1)
+ matched_posts = []
+ tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs'))
+ last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
+ post_versions: List[E621PostVersion] = []
+ for page in count(1):
+ post_versions_page = await e621.get_post_versions(last_post_version, page)
+ post_versions += post_versions_page
+ if not last_post_version or not post_versions_page:
+ break
+ for post_version in post_versions[::-1]:
+ if post_version.id > last_post_version:
+ last_post_version = post_version.id
+ post_tags = set(post_version.tags.split())
+ for tag_group in tag_list:
+ if set(tag_group) <= post_tags:
+ matched_posts.append(post_version.post_id)
+ break
+ matched_posts.sort()
+ if matched_posts:
+ logging.warning(f'Found {len(matched_posts)} posts')
+ already_sent: List = await redis.smismember('e621:sent', matched_posts)
+ posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent]
+ for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
+ chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
+ posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
+ for i, post in enumerate(posts):
+ logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(matched_posts)}')
+ await send_post(post, tag_list)
+ await redis.sadd('e621:sent', post.id)
+ await sleep(1)
+ await redis.set('e621:last_version', last_post_version)
+
+ # random.shuffle(tag_list)
+ # for tl_idx in range(0, len(tag_list), PAGE_SIZE):
+ # tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
+ # logging.warning(tags)
+ # posts = await e621.get_posts(tags)
+ # if not posts:
+ # return
+ # already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
+ # # last_index = len(posts)
+ # # if already_sent.count(True):
+ # # last_index = already_sent.index(True)
+ # # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
+ # for i in list(range(len(posts)))[::-1]:
+ # if already_sent[i]:
+ # continue
+ # await send_post(posts[i], tag_list)
+ # await sleep(1)
@dp.message(filters.Command('resend_after'), ChatFilter)
@@ -163,7 +199,7 @@ async def resend_after(msg: Message):
return
async with redis.lock('e621:update'):
- tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
+ tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
for i, tag in enumerate(tag_list):
await msg.reply(f'Checking tag {tag} ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
posts = []
@@ -195,6 +231,21 @@ async def add_tag(msg: Message):
await msg.reply(f'Tags {args} added')
+@dp.message(filters.Command('add_tags'), ChatFilter)
+async def add_tags(msg: Message):
+ args = ' '.join(msg.text.split()[1:])
+ if not args:
+ await msg.reply('Please provide tags to subscribe to')
+ return
+ tags = args.split()
+ tags.sort()
+ tags = ' '.join(tags)
+ posts = await e621.get_posts(tags)
+ await redis.sadd('e621:sent', *[post.id for post in posts])
+ await redis.sadd('e621:subs', tags)
+ await msg.reply(f'Tag group {tags}
added', parse_mode=ParseMode.HTML)
+
+
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
async def mark_old_as_sent(msg: Message):
logging.warning('Waiting for lock...')
@@ -270,7 +321,8 @@ async def test(msg: Message):
if not post:
await msg.reply('Post not found')
return
- await send_post(post[0], [])
+ tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
+ await send_post(post[0], tag_list)
@dp.callback_query(F.data.startswith('send '))