add test command
This commit is contained in:
		
							
								
								
									
										3
									
								
								e621.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								e621.py
									
									
									
									
									
								
							| @@ -137,3 +137,6 @@ class E621: | |||||||
|     async def get_posts(self, tags='', page=1, limit=50) -> List[E621Post]: |     async def get_posts(self, tags='', page=1, limit=50) -> List[E621Post]: | ||||||
|         r = (await self.client.get('/posts.json', params={'tags': tags, 'page': page, 'limit': limit})).json() |         r = (await self.client.get('/posts.json', params={'tags': tags, 'page': page, 'limit': limit})).json() | ||||||
|         return [E621Post.from_dict(p) for p in r['posts']] |         return [E621Post.from_dict(p) for p in r['posts']] | ||||||
|  |  | ||||||
|  |     async def get_post(self, post_id: str) -> E621Post: | ||||||
|  |         return (await self.get_posts(f'id:{post_id}'))[0] | ||||||
|   | |||||||
							
								
								
									
										59
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								main.py
									
									
									
									
									
								
							| @@ -2,17 +2,19 @@ import asyncio | |||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
|  | from typing import List | ||||||
|  |  | ||||||
| from PIL import Image | from PIL import Image | ||||||
|  |  | ||||||
| import httpx | import httpx | ||||||
| import redis.asyncio as aioredis | import redis.asyncio as aioredis | ||||||
| from aiogram import Bot, Dispatcher | from aiogram import Bot, Dispatcher | ||||||
| from aiogram.dispatcher import filters | from aiogram.dispatcher import filters | ||||||
| from aiogram.types import Message, ParseMode | from aiogram.types import Message, ParseMode, ChatActions | ||||||
| from aiogram.utils import executor, exceptions | from aiogram.utils import executor, exceptions | ||||||
| import dotenv | import dotenv | ||||||
|  |  | ||||||
| from e621 import E621 | from e621 import E621, E621Post | ||||||
|  |  | ||||||
| dotenv.load_dotenv('.env') | dotenv.load_dotenv('.env') | ||||||
|  |  | ||||||
| @@ -26,29 +28,14 @@ dp = Dispatcher(bot) | |||||||
| user_ids = list(map(int, os.environ['USERS'].split(','))) | user_ids = list(map(int, os.environ['USERS'].split(','))) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def check_updates(): | async def send_post(post: E621Post, tag_list: List[str]): | ||||||
|     logging.warning('Waiting for lock...') |     await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatActions.TYPING) | ||||||
|     async with redis.lock('e621:update'): |  | ||||||
|         logging.warning('Lock acquired...') |  | ||||||
|         tag_list = [t.decode() for t in await redis.smembers('e621:subs')] |  | ||||||
|         tag_list.sort() |  | ||||||
|         for tl_idx in range(0, len(tag_list), 40): |  | ||||||
|             tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40]) |  | ||||||
|             logging.warning(tags) |  | ||||||
|             posts = await e621.get_posts(tags) |  | ||||||
|             if not posts: |  | ||||||
|                 return |  | ||||||
|             already_sent = await redis.smismember('e621:sent', [p.id for p in posts]) |  | ||||||
|             for i in range(len(posts)): |  | ||||||
|                 if already_sent[i]: |  | ||||||
|                     continue |  | ||||||
|                 post = posts[i] |  | ||||||
|     monitored_tags = set(post.tags.flatten()) & set(tag_list) |     monitored_tags = set(post.tags.flatten()) & set(tag_list) | ||||||
|     artist_tags = post.tags.artist |     artist_tags = post.tags.artist | ||||||
|     character_tags = post.tags.character |     character_tags = post.tags.character | ||||||
|     copyright_tags = post.tags.copyright |     copyright_tags = post.tags.copyright | ||||||
|     caption = '\n'.join(l for l in [ |     caption = '\n'.join(l for l in [ | ||||||
|                     f'Monitored tags: <b>{" ".join(monitored_tags)}</b>', |         f'Monitored tags: <b>{" ".join(monitored_tags) or "None"}</b>', | ||||||
|         artist_tags and f'Artist: <b>{" ".join(artist_tags)}</b>', |         artist_tags and f'Artist: <b>{" ".join(artist_tags)}</b>', | ||||||
|         character_tags and f'Character: <b>{", ".join(character_tags)}</b>', |         character_tags and f'Character: <b>{", ".join(character_tags)}</b>', | ||||||
|         copyright_tags and f'Copyright: <b>{", ".join(copyright_tags)}</b>', |         copyright_tags and f'Copyright: <b>{", ".join(copyright_tags)}</b>', | ||||||
| @@ -107,6 +94,25 @@ async def check_updates(): | |||||||
|             logging.exception(e) |             logging.exception(e) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def check_updates(): | ||||||
|  |     logging.warning('Waiting for lock...') | ||||||
|  |     async with redis.lock('e621:update'): | ||||||
|  |         logging.warning('Lock acquired...') | ||||||
|  |         tag_list = [t.decode() for t in await redis.smembers('e621:subs')] | ||||||
|  |         tag_list.sort() | ||||||
|  |         for tl_idx in range(0, len(tag_list), 40): | ||||||
|  |             tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40]) | ||||||
|  |             logging.warning(tags) | ||||||
|  |             posts = await e621.get_posts(tags) | ||||||
|  |             if not posts: | ||||||
|  |                 return | ||||||
|  |             already_sent = await redis.smismember('e621:sent', [p.id for p in posts]) | ||||||
|  |             for i in range(len(posts)): | ||||||
|  |                 if already_sent[i]: | ||||||
|  |                     continue | ||||||
|  |                 await send_post(posts[i], tag_list) | ||||||
|  |  | ||||||
|  |  | ||||||
| @dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add']) | @dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add']) | ||||||
| async def add_tag(msg: Message): | async def add_tag(msg: Message): | ||||||
|     args = msg.get_args() |     args = msg.get_args() | ||||||
| @@ -168,6 +174,19 @@ async def update(msg: Message): | |||||||
|     await check_updates() |     await check_updates() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test']) | ||||||
|  | async def test(msg: Message): | ||||||
|  |     args = msg.get_args() | ||||||
|  |     if not args: | ||||||
|  |         await msg.reply('Please provide post id') | ||||||
|  |         return | ||||||
|  |     post = await e621.get_posts(f'id:{args[0]}') | ||||||
|  |     if not post: | ||||||
|  |         await msg.reply('Post not found') | ||||||
|  |         return | ||||||
|  |     await send_post(post[0], []) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def background_on_start(): | async def background_on_start(): | ||||||
|     await redis.delete('e621:update') |     await redis.delete('e621:update') | ||||||
|     while True: |     while True: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user