feat(api): Add endpoint for receiving node traffic data
This commit is contained in:
@ -2,13 +2,15 @@ from fastapi import APIRouter, HTTPException
|
|||||||
from ..schema.response import DetailResponse
|
from ..schema.response import DetailResponse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from scripts.db.database import db
|
||||||
|
|
||||||
from ..schema.config.ip import (
|
from ..schema.config.ip import (
|
||||||
EditInputBody,
|
EditInputBody,
|
||||||
StatusResponse,
|
StatusResponse,
|
||||||
AddNodeBody,
|
AddNodeBody,
|
||||||
DeleteNodeBody,
|
DeleteNodeBody,
|
||||||
NodeListResponse
|
NodeListResponse,
|
||||||
|
NodesTrafficPayload
|
||||||
)
|
)
|
||||||
import cli_api
|
import cli_api
|
||||||
|
|
||||||
@ -119,4 +121,41 @@ async def delete_node(body: DeleteNodeBody):
|
|||||||
cli_api.delete_node(body.name)
|
cli_api.delete_node(body.name)
|
||||||
return DetailResponse(detail=f"Node '{body.name}' deleted successfully.")
|
return DetailResponse(detail=f"Node '{body.name}' deleted successfully.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/nodestraffic', response_model=DetailResponse, summary='Receive and Aggregate Traffic from Node')
|
||||||
|
async def receive_node_traffic(body: NodesTrafficPayload):
|
||||||
|
"""
|
||||||
|
Receives traffic delta from a node and adds it to the user's total in the database.
|
||||||
|
Authentication is handled by the AuthMiddleware.
|
||||||
|
"""
|
||||||
|
if db is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Database connection is not available.")
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
for user_traffic in body.users:
|
||||||
|
try:
|
||||||
|
db_user = db.get_user(user_traffic.username)
|
||||||
|
if not db_user:
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_upload = db_user.get('upload_bytes', 0) + user_traffic.upload_bytes
|
||||||
|
new_download = db_user.get('download_bytes', 0) + user_traffic.download_bytes
|
||||||
|
|
||||||
|
update_data = {
|
||||||
|
'upload_bytes': new_upload,
|
||||||
|
'download_bytes': new_download,
|
||||||
|
'status': user_traffic.status,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not db_user.get('account_creation_date') and user_traffic.account_creation_date:
|
||||||
|
update_data['account_creation_date'] = user_traffic.account_creation_date
|
||||||
|
|
||||||
|
db.update_user(user_traffic.username, update_data)
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error updating traffic for user {user_traffic.username}: {e}")
|
||||||
|
|
||||||
|
return DetailResponse(detail=f"Successfully processed and aggregated traffic for {updated_count} users.")
|
||||||
@ -1,7 +1,8 @@
|
|||||||
from pydantic import BaseModel, field_validator, Field
|
from pydantic import BaseModel, field_validator
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
import re
|
import re
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
def validate_ip_or_domain(v: str) -> str | None:
|
def validate_ip_or_domain(v: str) -> str | None:
|
||||||
if v is None or v.strip() in ['', 'None']:
|
if v is None or v.strip() in ['', 'None']:
|
||||||
@ -35,7 +36,7 @@ class EditInputBody(StatusResponse):
|
|||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
ip: str
|
ip: str
|
||||||
port: Optional[int] = Field(default=None, ge=1, le=65535)
|
port: Optional[int] = None
|
||||||
sni: Optional[str] = None
|
sni: Optional[str] = None
|
||||||
pinSHA256: Optional[str] = None
|
pinSHA256: Optional[str] = None
|
||||||
obfs: Optional[str] = None
|
obfs: Optional[str] = None
|
||||||
@ -47,43 +48,40 @@ class Node(BaseModel):
|
|||||||
raise ValueError("IP or Domain field cannot be empty.")
|
raise ValueError("IP or Domain field cannot be empty.")
|
||||||
return validate_ip_or_domain(v)
|
return validate_ip_or_domain(v)
|
||||||
|
|
||||||
|
@field_validator('port')
|
||||||
|
def check_port(cls, v: int | None):
|
||||||
|
if v is not None and not (1 <= v <= 65535):
|
||||||
|
raise ValueError('Port must be between 1 and 65535.')
|
||||||
|
return v
|
||||||
|
|
||||||
@field_validator('sni', mode='before')
|
@field_validator('sni', mode='before')
|
||||||
def validate_sni_format(cls, v: str | None):
|
def check_sni(cls, v: str | None):
|
||||||
if v is None or not v.strip():
|
if v is None or not v.strip():
|
||||||
return None
|
return None
|
||||||
|
v = v.strip()
|
||||||
v_stripped = v.strip()
|
|
||||||
|
|
||||||
if "://" in v_stripped:
|
|
||||||
raise ValueError("SNI must not contain a protocol (e.g., http://).")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ip_address(v_stripped)
|
ip_address(v)
|
||||||
raise ValueError("SNI cannot be an IP address.")
|
raise ValueError("SNI must be a domain name, not an IP address.")
|
||||||
except ValueError as e:
|
except ValueError:
|
||||||
if "SNI cannot be an IP address" in str(e):
|
pass
|
||||||
raise e
|
if "://" in v:
|
||||||
|
raise ValueError("SNI cannot contain '://'")
|
||||||
domain_regex = re.compile(
|
domain_regex = re.compile(
|
||||||
r'^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$',
|
r'^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$',
|
||||||
re.IGNORECASE
|
re.IGNORECASE
|
||||||
)
|
)
|
||||||
if not domain_regex.match(v_stripped):
|
if not domain_regex.match(v):
|
||||||
raise ValueError(f"'{v_stripped}' is not a valid domain name for SNI.")
|
raise ValueError("Invalid domain name format for SNI.")
|
||||||
|
return v
|
||||||
return v_stripped
|
|
||||||
|
|
||||||
@field_validator('pinSHA256', mode='before')
|
@field_validator('pinSHA256', mode='before')
|
||||||
def validate_pin_format(cls, v: str | None):
|
def check_pin(cls, v: str | None):
|
||||||
if v is None or not v.strip():
|
if v is None or not v.strip():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
v_stripped = v.strip().upper()
|
v_stripped = v.strip().upper()
|
||||||
pin_regex = re.compile(r'^([0-9A-F]{2}:){31}[0-9A-F]{2}$')
|
pin_regex = re.compile(r'^([0-9A-F]{2}:){31}[0-9A-F]{2}$')
|
||||||
|
|
||||||
if not pin_regex.match(v_stripped):
|
if not pin_regex.match(v_stripped):
|
||||||
raise ValueError("Invalid SHA256 pin format.")
|
raise ValueError("Invalid SHA256 pin format.")
|
||||||
|
|
||||||
return v_stripped
|
return v_stripped
|
||||||
|
|
||||||
class AddNodeBody(Node):
|
class AddNodeBody(Node):
|
||||||
@ -92,4 +90,24 @@ class AddNodeBody(Node):
|
|||||||
class DeleteNodeBody(BaseModel):
|
class DeleteNodeBody(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
NodeListResponse = list[Node]
|
NodeListResponse = list[Node]
|
||||||
|
|
||||||
|
class NodeUserTraffic(BaseModel):
|
||||||
|
username: str
|
||||||
|
upload_bytes: int
|
||||||
|
download_bytes: int
|
||||||
|
status: str
|
||||||
|
account_creation_date: Optional[str] = None
|
||||||
|
|
||||||
|
@field_validator('account_creation_date')
|
||||||
|
def check_date_format(cls, v: str | None):
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
datetime.strptime(v, "%Y-%m-%d")
|
||||||
|
return v
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("account_creation_date must be in YYYY-MM-DD format.")
|
||||||
|
|
||||||
|
class NodesTrafficPayload(BaseModel):
|
||||||
|
users: List[NodeUserTraffic]
|
||||||
Reference in New Issue
Block a user