add test command
This commit is contained in:
		
							
								
								
									
										3
									
								
								e621.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								e621.py
									
									
									
									
									
								
							| @@ -137,3 +137,6 @@ class E621: | ||||
|     async def get_posts(self, tags='', page=1, limit=50) -> List[E621Post]: | ||||
|         r = (await self.client.get('/posts.json', params={'tags': tags, 'page': page, 'limit': limit})).json() | ||||
|         return [E621Post.from_dict(p) for p in r['posts']] | ||||
|  | ||||
|     async def get_post(self, post_id: str) -> E621Post: | ||||
|         return (await self.get_posts(f'id:{post_id}'))[0] | ||||
|   | ||||
							
								
								
									
										59
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								main.py
									
									
									
									
									
								
							| @@ -2,17 +2,19 @@ import asyncio | ||||
| import logging | ||||
| import os | ||||
| from io import BytesIO | ||||
| from typing import List | ||||
|  | ||||
| from PIL import Image | ||||
|  | ||||
| import httpx | ||||
| import redis.asyncio as aioredis | ||||
| from aiogram import Bot, Dispatcher | ||||
| from aiogram.dispatcher import filters | ||||
| from aiogram.types import Message, ParseMode | ||||
| from aiogram.types import Message, ParseMode, ChatActions | ||||
| from aiogram.utils import executor, exceptions | ||||
| import dotenv | ||||
|  | ||||
| from e621 import E621 | ||||
| from e621 import E621, E621Post | ||||
|  | ||||
| dotenv.load_dotenv('.env') | ||||
|  | ||||
| @@ -26,29 +28,14 @@ dp = Dispatcher(bot) | ||||
| user_ids = list(map(int, os.environ['USERS'].split(','))) | ||||
|  | ||||
|  | ||||
| async def check_updates(): | ||||
|     logging.warning('Waiting for lock...') | ||||
|     async with redis.lock('e621:update'): | ||||
|         logging.warning('Lock acquired...') | ||||
|         tag_list = [t.decode() for t in await redis.smembers('e621:subs')] | ||||
|         tag_list.sort() | ||||
|         for tl_idx in range(0, len(tag_list), 40): | ||||
|             tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40]) | ||||
|             logging.warning(tags) | ||||
|             posts = await e621.get_posts(tags) | ||||
|             if not posts: | ||||
|                 return | ||||
|             already_sent = await redis.smismember('e621:sent', [p.id for p in posts]) | ||||
|             for i in range(len(posts)): | ||||
|                 if already_sent[i]: | ||||
|                     continue | ||||
|                 post = posts[i] | ||||
| async def send_post(post: E621Post, tag_list: List[str]): | ||||
|     await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatActions.TYPING) | ||||
|     monitored_tags = set(post.tags.flatten()) & set(tag_list) | ||||
|     artist_tags = post.tags.artist | ||||
|     character_tags = post.tags.character | ||||
|     copyright_tags = post.tags.copyright | ||||
|     caption = '\n'.join(l for l in [ | ||||
|                     f'Monitored tags: <b>{" ".join(monitored_tags)}</b>', | ||||
|         f'Monitored tags: <b>{" ".join(monitored_tags) or "None"}</b>', | ||||
|         artist_tags and f'Artist: <b>{" ".join(artist_tags)}</b>', | ||||
|         character_tags and f'Character: <b>{", ".join(character_tags)}</b>', | ||||
|         copyright_tags and f'Copyright: <b>{", ".join(copyright_tags)}</b>', | ||||
| @@ -107,6 +94,25 @@ async def check_updates(): | ||||
|             logging.exception(e) | ||||
|  | ||||
|  | ||||
| async def check_updates(): | ||||
|     logging.warning('Waiting for lock...') | ||||
|     async with redis.lock('e621:update'): | ||||
|         logging.warning('Lock acquired...') | ||||
|         tag_list = [t.decode() for t in await redis.smembers('e621:subs')] | ||||
|         tag_list.sort() | ||||
|         for tl_idx in range(0, len(tag_list), 40): | ||||
|             tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40]) | ||||
|             logging.warning(tags) | ||||
|             posts = await e621.get_posts(tags) | ||||
|             if not posts: | ||||
|                 return | ||||
|             already_sent = await redis.smismember('e621:sent', [p.id for p in posts]) | ||||
|             for i in range(len(posts)): | ||||
|                 if already_sent[i]: | ||||
|                     continue | ||||
|                 await send_post(posts[i], tag_list) | ||||
|  | ||||
|  | ||||
| @dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add']) | ||||
| async def add_tag(msg: Message): | ||||
|     args = msg.get_args() | ||||
| @@ -168,6 +174,19 @@ async def update(msg: Message): | ||||
|     await check_updates() | ||||
|  | ||||
|  | ||||
| @dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test']) | ||||
| async def test(msg: Message): | ||||
|     args = msg.get_args() | ||||
|     if not args: | ||||
|         await msg.reply('Please provide post id') | ||||
|         return | ||||
|     post = await e621.get_posts(f'id:{args[0]}') | ||||
|     if not post: | ||||
|         await msg.reply('Post not found') | ||||
|         return | ||||
|     await send_post(post[0], []) | ||||
|  | ||||
|  | ||||
| async def background_on_start(): | ||||
|     await redis.delete('e621:update') | ||||
|     while True: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user