From 9dcace9792f80321192b67456e22e971226374b3 Mon Sep 17 00:00:00 2001 From: Iam54r1n4 Date: Fri, 7 Feb 2025 17:38:05 +0000 Subject: [PATCH] Fix wrong api key error handling --- .../webpanel/exception_handler/__init__.py | 2 +- .../webpanel/exception_handler/handler.py | 14 +++++++++----- core/scripts/webpanel/middleware/auth.py | 18 ++++++++++++------ 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/core/scripts/webpanel/exception_handler/__init__.py b/core/scripts/webpanel/exception_handler/__init__.py index 3dc88f1..8a11b3f 100644 --- a/core/scripts/webpanel/exception_handler/__init__.py +++ b/core/scripts/webpanel/exception_handler/__init__.py @@ -1 +1 @@ -from .handler import setup_exception_handler +from .handler import setup_exception_handler, exception_handler diff --git a/core/scripts/webpanel/exception_handler/handler.py b/core/scripts/webpanel/exception_handler/handler.py index d8520d0..8ddfe2d 100644 --- a/core/scripts/webpanel/exception_handler/handler.py +++ b/core/scripts/webpanel/exception_handler/handler.py @@ -13,8 +13,12 @@ def setup_exception_handler(app: FastAPI): Setup exception handler for FastAPI. ''' @app.exception_handler(HTTPException) - async def http_exception_handler(request: Request, exc: HTTPException): # type: ignore - return JSONResponse( - status_code=exc.status_code, - content=JSONErrorResponse(status=exc.status_code, detail=exc.detail).model_dump(), - ) + async def http_exception_handler_wrapper(request: Request, exc: HTTPException): # type: ignore + return exception_handler(exc) + + +def exception_handler(exc: HTTPException): # type: ignore + return JSONResponse( + status_code=exc.status_code, + content=JSONErrorResponse(status=exc.status_code, detail=exc.detail).model_dump(), + ) diff --git a/core/scripts/webpanel/middleware/auth.py b/core/scripts/webpanel/middleware/auth.py index dc80179..cfc893d 100644 --- a/core/scripts/webpanel/middleware/auth.py +++ b/core/scripts/webpanel/middleware/auth.py @@ -1,11 +1,12 @@ -from starlette.middleware.base import BaseHTTPMiddleware from fastapi import Request, Response, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware from fastapi.responses import RedirectResponse from datetime import datetime, timezone from typing import Awaitable, Callable from starlette.types import ASGIApp from urllib.parse import quote +from exception_handler import exception_handler from session import SessionManager from config import CONFIGS @@ -18,7 +19,7 @@ class AuthMiddleware(BaseHTTPMiddleware): self.__session_manager = session_manager self.__api_token = api_token - async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]): + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]): # type: ignore '''Handles session authentication.''' public_routes = [ f'/{CONFIGS.ROOT_PATH}/login', @@ -36,14 +37,14 @@ class AuthMiddleware(BaseHTTPMiddleware): if api_key == self.__api_token: return await call_next(request) else: - raise HTTPException(status_code=401, detail="Invalid API token.") + return self.__handle_api_failure(status=401, detail="Invalid API token.") # Extract session_id from cookies session_id = request.cookies.get("session_id") if not session_id: if is_api_request: - raise HTTPException(status_code=401, detail="Unauthorized.") + return self.__handle_api_failure(status=401, detail="Unauthorized.") return self.__redirect_to_login(request) @@ -51,18 +52,23 @@ class AuthMiddleware(BaseHTTPMiddleware): if not session_data: if is_api_request: - raise HTTPException(status_code=401, detail="The session is invalid.") + return self.__handle_api_failure(status=401, detail="The session is invalid.") return self.__redirect_to_login(request) if session_data.expires_at < datetime.now(timezone.utc): if is_api_request: - raise HTTPException(status_code=401, detail="The session has expired.") + return self.__handle_api_failure(status=401, detail="The session has expired.") return self.__redirect_to_login(request) return await call_next(request) + def __handle_api_failure(self, status: int, detail: str): + exc = HTTPException(status_code=status, detail=detail) + + return exception_handler(exc) + def __redirect_to_login(self, request: Request): next_url = quote(str(request.url)) redirect_url = str(request.url_for('login')) + f'?next_url={next_url}'