use post versions to find updates
This commit is contained in:
		
							
								
								
									
										33
									
								
								e621.py
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								e621.py
									
									
									
									
									
								
							@@ -132,6 +132,35 @@ class E621Post:
 | 
				
			|||||||
    duration: Optional[float]
 | 
					    duration: Optional[float]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass_json
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class E621PostVersion:
 | 
				
			||||||
 | 
					    id: int
 | 
				
			||||||
 | 
					    post_id: int
 | 
				
			||||||
 | 
					    tags: str
 | 
				
			||||||
 | 
					    updater_id: int
 | 
				
			||||||
 | 
					    updated_at: str
 | 
				
			||||||
 | 
					    rating: Rating
 | 
				
			||||||
 | 
					    parent_id: Optional[int]
 | 
				
			||||||
 | 
					    source: Optional[str]
 | 
				
			||||||
 | 
					    description: str
 | 
				
			||||||
 | 
					    reason: Optional[str]
 | 
				
			||||||
 | 
					    locked_tags: Optional[str]
 | 
				
			||||||
 | 
					    added_tags: List[str]
 | 
				
			||||||
 | 
					    removed_tags: List[str]
 | 
				
			||||||
 | 
					    added_locked_tags: List[str]
 | 
				
			||||||
 | 
					    removed_locked_tags: List[str]
 | 
				
			||||||
 | 
					    rating_changed: bool
 | 
				
			||||||
 | 
					    parent_changed: bool
 | 
				
			||||||
 | 
					    source_changed: bool
 | 
				
			||||||
 | 
					    description_changed: bool
 | 
				
			||||||
 | 
					    version: int
 | 
				
			||||||
 | 
					    obsolete_added_tags: str
 | 
				
			||||||
 | 
					    obsolete_removed_tags: str
 | 
				
			||||||
 | 
					    unchanged_tags: str
 | 
				
			||||||
 | 
					    updater_name: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class E621:
 | 
					class E621:
 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
        self.client = httpx.AsyncClient(headers={'user-agent': 'bot/1.0 (bakatrouble)'}, base_url='https://e621.net')
 | 
					        self.client = httpx.AsyncClient(headers={'user-agent': 'bot/1.0 (bakatrouble)'}, base_url='https://e621.net')
 | 
				
			||||||
@@ -142,3 +171,7 @@ class E621:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    async def get_post(self, post_id: str) -> E621Post:
 | 
					    async def get_post(self, post_id: str) -> E621Post:
 | 
				
			||||||
        return (await self.get_posts(f'id:{post_id}'))[0]
 | 
					        return (await self.get_posts(f'id:{post_id}'))[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_post_versions(self, start_id=0, page=1, limit=320) -> List[E621PostVersion]:
 | 
				
			||||||
 | 
					        r = (await self.client.get('/post_versions.json', params={'search[start_id]': start_id, 'limit': limit})).json()
 | 
				
			||||||
 | 
					        return [E621PostVersion.from_dict(p) for p in r]
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										98
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										98
									
								
								main.py
									
									
									
									
									
								
							@@ -8,6 +8,7 @@ import re
 | 
				
			|||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
from asyncio import sleep
 | 
					from asyncio import sleep
 | 
				
			||||||
from io import BytesIO
 | 
					from io import BytesIO
 | 
				
			||||||
 | 
					from itertools import count
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from tempfile import TemporaryDirectory
 | 
					from tempfile import TemporaryDirectory
 | 
				
			||||||
from time import time
 | 
					from time import time
 | 
				
			||||||
@@ -25,7 +26,7 @@ from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, B
 | 
				
			|||||||
    CallbackQuery
 | 
					    CallbackQuery
 | 
				
			||||||
import dotenv
 | 
					import dotenv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from e621 import E621, E621Post, E621PostFile
 | 
					from e621 import E621, E621Post, E621PostFile, E621PostVersion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
dotenv.load_dotenv('.env')
 | 
					dotenv.load_dotenv('.env')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -49,11 +50,15 @@ def recover_url(file: E621PostFile):
 | 
				
			|||||||
    return f'https://static1.e621.net/data/{file.md5[:2]}/{file.md5[2:4]}/{file.md5}.{file.ext}'
 | 
					    return f'https://static1.e621.net/data/{file.md5[:2]}/{file.md5[2:4]}/{file.md5}.{file.ext}'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def send_post(post: E621Post, tag_list: List[str]):
 | 
					async def send_post(post: E621Post, tag_list: Iterable[Iterable[str]]):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        logging.warning(f'Sending post #{post.id}')
 | 
					        logging.warning(f'Sending post #{post.id}')
 | 
				
			||||||
        await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)
 | 
					        await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatAction.TYPING)
 | 
				
			||||||
        monitored_tags = set(post.tags.flatten()) & set(tag_list)
 | 
					        flat_post_tags = set(post.tags.flatten())
 | 
				
			||||||
 | 
					        monitored_tags = set()
 | 
				
			||||||
 | 
					        for tag_group in tag_list:
 | 
				
			||||||
 | 
					            if set(tag_group) <= flat_post_tags:
 | 
				
			||||||
 | 
					                monitored_tags.update(tag_group)
 | 
				
			||||||
        artist_tags = post.tags.artist
 | 
					        artist_tags = post.tags.artist
 | 
				
			||||||
        character_tags = post.tags.character
 | 
					        character_tags = post.tags.character
 | 
				
			||||||
        copyright_tags = post.tags.copyright
 | 
					        copyright_tags = post.tags.copyright
 | 
				
			||||||
@@ -133,24 +138,55 @@ async def check_updates():
 | 
				
			|||||||
    logging.warning('Waiting for lock...')
 | 
					    logging.warning('Waiting for lock...')
 | 
				
			||||||
    async with redis.lock('e621:update'):
 | 
					    async with redis.lock('e621:update'):
 | 
				
			||||||
        logging.warning('Lock acquired...')
 | 
					        logging.warning('Lock acquired...')
 | 
				
			||||||
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
 | 
					        matched_posts = []
 | 
				
			||||||
        random.shuffle(tag_list)
 | 
					        tag_list = set(tuple(t.decode().split()) for t in await redis.smembers('e621:subs'))
 | 
				
			||||||
        for tl_idx in range(0, len(tag_list), PAGE_SIZE):
 | 
					        last_post_version = int((await redis.get('e621:last_version') or b'0').decode())
 | 
				
			||||||
            tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
 | 
					        post_versions: List[E621PostVersion] = []
 | 
				
			||||||
            logging.warning(tags)
 | 
					        for page in count(1):
 | 
				
			||||||
            posts = await e621.get_posts(tags)
 | 
					            post_versions_page = await e621.get_post_versions(last_post_version, page)
 | 
				
			||||||
            if not posts:
 | 
					            post_versions += post_versions_page
 | 
				
			||||||
                return
 | 
					            if not last_post_version or not post_versions_page:
 | 
				
			||||||
            already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
 | 
					                break
 | 
				
			||||||
            # last_index = len(posts)
 | 
					        for post_version in post_versions[::-1]:
 | 
				
			||||||
            # if already_sent.count(True):
 | 
					            if post_version.id > last_post_version:
 | 
				
			||||||
            #     last_index = already_sent.index(True)
 | 
					                last_post_version = post_version.id
 | 
				
			||||||
            #     await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
 | 
					            post_tags = set(post_version.tags.split())
 | 
				
			||||||
            for i in list(range(len(posts)))[::-1]:
 | 
					            for tag_group in tag_list:
 | 
				
			||||||
                if already_sent[i]:
 | 
					                if set(tag_group) <= post_tags:
 | 
				
			||||||
                    continue
 | 
					                    matched_posts.append(post_version.post_id)
 | 
				
			||||||
                await send_post(posts[i], tag_list)
 | 
					                    break
 | 
				
			||||||
                await sleep(1)
 | 
					        matched_posts.sort()
 | 
				
			||||||
 | 
					        if matched_posts:
 | 
				
			||||||
 | 
					            logging.warning(f'Found {len(matched_posts)} posts')
 | 
				
			||||||
 | 
					            already_sent: List = await redis.smismember('e621:sent', matched_posts)
 | 
				
			||||||
 | 
					            posts_to_send = [post_id for post_id, sent in zip(matched_posts, already_sent) if not sent]
 | 
				
			||||||
 | 
					            for post_chunk_idx in range(0, len(posts_to_send), PAGE_SIZE):
 | 
				
			||||||
 | 
					                chunk = posts_to_send[post_chunk_idx: post_chunk_idx + PAGE_SIZE]
 | 
				
			||||||
 | 
					                posts = await e621.get_posts('order:id id:' + ','.join(f'{post_id}' for post_id in chunk))
 | 
				
			||||||
 | 
					                for i, post in enumerate(posts):
 | 
				
			||||||
 | 
					                    logging.warning(f'Sending post {post_chunk_idx + i + 1}/{len(matched_posts)}')
 | 
				
			||||||
 | 
					                    await send_post(post, tag_list)
 | 
				
			||||||
 | 
					                    await redis.sadd('e621:sent', post.id)
 | 
				
			||||||
 | 
					                    await sleep(1)
 | 
				
			||||||
 | 
					        await redis.set('e621:last_version', last_post_version)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # random.shuffle(tag_list)
 | 
				
			||||||
 | 
					        # for tl_idx in range(0, len(tag_list), PAGE_SIZE):
 | 
				
			||||||
 | 
					        #     tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + PAGE_SIZE])
 | 
				
			||||||
 | 
					        #     logging.warning(tags)
 | 
				
			||||||
 | 
					        #     posts = await e621.get_posts(tags)
 | 
				
			||||||
 | 
					        #     if not posts:
 | 
				
			||||||
 | 
					        #         return
 | 
				
			||||||
 | 
					        #     already_sent: List = await redis.smismember('e621:sent', [p.id for p in posts])
 | 
				
			||||||
 | 
					        #     # last_index = len(posts)
 | 
				
			||||||
 | 
					        #     # if already_sent.count(True):
 | 
				
			||||||
 | 
					        #     #     last_index = already_sent.index(True)
 | 
				
			||||||
 | 
					        #     #     await redis.sadd('e621:sent', *[posts[i].id for i in range(last_index, len(posts))])
 | 
				
			||||||
 | 
					        #     for i in list(range(len(posts)))[::-1]:
 | 
				
			||||||
 | 
					        #         if already_sent[i]:
 | 
				
			||||||
 | 
					        #             continue
 | 
				
			||||||
 | 
					        #         await send_post(posts[i], tag_list)
 | 
				
			||||||
 | 
					        #         await sleep(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dp.message(filters.Command('resend_after'), ChatFilter)
 | 
					@dp.message(filters.Command('resend_after'), ChatFilter)
 | 
				
			||||||
@@ -163,7 +199,7 @@ async def resend_after(msg: Message):
 | 
				
			|||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with redis.lock('e621:update'):
 | 
					    async with redis.lock('e621:update'):
 | 
				
			||||||
        tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
 | 
					        tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
 | 
				
			||||||
        for i, tag in enumerate(tag_list):
 | 
					        for i, tag in enumerate(tag_list):
 | 
				
			||||||
            await msg.reply(f'Checking tag <b>{tag}</b> ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
 | 
					            await msg.reply(f'Checking tag <b>{tag}</b> ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
 | 
				
			||||||
            posts = []
 | 
					            posts = []
 | 
				
			||||||
@@ -195,6 +231,21 @@ async def add_tag(msg: Message):
 | 
				
			|||||||
    await msg.reply(f'Tags {args} added')
 | 
					    await msg.reply(f'Tags {args} added')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dp.message(filters.Command('add_tags'), ChatFilter)
 | 
				
			||||||
 | 
					async def add_tags(msg: Message):
 | 
				
			||||||
 | 
					    args = ' '.join(msg.text.split()[1:])
 | 
				
			||||||
 | 
					    if not args:
 | 
				
			||||||
 | 
					        await msg.reply('Please provide tags to subscribe to')
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    tags = args.split()
 | 
				
			||||||
 | 
					    tags.sort()
 | 
				
			||||||
 | 
					    tags = ' '.join(tags)
 | 
				
			||||||
 | 
					    posts = await e621.get_posts(tags)
 | 
				
			||||||
 | 
					    await redis.sadd('e621:sent', *[post.id for post in posts])
 | 
				
			||||||
 | 
					    await redis.sadd('e621:subs', tags)
 | 
				
			||||||
 | 
					    await msg.reply(f'Tag group <code>{tags}</code> added', parse_mode=ParseMode.HTML)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
 | 
					@dp.message(filters.Command('mark_old_as_sent'), ChatFilter)
 | 
				
			||||||
async def mark_old_as_sent(msg: Message):
 | 
					async def mark_old_as_sent(msg: Message):
 | 
				
			||||||
    logging.warning('Waiting for lock...')
 | 
					    logging.warning('Waiting for lock...')
 | 
				
			||||||
@@ -270,7 +321,8 @@ async def test(msg: Message):
 | 
				
			|||||||
    if not post:
 | 
					    if not post:
 | 
				
			||||||
        await msg.reply('Post not found')
 | 
					        await msg.reply('Post not found')
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    await send_post(post[0], [])
 | 
					    tag_list = [tuple(t.decode().split()) for t in await redis.smembers('e621:subs')]
 | 
				
			||||||
 | 
					    await send_post(post[0], tag_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dp.callback_query(F.data.startswith('send '))
 | 
					@dp.callback_query(F.data.startswith('send '))
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user