"""HTTP API for managing the e621 bot's subscriptions stored in a Redis set."""

import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps

import dotenv
import jwt
from passlib.hash import pbkdf2_sha256
from sanic import Sanic, Unauthorized
from sanic.response import json as json_response
from sanic_ext import validate
from sanic_ext.extensions.openapi import openapi
from sanic_ext.extensions.openapi.definitions import RequestBody
from sanic_redis import SanicRedis

from const import REDIS_SUBS_KEY

dotenv.load_dotenv('.env')

# Secret used to sign and verify the HS256 JWTs issued by /api/login.
api_secret = os.environ['API_SECRET']
# JSON object mapping username -> pbkdf2_sha256 hash string.
api_auth = json.loads(os.environ['API_AUTH'])
# Lifetime of issued tokens, in seconds (default: one day).
token_ttl = int(os.environ.get('API_TOKEN_TTL', 24 * 60 * 60))

app = Sanic('e621_bot_api')
app.config.update({
    'REDIS': os.environ.get('REDIS_URL', 'redis://localhost'),
})

# Bind the extension to the app so ``redis.conn`` is available in handlers;
# an unbound SanicRedis instance never creates its connection.
redis = SanicRedis()
redis.init_app(app)


async def get_subs(r):
    """Return every stored subscription as a set of decoded strings."""
    subs = await r.smembers(REDIS_SUBS_KEY)
    return {s.decode() for s in subs}


def _normalize_subs(subs):
    """Canonicalize subscription queries into a set of non-empty strings.

    Each query is lowercased and its whitespace-separated tags are sorted and
    re-joined with single spaces, so ``'B  a'`` and ``'a b'`` are treated as
    the same subscription. Queries that are empty after this are dropped.
    """
    normalized = (' '.join(sorted(sub.lower().split())) for sub in subs)
    return {sub for sub in normalized if sub}


def _extract_token(header):
    """Return the JWT from an ``Authorization`` header value.

    Accepts both ``Bearer <token>`` and a bare token.
    """
    scheme, _, credentials = header.partition(' ')
    if scheme.lower() == 'bearer' and credentials.strip():
        return credentials.strip()
    return header.strip()


@dataclass
class LoginRequest:
    username: str
    password: str


def protected(wrapped):
    """Decorate a handler so it only runs for requests carrying a valid JWT.

    Raises:
        Unauthorized: if the ``Authorization`` header is missing, or the token
            is expired or otherwise fails HS256 verification with ``api_secret``.
    """
    @wraps(wrapped)
    async def decorated(request, *args, **kwargs):
        header = request.headers.get('Authorization')
        if not header:
            raise Unauthorized('Authorization header is missing')
        try:
            jwt.decode(_extract_token(header), api_secret, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            raise Unauthorized('Token has expired') from None
        except jwt.InvalidTokenError:
            raise Unauthorized('Invalid token') from None
        return await wrapped(request, *args, **kwargs)

    return decorated


@app.post('/api/login')
@openapi.definition(
    body=RequestBody({
        'application/json': LoginRequest,
    })
)
@validate(json=LoginRequest)
async def login(request, body: LoginRequest):
    """Exchange a username/password pair for a signed, expiring JWT.

    ``body`` is the validated ``LoginRequest`` injected by ``@validate``.
    Responds 401 with an error payload when the credentials do not match.
    """
    stored_hash = api_auth.get(body.username)
    # ``verify`` re-derives the hash using the salt/rounds embedded in the
    # stored value; re-hashing the password and comparing strings never matches.
    if stored_hash is None or not pbkdf2_sha256.verify(body.password, stored_hash):
        return json_response(
            {'status': 'error', 'message': 'Invalid username or password'},
            status=401,
        )
    expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=token_ttl)
    token = jwt.encode(
        {'sub': body.username, 'exp': expires_at},
        api_secret,
        algorithm='HS256',
    )
    return json_response({'token': token})


@app.get('/api/subscriptions')
@protected
async def get_subscriptions(request):
    """List all stored subscriptions, sorted alphabetically."""
    async with redis.conn as r:
        subs = await get_subs(r)
    return json_response({'subscriptions': sorted(subs)})


@dataclass
class UpdateSubscriptionRequest:
    subs: list[str]


# ``ignore_body=False`` makes sure Sanic reads the JSON body of this DELETE
# request (explicit, since body handling for DELETE has varied between versions).
@app.delete('/api/subscriptions', ignore_body=False)
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@protected
@validate(json=UpdateSubscriptionRequest)
async def delete_subscriptions(request, body: UpdateSubscriptionRequest):
    """Remove the given subscriptions; all-or-nothing.

    Responds 400 if no non-blank subscription is given, and 404 (removing
    nothing) if any requested subscription is not currently stored.
    """
    requested_subs = _normalize_subs(body.subs)
    if not requested_subs:
        return json_response(
            {'status': 'error', 'message': 'No subscriptions provided'},
            status=400,
        )
    async with redis.conn as r:
        skipped = requested_subs - await get_subs(r)
        if skipped:
            return json_response(
                {
                    'status': 'error',
                    'message': 'Some subscriptions were not found',
                    'skipped': sorted(skipped),
                },
                status=404,
            )
        await r.srem(REDIS_SUBS_KEY, *requested_subs)
    return json_response({'status': 'ok', 'removed': sorted(requested_subs)})


@app.post('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@protected
@validate(json=UpdateSubscriptionRequest)
async def add_subscriptions(request, body: UpdateSubscriptionRequest):
    """Add the given subscriptions; all-or-nothing.

    Responds 400 if no non-blank subscription is given, and 409 (adding
    nothing) if any requested subscription already exists.
    """
    requested_subs = _normalize_subs(body.subs)
    if not requested_subs:
        return json_response(
            {'status': 'error', 'message': 'No subscriptions provided'},
            status=400,
        )
    async with redis.conn as r:
        conflicts = requested_subs & await get_subs(r)
        if conflicts:
            return json_response(
                {
                    'status': 'error',
                    'message': 'Some subscriptions already exist',
                    'conflicts': sorted(conflicts),
                },
                status=409,
            )
        # Store the normalized form: conflict checks and deletes normalize
        # their input, so raw strings would never match them later.
        await r.sadd(REDIS_SUBS_KEY, *requested_subs)
    return json_response({'status': 'ok', 'added': sorted(requested_subs)})


if __name__ == '__main__':
    # Presence of a ``.debug`` file switches on debug mode, access logs and
    # auto-reload.
    is_debug = os.path.exists('.debug')
    app.run(
        host=os.environ.get('API_HOST', '0.0.0.0'),
        port=int(os.environ.get('API_PORT', 8000)),
        debug=is_debug,
        access_log=is_debug,
        auto_reload=is_debug,
    )