# e621_bot/server.py
# 2025-07-14 17:12:42 +03:00
#
# 144 lines
# 3.8 KiB
# Python

import json
import os
from dataclasses import dataclass
from functools import wraps

import dotenv
import jwt
from passlib.hash import pbkdf2_sha256
from sanic import Sanic, Unauthorized
from sanic.response import json as json_response
from sanic_ext import validate
from sanic_ext.extensions.openapi import openapi
from sanic_ext.extensions.openapi.definitions import RequestBody
from sanic_redis import SanicRedis

from const import REDIS_SUBS_KEY
# Load settings from a local .env file before reading the environment.
dotenv.load_dotenv('.env')

# Secret used to sign (login) and verify (protected) the HS256 JWTs.
api_secret = os.environ['API_SECRET']
# JSON object of username -> pbkdf2_sha256 password hash, consulted by login.
api_auth = json.loads(os.environ['API_AUTH'])

app = Sanic('e621_bot_api')
app.config.update({
    # Overridable via REDIS_URL; defaults to the previously hard-coded URL.
    'REDIS': os.environ.get('REDIS_URL', 'redis://localhost'),
})

# SanicRedis reads the 'REDIS' config key. Per the sanic-redis README the
# extension must be bound to the app (init_app or SanicRedis(app)) before
# `redis.conn` can be used in handlers.
redis = SanicRedis()
redis.init_app(app)
async def get_subs(r):
    """Fetch every stored subscription as a set of ``str``.

    ``r`` is an open async Redis connection. Members of the
    ``REDIS_SUBS_KEY`` set are decoded from bytes before being returned.
    """
    raw_members = await r.smembers(REDIS_SUBS_KEY)
    decoded = set()
    for member in raw_members:
        decoded.add(member.decode())
    return decoded
@dataclass()
class LoginRequest:
    """JSON body accepted by ``POST /api/login``."""

    username: str  # key looked up in the API_AUTH mapping
    password: str  # plaintext password checked against the stored hash
def protected(wrapped):
    """Wrap a handler so it only runs for requests carrying a valid JWT.

    The token is read from the ``Authorization`` header, either bare or in
    the standard ``Bearer <token>`` form, and verified with ``api_secret``
    using HS256.

    Raises:
        Unauthorized: if the header is missing/empty, the token has expired,
            or the token is otherwise invalid.
    """
    @wraps(wrapped)
    async def decorated(request, *args, **kwargs):
        header = request.headers.get('Authorization')
        if not header:
            raise Unauthorized('Authorization header is missing')
        # Accept "Bearer <token>" (scheme is case-insensitive). A bare JWT
        # contains no spaces, so clients sending the raw token still work.
        scheme, _, credentials = header.partition(' ')
        token = credentials.strip() if scheme.lower() == 'bearer' else header
        try:
            jwt.decode(token, api_secret, algorithms=['HS256'])
        except jwt.ExpiredSignatureError as err:
            raise Unauthorized('Token has expired') from err
        except jwt.InvalidTokenError as err:
            raise Unauthorized('Invalid token') from err
        return await wrapped(request, *args, **kwargs)

    return decorated
@app.post('/api/login')
@openapi.definition(
    body=RequestBody({
        'application/json': LoginRequest,
    })
)
@validate(json=LoginRequest)
async def login(request, body=None):
    """Exchange a username/password pair for a signed JWT.

    ``body`` is the validated ``LoginRequest`` that sanic_ext's ``validate``
    decorator passes as the ``body`` keyword argument. When it is absent
    (e.g. a direct call), the raw ``request.json`` is used instead.

    Returns a JSON ``{'token': ...}`` response on success, or a 401 JSON
    error body when the user is unknown or the password does not match.
    """
    if body is not None:
        username, password = body.username, body.password
    else:
        username, password = request.json['username'], request.json['password']

    stored_hash = api_auth.get(username)
    # verify() expects a hash string, so unknown users are rejected first.
    if stored_hash is None or not pbkdf2_sha256.verify(password, stored_hash):
        return json_response(
            {'status': 'error', 'message': 'Invalid username or password'},
            status=401,
        )
    # NOTE(review): the token has no 'exp' claim, so it never expires even
    # though `protected` handles ExpiredSignatureError — consider adding one.
    return json_response({
        'token': jwt.encode({}, api_secret, algorithm='HS256'),
    })
@app.get('/api/subscriptions')
@protected
async def get_subscriptions(request):
    """Return every stored subscription as a sorted JSON list (auth required)."""
    async with redis.conn as r:
        # get_subs decodes the members; raw SMEMBERS output is a set of
        # bytes, which the JSON encoder cannot serialize.
        subs = await get_subs(r)
    return json_response({'subscriptions': sorted(subs)})
@dataclass()
class UpdateSubscriptionRequest:
    """JSON body for adding or removing subscriptions."""

    # Raw queries; the handlers lower-case each one, split it on whitespace
    # and sort its terms before comparing against the stored set.
    subs: list[str]
@app.delete('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def delete_subscriptions(request, body=None):
    """Remove the given subscriptions (auth required).

    Each query is normalized before comparison: lower-cased, split on
    whitespace, terms sorted and re-joined with single spaces. Blank queries
    are ignored. The operation is all-or-nothing: if any normalized query is
    not currently stored, nothing is removed and a 404 JSON error listing the
    missing queries is returned.

    ``body`` is the validated ``UpdateSubscriptionRequest`` that sanic_ext's
    ``validate`` decorator passes as a keyword argument. When it is absent,
    ``request.json`` is read instead.
    """
    raw_subs = body.subs if body is not None else request.json['subs']
    normalized = (' '.join(sorted(sub.lower().split())) for sub in raw_subs)
    requested_subs = {sub for sub in normalized if sub}
    if not requested_subs:
        # SREM with no members is a Redis argument error; nothing to remove.
        return json_response({'status': 'ok', 'removed': []})

    async with redis.conn as r:
        subs = await get_subs(r)
        skipped = requested_subs - subs
        if skipped:
            return json_response(
                {
                    'status': 'error',
                    'message': 'Some subscriptions were not found',
                    'skipped': sorted(skipped),
                },
                status=404,
            )
        await r.srem(REDIS_SUBS_KEY, *requested_subs)
    return json_response({'status': 'ok', 'removed': sorted(requested_subs)})
@app.post('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def add_subscriptions(request, body=None):
    """Store new subscriptions (auth required).

    Each query is normalized before comparison and storage: lower-cased,
    split on whitespace, terms sorted and re-joined with single spaces. Blank
    queries are ignored. The operation is all-or-nothing: if any normalized
    query already exists, nothing is added and a 409 JSON error listing the
    conflicts is returned.

    ``body`` is the validated ``UpdateSubscriptionRequest`` that sanic_ext's
    ``validate`` decorator passes as a keyword argument. When it is absent,
    ``request.json`` is read instead.
    """
    raw_subs = body.subs if body is not None else request.json['subs']
    normalized = (' '.join(sorted(sub.lower().split())) for sub in raw_subs)
    requested_subs = {sub for sub in normalized if sub}
    if not requested_subs:
        # SADD with no members is a Redis argument error; nothing to add.
        return json_response({'status': 'ok', 'added': []})

    async with redis.conn as r:
        subs = await get_subs(r)
        conflicts = requested_subs & subs
        if conflicts:
            return json_response(
                {
                    'status': 'error',
                    'message': 'Some subscriptions already exist',
                    'conflicts': sorted(conflicts),
                },
                status=409,
            )
        # Store the normalized form so later conflict checks and deletes,
        # which normalize the same way, match what is stored.
        await r.sadd(REDIS_SUBS_KEY, *requested_subs)
    return json_response({'status': 'ok', 'added': sorted(requested_subs)})
if __name__ == '__main__':
    # The presence of a `.debug` file in the working directory switches on
    # debug mode, the access log and auto-reload together.
    debug_mode = os.path.exists('.debug')
    listen_host = os.environ.get('API_HOST', '0.0.0.0')
    listen_port = int(os.environ.get('API_PORT', 8000))
    app.run(
        host=listen_host,
        port=listen_port,
        debug=debug_mode,
        access_log=debug_mode,
        auto_reload=debug_mode,
    )