use post versions to find updates

master
bakatrouble 3 months ago
parent 061115be9e
commit 2d1fb6f9b0

@@ -132,6 +132,35 @@ class E621Post:
     duration: Optional[float]


+@dataclass_json
+@dataclass
+class E621PostVersion:
+    id: int
+    post_id: int
+    tags: str
+    updater_id: int
+    updated_at: str
+    rating: Rating
+    parent_id: Optional[int]
+    source: Optional[str]
+    description: str
+    reason: Optional[str]
+    locked_tags: Optional[str]
+    added_tags: List[str]
+    removed_tags: List[str]
+    added_locked_tags: List[str]
+    removed_locked_tags: List[str]
+    rating_changed: bool
+    parent_changed: bool
+    source_changed: bool
+    description_changed: bool
+    version: int
+    obsolete_added_tags: str
+    obsolete_removed_tags: str
+    unchanged_tags: str
+    updater_name: str
+
+
 class E621:
     def __init__(self):
         self.client = httpx.AsyncClient(headers={'user-agent': 'bot/1.0 (bakatrouble)'}, base_url='https://e621.net')
@@ -142,3 +171,7 @@ class E621:
     async def get_post(self, post_id: str) -> E621Post:
         return (await self.get_posts(f'id:{post_id}'))[0]
+
+    async def get_post_versions(self, start_id=0, page=1, limit=320) -> List[E621PostVersion]:
+        r = (await self.client.get('/post_versions.json', params={'search[start_id]': start_id, 'limit': limit})).json()
+        return [E621PostVersion.from_dict(p) for p in r]
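
A rough usage sketch for the new get_post_versions wrapper, showing how a caller is expected to fetch a page of post versions; the module path e621, the start_id value, and the standalone asyncio entry point are assumptions for illustration, not code from this commit:

import asyncio

from e621 import E621


async def main():
    e621 = E621()
    try:
        # One page of post versions after a known version id; the bot's update
        # loop keeps requesting pages until an empty one comes back.
        versions = await e621.get_post_versions(start_id=1234567)
        for v in versions:
            print(v.post_id, v.version, v.added_tags, v.removed_tags)
    finally:
        await e621.client.aclose()


asyncio.run(main())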

@@ -8,6 +8,7 @@ import re
 import traceback
 from asyncio import sleep
 from io import BytesIO
+from itertools import count
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from time import time
@@ -25,7 +26,7 @@ from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, B
     CallbackQuery

 import dotenv

-from e621 import E621, E621Post, E621PostFile
+from e621 import E621, E621Post, E621PostFile, E621PostVersion

 dotenv.load_dotenv('.env')
@@ -49,11 +50,15 @@ def recover_url(file: E621PostFile):
     return f'https://static1.e621.net/data/{file.md5[:2]}/{file.md5[2:4]}/{file.md5}.{file.ext}'


-async def send_post(post: E621Post, tag_list: List[str]):
+async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
     try:
         logging.warning(f'Sending post #{post.id}')
         await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)
-        monitored_tags = set(post.tags.flatten()) & set(tag_list)
+        flat_post_tags = set(post.tags.flatten())
+        monitored_tags = set()
+        for tag_group in tag_list:
+            if set(tag_group) <= flat_post_tags:
+                monitored_tags.update(tag_group)
         artist_tags = post.tags.artist
         character_tags = post.tags.character
         copyright_tags = post.tags.copyright
@@ -133,24 +138,55 @@ async def check_updates():
     logging.warning('Waiting for lock...')
     async with redis.lock('e621:update'):
         logging.warning('Lock acquired...')
-        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
-        random.shuffle(tag_list)
-        for tl_idx in range(0, len(tag_list), PAGE_SIZE):
-            tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
-            logging.warning(tags)
-            posts = await e621.get_posts(tags)
-            if not posts:
-                return
-            already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
-            # last_index = len(posts)
-            # if already_sent.count(True):
-            #     last_index = already_sent.index(True)
-            # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
-            for i in list(range(len(posts)))[::-1]:
-                if already_sent[i]:
-                    continue
-                await send_post(posts[i], tag_list)
-                await sleep(1)
+        matched_posts = []
+        tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs'))
+        last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
+        post_versions: List[E621PostVersion] = []
+        for page in count(1):
+            post_versions_page = await e621.get_post_versions(last_post_version, page)
+            post_versions += post_versions_page
+            if not last_post_version or not post_versions_page:
+                break
+        for post_version in post_versions[::-1]:
+            if post_version.id > last_post_version:
+                last_post_version = post_version.id
+            post_tags = set(post_version.tags.split())
+            for tag_group in tag_list:
+                if set(tag_group) <= post_tags:
+                    matched_posts.append(post_version.post_id)
+                    break
+        matched_posts.sort()
+        if matched_posts:
+            logging.warning(f'Found {len(matched_posts)} posts')
+            already_sent: List = await redis.smismember('e621:sent', matched_posts)
+            posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent]
+            for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
+                chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
+                posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
+                for i, post in enumerate(posts):
+                    logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(matched_posts)}')
+                    await send_post(post, tag_list)
+                    await redis.sadd('e621:sent', post.id)
+                    await sleep(1)
+        await redis.set('e621:last_version', last_post_version)
+
+        # random.shuffle(tag_list)
+        # for tl_idx in range(0, len(tag_list), PAGE_SIZE):
+        #     tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
+        #     logging.warning(tags)
+        #     posts = await e621.get_posts(tags)
+        #     if not posts:
+        #         return
+        #     already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
+        #     # last_index = len(posts)
+        #     # if already_sent.count(True):
+        #     #     last_index = already_sent.index(True)
+        #     # await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
+        #     for i in list(range(len(posts)))[::-1]:
+        #         if already_sent[i]:
+        #             continue
+        #         await send_post(posts[i], tag_list)
+        #         await sleep(1)


 @dp.message(filters.Command('resend_after'), ChatFilter)
@@ -163,7 +199,7 @@ async def resend_after(msg: Message):
         return

     async with redis.lock('e621:update'):
-        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
+        tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
         for i, tag in enumerate(tag_list):
             await msg.reply(f'Checking tag <b>{tag}</b> ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
             posts = []
@@ -195,6 +231,21 @@ async def add_tag(msg: Message):
     await msg.reply(f'Tags {args} added')


+@dp.message(filters.Command('add_tags'), ChatFilter)
+async def add_tags(msg: Message):
+    args = ' '.join(msg.text.split()[1:])
+    if not args:
+        await msg.reply('Please provide tags to subscribe to')
+        return
+    tags = args.split()
+    tags.sort()
+    tags = ' '.join(tags)
+    posts = await e621.get_posts(tags)
+    await redis.sadd('e621:sent', *[post.id for post in posts])
+    await redis.sadd('e621:subs', tags)
+    await msg.reply(f'Tag group <code>{tags}</code> added', parse_mode=ParseMode.HTML)
+
+
 @dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
 async def mark_old_as_sent(msg: Message):
     logging.warning('Waiting for lock...')
@@ -270,7 +321,8 @@ async def test(msg: Message):
     if not post:
         await msg.reply('Post not found')
         return
-    await send_post(post[0], [])
+    tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
+    await send_post(post[0], tag_list)


 @dp.callback_query(F.data.startswith('send '))
