Fix wrong api key error handling

This commit is contained in:
Iam54r1n4
2025-02-07 17:38:05 +00:00
parent da421b0d15
commit 9dcace9792
3 changed files with 22 additions and 12 deletions

View File

@ -1 +1 @@
from .handler import setup_exception_handler from .handler import setup_exception_handler, exception_handler

View File

@ -13,8 +13,12 @@ def setup_exception_handler(app: FastAPI):
Setup exception handler for FastAPI. Setup exception handler for FastAPI.
''' '''
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException): # type: ignore async def http_exception_handler_wrapper(request: Request, exc: HTTPException): # type: ignore
return JSONResponse( return exception_handler(exc)
status_code=exc.status_code,
content=JSONErrorResponse(status=exc.status_code, detail=exc.detail).model_dump(),
) def exception_handler(exc: HTTPException): # type: ignore
return JSONResponse(
status_code=exc.status_code,
content=JSONErrorResponse(status=exc.status_code, detail=exc.detail).model_dump(),
)

View File

@ -1,11 +1,12 @@
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request, Response, HTTPException from fastapi import Request, Response, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Awaitable, Callable from typing import Awaitable, Callable
from starlette.types import ASGIApp from starlette.types import ASGIApp
from urllib.parse import quote from urllib.parse import quote
from exception_handler import exception_handler
from session import SessionManager from session import SessionManager
from config import CONFIGS from config import CONFIGS
@ -18,7 +19,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
self.__session_manager = session_manager self.__session_manager = session_manager
self.__api_token = api_token self.__api_token = api_token
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]): async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]): # type: ignore
'''Handles session authentication.''' '''Handles session authentication.'''
public_routes = [ public_routes = [
f'/{CONFIGS.ROOT_PATH}/login', f'/{CONFIGS.ROOT_PATH}/login',
@ -36,14 +37,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
if api_key == self.__api_token: if api_key == self.__api_token:
return await call_next(request) return await call_next(request)
else: else:
raise HTTPException(status_code=401, detail="Invalid API token.") return self.__handle_api_failure(status=401, detail="Invalid API token.")
# Extract session_id from cookies # Extract session_id from cookies
session_id = request.cookies.get("session_id") session_id = request.cookies.get("session_id")
if not session_id: if not session_id:
if is_api_request: if is_api_request:
raise HTTPException(status_code=401, detail="Unauthorized.") return self.__handle_api_failure(status=401, detail="Unauthorized.")
return self.__redirect_to_login(request) return self.__redirect_to_login(request)
@ -51,18 +52,23 @@ class AuthMiddleware(BaseHTTPMiddleware):
if not session_data: if not session_data:
if is_api_request: if is_api_request:
raise HTTPException(status_code=401, detail="The session is invalid.") return self.__handle_api_failure(status=401, detail="The session is invalid.")
return self.__redirect_to_login(request) return self.__redirect_to_login(request)
if session_data.expires_at < datetime.now(timezone.utc): if session_data.expires_at < datetime.now(timezone.utc):
if is_api_request: if is_api_request:
raise HTTPException(status_code=401, detail="The session has expired.") return self.__handle_api_failure(status=401, detail="The session has expired.")
return self.__redirect_to_login(request) return self.__redirect_to_login(request)
return await call_next(request) return await call_next(request)
def __handle_api_failure(self, status: int, detail: str):
exc = HTTPException(status_code=status, detail=detail)
return exception_handler(exc)
def __redirect_to_login(self, request: Request): def __redirect_to_login(self, request: Request):
next_url = quote(str(request.url)) next_url = quote(str(request.url))
redirect_url = str(request.url_for('login')) + f'?next_url={next_url}' redirect_url = str(request.url_for('login')) + f'?next_url={next_url}'