diff --git a/core/traffic.py b/core/traffic.py index 51d1d6d..4e7efb8 100644 --- a/core/traffic.py +++ b/core/traffic.py @@ -5,201 +5,227 @@ import os import sys import fcntl import datetime -from hysteria2_api import Hysteria2Client +import logging +from typing import Dict, Any, Optional, List, Tuple SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(SCRIPT_DIR, 'scripts')) +from hysteria2_api import Hysteria2Client from db.database import db CONFIG_FILE = '/etc/hysteria/config.json' API_BASE_URL = 'http://127.0.0.1:25413' LOCKFILE = "/tmp/hysteria_traffic.lock" -def acquire_lock(): - try: - lock_file = open(LOCKFILE, 'w') - fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB) - return lock_file - except IOError: - sys.exit(1) +STATUS_ONLINE = "Online" +STATUS_OFFLINE = "Offline" +STATUS_ON_HOLD = "On-hold" -def get_secret(): - try: - with open(CONFIG_FILE, 'r') as f: - config = json.load(f) - return config.get('trafficStats', {}).get('secret') - except (json.JSONDecodeError, FileNotFoundError): - return None +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -def format_bytes(bytes_val): +def format_bytes(bytes_val: int) -> str: + if not isinstance(bytes_val, (int, float)): return "0B" if bytes_val < 1024: return f"{bytes_val}B" - elif bytes_val < 1048576: return f"{bytes_val / 1024:.2f}KB" - elif bytes_val < 1073741824: return f"{bytes_val / 1048576:.2f}MB" - elif bytes_val < 1099511627776: return f"{bytes_val / 1073741824:.2f}GB" - else: return f"{bytes_val / 1099511627776:.2f}TB" + if bytes_val < 1024**2: return f"{bytes_val / 1024:.2f}KB" + if bytes_val < 1024**3: return f"{bytes_val / 1024**2:.2f}MB" + if bytes_val < 1024**4: return f"{bytes_val / 1024**3:.2f}GB" + return f"{bytes_val / 1024**4:.2f}TB" -def display_traffic_data(data, green, cyan, NC): +def display_traffic_data(data: Dict[str, Dict[str, Any]]): if not data: print("No traffic data to display.") return + green, cyan, nc = '\033[0;32m', '\033[0;36m', '\033[0m' + headers = ["User", "Upload (TX)", "Download (RX)", "Status"] + header_line = f"{headers[0]:<15} {headers[1]:<15} {headers[2]:<15} {headers[3]:<10}" + separator = "-" * len(header_line) + print("Traffic Data:") - print("-------------------------------------------------") - print(f"{'User':<15} {'Upload (TX)':<15} {'Download (RX)':<15} {'Status':<10}") - print("-------------------------------------------------") + print(separator) + print(header_line) + print(separator) for user, entry in data.items(): - upload_bytes = entry.get("upload_bytes", 0) - download_bytes = entry.get("download_bytes", 0) - status = entry.get("status", "On-hold") - formatted_tx = format_bytes(upload_bytes) - formatted_rx = format_bytes(download_bytes) - print(f"{user:<15} {green}{formatted_tx:<15}{NC} {cyan}{formatted_rx:<15}{NC} {status:<10}") - print("-------------------------------------------------") + formatted_tx = format_bytes(entry.get("upload_bytes", 0)) + formatted_rx = format_bytes(entry.get("download_bytes", 0)) + status = entry.get("status", STATUS_ON_HOLD) + print(f"{user:<15} {green}{formatted_tx:<15}{nc} {cyan}{formatted_rx:<15}{nc} {status:<10}") + print(separator) -def traffic_status(no_gui=False): - green, cyan, NC = '\033[0;32m', '\033[0;36m', '\033[0m' - - if db is None: - if not no_gui: print("Error: Database connection failed.") - return None +class TrafficManager: + def __init__(self, db_conn, api_base_url: str): + self.db = db_conn + if self.db is None: + raise ValueError("Database connection is not available.") + self.secret = self._get_secret() + if not self.secret: + raise ValueError(f"Secret not found or failed to read {CONFIG_FILE}.") + self.client = Hysteria2Client(base_url=api_base_url, secret=self.secret) + self.today_date = datetime.datetime.now().strftime("%Y-%m-%d") - secret = get_secret() - if not secret: - if not no_gui: print(f"Error: Secret not found or failed to read {CONFIG_FILE}.") - return None - - client = Hysteria2Client(base_url=API_BASE_URL, secret=secret) - try: - traffic_stats = client.get_traffic_stats(clear=True) - online_status = client.get_online_clients() - except Exception as e: - if not no_gui: print(f"Error communicating with Hysteria2 API: {e}") - return None - - try: - all_users = db.get_all_users() - initial_users_data = {user['_id']: user for user in all_users} - except Exception as e: - if not no_gui: print(f"Error fetching users from database: {e}") - return None - - today_date = datetime.datetime.now().strftime("%Y-%m-%d") - users_to_update = {} - - for username, user_data in initial_users_data.items(): - updates = {} - is_online_locally = username in online_status and online_status[username].is_online - online_count_db = user_data.get('online_count', 0) - - is_online_globally = is_online_locally or online_count_db > 0 - - if username in traffic_stats: - new_upload = user_data.get('upload_bytes', 0) + traffic_stats[username].upload_bytes - new_download = user_data.get('download_bytes', 0) + traffic_stats[username].download_bytes - if new_upload != user_data.get('upload_bytes'): updates['upload_bytes'] = new_upload - if new_download != user_data.get('download_bytes'): updates['download_bytes'] = new_download + @staticmethod + def _get_secret() -> Optional[str]: + try: + with open(CONFIG_FILE, 'r') as f: + config = json.load(f) + return config.get('trafficStats', {}).get('secret') + except (json.JSONDecodeError, FileNotFoundError): + logging.error(f"Could not read or parse secret from {CONFIG_FILE}") + return None - is_activated = "account_creation_date" in user_data - - if not is_activated: - current_traffic = traffic_stats.get(username) - has_activity = is_online_globally or (current_traffic and (current_traffic.upload_bytes > 0 or current_traffic.download_bytes > 0)) + def _get_online_connection_count(self, user_status_from_api: Any) -> int: + if not hasattr(user_status_from_api, 'is_online') or not user_status_from_api.is_online: + return 0 + if not hasattr(user_status_from_api, 'connections'): + return 1 - if has_activity: - updates["account_creation_date"] = today_date - updates["status"] = "Online" if is_online_globally else "Offline" - else: - if user_data.get("status") != "On-hold": - updates["status"] = "On-hold" - else: - new_status = "Online" if is_online_globally else "Offline" + connections_attr = user_status_from_api.connections + try: + return len(connections_attr) + except TypeError: + return int(connections_attr) if isinstance(connections_attr, int) else 1 + + def process_and_update_traffic(self) -> Dict[str, Any]: + try: + live_traffic = self.client.get_traffic_stats(clear=True) + live_status = self.client.get_online_clients() + db_users = {u['_id']: u for u in self.db.get_all_users()} + except Exception as e: + logging.error(f"Error communicating with Hysteria2 API or DB: {e}") + return {} + + users_to_update: List[Tuple[str, Dict[str, Any]]] = [] + for username, user_data in db_users.items(): + updates = self._calculate_user_updates(username, user_data, live_traffic, live_status) + if updates: + users_to_update.append((username, updates)) + + if users_to_update: + for username, update_data in users_to_update: + try: + self.db.update_user(username, update_data) + db_users[username].update(update_data) + except Exception as e: + logging.error(f"Failed to update user {username} in DB: {e}") + return db_users + + def _calculate_user_updates(self, username: str, user_data: Dict, live_traffic: Dict, live_status: Dict) -> Dict[str, Any]: + updates = {} + online_count = self._get_online_connection_count(live_status.get(username)) + is_online = online_count > 0 + if user_data.get('online_count') != online_count: + updates['online_count'] = online_count + + if username in live_traffic: + updates['upload_bytes'] = user_data.get('upload_bytes', 0) + live_traffic[username].upload_bytes + updates['download_bytes'] = user_data.get('download_bytes', 0) + live_traffic[username].download_bytes + + is_activated = "account_creation_date" in user_data + has_activity = is_online or (username in live_traffic and (live_traffic[username].upload_bytes > 0 or live_traffic[username].download_bytes > 0)) + + if not is_activated and has_activity: + updates["account_creation_date"] = self.today_date + updates["status"] = STATUS_ONLINE if is_online else STATUS_OFFLINE + elif is_activated: + new_status = STATUS_ONLINE if is_online else STATUS_OFFLINE if user_data.get("status") != new_status: updates["status"] = new_status - - if updates: - users_to_update[username] = updates + elif not is_activated and not has_activity and user_data.get("status") != STATUS_ON_HOLD: + updates["status"] = STATUS_ON_HOLD + + return updates - if users_to_update: + def kick_expired_users(self): try: - for username, update_data in users_to_update.items(): - db.update_user(username, update_data) + all_users = self.db.get_all_users() except Exception as e: - if not no_gui: print(f"Error updating database: {e}") - return None + logging.error(f"Failed to fetch users for expiration check: {e}") + return - if not no_gui: - # For display, merge updates into the initial data - for username, updates in users_to_update.items(): - initial_users_data[username].update(updates) - display_traffic_data(initial_users_data, green, cyan, NC) - - return initial_users_data + now = datetime.datetime.now() + users_to_kick, users_to_block = [], [] + + for user in all_users: + username = user.get('_id') + if not username or user.get('blocked') or not user.get('account_creation_date'): continue -def kick_api_call(usernames, secret): + try: + total_bytes = user.get('download_bytes', 0) + user.get('upload_bytes', 0) + expired_by_date = (user.get('expiration_days', 0) > 0 and now >= datetime.datetime.strptime(user['account_creation_date'], "%Y-%m-%d") + datetime.timedelta(days=user['expiration_days'])) + expired_by_traffic = (user.get('max_download_bytes', 0) > 0 and total_bytes >= user['max_download_bytes']) + + if expired_by_date or expired_by_traffic: + users_to_block.append(username) + if user.get("online_count", 0) > 0 or user.get("status") == STATUS_ONLINE: + users_to_kick.append(username) + except (ValueError, TypeError): continue + + if users_to_block: + for username in users_to_block: + self.db.update_user(username, {'blocked': True, 'status': STATUS_OFFLINE, 'online_count': 0}) + + if users_to_kick: + for i in range(0, len(users_to_kick), 50): + self._kick_api_call(users_to_kick[i:i+50]) + + def _kick_api_call(self, usernames: List[str]): + try: + self.client.kick_clients(usernames) + logging.info(f"Successfully kicked users: {', '.join(usernames)}") + except Exception as e: + logging.error(f"Failed to kick users via API: {e}") + + +def traffic_status(no_gui=False) -> Optional[Dict[str, Any]]: + """ + Processes traffic stats, updates the database, and optionally displays output. + This function is the primary entry point for external modules. + """ try: - client = Hysteria2Client(base_url=API_BASE_URL, secret=secret) - client.kick_clients(usernames) - except Exception as e: - print(f"Failed to kick users via API: {e}", file=sys.stderr) + manager = TrafficManager(db_conn=db, api_base_url=API_BASE_URL) + final_data = manager.process_and_update_traffic() + if not no_gui: + display_traffic_data(final_data) + return final_data + except ValueError as e: + logging.critical(str(e)) + return None def kick_expired_users(): - if db is None: - print("Error: Database connection failed.", file=sys.stderr) - return - - secret = get_secret() - if not secret: - print(f"Error: Secret not found or failed to read {CONFIG_FILE}.", file=sys.stderr) - return - - all_users = db.get_all_users() - users_to_kick, users_to_block = [], [] - - for user in all_users: - username = user.get('_id') - if not username or user.get('blocked', False) or not user.get('account_creation_date'): - continue - - total_bytes = user.get('download_bytes', 0) + user.get('upload_bytes', 0) - should_block = False - try: - if user.get('expiration_days', 0) > 0: - creation_date = datetime.datetime.strptime(user['account_creation_date'], "%Y-%m-%d") - if datetime.datetime.now() >= creation_date + datetime.timedelta(days=user['expiration_days']): - should_block = True - - if not should_block and user.get('max_download_bytes', 0) > 0 and total_bytes >= user['max_download_bytes']: - should_block = True - - if should_block: - users_to_kick.append(username) - users_to_block.append(username) - except (ValueError, TypeError): - continue - - if users_to_block: - for username in users_to_block: - db.update_user(username, {'blocked': True}) - - if users_to_kick: - for i in range(0, len(users_to_kick), 50): - kick_api_call(users_to_kick[i:i+50], secret) - -if __name__ == "__main__": - lock_file = acquire_lock() + """ + Finds and kicks users who have expired by date or traffic limit. + This function is the primary entry point for external modules. + """ try: - if len(sys.argv) > 1: - if sys.argv[1] == "kick": - kick_expired_users() - elif sys.argv[1] == "--no-gui": - traffic_status(no_gui=True) - kick_expired_users() - else: - print(f"Usage: python {sys.argv[0]} [kick|--no-gui]") + manager = TrafficManager(db_conn=db, api_base_url=API_BASE_URL) + manager.kick_expired_users() + except ValueError as e: + logging.critical(str(e)) + +def main(): + lock_file = None + try: + lock_file = open(LOCKFILE, 'w') + fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB) + + args = sys.argv[1:] + if "kick" in args: + kick_expired_users() + elif "--no-gui" in args: + traffic_status(no_gui=True) + kick_expired_users() else: traffic_status(no_gui=False) + + except IOError: + logging.warning("Another instance of the script is already running.") + sys.exit(1) finally: - fcntl.flock(lock_file, fcntl.LOCK_UN) - lock_file.close() \ No newline at end of file + if lock_file: + fcntl.flock(lock_file, fcntl.LOCK_UN) + lock_file.close() + +if __name__ == "__main__": + main() \ No newline at end of file