144 lines
3.8 KiB
Python
144 lines
3.8 KiB
Python
import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps

import dotenv
import jwt
from passlib.hash import pbkdf2_sha256
from sanic import Sanic, Unauthorized
from sanic.response import json as json_response
from sanic_ext import validate
from sanic_ext.extensions.openapi import openapi
from sanic_ext.extensions.openapi.definitions import RequestBody
from sanic_redis import SanicRedis

from const import REDIS_SUBS_KEY
|
|
|
|
# Load variables from a local .env file into os.environ before reading them.
dotenv.load_dotenv('.env')

# Secret used to sign and verify the HS256 API tokens.
api_secret = os.environ['API_SECRET']
# JSON object mapping usernames to their stored password hashes
# (presumably pbkdf2_sha256 strings produced by passlib — confirm).
api_auth = json.loads(os.environ['API_AUTH'])

app = Sanic('e621_bot_api')
app.config.update({
    # Redis connection URL consumed by sanic_redis.
    'REDIS': 'redis://localhost',
})

# A SanicRedis created without an app is not attached to any Sanic
# lifecycle, so `redis.conn` would never become usable; bind it explicitly.
redis = SanicRedis()
redis.init_app(app)
|
|
|
|
|
|
async def get_subs(r):
    """Fetch every stored subscription from Redis as decoded strings.

    Args:
        r: An open Redis connection exposing an awaitable ``smembers``.

    Returns:
        set[str]: The members of ``REDIS_SUBS_KEY``, each decoded from bytes.
    """
    raw_members = await r.smembers(REDIS_SUBS_KEY)
    decoded = set()
    for member in raw_members:
        decoded.add(member.decode())
    return decoded
|
|
|
|
|
|
@dataclass
class LoginRequest:
    """JSON body schema for ``POST /api/login``."""

    # Account name; looked up as a key in ``api_auth``.
    username: str
    # Plain-text password supplied by the client.
    password: str
|
|
|
|
|
|
def protected(wrapped):
    """Decorate an async Sanic handler so it requires a valid HS256 JWT.

    The token is read from the ``Authorization`` header, either bare or in
    the conventional ``Bearer <token>`` form, and verified against
    ``api_secret``. Any extra positional/keyword arguments are passed through
    to the wrapped handler unchanged.

    Args:
        wrapped: The async handler ``(request, *args, **kwargs)`` to protect.

    Returns:
        The wrapping coroutine function.

    Raises:
        Unauthorized: If the header is missing, or the token has expired or
            is otherwise invalid.
    """

    @wraps(wrapped)
    async def decorated(request, *args, **kwargs):
        header = request.headers.get('Authorization')
        if not header:
            raise Unauthorized('Authorization header is missing')

        # Accept "Bearer <token>" as well as a bare token. A JWT never contains
        # whitespace, so a bare token passes through unchanged.
        scheme, _, credentials = header.partition(' ')
        token = credentials.strip() if scheme.lower() == 'bearer' else header

        try:
            jwt.decode(token, api_secret, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            raise Unauthorized('Token has expired') from None
        except jwt.InvalidTokenError:
            raise Unauthorized('Invalid token') from None

        return await wrapped(request, *args, **kwargs)

    return decorated
|
|
|
|
|
|
@app.post('/api/login')
@openapi.definition(
    body=RequestBody({
        'application/json': LoginRequest,
    })
)
@validate(json=LoginRequest)
async def login(request, body=None):
    """Exchange a username/password pair for a signed HS256 JWT.

    Args:
        request: The incoming Sanic request.
        body: The validated ``LoginRequest``. sanic-ext's ``@validate``
            injects it under this keyword. When absent, the request JSON is
            parsed into a ``LoginRequest`` directly.

    Returns:
        A JSON response ``{'token': ...}`` on success, or
        ``{'status': 'error', 'message': ...}`` when the username is unknown
        or the password does not match. The token carries an ``exp`` claim
        ``API_TOKEN_TTL`` seconds (default 86400) in the future.
    """
    if body is None:
        body = LoginRequest(**request.json)

    stored_hash = api_auth.get(body.username)
    # passlib hashes are salted, so hashing the password again never equals
    # the stored string. verify() re-derives using the salt and rounds
    # embedded in the stored hash.
    if stored_hash is None or not pbkdf2_sha256.verify(body.password, stored_hash):
        return json_response({'status': 'error', 'message': 'Invalid username or password'})

    # Give tokens a lifetime so the expiry check in `protected` can trigger.
    lifetime = timedelta(seconds=int(os.environ.get('API_TOKEN_TTL', 86400)))
    expires_at = datetime.now(tz=timezone.utc) + lifetime
    token = jwt.encode({'exp': expires_at}, api_secret, algorithm='HS256')
    # PyJWT < 2 returns bytes; normalise so the token is JSON-serialisable.
    if isinstance(token, bytes):
        token = token.decode()

    return json_response({
        'token': token,
    })
|
|
|
|
|
|
@app.get('/api/subscriptions')
@protected
async def get_subscriptions(request):
    """Return every stored subscription as a sorted JSON list.

    Requires a valid token in the ``Authorization`` header.

    Returns:
        A JSON response ``{'subscriptions': [...]}``. Members are decoded to
        ``str`` and sorted, because a set of ``bytes`` cannot be serialised
        to JSON.
    """
    async with redis.conn as r:
        subs = await get_subs(r)

    return json_response({
        'subscriptions': sorted(subs),
    })
|
|
|
|
|
|
@dataclass
class UpdateSubscriptionRequest:
    """JSON body schema shared by ``POST`` and ``DELETE /api/subscriptions``."""

    # Subscription queries to add or remove, as sent by the client.
    subs: list[str]
|
|
|
|
|
|
@app.delete('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def delete_subscriptions(request, body=None):
    """Remove the requested subscriptions, all-or-nothing.

    Each entry is normalised before comparison: lower-cased, split on
    whitespace, terms sorted, then re-joined with single spaces. Blank
    entries are ignored. If any normalised entry is not currently stored,
    nothing is removed.

    Args:
        request: The incoming Sanic request.
        body: The validated ``UpdateSubscriptionRequest``, injected by
            sanic-ext's ``@validate``. When absent, ``request.json['subs']``
            is used.

    Returns:
        A JSON response. It is ``{'status': 'error', ..., 'skipped': [...]}``
        when some entries are unknown, otherwise
        ``{'status': 'ok', 'removed': [...]}``.
    """
    raw_subs = body.subs if body is not None else request.json['subs']
    normalized = (' '.join(sorted(sub.lower().split())) for sub in raw_subs)
    # Whitespace-only entries normalise to '' and are dropped.
    requested_subs = {sub for sub in normalized if sub}

    async with redis.conn as r:
        subs = await get_subs(r)
        skipped = requested_subs - subs
        if skipped:
            return json_response({'status': 'error', 'message': 'Some subscriptions were not found', 'skipped': sorted(skipped)})
        # SREM with no members is a Redis arity error, so skip it for an empty request.
        if requested_subs:
            await r.srem(REDIS_SUBS_KEY, *requested_subs)

    return json_response({'status': 'ok', 'removed': sorted(requested_subs)})
|
|
|
|
|
|
@app.post('/api/subscriptions')
@openapi.definition(
    body=RequestBody({
        'application/json': UpdateSubscriptionRequest,
    })
)
@validate(json=UpdateSubscriptionRequest)
@protected
async def add_subscriptions(request, body=None):
    """Add the requested subscriptions, all-or-nothing.

    Each entry is normalised: lower-cased, split on whitespace, terms sorted,
    then re-joined with single spaces. Blank entries are ignored. The
    normalised form is both conflict-checked and stored, so the stored value
    matches what is reported back as added. If any entry already exists,
    nothing is added.

    Args:
        request: The incoming Sanic request.
        body: The validated ``UpdateSubscriptionRequest``, injected by
            sanic-ext's ``@validate``. When absent, ``request.json['subs']``
            is used.

    Returns:
        A JSON response. It is
        ``{'status': 'error', ..., 'conflicts': [...]}`` when some entries
        already exist, otherwise ``{'status': 'ok', 'added': [...]}``.
    """
    raw_subs = body.subs if body is not None else request.json['subs']
    normalized = (' '.join(sorted(sub.lower().split())) for sub in raw_subs)
    # Whitespace-only entries normalise to '' and are dropped.
    requested_subs = {sub for sub in normalized if sub}

    async with redis.conn as r:
        subs = await get_subs(r)
        conflicts = requested_subs & subs
        if conflicts:
            return json_response({'status': 'error', 'message': 'Some subscriptions already exist', 'conflicts': sorted(conflicts)})
        # Store the normalised form (not the raw input). SADD with no members
        # is a Redis arity error, so skip it for an empty request.
        if requested_subs:
            await r.sadd(REDIS_SUBS_KEY, *requested_subs)

    return json_response({'status': 'ok', 'added': sorted(requested_subs)})
|
|
|
|
|
|
if __name__ == '__main__':
    # A `.debug` file in the current working directory switches on debug
    # mode, access logging and auto-reload together.
    debug_enabled = os.path.exists('.debug')
    listen_host = os.environ.get('API_HOST', '0.0.0.0')
    listen_port = int(os.environ.get('API_PORT', 8000))
    app.run(
        host=listen_host,
        port=listen_port,
        debug=debug_enabled,
        access_log=debug_enabled,
        auto_reload=debug_enabled,
    )
|
|
|