Compare commits
33 Commits
5c35f8cda5
...
master
Author | SHA1 | Date | |
---|---|---|---|
b8b13616a8 | |||
8f5772b5dc | |||
5d11a1a3de | |||
abb4570958 | |||
1b88ad2ac8 | |||
f1684edef8 | |||
3e275f54c4 | |||
a9518259fb | |||
2c2088d993 | |||
69d0950037 | |||
baea50eb73 | |||
a9c022b992 | |||
40644bcb4b | |||
3723280e3d | |||
cf9aad9aee | |||
94184f8635 | |||
b52afa59cd | |||
15d4853402 | |||
ca194a8a2c | |||
687dabc354 | |||
af1479bc75 | |||
7bdebae54a | |||
c88e5e7ab1 | |||
973f432b53 | |||
1b3af70082 | |||
9b9197be0e | |||
fed68c0dfb | |||
60c6448f72 | |||
01ba70c0ce | |||
4c3b80fc5a | |||
2b6e9bd2df | |||
fab1798f7e | |||
92599a335f |
4
e621.py
4
e621.py
@@ -177,11 +177,11 @@ class E621:
|
||||
})).json()
|
||||
if 'success' in r:
|
||||
return []
|
||||
return [E621PostVersion.from_dict(p) for p in r]
|
||||
return [E621PostVersion.from_dict(p) for p in r if p.get('tags') is not None]
|
||||
|
||||
async def get_tag_aliases(self, name: str) -> List[str]:
|
||||
data = (await self.client.get('/tag_aliases.json', params={'search[antecedent_name]': name})).json()
|
||||
logging.warning(f'{name}: {data}')
|
||||
if 'tag_aliases' in data:
|
||||
return []
|
||||
return [alias['consequent_name'] for alias in data]
|
||||
return [alias['consequent_name'] for alias in data if alias['status'] == 'active']
|
||||
|
@@ -4,3 +4,4 @@ USERS=9893249151
|
||||
AWS_ACCESS_KEY=AKIAUIXZQT
|
||||
AWS_SECRET_KEY=QyBnXOhmlc
|
||||
AWS_S3_BUCKET=bucket
|
||||
UPLOAD_KEY=123
|
||||
|
53
main.py
53
main.py
@@ -36,6 +36,7 @@ e621 = E621()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
bot = Bot(token=os.environ['BOT_TOKEN'])
|
||||
dp = Dispatcher()
|
||||
upload_key = os.environ['UPLOAD_KEY']
|
||||
|
||||
ChatFilter = F.chat.id.in_(set(map(int, os.environ['USERS'].split(','))))
|
||||
|
||||
@@ -183,27 +184,38 @@ async def check_updates():
|
||||
|
||||
@dp.message(filters.Command('resend_after'), ChatFilter)
|
||||
async def resend_after(msg: Message):
|
||||
args = msg.text.split()[1:]
|
||||
try:
|
||||
timestamp = int(msg.text.split()[1])
|
||||
timestamp = int(args[0])
|
||||
skip_to_sub = args[1] if len(args) > 1 else None
|
||||
except:
|
||||
traceback.print_exc()
|
||||
await msg.reply('Invalid timestamp or not provided')
|
||||
return
|
||||
|
||||
async with redis.lock(REDIS_LOCK_KEY):
|
||||
tag_list = [tuple(t.decode().split()) for t in await redis.smembers(REDIS_SUBS_KEY)]
|
||||
tags = sorted(await redis.smembers(REDIS_SUBS_KEY))
|
||||
if skip_to_sub is not None and skip_to_sub in tags:
|
||||
tags = tags[tags.index(skip_to_sub):]
|
||||
tag_list = [tuple(t.decode().split()) for t in tags]
|
||||
for i, tag in enumerate(tag_list):
|
||||
await msg.reply(f'Checking tag <b>{tag}</b> ({i+1}/{len(tag_list)})', parse_mode=ParseMode.HTML)
|
||||
posts = []
|
||||
page = 1
|
||||
while True:
|
||||
page_posts = await e621.get_posts(tag, page)
|
||||
break_loop = False
|
||||
if page > 10:
|
||||
break
|
||||
page_posts = await e621.get_posts(' '.join(tag), page)
|
||||
if not page_posts:
|
||||
break
|
||||
for post in page_posts:
|
||||
if datetime.datetime.fromisoformat(post.created_at).timestamp() < timestamp:
|
||||
post_created_at = datetime.datetime.fromisoformat(post.created_at)
|
||||
if post_created_at.timestamp() < timestamp:
|
||||
break_loop = True
|
||||
break
|
||||
posts.append(post)
|
||||
if break_loop:
|
||||
break
|
||||
page += 1
|
||||
for post in posts[::-1]:
|
||||
await send_post(post, tag_list)
|
||||
@@ -218,7 +230,10 @@ async def add_tag(msg: Message):
|
||||
return
|
||||
for tag in args.split():
|
||||
posts = await e621.get_posts(tag)
|
||||
await redis.sadd(REDIS_SENT_KEY, *[post.id for post in posts])
|
||||
if posts:
|
||||
await redis.sadd(REDIS_SENT_KEY, *[post.id for post in posts])
|
||||
else:
|
||||
logging.warning(f'No posts found for tag {tag}')
|
||||
await redis.sadd(REDIS_SUBS_KEY, tag)
|
||||
await msg.reply(f'Tags {args} added')
|
||||
|
||||
@@ -311,11 +326,16 @@ async def check_aliases(msg: Message):
|
||||
await resp.edit_text(f'Checking aliases {progress}/{len(tags)}\n\n{l}', parse_mode=ParseMode.HTML)
|
||||
|
||||
for sub in tags:
|
||||
replaced_tags = False
|
||||
for subtag in sub.split():
|
||||
if replacements := await e621.get_tag_aliases(subtag):
|
||||
lines.append(f'- {subtag} -> {replacements[0]}, (<code>{sub}</code>)')
|
||||
replaced_tags = False
|
||||
progress += 1
|
||||
await send_progress()
|
||||
if replaced_tags:
|
||||
await send_progress()
|
||||
|
||||
await send_progress()
|
||||
|
||||
|
||||
@dp.message(filters.Command('update'), ChatFilter)
|
||||
@@ -344,20 +364,19 @@ async def send_callback(cq: CallbackQuery):
|
||||
img_bytes = BytesIO()
|
||||
await bot.download(cq.message.photo[-1], img_bytes)
|
||||
img_bytes.seek(0)
|
||||
data = base64.b64encode(img_bytes.read()).decode()
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
r = await client.post(f'https://bots.bakatrouble.me/bots_rpc/{destination}/', json={
|
||||
"method": "post_photo",
|
||||
"params": [data, True],
|
||||
"jsonrpc": "2.0",
|
||||
"id": 0,
|
||||
})
|
||||
subdomain = 'ch' + ('sfw' if destination == 'pics' else 'nsfw')
|
||||
r = await client.post(f'https://{subdomain}.bakatrouble.me/{upload_key}/photo', files={'upload': img_bytes})
|
||||
logging.info(r.text)
|
||||
resp = r.json()
|
||||
if 'result' in resp and resp['result'] == True:
|
||||
await cq.answer('Sent')
|
||||
elif 'result' in resp and resp['result'] == 'duplicate':
|
||||
status = resp.get('status')
|
||||
if not status:
|
||||
raise Exception(f'No result in response: {resp}')
|
||||
elif status == 'duplicate':
|
||||
await cq.answer('Duplicate')
|
||||
elif status == 'ok':
|
||||
await cq.answer('Sent')
|
||||
else:
|
||||
raise Exception(resp)
|
||||
except:
|
||||
|
48
server.py
48
server.py
@@ -5,8 +5,8 @@ from functools import wraps
|
||||
|
||||
import dotenv
|
||||
import jwt
|
||||
from sanic import Sanic, Unauthorized
|
||||
from sanic_ext import validate
|
||||
from sanic import Sanic, Unauthorized, json as jsonr
|
||||
from sanic_ext import validate, Extend
|
||||
from sanic_ext.extensions.openapi import openapi
|
||||
from sanic_ext.extensions.openapi.definitions import RequestBody
|
||||
from sanic_redis import SanicRedis
|
||||
@@ -19,11 +19,16 @@ api_secret = os.environ['API_SECRET']
|
||||
api_auth = json.loads(os.environ['API_AUTH'])
|
||||
|
||||
app = Sanic('e621_bot_api')
|
||||
app.config.CORS_ORIGINS = '*'
|
||||
app.config.CORS_HEADERS = 'Authorization, *'
|
||||
app.config.update({
|
||||
'REDIS': 'redis://localhost',
|
||||
})
|
||||
Extend(app)
|
||||
|
||||
|
||||
redis = SanicRedis()
|
||||
redis.init_app(app)
|
||||
|
||||
|
||||
async def get_subs(r):
|
||||
@@ -65,21 +70,22 @@ def protected(wrapped):
|
||||
})
|
||||
)
|
||||
@validate(json=LoginRequest)
|
||||
async def login(request):
|
||||
if pbkdf2_sha256(request.json['password']) != api_auth.get(request.json['username']):
|
||||
return {'status': 'error', 'message': 'Invalid username or password'}
|
||||
return {
|
||||
async def login(_, body: LoginRequest):
|
||||
hash = api_auth.get(body.username)
|
||||
if not hash or not pbkdf2_sha256(10000, salt=b'salt').verify(body.password, hash):
|
||||
return jsonr({'status': 'error', 'message': 'Invalid username or password'}, 401)
|
||||
return jsonr({
|
||||
'token': jwt.encode({}, api_secret, algorithm='HS256'),
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@app.get('/api/subscriptions')
|
||||
@protected
|
||||
async def get_subscriptions(request):
|
||||
async def get_subscriptions(_):
|
||||
async with redis.conn as r:
|
||||
return {
|
||||
'subscriptions': await r.smembers(REDIS_SUBS_KEY),
|
||||
}
|
||||
return jsonr({
|
||||
'subscriptions': sorted(list(await get_subs(r))),
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -95,18 +101,17 @@ class UpdateSubscriptionRequest:
|
||||
)
|
||||
@validate(json=UpdateSubscriptionRequest)
|
||||
@protected
|
||||
async def delete_subscriptions(request):
|
||||
data = request.json
|
||||
requested_subs = {' '.join(sorted(sub.lower().split())) for sub in data['subs']}
|
||||
async def delete_subscriptions(_, body: UpdateSubscriptionRequest):
|
||||
requested_subs = {' '.join(sorted(sub.lower().split())) for sub in body.subs}
|
||||
|
||||
async with redis.conn as r:
|
||||
subs = await get_subs(r)
|
||||
skipped = requested_subs - subs
|
||||
if skipped:
|
||||
return {'status': 'error', 'message': 'Some subscriptions were not found', 'skipped': sorted(skipped)}
|
||||
return jsonr({'status': 'error', 'message': 'Some subscriptions were not found', 'skipped': sorted(skipped)}, 404)
|
||||
await r.srem(REDIS_SUBS_KEY, *requested_subs)
|
||||
|
||||
return {'status': 'ok', 'removed': sorted(requested_subs)}
|
||||
return jsonr({'status': 'ok', 'removed': sorted(requested_subs)})
|
||||
|
||||
|
||||
@app.post('/api/subscriptions')
|
||||
@@ -117,18 +122,17 @@ async def delete_subscriptions(request):
|
||||
)
|
||||
@validate(json=UpdateSubscriptionRequest)
|
||||
@protected
|
||||
async def add_subscriptions(request):
|
||||
data = request.json
|
||||
requested_subs = {' '.join(sorted(sub.lower().split())) for sub in data['subs']}
|
||||
async def add_subscriptions(_, body: UpdateSubscriptionRequest):
|
||||
requested_subs = {' '.join(sorted(sub.lower().split())) for sub in body.subs}
|
||||
|
||||
async with redis.conn as r:
|
||||
subs = await get_subs(r)
|
||||
conflicts = requested_subs & subs
|
||||
if conflicts:
|
||||
return {'status': 'error', 'message': 'Some subscriptions already exist', 'conflicts': sorted(conflicts)}
|
||||
await r.sadd(REDIS_SUBS_KEY, *data['subs'])
|
||||
return jsonr({'status': 'error', 'message': 'Some subscriptions already exist', 'conflicts': sorted(conflicts)}, 409)
|
||||
await r.sadd(REDIS_SUBS_KEY, *body.subs)
|
||||
|
||||
return {'status': 'ok', 'added': sorted(requested_subs)}
|
||||
return jsonr({'status': 'ok', 'added': sorted(requested_subs)})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user