# e621_bot/server.py
#
# 146 lines
# 3.9 KiB
# Python

import json
import os
from dataclasses import dataclass
from functools import wraps
import dotenv
import jwt
from sanic import Sanic, Unauthorized, json as jsonr
from sanic_ext import validate, Extend
from sanic_ext.extensions.openapi import openapi
from sanic_ext.extensions.openapi.definitions import RequestBody
from sanic_redis import SanicRedis
from passlib.hash import pbkdf2_sha256
from const import REDIS_SUBS_KEY
# Load settings from a local .env file into the process environment.
dotenv.load_dotenv('.env')

# Secret used to sign (login) and verify (protected) the HS256 JWTs.
api_secret = os.environ['API_SECRET']
# JSON object mapping username -> stored passlib pbkdf2_sha256 hash; login()
# looks credentials up with api_auth.get(username).
api_auth = json.loads(os.environ['API_AUTH'])

app = Sanic('e621_bot_api')
# Allow cross-origin requests from any origin.
app.config.CORS_ORIGINS = '*'
# Redis URL; SanicRedis reads the "REDIS" config key by default.
app.config.update({
    'REDIS': 'redis://localhost',
})
Extend(app)

redis = SanicRedis()
# Attach the extension to the app. Without this call the connection is never
# set up, so every `async with redis.conn as r:` in the handlers fails.
redis.init_app(app)
async def get_subs(r):
    """Fetch the stored subscriptions as a set of ``str``.

    ``r`` is an open Redis connection. The members of the ``REDIS_SUBS_KEY``
    set come back as bytes and are decoded one by one.
    """
    decoded = set()
    for member in await r.smembers(REDIS_SUBS_KEY):
        decoded.add(member.decode())
    return decoded
@dataclass()
class LoginRequest:
    """JSON body accepted by ``POST /api/login``."""

    username: str  # account name, looked up as a key of ``api_auth``
    password: str  # plaintext password, checked against the stored hash
def protected(wrapped):
    """Require a valid HS256 JWT in the ``Authorization`` header.

    Wraps an async Sanic handler. The header may carry either the bare token
    or the standard ``Bearer <token>`` form. The token must verify against
    ``api_secret``. Otherwise ``Unauthorized`` is raised and the wrapped
    handler is not called.

    Raises:
        Unauthorized: header missing, token expired, or token invalid.
    """
    @wraps(wrapped)
    async def decorated(request, *args, **kwargs):
        header = request.headers.get('Authorization')
        if not header:
            raise Unauthorized('Authorization header is missing')
        # Accept "Bearer <token>" (RFC 6750) as well as a bare token. A JWT
        # contains no spaces, so a bare token never matches the scheme test.
        scheme, _, credentials = header.partition(' ')
        token = credentials.strip() if scheme.lower() == 'bearer' else header
        try:
            jwt.decode(token, api_secret, algorithms=['HS256'])
        except jwt.ExpiredSignatureError as err:
            # Must precede InvalidTokenError, which is its base class.
            raise Unauthorized('Token has expired') from err
        except jwt.InvalidTokenError as err:
            raise Unauthorized('Invalid token') from err
        return await wrapped(request, *args, **kwargs)

    return decorated
@app.post('/api/login')
@openapi.definition(
    body=RequestBody({
        'application/json': LoginRequest,
    })
)
@validate(json=LoginRequest)
async def login(_, body: LoginRequest):
    """Exchange a username/password pair for a signed JWT.

    Credentials are checked against ``api_auth`` (username -> passlib
    pbkdf2_sha256 hash).

    Returns:
        401 with an error payload if the user is unknown or the password does
        not match; otherwise ``{'token': <HS256 JWT signed with api_secret>}``.
    """
    stored_hash = api_auth.get(body.username)
    # `verify` is a classmethod that takes rounds and salt from the stored
    # hash, so no handler instance (or hard-coded salt) is needed here.
    if not stored_hash or not pbkdf2_sha256.verify(body.password, stored_hash):
        return jsonr({'status': 'error', 'message': 'Invalid username or password'}, 401)
    # NOTE(review): the payload is empty (no `exp` claim), so issued tokens
    # never expire even though `protected` handles ExpiredSignatureError.
    return jsonr({
        'token': jwt.encode({}, api_secret, algorithm='HS256'),
    })
@app.get('/api/subscriptions')
@protected
async def get_subscriptions(_):
    """Return every stored subscription as a sorted list of strings."""
    async with redis.conn as r:
        subs = await get_subs(r)
    # SMEMBERS yields a set of bytes, which cannot be JSON-encoded. get_subs
    # decodes the members; sorting gives a stable, serializable list.
    return jsonr({
        'subscriptions': sorted(subs),
    })
@dataclass()
class UpdateSubscriptionRequest:
    """JSON body shared by the subscription POST (add) and DELETE endpoints."""

    subs: list[str]  # raw subscription strings; the handlers normalize them
@app.delete('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def delete_subscriptions(_, body: UpdateSubscriptionRequest):
    """Remove the requested subscriptions, all-or-nothing.

    Each entry is normalized before it is compared with the stored set:
    lower-cased, split on whitespace, then its terms sorted and re-joined with
    single spaces.

    Returns:
        404 listing the missing entries under ``skipped`` if any normalized
        entry is not stored (nothing is removed); otherwise ``status: ok``
        with the sorted removed entries under ``removed``.
    """
    requested_subs = {' '.join(sorted(sub.lower().split())) for sub in body.subs}
    if not requested_subs:
        # SREM needs at least one member; treat an empty request as a no-op
        # instead of letting Redis reject the command.
        return jsonr({'status': 'ok', 'removed': []})
    async with redis.conn as r:
        # NOTE(review): the membership check and SREM are separate round
        # trips, so concurrent requests can interleave between them.
        subs = await get_subs(r)
        skipped = requested_subs - subs
        if skipped:
            return jsonr({'status': 'error', 'message': 'Some subscriptions were not found', 'skipped': sorted(skipped)}, 404)
        await r.srem(REDIS_SUBS_KEY, *requested_subs)
    return jsonr({'status': 'ok', 'removed': sorted(requested_subs)})
@app.post('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def add_subscriptions(_, body: UpdateSubscriptionRequest):
    """Add the requested subscriptions, all-or-nothing.

    Each entry is normalized before it is compared and stored: lower-cased,
    split on whitespace, then its terms sorted and re-joined with single
    spaces.

    Returns:
        409 listing the clashing entries under ``conflicts`` if any normalized
        entry already exists (nothing is added); otherwise ``status: ok`` with
        the sorted added entries under ``added``.
    """
    requested_subs = {' '.join(sorted(sub.lower().split())) for sub in body.subs}
    if not requested_subs:
        # SADD needs at least one member; treat an empty request as a no-op
        # instead of letting Redis reject the command.
        return jsonr({'status': 'ok', 'added': []})
    async with redis.conn as r:
        subs = await get_subs(r)
        conflicts = requested_subs & subs
        if conflicts:
            return jsonr({'status': 'error', 'message': 'Some subscriptions already exist', 'conflicts': sorted(conflicts)}, 409)
        # Store the normalized form. The conflict check above and the DELETE
        # endpoint both compare normalized strings, so raw input would never
        # match later.
        await r.sadd(REDIS_SUBS_KEY, *requested_subs)
    return jsonr({'status': 'ok', 'added': sorted(requested_subs)})
if __name__ == '__main__':
    # A `.debug` path in the working directory switches on debug mode, access
    # logging and auto-reload together.
    debug_enabled = os.path.exists('.debug')
    run_kwargs = dict.fromkeys(('debug', 'access_log', 'auto_reload'), debug_enabled)
    run_kwargs['host'] = os.environ.get('API_HOST', '0.0.0.0')
    run_kwargs['port'] = int(os.environ.get('API_PORT', 8000))
    app.run(**run_kwargs)