Implement custom url_for to generating urls with ROOT_PATH prefix

This commit is contained in:
Iam54r1n4
2025-02-04 16:34:52 +00:00
parent 468f4a4abc
commit 360e6ac4ba
4 changed files with 29 additions and 8 deletions

View File

@ -1 +1 @@
from .dependency import get_templates, get_session_manager from .dependency import get_templates, get_session_manager, url_for

View File

@ -1,4 +1,8 @@
from fastapi import Request
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from jinja2 import pass_context
from typing import Any
from starlette.datastructures import URL
from session import SessionStorage, SessionManager from session import SessionStorage, SessionManager
from config import CONFIGS from config import CONFIGS
@ -6,6 +10,21 @@ from config import CONFIGS
__TEMPLATES = Jinja2Templates(directory='templates') __TEMPLATES = Jinja2Templates(directory='templates')
@pass_context
def url_for(context: dict[str, Any], name: str = '', **path_params: dict[str, Any]) -> URL:
'''
Custom url_for function for Jinja2 to add a prefix to the generated URL.
'''
request: Request = context["request"]
url = request.url_for(name, **path_params)
prefixed_path = f"{CONFIGS.ROOT_PATH.rstrip('/')}/{url.path.lstrip('/')}"
return url.replace(path=prefixed_path)
__TEMPLATES.env.globals['url_for'] = url_for # type: ignore
def get_templates() -> Jinja2Templates: def get_templates() -> Jinja2Templates:
return __TEMPLATES return __TEMPLATES

View File

@ -5,6 +5,7 @@ from starlette.types import ASGIApp
from typing import Awaitable, Callable from typing import Awaitable, Callable
from datetime import datetime, timezone from datetime import datetime, timezone
from dependency import url_for
from session import SessionManager from session import SessionManager
@ -15,6 +16,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
super().__init__(app) super().__init__(app)
self.__session_manager = session_manager self.__session_manager = session_manager
self.__api_token = api_token self.__api_token = api_token
self.__url_for = url_for
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]): async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]):
'''Handles session authentication.''' '''Handles session authentication.'''
@ -41,7 +43,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
if not session_id: if not session_id:
if is_api_request: if is_api_request:
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
return RedirectResponse(url='/login', status_code=302) return RedirectResponse(url=self.__url_for(context={'request': request}, name='login'), status_code=302)
session_data = self.__session_manager.get_session(session_id) session_data = self.__session_manager.get_session(session_id)
@ -49,12 +51,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
if is_api_request: if is_api_request:
raise HTTPException(status_code=401, detail="The session is invalid.") raise HTTPException(status_code=401, detail="The session is invalid.")
return RedirectResponse(url='/login', status_code=302) return RedirectResponse(url=self.__url_for(context={'request': request}, name='login'), status_code=302)
if session_data.expires_at < datetime.now(timezone.utc): if session_data.expires_at < datetime.now(timezone.utc):
if is_api_request: if is_api_request:
raise HTTPException(status_code=401, detail="The session has expired.") raise HTTPException(status_code=401, detail="The session has expired.")
return RedirectResponse(url='/login', status_code=302) return RedirectResponse(url=self.__url_for(context={'request': request}, name='login'), status_code=302)
return await call_next(request) return await call_next(request)

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, Form, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from dependency import get_templates, get_session_manager from dependency import get_templates, get_session_manager, url_for
from session import SessionManager from session import SessionManager
from config import CONFIGS from config import CONFIGS
@ -18,7 +18,7 @@ async def login(request: Request, templates: Jinja2Templates = Depends(get_templ
async def login_post( async def login_post(
request: Request, request: Request,
templates: Jinja2Templates = Depends(get_templates), session_manager: SessionManager = Depends(get_session_manager), templates: Jinja2Templates = Depends(get_templates), session_manager: SessionManager = Depends(get_session_manager),
username: str = Form(), password: str = Form(), username: str = Form(), password: str = Form()
): ):
ADMIN_USERNAME = CONFIGS.ADMIN_USERNAME ADMIN_USERNAME = CONFIGS.ADMIN_USERNAME
ADMIN_PASSWORD = CONFIGS.ADMIN_PASSWORD ADMIN_PASSWORD = CONFIGS.ADMIN_PASSWORD
@ -28,7 +28,7 @@ async def login_post(
session_id = session_manager.set_session(username) session_id = session_manager.set_session(username)
res = RedirectResponse(url='/', status_code=302) res = RedirectResponse(url=url_for(context={'request': request}, name='index'), status_code=302)
res.set_cookie(key='session_id', value=session_id) res.set_cookie(key='session_id', value=session_id)
return res return res
@ -40,6 +40,6 @@ async def logout(request: Request, session_manager: SessionManager = Depends(get
if session_id: if session_id:
session_manager.revoke_session(session_id) session_manager.revoke_session(session_id)
res = RedirectResponse(url='/', status_code=302) res = RedirectResponse(url=url_for(context={'request': request}, name='index'), status_code=302)
res.delete_cookie('session_id') res.delete_cookie('session_id')
return res return res