add test command

This commit is contained in:
bakatrouble 2023-06-01 00:04:20 +03:00
parent 728373cdd5
commit d10f9a52be
2 changed files with 87 additions and 65 deletions

View File

@ -137,3 +137,6 @@ class E621:
async def get_posts(self, tags='', page=1, limit=50) -> List[E621Post]:
r = (await self.client.get('/posts.json', params={'tags': tags, 'page': page, 'limit': limit})).json()
return [E621Post.from_dict(p) for p in r['posts']]
async def get_post(self, post_id: str) -> E621Post:
return (await self.get_posts(f'id:{post_id}'))[0]

59
main.py
View File

@ -2,17 +2,19 @@ import asyncio
import logging
import os
from io import BytesIO
from typing import List
from PIL import Image
import httpx
import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher
from aiogram.dispatcher import filters
from aiogram.types import Message, ParseMode
from aiogram.types import Message, ParseMode, ChatActions
from aiogram.utils import executor, exceptions
import dotenv
from e621 import E621
from e621 import E621, E621Post
dotenv.load_dotenv('.env')
@ -26,29 +28,14 @@ dp = Dispatcher(bot)
user_ids = list(map(int, os.environ['USERS'].split(',')))
async def check_updates():
logging.warning('Waiting for lock...')
async with redis.lock('e621:update'):
logging.warning('Lock acquired...')
tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
tag_list.sort()
for tl_idx in range(0, len(tag_list), 40):
tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40])
logging.warning(tags)
posts = await e621.get_posts(tags)
if not posts:
return
already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
for i in range(len(posts)):
if already_sent[i]:
continue
post = posts[i]
async def send_post(post: E621Post, tag_list: List[str]):
await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatActions.TYPING)
monitored_tags = set(post.tags.flatten()) & set(tag_list)
artist_tags = post.tags.artist
character_tags = post.tags.character
copyright_tags = post.tags.copyright
caption = '\n'.join(l for l in [
f'Monitored tags: <b>{" ".join(monitored_tags)}</b>',
f'Monitored tags: <b>{" ".join(monitored_tags) or "None"}</b>',
artist_tags and f'Artist: <b>{" ".join(artist_tags)}</b>',
character_tags and f'Character: <b>{", ".join(character_tags)}</b>',
copyright_tags and f'Copyright: <b>{", ".join(copyright_tags)}</b>',
@ -107,6 +94,25 @@ async def check_updates():
logging.exception(e)
async def check_updates():
logging.warning('Waiting for lock...')
async with redis.lock('e621:update'):
logging.warning('Lock acquired...')
tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
tag_list.sort()
for tl_idx in range(0, len(tag_list), 40):
tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40])
logging.warning(tags)
posts = await e621.get_posts(tags)
if not posts:
return
already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
for i in range(len(posts)):
if already_sent[i]:
continue
await send_post(posts[i], tag_list)
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add'])
async def add_tag(msg: Message):
args = msg.get_args()
@ -168,6 +174,19 @@ async def update(msg: Message):
await check_updates()
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test'])
async def test(msg: Message):
args = msg.get_args()
if not args:
await msg.reply('Please provide post id')
return
post = await e621.get_posts(f'id:{args[0]}')
if not post:
await msg.reply('Post not found')
return
await send_post(post[0], [])
async def background_on_start():
await redis.delete('e621:update')
while True: