Fix: Enforce strict subpath validation and URL safety

This commit is contained in:
Whispering Wind
2025-02-28 21:17:33 +03:30
committed by GitHub
parent 40bbac95ab
commit eb2b3f590b

View File

@ -1,4 +1,3 @@
# import logging
import os import os
import ssl import ssl
import json import json
@ -19,12 +18,9 @@ import qrcode
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
load_dotenv() load_dotenv()
# logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@dataclass @dataclass
class AppConfig: class AppConfig:
"""Application configuration settings"""
domain: str domain: str
cert_file: str cert_file: str
key_file: str key_file: str
@ -38,47 +34,33 @@ class AppConfig:
template_dir: str template_dir: str
subpath: str subpath: str
class RateLimiter: class RateLimiter:
"""Handles rate limiting for requests"""
def __init__(self, limit: int, window: int): def __init__(self, limit: int, window: int):
self.limit = limit self.limit = limit
self.window = window self.window = window
self.store: Dict[str, Tuple[int, float]] = {} self.store: Dict[str, Tuple[int, float]] = {}
def check_limit(self, client_ip: str) -> bool: def check_limit(self, client_ip: str) -> bool:
"""Checks if a client has exceeded their rate limit
Returns:
bool: True if rate limit not exceeded, False otherwise
"""
current_time = time.monotonic() current_time = time.monotonic()
requests, last_request_time = self.store.get(client_ip, (0, 0)) requests, last_request_time = self.store.get(client_ip, (0, 0))
if current_time - last_request_time < self.window: if current_time - last_request_time < self.window:
if requests >= self.limit: if requests >= self.limit:
return False return False
else: else:
requests = 0 requests = 0
self.store[client_ip] = (requests + 1, current_time) self.store[client_ip] = (requests + 1, current_time)
return True return True
@dataclass @dataclass
class UriComponents: class UriComponents:
"""Components extracted from a Hysteria2 URI"""
username: Optional[str] username: Optional[str]
password: Optional[str] password: Optional[str]
ip: Optional[str] ip: Optional[str]
port: Optional[int] port: Optional[int]
obfs_password: str obfs_password: str
@dataclass @dataclass
class UserInfo: class UserInfo:
"""User information and statistics"""
username: str username: str
upload_bytes: int upload_bytes: int
download_bytes: int download_bytes: int
@ -88,12 +70,10 @@ class UserInfo:
@property @property
def total_usage(self) -> int: def total_usage(self) -> int:
"""Total bandwidth usage"""
return self.upload_bytes + self.download_bytes return self.upload_bytes + self.download_bytes
@property @property
def expiration_timestamp(self) -> int: def expiration_timestamp(self) -> int:
"""Unix timestamp when account expires"""
if not self.account_creation_date or self.expiration_days <= 0: if not self.account_creation_date or self.expiration_days <= 0:
return 0 return 0
creation_timestamp = int(time.mktime(time.strptime(self.account_creation_date, "%Y-%m-%d"))) creation_timestamp = int(time.mktime(time.strptime(self.account_creation_date, "%Y-%m-%d")))
@ -101,7 +81,6 @@ class UserInfo:
@property @property
def expiration_date(self) -> str: def expiration_date(self) -> str:
"""Formatted expiration date string"""
if not self.account_creation_date or self.expiration_days <= 0: if not self.account_creation_date or self.expiration_days <= 0:
return "N/A" return "N/A"
creation_timestamp = int(time.mktime(time.strptime(self.account_creation_date, "%Y-%m-%d"))) creation_timestamp = int(time.mktime(time.strptime(self.account_creation_date, "%Y-%m-%d")))
@ -110,23 +89,19 @@ class UserInfo:
@property @property
def usage_human_readable(self) -> str: def usage_human_readable(self) -> str:
"""Human readable string of usage"""
total = Utils.human_readable_bytes(self.max_download_bytes) total = Utils.human_readable_bytes(self.max_download_bytes)
used = Utils.human_readable_bytes(self.total_usage) used = Utils.human_readable_bytes(self.total_usage)
return f"{used} / {total}" return f"{used} / {total}"
@property @property
def usage_detailed(self) -> str: def usage_detailed(self) -> str:
"""Detailed usage breakdown"""
total = Utils.human_readable_bytes(self.max_download_bytes) total = Utils.human_readable_bytes(self.max_download_bytes)
upload = Utils.human_readable_bytes(self.upload_bytes) upload = Utils.human_readable_bytes(self.upload_bytes)
download = Utils.human_readable_bytes(self.download_bytes) download = Utils.human_readable_bytes(self.download_bytes)
return f"Upload: {upload}, Download: {download}, Total: {total}" return f"Upload: {upload}, Download: {download}, Total: {total}"
@dataclass @dataclass
class TemplateContext: class TemplateContext:
"""Context for HTML template rendering"""
username: str username: str
usage: str usage: str
usage_raw: str usage_raw: str
@ -138,29 +113,18 @@ class TemplateContext:
ipv4_uri: Optional[str] ipv4_uri: Optional[str]
ipv6_uri: Optional[str] ipv6_uri: Optional[str]
class Utils: class Utils:
"""Utility functions"""
@staticmethod @staticmethod
def sanitize_input(value: str, pattern: str) -> str: def sanitize_input(value: str, pattern: str) -> str:
"""Sanitizes input using a regex pattern and quotes it for shell commands"""
if not re.match(pattern, value): if not re.match(pattern, value):
raise ValueError(f"Invalid value: {value}") raise ValueError(f"Invalid value: {value}")
return shlex.quote(value) return shlex.quote(value)
@staticmethod @staticmethod
def generate_qrcode_base64(data: str) -> str: def generate_qrcode_base64(data: str) -> str:
"""Generates a base64-encoded PNG QR code image"""
if not data: if not data:
return None return None
qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=10, border=4)
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=10,
border=4,
)
qr.add_data(data) qr.add_data(data)
qr.make(fit=True) qr.make(fit=True)
img = qr.make_image(fill_color="black", back_color="white") img = qr.make_image(fill_color="black", back_color="white")
@ -170,7 +134,6 @@ class Utils:
@staticmethod @staticmethod
def human_readable_bytes(bytes_value: int) -> str: def human_readable_bytes(bytes_value: int) -> str:
"""Converts bytes to a human-readable string (KB, MB, GB, etc.)"""
units = ["Bytes", "KB", "MB", "GB", "TB"] units = ["Bytes", "KB", "MB", "GB", "TB"]
size = float(bytes_value) size = float(bytes_value)
for unit in units: for unit in units:
@ -181,27 +144,30 @@ class Utils:
@staticmethod @staticmethod
def build_url(base: str, path: str) -> str: def build_url(base: str, path: str) -> str:
"""Constructs a URL, handling potential double slashes correctly."""
return urljoin(base, path) return urljoin(base, path)
@staticmethod
def is_valid_url(url: str) -> bool:
"""Checks if the given string is a valid URL."""
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except ValueError:
return False
class HysteriaCLI: class HysteriaCLI:
"""Interface for Hysteria CLI commands"""
def __init__(self, cli_path: str): def __init__(self, cli_path: str):
self.cli_path = cli_path self.cli_path = cli_path
def _run_command(self, args: List[str]) -> str: def _run_command(self, args: List[str]) -> str:
"""Runs the hysteria CLI with the given arguments and returns the output"""
try: try:
command = ['python3', self.cli_path] + args command = ['python3', self.cli_path] + args
return subprocess.check_output(command, stderr=subprocess.DEVNULL, text=True).strip() return subprocess.check_output(command, stderr=subprocess.DEVNULL, text=True).strip()
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Hysteria CLI error: {e}") # Log the error print(f"Hysteria CLI error: {e}")
raise raise
def get_user_info(self, username: str) -> UserInfo: def get_user_info(self, username: str) -> UserInfo:
"""Retrieves user information"""
raw_info = json.loads(self._run_command(['get-user', '-u', username])) raw_info = json.loads(self._run_command(['get-user', '-u', username]))
return UserInfo( return UserInfo(
username=username, username=username,
@ -213,28 +179,20 @@ class HysteriaCLI:
) )
def get_user_uri(self, username: str, ip_version: Optional[str] = None) -> str: def get_user_uri(self, username: str, ip_version: Optional[str] = None) -> str:
"""Gets the URI for a user, optionally specifying IP version"""
if ip_version: if ip_version:
return self._run_command(['show-user-uri', '-u', username, '-ip', ip_version]) return self._run_command(['show-user-uri', '-u', username, '-ip', ip_version])
else: else:
output = self._run_command(['show-user-uri', '-u', username, '-a']) return self._run_command(['show-user-uri', '-u', username, '-a'])
return output
def get_uris(self, username: str) -> Tuple[Optional[str], Optional[str]]: def get_uris(self, username: str) -> Tuple[Optional[str], Optional[str]]:
"""Retrieves IPv4 and IPv6 URIs for a user"""
output = self._run_command(['show-user-uri', '-u', username, '-a']) output = self._run_command(['show-user-uri', '-u', username, '-a'])
ipv4_uri = re.search(r'IPv4:\s*(.*)', output) ipv4_uri = re.search(r'IPv4:\s*(.*)', output)
ipv6_uri = re.search(r'IPv6:\s*(.*)', output) ipv6_uri = re.search(r'IPv6:\s*(.*)', output)
return (ipv4_uri.group(1).strip() if ipv4_uri else None, return (ipv4_uri.group(1).strip() if ipv4_uri else None, ipv6_uri.group(1).strip() if ipv6_uri else None)
ipv6_uri.group(1).strip() if ipv6_uri else None)
class UriParser: class UriParser:
"""Parser for Hysteria2 URIs"""
@staticmethod @staticmethod
def extract_uri_components(uri: Optional[str], prefix: str) -> Optional[UriComponents]: def extract_uri_components(uri: Optional[str], prefix: str) -> Optional[UriComponents]:
"""Extracts components from a Hysteria2 URI"""
if not uri or not uri.startswith(prefix): if not uri or not uri.startswith(prefix):
return None return None
uri = uri[len(prefix):].strip() uri = uri[len(prefix):].strip()
@ -245,14 +203,7 @@ class UriParser:
hostname = parsed_url.hostname hostname = parsed_url.hostname
if hostname and hostname.startswith('[') and hostname.endswith(']'): if hostname and hostname.startswith('[') and hostname.endswith(']'):
hostname = hostname[1:-1] hostname = hostname[1:-1]
port = parsed_url.port if parsed_url.port is not None else None
port = None
if parsed_url.port is not None:
try:
port = int(parsed_url.port)
except ValueError:
print(f"Warning: Invalid port in URI: {parsed_url.port}")
return UriComponents( return UriComponents(
username=parsed_url.username, username=parsed_url.username,
password=parsed_url.password, password=parsed_url.password,
@ -264,10 +215,7 @@ class UriParser:
print(f"Error during URI parsing: {e}, URI: {uri}") print(f"Error during URI parsing: {e}, URI: {uri}")
return None return None
class SingboxConfigGenerator: class SingboxConfigGenerator:
"""Generator for Sing-box configurations"""
def __init__(self, hysteria_cli: HysteriaCLI, default_sni: str): def __init__(self, hysteria_cli: HysteriaCLI, default_sni: str):
self.hysteria_cli = hysteria_cli self.hysteria_cli = hysteria_cli
self.default_sni = default_sni self.default_sni = default_sni
@ -275,12 +223,10 @@ class SingboxConfigGenerator:
self.template_path = None self.template_path = None
def set_template_path(self, path: str): def set_template_path(self, path: str):
"""Sets the path to the template file"""
self.template_path = path self.template_path = path
self._template_cache = None self._template_cache = None
def get_template(self) -> Dict[str, Any]: def get_template(self) -> Dict[str, Any]:
"""Loads and caches the singbox template"""
if self._template_cache is None: if self._template_cache is None:
try: try:
with open(self.template_path, 'r') as f: with open(self.template_path, 'r') as f:
@ -290,19 +236,17 @@ class SingboxConfigGenerator:
return self._template_cache.copy() return self._template_cache.copy()
def generate_config(self, username: str, ip_version: str, fragment: str) -> Optional[Dict[str, Any]]: def generate_config(self, username: str, ip_version: str, fragment: str) -> Optional[Dict[str, Any]]:
"""Generates a Sing-box outbound configuration for a given user and IP version"""
try: try:
uri = self.hysteria_cli.get_user_uri(username, ip_version) uri = self.hysteria_cli.get_user_uri(username, ip_version)
except Exception: except Exception:
print(f"Warning: Failed to get URI for {username} with IP version {ip_version}. Skipping.") print(f"Failed to get URI for {username} with IP version {ip_version}. Skipping.")
return None return None
if not uri: if not uri:
print(f"Warning: No URI found for {username} with IP version {ip_version}. Skipping.") print(f"No URI found for {username} with IP version {ip_version}. Skipping.")
return None return None
components = UriParser.extract_uri_components(uri, f'IPv{ip_version}:') components = UriParser.extract_uri_components(uri, f'IPv{ip_version}:')
if components is None or components.port is None: if components is None or components.port is None:
print(f"Warning: Invalid URI components for {username} with IP version {ip_version}. Skipping.") print(f"Invalid URI components for {username} with IP version {ip_version}. Skipping.")
return None return None
return { return {
@ -324,27 +268,21 @@ class SingboxConfigGenerator:
}] }]
} }
def combine_configs(self, username: str, config_v4: Optional[Dict[str, Any]], def combine_configs(self, username: str, config_v4: Optional[Dict[str, Any]], config_v6: Optional[Dict[str, Any]]) -> Dict[str, Any]:
config_v6: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Combines IPv4 and IPv6 configurations into a single config"""
combined_config = self.get_template() combined_config = self.get_template()
combined_config['outbounds'] = [outbound for outbound in combined_config['outbounds'] if outbound.get('type') != 'hysteria2']
combined_config['outbounds'] = [
outbound for outbound in combined_config['outbounds']
if outbound.get('type') != 'hysteria2'
]
modified_v4_outbounds = [] modified_v4_outbounds = []
if config_v4: if config_v4:
v4_outbound = config_v4['outbounds'][0] v4_outbound = config_v4['outbounds'][0]
v4_outbound['tag'] = f"{username}-IPv4" v4_outbound['tag'] = f"{username}-IPv4"
modified_v4_outbounds = [v4_outbound] modified_v4_outbounds.append(v4_outbound)
modified_v6_outbounds = [] modified_v6_outbounds = []
if config_v6: if config_v6:
v6_outbound = config_v6['outbounds'][0] v6_outbound = config_v6['outbounds'][0]
v6_outbound['tag'] = f"{username}-IPv6" v6_outbound['tag'] = f"{username}-IPv6"
modified_v6_outbounds = [v6_outbound] modified_v6_outbounds.append(v6_outbound)
select_outbounds = ["auto"] select_outbounds = ["auto"]
if config_v4: if config_v4:
@ -363,23 +301,17 @@ class SingboxConfigGenerator:
outbound['outbounds'] = select_outbounds outbound['outbounds'] = select_outbounds
elif outbound.get('tag') == 'auto': elif outbound.get('tag') == 'auto':
outbound['outbounds'] = auto_outbounds outbound['outbounds'] = auto_outbounds
combined_config['outbounds'].extend(modified_v4_outbounds + modified_v6_outbounds) combined_config['outbounds'].extend(modified_v4_outbounds + modified_v6_outbounds)
return combined_config return combined_config
class SubscriptionManager: class SubscriptionManager:
"""Handles user subscription generation"""
def __init__(self, hysteria_cli: HysteriaCLI, config: AppConfig): def __init__(self, hysteria_cli: HysteriaCLI, config: AppConfig):
self.hysteria_cli = hysteria_cli self.hysteria_cli = hysteria_cli
self.config = config self.config = config
def get_normal_subscription(self, username: str, user_agent: str) -> str: def get_normal_subscription(self, username: str, user_agent: str) -> str:
"""Generates the user URI for normal subscriptions"""
user_info = self.hysteria_cli.get_user_info(username) user_info = self.hysteria_cli.get_user_info(username)
ipv4_uri, ipv6_uri = self.hysteria_cli.get_uris(username) ipv4_uri, ipv6_uri = self.hysteria_cli.get_uris(username)
output_lines = [uri for uri in [ipv4_uri, ipv6_uri] if uri] output_lines = [uri for uri in [ipv4_uri, ipv6_uri] if uri]
if not output_lines: if not output_lines:
return "No URI available" return "No URI available"
@ -401,48 +333,48 @@ class SubscriptionManager:
f"expire={user_info.expiration_timestamp}\n" f"expire={user_info.expiration_timestamp}\n"
) )
profile_lines = f"//profile-title: {username}-Hysteria2 🚀\n//profile-update-interval: 1\n" profile_lines = f"//profile-title: {username}-Hysteria2 🚀\n//profile-update-interval: 1\n"
return profile_lines + subscription_info + "\n".join(processed_uris) return profile_lines + subscription_info + "\n".join(processed_uris)
class TemplateRenderer: class TemplateRenderer:
"""Handles HTML template rendering"""
def __init__(self, template_dir: str, config: AppConfig): def __init__(self, template_dir: str, config: AppConfig):
self.env = Environment(loader=FileSystemLoader(template_dir), autoescape=True) self.env = Environment(loader=FileSystemLoader(template_dir), autoescape=True)
self.html_template = self.env.get_template('template.html') self.html_template = self.env.get_template('template.html')
self.config = config self.config = config
def render(self, context: TemplateContext) -> str: def render(self, context: TemplateContext) -> str:
"""Renders the HTML template with the given context"""
return self.html_template.render(vars(context)) return self.html_template.render(vars(context))
class HysteriaServer: class HysteriaServer:
"""Main application server class"""
def __init__(self): def __init__(self):
self.config = self._load_config() self.config = self._load_config()
self.rate_limiter = RateLimiter(self.config.rate_limit, self.config.rate_limit_window) self.rate_limiter = RateLimiter(self.config.rate_limit, self.config.rate_limit_window)
self.hysteria_cli = HysteriaCLI(self.config.hysteria_cli_path) self.hysteria_cli = HysteriaCLI(self.config.hysteria_cli_path)
self.singbox_generator = SingboxConfigGenerator(self.hysteria_cli, self.config.sni) self.singbox_generator = SingboxConfigGenerator(self.hysteria_cli, self.config.sni)
self.singbox_generator.set_template_path(self.config.singbox_template_path) self.singbox_generator.set_template_path(self.config.singbox_template_path)
self.subscription_manager = SubscriptionManager(self.hysteria_cli, self.config) self.subscription_manager = SubscriptionManager(self.hysteria_cli, self.config)
self.template_renderer = TemplateRenderer(self.config.template_dir, self.config) self.template_renderer = TemplateRenderer(self.config.template_dir, self.config)
self.app = web.Application(middlewares=[self._rate_limit_middleware]) self.app = web.Application(middlewares=[self._rate_limit_middleware])
self.app.add_routes([web.get(Utils.build_url('/{subpath}/sub/normal/', '{username}'), self.handle)])
self.app.router.add_route('*', '/{subpath:[^{}]+}/{tail:.*}', self.handle_404) safe_subpath = self.validate_and_escape_subpath(self.config.subpath)
self.app.router.add_route('*', '/{tail:.*}', self.handle_404) self.app.add_routes([
web.get(f'/{safe_subpath}/sub/normal/{{username}}', self.handle)
])
self.app.router.add_route('*', f'/{safe_subpath}/{{tail:.*}}', self.handle_404)
self.app.router.add_route('*', '/{tail:.*}', self.handle_generic_404)
def _load_config(self) -> AppConfig: def _load_config(self) -> AppConfig:
"""Loads application configuration from environment variables"""
domain = os.getenv('HYSTERIA_DOMAIN', 'localhost') domain = os.getenv('HYSTERIA_DOMAIN', 'localhost')
cert_file = os.getenv('HYSTERIA_CERTFILE') cert_file = os.getenv('HYSTERIA_CERTFILE')
key_file = os.getenv('HYSTERIA_KEYFILE') key_file = os.getenv('HYSTERIA_KEYFILE')
port = int(os.getenv('HYSTERIA_PORT', '3326')) port = int(os.getenv('HYSTERIA_PORT', '3326'))
subpath = os.getenv('SUBPATH', '').strip().strip("/") subpath = os.getenv('SUBPATH', '').strip().strip("/")
if not self.is_valid_subpath(subpath):
raise ValueError(f"Invalid SUBPATH: '{subpath}'. Subpath must contain only alphanumeric characters, hyphens, and underscores.")
sni_file = '/etc/hysteria/.configs.env' sni_file = '/etc/hysteria/.configs.env'
singbox_template_path = '/etc/hysteria/core/scripts/normalsub/singbox.json' singbox_template_path = '/etc/hysteria/core/scripts/normalsub/singbox.json'
hysteria_cli_path = '/etc/hysteria/core/cli.py' hysteria_cli_path = '/etc/hysteria/core/cli.py'
@ -451,24 +383,12 @@ class HysteriaServer:
template_dir = os.path.dirname(__file__) template_dir = os.path.dirname(__file__)
sni = self._load_sni_from_env(sni_file) sni = self._load_sni_from_env(sni_file)
return AppConfig(domain=domain, cert_file=cert_file, key_file=key_file, port=port, sni_file=sni_file,
return AppConfig( singbox_template_path=singbox_template_path, hysteria_cli_path=hysteria_cli_path,
domain=domain, rate_limit=rate_limit, rate_limit_window=rate_limit_window, sni=sni, template_dir=template_dir,
cert_file=cert_file, subpath=subpath)
key_file=key_file,
port=port,
sni_file=sni_file,
singbox_template_path=singbox_template_path,
hysteria_cli_path=hysteria_cli_path,
rate_limit=rate_limit,
rate_limit_window=rate_limit_window,
sni=sni,
template_dir=template_dir,
subpath=subpath
)
def _load_sni_from_env(self, sni_file: str) -> str: def _load_sni_from_env(self, sni_file: str) -> str:
"""Loads SNI configuration from the environment file"""
try: try:
with open(sni_file, 'r') as f: with open(sni_file, 'r') as f:
for line in f: for line in f:
@ -477,34 +397,36 @@ class HysteriaServer:
except FileNotFoundError: except FileNotFoundError:
print("Warning: SNI file not found. Using default SNI.") print("Warning: SNI file not found. Using default SNI.")
return "bts.com" return "bts.com"
def is_valid_subpath(self, subpath: str) -> bool:
"""Validates the subpath using a regex."""
return bool(re.match(r"^[a-zA-Z0-9_-]+$", subpath))
def validate_and_escape_subpath(self, subpath: str) -> str:
"""Validates the subpath and returns the escaped version."""
if not self.is_valid_subpath(subpath):
raise ValueError(f"Invalid subpath: {subpath}")
return re.escape(subpath)
@middleware @middleware
async def _rate_limit_middleware(self, request: web.Request, handler): async def _rate_limit_middleware(self, request: web.Request, handler):
"""Middleware for rate limiting requests""" client_ip = request.headers.get('X-Forwarded-For', request.headers.get('X-Real-IP', request.remote))
client_ip = request.headers.get('X-Forwarded-For',
request.headers.get('X-Real-IP', request.remote))
if not self.rate_limiter.check_limit(client_ip): if not self.rate_limiter.check_limit(client_ip):
return web.Response(status=429, text="Rate limit exceeded.") return web.Response(status=429, text="Rate limit exceeded.")
return await handler(request) return await handler(request)
async def handle(self, request: web.Request) -> web.Response: async def handle(self, request: web.Request) -> web.Response:
"""Main request handler"""
try: try:
# No need to extract subpath here; aiohttp handles it in the route
username = Utils.sanitize_input(request.match_info.get('username', ''), r'^[a-zA-Z0-9_-]+$') username = Utils.sanitize_input(request.match_info.get('username', ''), r'^[a-zA-Z0-9_-]+$')
if not username: if not username:
return web.Response(status=400, text="Error: Missing 'username' parameter.") return web.Response(status=400, text="Error: Missing 'username' parameter.")
user_agent = request.headers.get('User-Agent', '').lower() user_agent = request.headers.get('User-Agent', '').lower()
if any(browser in user_agent for browser in ['chrome', 'firefox', 'safari', 'edge', 'opera']): if any(browser in user_agent for browser in ['chrome', 'firefox', 'safari', 'edge', 'opera']):
return await self._handle_html(request, username) return await self._handle_html(request, username)
else: fragment = request.query.get('fragment', '')
fragment = request.query.get('fragment', '') if not user_agent.startswith('hiddifynext') and ('singbox' in user_agent or 'sing' in user_agent):
if not user_agent.startswith('hiddifynext') and ('singbox' in user_agent or 'sing' in user_agent): return await self._handle_singbox(username, fragment)
return await self._handle_singbox(username, fragment) return await self._handle_normalsub(request, username)
return await self._handle_normalsub(request, username)
except ValueError as e: except ValueError as e:
return web.Response(status=400, text=f"Error: {e}") return web.Response(status=400, text=f"Error: {e}")
except Exception as e: except Exception as e:
@ -512,36 +434,31 @@ class HysteriaServer:
return web.Response(status=500, text="Error: Internal server error") return web.Response(status=500, text="Error: Internal server error")
async def _handle_html(self, request: web.Request, username: str) -> web.Response: async def _handle_html(self, request: web.Request, username: str) -> web.Response:
"""Handles requests for HTML output"""
context = await self._get_template_context(username) context = await self._get_template_context(username)
rendered_html = self.template_renderer.render(context) return web.Response(text=self.template_renderer.render(context), content_type='text/html')
return web.Response(text=rendered_html, content_type='text/html')
async def _handle_singbox(self, username: str, fragment: str) -> web.Response: async def _handle_singbox(self, username: str, fragment: str) -> web.Response:
"""Handles requests for Sing-box configuration"""
config_v4 = self.singbox_generator.generate_config(username, '4', fragment) config_v4 = self.singbox_generator.generate_config(username, '4', fragment)
config_v6 = self.singbox_generator.generate_config(username, '6', fragment) config_v6 = self.singbox_generator.generate_config(username, '6', fragment)
if config_v4 is None and config_v6 is None: if config_v4 is None and config_v6 is None:
return web.Response(status=404, text=f"Error: No valid URIs found for user {username}.") return web.Response(status=404, text=f"Error: No valid URIs found for user {username}.")
combined_config = self.singbox_generator.combine_configs(username, config_v4, config_v6) combined_config = self.singbox_generator.combine_configs(username, config_v4, config_v6)
return web.Response(text=json.dumps(combined_config, indent=4, sort_keys=True), return web.Response(text=json.dumps(combined_config, indent=4, sort_keys=True), content_type='application/json')
content_type='application/json')
async def _handle_normalsub(self, request: web.Request, username: str) -> web.Response: async def _handle_normalsub(self, request: web.Request, username: str) -> web.Response:
"""Handles requests for normal subscription links"""
user_agent = request.headers.get('User-Agent', '').lower() user_agent = request.headers.get('User-Agent', '').lower()
subscription = self.subscription_manager.get_normal_subscription(username, user_agent) subscription = self.subscription_manager.get_normal_subscription(username, user_agent)
return web.Response(text=subscription, content_type='text/plain') return web.Response(text=subscription, content_type='text/plain')
async def _get_template_context(self, username: str) -> TemplateContext: async def _get_template_context(self, username: str) -> TemplateContext:
"""Generates the context for HTML template rendering, incorporating subpath"""
user_info = self.hysteria_cli.get_user_info(username) user_info = self.hysteria_cli.get_user_info(username)
ipv4_uri, ipv6_uri = self.hysteria_cli.get_uris(username) ipv4_uri, ipv6_uri = self.hysteria_cli.get_uris(username)
base_url = f"https://{self.config.domain}:{self.config.port}" base_url = f"https://{self.config.domain}:{self.config.port}"
sub_link = Utils.build_url(base_url, f"/{self.config.subpath}/sub/normal/{username}") if not Utils.is_valid_url(base_url):
raise ValueError(f"Invalid base URL constructed: {base_url}")
sub_link = f"{base_url}/{self.config.subpath}/sub/normal/{username}"
ipv4_qrcode = Utils.generate_qrcode_base64(ipv4_uri) ipv4_qrcode = Utils.generate_qrcode_base64(ipv4_uri)
ipv6_qrcode = Utils.generate_qrcode_base64(ipv6_uri) ipv6_qrcode = Utils.generate_qrcode_base64(ipv6_uri)
sublink_qrcode = Utils.generate_qrcode_base64(sub_link) sublink_qrcode = Utils.generate_qrcode_base64(sub_link)
@ -560,19 +477,21 @@ class HysteriaServer:
) )
async def handle_404(self, request: web.Request) -> web.Response: async def handle_404(self, request: web.Request) -> web.Response:
"""Handles 404 Not Found errors""" """Handles 404 Not Found errors *within* the subpath."""
print(f"404 Not Found: {request.path}") print(f"404 Not Found (within subpath): {request.path}")
return web.Response(status=404, text="Not Found within Subpath")
async def handle_generic_404(self, request: web.Request) -> web.Response:
"""Handles 404 Not Found errors *outside* the subpath."""
print(f"404 Not Found (generic): {request.path}")
return web.Response(status=404, text="Not Found") return web.Response(status=404, text="Not Found")
def run(self): def run(self):
"""Runs the web server"""
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(certfile=self.config.cert_file, keyfile=self.config.key_file) ssl_context.load_cert_chain(certfile=self.config.cert_file, keyfile=self.config.key_file)
ssl_context.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384') ssl_context.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384')
web.run_app(self.app, port=self.config.port, ssl_context=ssl_context) web.run_app(self.app, port=self.config.port, ssl_context=ssl_context)
if __name__ == '__main__': if __name__ == '__main__':
server = HysteriaServer() server = HysteriaServer()
server.run() server.run()