"""HTTP API for the ``e621_bot_api`` Sanic application.

Provides a login endpoint that issues a JWT and JWT-protected endpoints that
list, add and remove the subscriptions stored in the Redis set
``REDIS_SUBS_KEY``.
"""

import json
import os
from dataclasses import dataclass
from functools import wraps

import dotenv
import jwt
from passlib.hash import pbkdf2_sha256
from sanic import Sanic, Unauthorized, json as jsonr
from sanic_ext import validate, Extend
from sanic_ext.extensions.openapi import openapi
from sanic_ext.extensions.openapi.definitions import RequestBody
from sanic_redis import SanicRedis

from const import REDIS_SUBS_KEY

dotenv.load_dotenv('.env')

# Secret used to sign and verify the JWTs issued by /api/login.
api_secret = os.environ['API_SECRET']
# JSON object looked up by username in login(); values are passed to
# pbkdf2_sha256.verify() as the stored hash.
api_auth = json.loads(os.environ['API_AUTH'])

app = Sanic('e621_bot_api')
app.config.CORS_ORIGINS = '*'
app.config.CORS_HEADERS = 'Authorization, *'
app.config.update({
    'REDIS': 'redis://localhost',
})
Extend(app)

redis = SanicRedis()
# Bind the extension to the app so that ``redis.conn`` is set up on startup.
# NOTE(review): follows the sanic-redis README; confirm against the installed
# version.
redis.init_app(app)

# Optional scheme prefix accepted in the Authorization header.
_BEARER_PREFIX = 'Bearer '


async def get_subs(r):
    """Return the members of the ``REDIS_SUBS_KEY`` set as decoded strings.

    :param r: Redis client exposing an awaitable ``smembers``.
    :return: set of ``str``.
    """
    raw_members = await r.smembers(REDIS_SUBS_KEY)
    return {member.decode() for member in raw_members}


def _normalize_subs(subs):
    """Return the canonical form of every subscription in *subs* as a set.

    Each entry is lower-cased and its whitespace-separated words are sorted
    and re-joined with single spaces, so entries that differ only in case,
    word order or spacing collapse to the same value.
    """
    return {' '.join(sorted(sub.lower().split())) for sub in subs}


@dataclass
class LoginRequest:
    """JSON body accepted by ``POST /api/login``."""

    username: str
    password: str


def protected(wrapped):
    """Decorate a handler so that it requires a valid JWT.

    The token is read from the ``Authorization`` header, either bare or
    prefixed with ``Bearer ``, and verified with ``api_secret`` (HS256).

    :raises Unauthorized: if the header is missing, or the token is expired
        or otherwise invalid.
    """
    @wraps(wrapped)
    async def decorated(request, *args, **kwargs):
        token = request.headers.get('Authorization')
        if not token:
            raise Unauthorized('Authorization header is missing')

        # Accept the standard "Bearer <token>" form as well as a bare token.
        token = token.removeprefix(_BEARER_PREFIX)

        try:
            jwt.decode(token, api_secret, algorithms=['HS256'])
        except jwt.ExpiredSignatureError as err:
            raise Unauthorized('Token has expired') from err
        except jwt.InvalidTokenError as err:
            raise Unauthorized('Invalid token') from err

        return await wrapped(request, *args, **kwargs)

    return decorated


@app.post('/api/login')
@openapi.definition(
    body=RequestBody({
        'application/json': LoginRequest,
    })
)
@validate(json=LoginRequest)
async def login(_, body: LoginRequest):
    """Check the credentials against ``api_auth`` and issue a JWT.

    Responds with 401 when the username is unknown or the password does not
    match the stored hash.
    """
    stored_hash = api_auth.get(body.username)
    # verify() is a classmethod: rounds and salt are parsed from the stored
    # hash string, so no handler instance has to be configured here.
    if not stored_hash or not pbkdf2_sha256.verify(body.password, stored_hash):
        return jsonr(
            {'status': 'error', 'message': 'Invalid username or password'},
            401,
        )

    # NOTE(review): the payload carries no ``exp`` claim, so issued tokens
    # never expire; consider adding one.
    return jsonr({
        'token': jwt.encode({}, api_secret, algorithm='HS256'),
    })


@app.get('/api/subscriptions')
@protected
async def get_subscriptions(_):
    """Return all stored subscriptions as a sorted list of strings."""
    async with redis.conn as r:
        # smembers yields a set of bytes, which is not JSON-serialisable;
        # decode via get_subs and sort for a stable response.
        subs = await get_subs(r)

    return jsonr({
        'subscriptions': sorted(subs),
    })


@dataclass
class UpdateSubscriptionRequest:
    """JSON body accepted by the add/delete subscription endpoints."""

    subs: list[str]


@app.delete('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def delete_subscriptions(_, body: UpdateSubscriptionRequest):
    """Remove the requested subscriptions.

    Responds with 404 and removes nothing if any requested subscription is
    not currently stored.
    """
    requested_subs = _normalize_subs(body.subs)

    async with redis.conn as r:
        subs = await get_subs(r)
        skipped = requested_subs - subs
        if skipped:
            return jsonr(
                {
                    'status': 'error',
                    'message': 'Some subscriptions were not found',
                    'skipped': sorted(skipped),
                },
                404,
            )

        # SREM without members is rejected by Redis; skip the call when the
        # request contained no subscriptions.
        if requested_subs:
            await r.srem(REDIS_SUBS_KEY, *requested_subs)

    return jsonr({'status': 'ok', 'removed': sorted(requested_subs)})


@app.post('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def add_subscriptions(_, body: UpdateSubscriptionRequest):
    """Add the requested subscriptions.

    Responds with 409 and adds nothing if any requested subscription is
    already stored.
    """
    requested_subs = _normalize_subs(body.subs)

    async with redis.conn as r:
        subs = await get_subs(r)
        conflicts = requested_subs & subs
        if conflicts:
            return jsonr(
                {
                    'status': 'error',
                    'message': 'Some subscriptions already exist',
                    'conflicts': sorted(conflicts),
                },
                409,
            )

        # Store the normalised names (the same values that were checked for
        # conflicts and are reported back), not the raw request strings.
        # SADD without members is rejected by Redis, hence the guard.
        if requested_subs:
            await r.sadd(REDIS_SUBS_KEY, *requested_subs)

    return jsonr({'status': 'ok', 'added': sorted(requested_subs)})


if __name__ == '__main__':
    # The presence of a ``.debug`` file switches on debug mode, access
    # logging and auto-reload.
    is_debug = os.path.exists('.debug')
    app.run(
        host=os.environ.get('API_HOST', '0.0.0.0'),
        port=int(os.environ.get('API_PORT', 8000)),
        debug=is_debug,
        access_log=is_debug,
        auto_reload=is_debug,
    )