Redirect after successful login
This commit is contained in:
@ -1,9 +1,10 @@
|
|||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from fastapi import Request, Response, HTTPException
|
from fastapi import Request, Response, HTTPException
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
from typing import Awaitable, Callable
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import Awaitable, Callable
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from session import SessionManager
|
from session import SessionManager
|
||||||
from config import CONFIGS
|
from config import CONFIGS
|
||||||
@ -31,18 +32,20 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if is_api_request:
|
if is_api_request:
|
||||||
if self.__api_token:
|
if self.__api_token:
|
||||||
# Attempt to authenticate with API token
|
# Attempt to authenticate with API token
|
||||||
if auth_header := request.headers.get('Authorization'):
|
if api_key := request.headers.get('Authorization'):
|
||||||
scheme, _, token = auth_header.partition(' ')
|
if api_key == self.__api_token:
|
||||||
if scheme.lower() == 'bearer' and token == self.__api_token:
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API token.")
|
||||||
|
|
||||||
# Extract session_id from cookies
|
# Extract session_id from cookies
|
||||||
session_id = request.cookies.get("session_id")
|
session_id = request.cookies.get("session_id")
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
if is_api_request:
|
if is_api_request:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized.")
|
||||||
return RedirectResponse(url=request.url_for('login'), status_code=302)
|
|
||||||
|
return self.__redirect_to_login(request)
|
||||||
|
|
||||||
session_data = self.__session_manager.get_session(session_id)
|
session_data = self.__session_manager.get_session(session_id)
|
||||||
|
|
||||||
@ -50,12 +53,18 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if is_api_request:
|
if is_api_request:
|
||||||
raise HTTPException(status_code=401, detail="The session is invalid.")
|
raise HTTPException(status_code=401, detail="The session is invalid.")
|
||||||
|
|
||||||
return RedirectResponse(url=request.url_for('login'), status_code=302)
|
return self.__redirect_to_login(request)
|
||||||
|
|
||||||
if session_data.expires_at < datetime.now(timezone.utc):
|
if session_data.expires_at < datetime.now(timezone.utc):
|
||||||
if is_api_request:
|
if is_api_request:
|
||||||
raise HTTPException(status_code=401, detail="The session has expired.")
|
raise HTTPException(status_code=401, detail="The session has expired.")
|
||||||
|
|
||||||
return RedirectResponse(url=request.url_for('login'), status_code=302)
|
return self.__redirect_to_login(request)
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
def __redirect_to_login(self, request: Request):
|
||||||
|
next_url = quote(str(request.url))
|
||||||
|
redirect_url = str(request.url_for('login')) + f'?next_url={next_url}'
|
||||||
|
|
||||||
|
return RedirectResponse(url=redirect_url, status_code=302)
|
||||||
|
|||||||
@ -18,8 +18,11 @@ async def login(request: Request, templates: Jinja2Templates = Depends(get_templ
|
|||||||
async def login_post(
|
async def login_post(
|
||||||
request: Request,
|
request: Request,
|
||||||
templates: Jinja2Templates = Depends(get_templates), session_manager: SessionManager = Depends(get_session_manager),
|
templates: Jinja2Templates = Depends(get_templates), session_manager: SessionManager = Depends(get_session_manager),
|
||||||
username: str = Form(), password: str = Form()
|
username: str = Form(), password: str = Form(), next_url: str = Form(default='/')
|
||||||
):
|
):
|
||||||
|
'''
|
||||||
|
Handles login form submission.
|
||||||
|
'''
|
||||||
ADMIN_USERNAME = CONFIGS.ADMIN_USERNAME
|
ADMIN_USERNAME = CONFIGS.ADMIN_USERNAME
|
||||||
ADMIN_PASSWORD = CONFIGS.ADMIN_PASSWORD
|
ADMIN_PASSWORD = CONFIGS.ADMIN_PASSWORD
|
||||||
|
|
||||||
@ -28,7 +31,13 @@ async def login_post(
|
|||||||
|
|
||||||
session_id = session_manager.set_session(username)
|
session_id = session_manager.set_session(username)
|
||||||
|
|
||||||
res = RedirectResponse(url=request.url_for('index'), status_code=302)
|
# Redirect to the index page if there is no next query parameter in the URL
|
||||||
|
if next_url == '/':
|
||||||
|
redirect_url = request.url_for('index')
|
||||||
|
else:
|
||||||
|
redirect_url = next_url
|
||||||
|
|
||||||
|
res = RedirectResponse(url=redirect_url, status_code=302)
|
||||||
res.set_cookie(key='session_id', value=session_id)
|
res.set_cookie(key='session_id', value=session_id)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
|
|
||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8">
|
<meta charset="utf-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
@ -14,6 +15,7 @@
|
|||||||
<!-- Theme style -->
|
<!-- Theme style -->
|
||||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/admin-lte/3.2.0/css/adminlte.min.css">
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/admin-lte/3.2.0/css/adminlte.min.css">
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body class="hold-transition login-page">
|
<body class="hold-transition login-page">
|
||||||
<div class="login-box">
|
<div class="login-box">
|
||||||
<div class="login-logo">
|
<div class="login-logo">
|
||||||
@ -29,6 +31,7 @@
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
<form action="{{ url_for('login') }}" method="post">
|
<form action="{{ url_for('login') }}" method="post">
|
||||||
|
<input type="hidden" name="next_url" value="{{ request.query_params.get('next_url', '/') }}">
|
||||||
<div class="input-group mb-3">
|
<div class="input-group mb-3">
|
||||||
<input type="text" name="username" class="form-control" placeholder="Username" required>
|
<input type="text" name="username" class="form-control" placeholder="Username" required>
|
||||||
<div class="input-group-append">
|
<div class="input-group-append">
|
||||||
@ -64,4 +67,5 @@
|
|||||||
<!-- AdminLTE App -->
|
<!-- AdminLTE App -->
|
||||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/admin-lte/3.2.0/js/adminlte.min.js"></script>
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/admin-lte/3.2.0/js/adminlte.min.js"></script>
|
||||||
</body>
|
</body>
|
||||||
|
|
||||||
</html>
|
</html>
|
||||||
Reference in New Issue
Block a user