diff --git a/network_latency/websocket_latency_test.py b/network_latency/websocket_latency_test.py index f59fcb9..b3e77e7 100644 --- a/network_latency/websocket_latency_test.py +++ b/network_latency/websocket_latency_test.py @@ -81,7 +81,7 @@ def resolve_dns(host: str, port: int) -> dict: # ============================== # WEBSOCKET HANDSHAKE # ============================== -def measure_ws_handshake(host: str, port: int, ip_address: str, ws_path: str, api_key: str) -> dict: +def measure_ws_handshake(host: str, port: int, ip_address: str, ws_path: str, api_key: str, use_tls: bool = True) -> dict: """Measure WebSocket handshake with detailed phase breakdown. Uses pre-resolved IP.""" result = {"ip_address": ip_address} total_start = time.perf_counter() @@ -99,22 +99,26 @@ def measure_ws_handshake(host: str, port: int, ip_address: str, ws_path: str, ap "error_message": str(e), } - # TLS handshake - try: - tls_start = time.perf_counter() - context = ssl.create_default_context() - ssock = context.wrap_socket(sock, server_hostname=host) - result["tls_ms"] = (time.perf_counter() - tls_start) * 1000 - except Exception as e: - sock.close() - return { - "status": "failed", - "phase": "tls", - "tcp_ms": result.get("tcp_ms"), - "error_type": type(e).__name__, - "error_message": str(e), - } - + # TLS handshake (skip for plain ws://) + if use_tls: + try: + tls_start = time.perf_counter() + context = ssl.create_default_context() + ssock = context.wrap_socket(sock, server_hostname=host) + result["tls_ms"] = (time.perf_counter() - tls_start) * 1000 + except Exception as e: + sock.close() + return { + "status": "failed", + "phase": "tls", + "tcp_ms": result.get("tcp_ms"), + "error_type": type(e).__name__, + "error_message": str(e), + } + else: + ssock = sock + result["tls_ms"] = None + # WebSocket upgrade try: ws_key = base64.b64encode(os.urandom(16)).decode("utf-8") @@ -261,7 +265,7 @@ def run_test(args): output_file = output_dir / f"ws_latency_{host}.jsonl" print(f"WebSocket Latency Test") - print(f" Target: wss://{host}:{port}{ws_path}") + print(f" Target: {args.scheme}://{host}:{port}{ws_path}") print(f" Output: {output_file}") if args.duration: print(f" Duration: {args.duration}s") @@ -303,7 +307,8 @@ def run_test(args): else: # Measure WebSocket handshake using resolved IP ws_result = measure_ws_handshake( - host, port, dns_result["ip_address"], ws_path, api_key + host, port, dns_result["ip_address"], ws_path, api_key, + use_tls=(args.scheme == "wss"), ) # Run traceroute if enabled @@ -400,6 +405,7 @@ def main(): parser.error(f"URL must start with wss:// or ws://") parsed = urlparse(url_input) + args.scheme = parsed.scheme args.host = parsed.hostname args.port = parsed.port or (443 if parsed.scheme == "wss" else 80) args.ws_path = parsed.path