Fix wrong api key error handling
This commit is contained in:
@ -1 +1 @@
|
|||||||
from .handler import setup_exception_handler
|
from .handler import setup_exception_handler, exception_handler
|
||||||
|
|||||||
@ -13,8 +13,12 @@ def setup_exception_handler(app: FastAPI):
|
|||||||
Setup exception handler for FastAPI.
|
Setup exception handler for FastAPI.
|
||||||
'''
|
'''
|
||||||
@app.exception_handler(HTTPException)
|
@app.exception_handler(HTTPException)
|
||||||
async def http_exception_handler(request: Request, exc: HTTPException): # type: ignore
|
async def http_exception_handler_wrapper(request: Request, exc: HTTPException): # type: ignore
|
||||||
return JSONResponse(
|
return exception_handler(exc)
|
||||||
status_code=exc.status_code,
|
|
||||||
content=JSONErrorResponse(status=exc.status_code, detail=exc.detail).model_dump(),
|
|
||||||
)
|
def exception_handler(exc: HTTPException): # type: ignore
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=JSONErrorResponse(status=exc.status_code, detail=exc.detail).model_dump(),
|
||||||
|
)
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from fastapi import Request, Response, HTTPException
|
from fastapi import Request, Response, HTTPException
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from exception_handler import exception_handler
|
||||||
from session import SessionManager
|
from session import SessionManager
|
||||||
from config import CONFIGS
|
from config import CONFIGS
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
self.__session_manager = session_manager
|
self.__session_manager = session_manager
|
||||||
self.__api_token = api_token
|
self.__api_token = api_token
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]):
|
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]): # type: ignore
|
||||||
'''Handles session authentication.'''
|
'''Handles session authentication.'''
|
||||||
public_routes = [
|
public_routes = [
|
||||||
f'/{CONFIGS.ROOT_PATH}/login',
|
f'/{CONFIGS.ROOT_PATH}/login',
|
||||||
@ -36,14 +37,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if api_key == self.__api_token:
|
if api_key == self.__api_token:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=401, detail="Invalid API token.")
|
return self.__handle_api_failure(status=401, detail="Invalid API token.")
|
||||||
|
|
||||||
# Extract session_id from cookies
|
# Extract session_id from cookies
|
||||||
session_id = request.cookies.get("session_id")
|
session_id = request.cookies.get("session_id")
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
if is_api_request:
|
if is_api_request:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized.")
|
return self.__handle_api_failure(status=401, detail="Unauthorized.")
|
||||||
|
|
||||||
return self.__redirect_to_login(request)
|
return self.__redirect_to_login(request)
|
||||||
|
|
||||||
@ -51,18 +52,23 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
if not session_data:
|
if not session_data:
|
||||||
if is_api_request:
|
if is_api_request:
|
||||||
raise HTTPException(status_code=401, detail="The session is invalid.")
|
return self.__handle_api_failure(status=401, detail="The session is invalid.")
|
||||||
|
|
||||||
return self.__redirect_to_login(request)
|
return self.__redirect_to_login(request)
|
||||||
|
|
||||||
if session_data.expires_at < datetime.now(timezone.utc):
|
if session_data.expires_at < datetime.now(timezone.utc):
|
||||||
if is_api_request:
|
if is_api_request:
|
||||||
raise HTTPException(status_code=401, detail="The session has expired.")
|
return self.__handle_api_failure(status=401, detail="The session has expired.")
|
||||||
|
|
||||||
return self.__redirect_to_login(request)
|
return self.__redirect_to_login(request)
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
def __handle_api_failure(self, status: int, detail: str):
|
||||||
|
exc = HTTPException(status_code=status, detail=detail)
|
||||||
|
|
||||||
|
return exception_handler(exc)
|
||||||
|
|
||||||
def __redirect_to_login(self, request: Request):
|
def __redirect_to_login(self, request: Request):
|
||||||
next_url = quote(str(request.url))
|
next_url = quote(str(request.url))
|
||||||
redirect_url = str(request.url_for('login')) + f'?next_url={next_url}'
|
redirect_url = str(request.url_for('login')) + f'?next_url={next_url}'
|
||||||
|
|||||||
Reference in New Issue
Block a user