add test command

This commit is contained in:
bakatrouble 2023-06-01 00:04:20 +03:00
parent 728373cdd5
commit d10f9a52be
2 changed files with 87 additions and 65 deletions

View File

@ -137,3 +137,6 @@ class E621:
async def get_posts(self, tags='', page=1, limit=50) -> List[E621Post]: async def get_posts(self, tags='', page=1, limit=50) -> List[E621Post]:
r = (await self.client.get('/posts.json', params={'tags': tags, 'page': page, 'limit': limit})).json() r = (await self.client.get('/posts.json', params={'tags': tags, 'page': page, 'limit': limit})).json()
return [E621Post.from_dict(p) for p in r['posts']] return [E621Post.from_dict(p) for p in r['posts']]
async def get_post(self, post_id: str) -> E621Post:
return (await self.get_posts(f'id:{post_id}'))[0]

59
main.py
View File

@ -2,17 +2,19 @@ import asyncio
import logging import logging
import os import os
from io import BytesIO from io import BytesIO
from typing import List
from PIL import Image from PIL import Image
import httpx import httpx
import redis.asyncio as aioredis import redis.asyncio as aioredis
from aiogram import Bot, Dispatcher from aiogram import Bot, Dispatcher
from aiogram.dispatcher import filters from aiogram.dispatcher import filters
from aiogram.types import Message, ParseMode from aiogram.types import Message, ParseMode, ChatActions
from aiogram.utils import executor, exceptions from aiogram.utils import executor, exceptions
import dotenv import dotenv
from e621 import E621 from e621 import E621, E621Post
dotenv.load_dotenv('.env') dotenv.load_dotenv('.env')
@ -26,29 +28,14 @@ dp = Dispatcher(bot)
user_ids = list(map(int, os.environ['USERS'].split(','))) user_ids = list(map(int, os.environ['USERS'].split(',')))
async def check_updates(): async def send_post(post: E621Post, tag_list: List[str]):
logging.warning('Waiting for lock...') await bot.send_chat_action(int(os.environ['SEND_CHAT']), action=ChatActions.TYPING)
async with redis.lock('e621:update'):
logging.warning('Lock acquired...')
tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
tag_list.sort()
for tl_idx in range(0, len(tag_list), 40):
tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40])
logging.warning(tags)
posts = await e621.get_posts(tags)
if not posts:
return
already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
for i in range(len(posts)):
if already_sent[i]:
continue
post = posts[i]
monitored_tags = set(post.tags.flatten()) & set(tag_list) monitored_tags = set(post.tags.flatten()) & set(tag_list)
artist_tags = post.tags.artist artist_tags = post.tags.artist
character_tags = post.tags.character character_tags = post.tags.character
copyright_tags = post.tags.copyright copyright_tags = post.tags.copyright
caption = '\n'.join(l for l in [ caption = '\n'.join(l for l in [
f'Monitored tags: <b>{" ".join(monitored_tags)}</b>', f'Monitored tags: <b>{" ".join(monitored_tags) or "None"}</b>',
artist_tags and f'Artist: <b>{" ".join(artist_tags)}</b>', artist_tags and f'Artist: <b>{" ".join(artist_tags)}</b>',
character_tags and f'Character: <b>{", ".join(character_tags)}</b>', character_tags and f'Character: <b>{", ".join(character_tags)}</b>',
copyright_tags and f'Copyright: <b>{", ".join(copyright_tags)}</b>', copyright_tags and f'Copyright: <b>{", ".join(copyright_tags)}</b>',
@ -107,6 +94,25 @@ async def check_updates():
logging.exception(e) logging.exception(e)
async def check_updates():
logging.warning('Waiting for lock...')
async with redis.lock('e621:update'):
logging.warning('Lock acquired...')
tag_list = [t.decode() for t in await redis.smembers('e621:subs')]
tag_list.sort()
for tl_idx in range(0, len(tag_list), 40):
tags = ' '.join(f'~{tag}' for tag in tag_list[tl_idx: tl_idx + 40])
logging.warning(tags)
posts = await e621.get_posts(tags)
if not posts:
return
already_sent = await redis.smismember('e621:sent', [p.id for p in posts])
for i in range(len(posts)):
if already_sent[i]:
continue
await send_post(posts[i], tag_list)
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add']) @dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['add'])
async def add_tag(msg: Message): async def add_tag(msg: Message):
args = msg.get_args() args = msg.get_args()
@ -168,6 +174,19 @@ async def update(msg: Message):
await check_updates() await check_updates()
@dp.message_handler(filters.IDFilter(chat_id=user_ids), commands=['test'])
async def test(msg: Message):
args = msg.get_args()
if not args:
await msg.reply('Please provide post id')
return
post = await e621.get_posts(f'id:{args[0]}')
if not post:
await msg.reply('Post not found')
return
await send_post(post[0], [])
async def background_on_start(): async def background_on_start():
await redis.delete('e621:update') await redis.delete('e621:update')
while True: while True: