diff --git a/scalene/scalene_analysis.py b/scalene/scalene_analysis.py
index 5e7204d52..1c3afa023 100644
--- a/scalene/scalene_analysis.py
+++ b/scalene/scalene_analysis.py
@@ -20,6 +20,13 @@ def is_native(package_name: str) -> bool:
         """
         Returns whether a package is native or not.
         """
+        import platform
+
+        # Skip platform-specific modules on incompatible platforms
+        if (package_name == "scalene.scalene_apple_gpu" and
+            platform.system() != "Darwin"):
+            return False
+
         result = False
         try:
             package = importlib.import_module(package_name)
@@ -41,6 +48,9 @@ def is_native(package_name: str) -> bool:
         except TypeError:
             # __file__ is there, but empty (os.path.dirname() returns TypeError). Let's call it native.
             result = True
+        except OSError:
+            # Platform-specific import failed (e.g., missing system libraries)
+            result = False
         return result
 
     @staticmethod
diff --git a/scalene/scalene_process_manager.py b/scalene/scalene_process_manager.py
new file mode 100644
index 000000000..34a099a98
--- /dev/null
+++ b/scalene/scalene_process_manager.py
@@ -0,0 +1,116 @@
+"""Process management functionality for Scalene profiler."""
+
+from __future__ import annotations
+
+import os
+import tempfile
+import pathlib
+from typing import Set
+
+from scalene.scalene_preload import ScalenePreload
+
+
+class ProcessManager:
+    """Manages multiprocessing and child process functionality."""
+
+    def __init__(self, args):
+        self._args = args
+        self._child_pids: Set[int] = set()
+        self._is_child = -1
+        self._parent_pid = -1
+        self._pid = os.getpid()
+        self._python_alias_dir: pathlib.Path
+        self._orig_python = ""
+
+        # Set up process-specific configuration
+        self._setup_process_config()
+
+    def _setup_process_config(self) -> None:
+        """Set up process-specific configuration."""
+        if self._args.pid:
+            # Child process
+            self._is_child = 1
+            self._parent_pid = self._args.pid
+            # Use the same directory as the parent
+            dirname = os.environ["PATH"].split(os.pathsep)[0]
+            self._python_alias_dir = pathlib.Path(dirname)
+        else:
+            # Parent process
+            self._is_child = 0
+            self._parent_pid = self._pid
+            # Create a temporary directory for Python aliases
+            self._python_alias_dir = pathlib.Path(
+                tempfile.mkdtemp(prefix="scalene")
+            )
+
+    def get_child_pids(self) -> Set[int]:
+        """Get the set of child process IDs."""
+        return self._child_pids
+
+    def add_child_pid(self, pid: int) -> None:
+        """Add a child process ID to tracking."""
+        self._child_pids.add(pid)
+
+    def remove_child_pid(self, pid: int) -> None:
+        """Remove a child process ID from tracking."""
+        self._child_pids.discard(pid)
+
+    def is_child_process(self) -> bool:
+        """Check if this is a child process."""
+        return self._is_child == 1
+
+    def get_parent_pid(self) -> int:
+        """Get the parent process ID."""
+        return self._parent_pid
+
+    def get_current_pid(self) -> int:
+        """Get the current process ID."""
+        return self._pid
+
+    def get_python_alias_dir(self) -> pathlib.Path:
+        """Get the directory containing Python aliases."""
+        return self._python_alias_dir
+
+    def set_original_python_executable(self, executable: str) -> None:
+        """Set the original Python executable path."""
+        self._orig_python = executable
+
+    def get_original_python_executable(self) -> str:
+        """Get the original Python executable path."""
+        return self._orig_python
+
+    def before_fork(self) -> None:
+        """Handle operations before forking a new process."""
+        # Disable signals before forking to avoid race conditions
+        pass
+
+    def after_fork_in_parent(self, child_pid: int) -> None:
+        """Handle operations in parent process after forking."""
+        self.add_child_pid(child_pid)
+
+    def after_fork_in_child(self) -> None:
+        """Handle operations in child process after forking."""
+        # Reset child process state
+        self._child_pids.clear()
+        self._is_child = 1
+        self._pid = os.getpid()
+
+        # Set up preloading for child process
+        if hasattr(self._args, 'preload') and self._args.preload:
+            ScalenePreload.setup_preload(self._args)
+
+    def setup_multiprocessing_redirection(self) -> None:
+        """Set up redirection for multiprocessing calls."""
+        # This would contain the logic for redirecting Python calls
+        # to go through Scalene for child processes
+        pass
+
+    def cleanup_process_resources(self) -> None:
+        """Clean up process-specific resources."""
+        # Clean up temporary directories and aliases
+        if not self.is_child_process() and self._python_alias_dir.exists():
+            try:
+                import shutil
+                shutil.rmtree(self._python_alias_dir)
+            except Exception:
+                pass  # Best effort cleanup
diff --git a/scalene/scalene_profiler.py b/scalene/scalene_profiler.py
index 88f5920fb..5a7dc4449 100644
--- a/scalene/scalene_profiler.py
+++ b/scalene/scalene_profiler.py
@@ -108,6 +108,11 @@ from scalene.scalene_parseargs import ScaleneParseArgs, StopJupyterExecution
 from scalene.scalene_sigqueue import ScaleneSigQueue
 from scalene.scalene_accelerator import ScaleneAccelerator
 
+from scalene.scalene_profiler_core import ProfilerCore
+from scalene.scalene_signal_handler import SignalHandler
+from scalene.scalene_trace_manager import TraceManager
+from scalene.scalene_process_manager import ProcessManager
+from scalene.scalene_profiler_lifecycle import ProfilerLifecycle
 
 console = Console(style="white on blue")
 
@@ -234,9 +239,17 @@ def last_profiled_tuple() -> Tuple[Filename, LineNumber, ByteCodeIndex]:
     __accelerator: Optional[ScaleneAccelerator] = (
         None  # initialized after parsing arguments in `main`
     )
+
+    # New component classes for better separation of concerns
+    __profiler_core: Optional[ProfilerCore] = None
+    __signal_handler: Optional[SignalHandler] = None
+    __trace_manager: Optional[TraceManager] = None
+    __process_manager: Optional[ProcessManager] = None
+    __profiler_lifecycle: Optional[ProfilerLifecycle] = None
+
     __invalidate_queue: List[Tuple[Filename, LineNumber]] = []
     __invalidate_mutex: threading.Lock
-    __profiler_base: str
+    __profiler_base: str = ""  # Will be set in __init__
 
     @staticmethod
     def get_original_lock() -> threading.Lock:
@@ -303,18 +316,27 @@ def get_all_signals_set() -> Set[int]:
 
         Used by replacement_signal_fns.py to shim signals used by the client program.
         """
+        if Scalene.__signal_handler:
+            return Scalene.__signal_handler.get_all_signals_set()
         return set(Scalene.__signal_manager.get_signals().get_all_signals())
 
     @staticmethod
     def get_lifecycle_signals() -> Tuple[signal.Signals, signal.Signals]:
+        if Scalene.__signal_handler:
+            return Scalene.__signal_handler.get_lifecycle_signals()
         return Scalene.__signal_manager.get_signals().get_lifecycle_signals()
 
     @staticmethod
     def disable_lifecycle() -> None:
-        Scalene.__lifecycle_disabled = True
+        if Scalene.__signal_handler:
+            Scalene.__signal_handler.disable_lifecycle()
+        else:
+            Scalene.__lifecycle_disabled = True
 
     @staticmethod
     def get_lifecycle_disabled() -> bool:
+        if Scalene.__signal_handler:
+            return Scalene.__signal_handler.get_lifecycle_disabled()
         return Scalene.__lifecycle_disabled
 
     @staticmethod
@@ -323,6 +345,8 @@ def get_timer_signals() -> Tuple[int, signal.Signals]:
 
         Used by replacement_signal_fns.py to shim timers used by the client program.
         """
+        if Scalene.__signal_handler:
+            return Scalene.__signal_handler.get_timer_signals()
         return Scalene.__signal_manager.get_signals().get_timer_signals()
 
     @staticmethod
@@ -365,17 +389,24 @@ def update_profiled() -> None:
     @classmethod
     def clear_metrics(cls) -> None:
         """Clear the various states for forked processes."""
-        cls.__stats.clear()
+        if cls.__profiler_lifecycle:
+            cls.__profiler_lifecycle.clear_metrics()
+        else:
+            cls.__stats.clear()
         cls.child_pids.clear()
 
     @classmethod
     def add_child_pid(cls, pid: int) -> None:
         """Add this pid to the set of children. Used when forking."""
+        if cls.__process_manager:
+            cls.__process_manager.add_child_pid(pid)
         cls.child_pids.add(pid)
 
     @classmethod
     def remove_child_pid(cls, pid: int) -> None:
         """Remove a child once we have joined with it (used by replacement_pjoin.py)."""
+        if cls.__process_manager:
+            cls.__process_manager.remove_child_pid(pid)
         with contextlib.suppress(KeyError):
             cls.child_pids.remove(pid)
 
@@ -444,12 +475,18 @@ def windows_timer_loop() -> None:
     @staticmethod
     def start_signal_queues() -> None:
         """Start the signal processing queues (i.e., their threads)."""
-        Scalene.__signal_manager.start_signal_queues()
+        if Scalene.__signal_handler:
+            Scalene.__signal_handler.start_signal_queues()
+        else:
+            Scalene.__signal_manager.start_signal_queues()
 
     @staticmethod
     def stop_signal_queues() -> None:
         """Stop the signal processing queues (i.e., their threads)."""
-        Scalene.__signal_manager.stop_signal_queues()
+        if Scalene.__signal_handler:
+            Scalene.__signal_handler.stop_signal_queues()
+        else:
+            Scalene.__signal_manager.stop_signal_queues()
 
     @staticmethod
     def term_signal_handler(
@@ -595,17 +632,31 @@ def __init__(
             import scalene.replacement_poll_selector  # noqa: F401
 
         Scalene.__args = ScaleneArguments(**vars(arguments))
-        Scalene.__alloc_sigq = ScaleneSigQueue(
-            Scalene.alloc_sigqueue_processor
-        )
-        Scalene.__memcpy_sigq = ScaleneSigQueue(
+
+        # Initialize component classes for better separation of concerns
+        Scalene.__profiler_core = ProfilerCore(Scalene.__stats)
+        Scalene.__signal_handler = SignalHandler()
+        Scalene.__trace_manager = TraceManager(Scalene.__args)
+        Scalene.__process_manager = ProcessManager(Scalene.__args)
+        Scalene.__profiler_lifecycle = ProfilerLifecycle(Scalene.__stats, Scalene.__args)
+
+        # Synchronize files to profile between main class and trace manager
+        for filename in Scalene.__files_to_profile:
+            Scalene.__trace_manager.add_file_to_profile(filename)
+
+        # Set up signal queues through the signal handler
+        Scalene.__signal_handler.setup_signal_queues(
+            Scalene.alloc_sigqueue_processor,
             Scalene.memcpy_sigqueue_processor
         )
+
+        Scalene.__alloc_sigq = Scalene.__signal_handler.get_alloc_sigqueue()
+        Scalene.__memcpy_sigq = Scalene.__signal_handler.get_memcpy_sigqueue()
         Scalene.__sigqueues = [
             Scalene.__alloc_sigq,
             Scalene.__memcpy_sigq,
         ]
-        # Add signal queues to the signal manager
+        # Add signal queues to the signal manager (keeping original for backward compatibility)
         Scalene.__signal_manager.add_signal_queue(Scalene.__alloc_sigq)
         Scalene.__signal_manager.add_signal_queue(Scalene.__memcpy_sigq)
         Scalene.__invalidate_mutex = Scalene.get_original_lock()
@@ -843,8 +894,11 @@ def get_line_info(
 
     @staticmethod
     def profile_this_code(fname: Filename, lineno: LineNumber) -> bool:
-        # sourcery skip: inline-immediately-returned-variable
         """When using @profile, only profile files & lines that have been decorated."""
+        if Scalene.__trace_manager:
+            return Scalene.__trace_manager.profile_this_code(fname, lineno)
+
+        # Fallback to original logic if trace manager not initialized
         if not Scalene.__files_to_profile:
             return True
         if fname not in Scalene.__files_to_profile:
@@ -1269,10 +1323,7 @@ def _passes_exclusion_rules(filename: Filename) -> bool:
 
         # Check explicit exclude patterns
         profile_exclude_list = Scalene.__args.profile_exclude.split(",")
-        if any(prof in filename for prof in profile_exclude_list if prof != ""):
-            return False
-
-        return True
+        return not any(prof in filename for prof in profile_exclude_list if prof != "")
 
     @staticmethod
     def _handle_jupyter_cell(filename: Filename) -> bool:
diff --git a/scalene/scalene_profiler_core.py b/scalene/scalene_profiler_core.py
new file mode 100644
index 000000000..3fae8c525
--- /dev/null
+++ b/scalene/scalene_profiler_core.py
@@ -0,0 +1,90 @@
+"""Core profiling functionality for Scalene profiler."""
+
+from __future__ import annotations
+
+import time
+from typing import List, Tuple, Optional, Any, Set
+from types import FrameType
+
+from scalene.scalene_statistics import (
+    Filename,
+    LineNumber,
+    ByteCodeIndex,
+    ScaleneStatistics,
+    ProfilingSample,
+)
+
+
+class ProfilerCore:
+    """Handles core profiling functionality including CPU sampling and frame processing."""
+
+    def __init__(self, stats: ScaleneStatistics):
+        self._stats = stats
+        self._last_profiled = [Filename("NADA"), LineNumber(0), ByteCodeIndex(0)]
+        self._last_profiled_invalidated = False
+
+    def get_last_profiled(self) -> List[Any]:
+        """Get the last profiled location."""
+        return self._last_profiled
+
+    def set_last_profiled(self, fname: Filename, lineno: LineNumber, bytecode_index: ByteCodeIndex) -> None:
+        """Set the last profiled location."""
+        self._last_profiled = [fname, lineno, bytecode_index]
+
+    def get_last_profiled_invalidated(self) -> bool:
+        """Check if last profiled location has been invalidated."""
+        return self._last_profiled_invalidated
+
+    def set_last_profiled_invalidated(self, value: bool) -> None:
+        """Set the invalidated status of last profiled location."""
+        self._last_profiled_invalidated = value
+
+    def compute_frames_to_record(self) -> List[Tuple[FrameType, int, FrameType]]:
+        """Compute which frames should be recorded for profiling.
+
+        Returns:
+            List of tuples containing (frame, line_number, outer_frame)
+        """
+        import sys
+
+        frames_to_record: List[Tuple[FrameType, int, FrameType]] = []
+        current_frame = sys._getframe(2)  # Skip this frame and the caller
+
+        while current_frame:
+            filename = Filename(current_frame.f_code.co_filename)
+            lineno = LineNumber(current_frame.f_lineno)
+
+            # Check if this frame should be profiled
+            if self._should_profile_frame(filename, lineno, current_frame):
+                frames_to_record.append((current_frame, lineno, current_frame.f_back or current_frame))
+
+            current_frame = current_frame.f_back
+
+        return frames_to_record
+
+    def _should_profile_frame(self, filename: Filename, lineno: LineNumber, frame: FrameType) -> bool:
+        """Determine if a frame should be profiled."""
+        # This is a simplified version - in the full refactor, this would contain
+        # the logic from the original should_trace method
+        return True  # For now, profile all frames
+
+    def process_cpu_sample(
+        self,
+        frame: FrameType,
+        time_per_sample: float,
+        python_elapsed: float,
+        sys_elapsed: float
+    ) -> None:
+        """Process a CPU sample for profiling."""
+        filename = Filename(frame.f_code.co_filename)
+        lineno = LineNumber(frame.f_lineno)
+        bytecode_index = ByteCodeIndex(frame.f_lasti)
+
+        # Record CPU samples directly in statistics (like original code)
+        self._stats.cpu_stats.cpu_samples_python[filename][lineno] += python_elapsed
+        self._stats.cpu_stats.cpu_samples_c[filename][lineno] += sys_elapsed
+        self._stats.cpu_stats.cpu_samples[filename] += time_per_sample
+        self._stats.cpu_stats.total_cpu_samples += time_per_sample
+
+        # Update last profiled location
+        self.set_last_profiled(filename, lineno, bytecode_index)
diff --git a/scalene/scalene_profiler_lifecycle.py b/scalene/scalene_profiler_lifecycle.py
new file mode 100644
index 000000000..75891b2bd
--- /dev/null
+++ b/scalene/scalene_profiler_lifecycle.py
@@ -0,0 +1,173 @@
+"""Profiler lifecycle management for Scalene."""
+
+from __future__ import annotations
+
+import atexit
+import sys
+import time
+import threading
+from typing import Optional, Any
+
+from scalene.scalene_statistics import ScaleneStatistics
+
+
+class ProfilerLifecycle:
+    """Manages the lifecycle of the profiler (start, stop, cleanup)."""
+
+    def __init__(self, stats: ScaleneStatistics, args):
+        self._stats = stats
+        self._args = args
+        self._initialized = False
+        self._start_time = 0
+        self._stop_time = 0
+        self._running = False
+        self._cleanup_registered = False
+
+    def is_initialized(self) -> bool:
+        """Check if the profiler has been initialized."""
+        return self._initialized
+
+    def set_initialized(self) -> None:
+        """Mark the profiler as initialized."""
+        self._initialized = True
+
+    def is_running(self) -> bool:
+        """Check if the profiler is currently running."""
+        return self._running
+
+    def get_start_time(self) -> float:
+        """Get the profiler start time in nanoseconds."""
+        return self._start_time
+
+    def get_elapsed_time(self) -> float:
+        """Get the elapsed time since profiling started."""
+        if self._start_time == 0:
+            return 0.0
+        current_time = time.perf_counter_ns()
+        return (current_time - self._start_time) / 1e9
+
+    def start(self, signal_handler, process_manager) -> None:
+        """Start the profiler."""
+        if self._running:
+            return
+
+        # Record start time
+        self._start_time = time.perf_counter_ns()
+        self._running = True
+
+        # Register cleanup handler if not already done
+        if not self._cleanup_registered:
+            atexit.register(self.exit_handler)
+            self._cleanup_registered = True
+
+        # Enable signal handling
+        signal_handler.enable_signals()
+        signal_handler.start_signal_queues()
+
+        # Set timer signals
+        signal_handler.get_signals().set_timer_signals(self._args.use_virtual_time)
+
+        # Clear any existing statistics
+        self._stats.clear_all()
+
+    def stop(self, signal_handler) -> None:
+        """Stop the profiler."""
+        if not self._running:
+            return
+
+        # Record stop time
+        self._stop_time = time.perf_counter_ns()
+        self._running = False
+
+        # Disable signal handling
+        signal_handler.disable_signals()
+        signal_handler.stop_signal_queues()
+
+    def is_done(self) -> bool:
+        """Check if profiling is complete."""
+        return not self._running and self._stop_time > 0
+
+    def exit_handler(self) -> None:
+        """Handle cleanup when the program exits."""
+        try:
+            if self._running:
+                # Generate final output if needed
+                self._generate_final_output()
+        except Exception:
+            # Best effort cleanup - don't let exceptions propagate
+            pass
+
+    def _generate_final_output(self) -> None:
+        """Generate final profiling output."""
+        # This would contain the logic for generating the final report
+        # when the profiler is shutting down
+        pass
+
+    def clear_metrics(self) -> None:
+        """Clear all profiling metrics."""
+        self._stats.clear_all()
+
+    def reset(self) -> None:
+        """Reset the profiler state."""
+        self._start_time = 0
+        self._stop_time = 0
+        self._running = False
+        self._stats.clear_all()
+
+    def profile_code(self,
+                     code_object: Any,
+                     local_vars: dict,
+                     global_vars: dict,
+                     program_args: list,
+                     signal_handler,
+                     process_manager) -> int:
+        """Profile execution of code object."""
+        exit_status = 0
+
+        try:
+            # Start profiling
+            self.start(signal_handler, process_manager)
+
+            # Execute the code
+            exec(code_object, global_vars, local_vars)
+
+        except SystemExit as e:
+            # sys.exit() with no argument carries code None, which means
+            # success; any non-integer code (e.g. a message) means failure.
+            if e.code is None:
+                exit_status = 0
+            elif isinstance(e.code, int):
+                exit_status = e.code
+            else:
+                exit_status = 1
+        except Exception as e:
+            print(f"Error during profiling: {e}", file=sys.stderr)
+            exit_status = 1
+        finally:
+            # Stop profiling
+            self.stop(signal_handler)
+
+        return exit_status
+
+    def force_cleanup(self, process_manager, mapfiles: list) -> None:
+        """Force cleanup of all resources."""
+        try:
+            # Stop profiling if running
+            if self._running:
+                self._running = False
+
+            # Clean up mapfiles
+            for mapfile in mapfiles:
+                try:
+                    mapfile.close()
+                    if not process_manager.is_child_process():
+                        mapfile.cleanup()
+                except Exception:
+                    pass
+
+            # Clean up process resources
+            process_manager.cleanup_process_resources()
+
+        except Exception:
+            # Best effort cleanup
+            pass
diff --git a/scalene/scalene_signal_handler.py b/scalene/scalene_signal_handler.py
new file mode 100644
index 000000000..e7dce47c9
--- /dev/null
+++ b/scalene/scalene_signal_handler.py
@@ -0,0 +1,121 @@
+"""Signal handling functionality for Scalene profiler."""
+
+from __future__ import annotations
+
+import signal
+import threading
+import contextlib
+from typing import Set, Tuple, Optional, Any, List
+
+from scalene.scalene_signals import ScaleneSignals
+from scalene.scalene_signal_manager import ScaleneSignalManager
+from scalene.scalene_sigqueue import ScaleneSigQueue
+
+
+class SignalHandler:
+    """Handles signal management and processing for profiling."""
+
+    def __init__(self):
+        self._signals = ScaleneSignals()
+        self._signal_manager = ScaleneSignalManager()
+        self._original_lock = threading.Lock()
+        self._lifecycle_disabled = False
+
+        # Signal queues for different types of events
+        self._alloc_sigq: Optional[ScaleneSigQueue] = None
+        self._memcpy_sigq: Optional[ScaleneSigQueue] = None
+        self._sigqueues: List[ScaleneSigQueue] = []
+
+    def get_signals(self) -> ScaleneSignals:
+        """Get the signals manager."""
+        return self._signals
+
+    def get_signal_manager(self) -> ScaleneSignalManager:
+        """Get the signal manager."""
+        return self._signal_manager
+
+    def get_original_lock(self) -> threading.Lock:
+        """Get the original threading lock."""
+        return self._original_lock
+
+    def setup_signal_queues(self, alloc_processor, memcpy_processor) -> None:
+        """Set up signal queues for memory profiling events."""
+        self._alloc_sigq = ScaleneSigQueue(alloc_processor)
+        self._memcpy_sigq = ScaleneSigQueue(memcpy_processor)
+        self._sigqueues = [self._alloc_sigq, self._memcpy_sigq]
+
+        # Add signal queues to the signal manager
+        self._signal_manager.add_signal_queue(self._alloc_sigq)
+        self._signal_manager.add_signal_queue(self._memcpy_sigq)
+
+    def get_alloc_sigqueue(self) -> Optional[ScaleneSigQueue]:
+        """Get the allocation signal queue."""
+        return self._alloc_sigq
+
+    def get_memcpy_sigqueue(self) -> Optional[ScaleneSigQueue]:
+        """Get the memcpy signal queue."""
+        return self._memcpy_sigq
+
+    def get_all_signals_set(self) -> Set[int]:
+        """Get all signals that are being handled."""
+        # Ask ScaleneSignals directly; a cached set that is never filled
+        # would report that no signals are in use.
+        return set(self._signals.get_all_signals())
+
+    def get_lifecycle_signals(self) -> Tuple[signal.Signals, signal.Signals]:
+        """Get the lifecycle signals (start and stop)."""
+        return self._signals.get_lifecycle_signals()
+
+    def disable_lifecycle(self) -> None:
+        """Disable lifecycle signal handling."""
+        self._lifecycle_disabled = True
+
+    def get_lifecycle_disabled(self) -> bool:
+        """Check if lifecycle signals are disabled."""
+        return self._lifecycle_disabled
+
+    def get_timer_signals(self) -> Tuple[int, signal.Signals]:
+        """Get timer signals for profiling."""
+        return self._signals.get_timer_signals()
+
+    def start_signal_queues(self) -> None:
+        """Start all signal queues."""
+        for sigqueue in self._sigqueues:
+            sigqueue.start()
+
+    def stop_signal_queues(self) -> None:
+        """Stop all signal queues."""
+        for sigqueue in self._sigqueues:
+            sigqueue.stop()
+
+    def enable_signals(self) -> None:
+        """Enable signal handling for profiling."""
+        self._signals.enable_signals()
+
+    def disable_signals(self, retry: bool = True) -> None:
+        """Disable signal handling for profiling."""
+        try:
+            self._signals.disable_signals()
+        except Exception:
+            if retry:
+                # Try once more
+                with contextlib.suppress(Exception):
+                    self._signals.disable_signals()
+
+    def setup_signal_handlers(self,
+                              malloc_handler,
+                              free_handler,
+                              memcpy_handler,
+                              start_handler,
+                              stop_handler,
+                              term_handler) -> None:
+        """Set up signal handlers for profiling events."""
+        # Set up memory profiling signal handlers
+        self._signals.set_malloc_signal_handler(malloc_handler)
+        self._signals.set_free_signal_handler(free_handler)
+        self._signals.set_memcpy_signal_handler(memcpy_handler)
+
+        # Set up lifecycle signal handlers
+        self._signals.set_start_signal_handler(start_handler)
+        self._signals.set_stop_signal_handler(stop_handler)
+        self._signals.set_term_signal_handler(term_handler)
diff --git a/scalene/scalene_trace_manager.py b/scalene/scalene_trace_manager.py
new file mode 100644
index 000000000..77b4c08cc
--- /dev/null
+++ b/scalene/scalene_trace_manager.py
@@ -0,0 +1,144 @@
+"""Trace management functionality for Scalene profiler."""
+
+from __future__ import annotations
+
+import os
+from typing import Set, List, Tuple, Any
+
+from scalene.scalene_statistics import Filename, LineNumber
+
+
+class TraceManager:
+    """Manages which files and lines should be traced during profiling."""
+
+    def __init__(self, args):
+        self._args = args
+        self._files_to_profile: Set[Filename] = set()
+        self._line_info: dict = {}
+
+    def get_files_to_profile(self) -> Set[Filename]:
+        """Get the set of files that should be profiled."""
+        return self._files_to_profile
+
+    def add_file_to_profile(self, filename: Filename) -> None:
+        """Add a file to the profiling set."""
+        self._files_to_profile.add(filename)
+
+    def get_line_info(self, filename: Filename) -> List[Tuple[List[str], int]]:
+        """Get line information for a file."""
+        return self._line_info.get(filename, [])
+
+    def set_line_info(self, filename: Filename, line_info: List[Tuple[List[str], int]]) -> None:
+        """Set line information for a file."""
+        self._line_info[filename] = line_info
+
+    def profile_this_code(self, fname: Filename, lineno: LineNumber) -> bool:
+        """Check if a specific file and line should be profiled.
+
+        When using @profile, only profile files & lines that have been decorated.
+        """
+        if not self._files_to_profile:
+            return True
+        if fname not in self._files_to_profile:
+            return False
+        # Now check to see if it's the right line range.
+        line_info = self.get_line_info(fname)
+        if not line_info:
+            # No line info available, default to True
+            return True
+        found_function = any(
+            line_start <= lineno < line_start + len(lines)
+            for (lines, line_start) in line_info
+        )
+        return found_function
+
+    def should_trace(self, filename: Filename, func: str) -> bool:
+        """Determine if a file should be traced based on various criteria."""
+        # Handle decorated functions
+        if self._should_trace_decorated_function(filename, func):
+            return True
+
+        # Apply exclusion rules
+        if not self._passes_exclusion_rules(filename):
+            return False
+
+        # Handle Jupyter cells
+        if not self._handle_jupyter_cell(filename):
+            return False
+
+        # Apply profile-only rules
+        if not self._passes_profile_only_rules(filename):
+            return False
+
+        # Check location-based rules
+        return self._should_trace_by_location(filename)
+
+    def _should_trace_decorated_function(self, filename: Filename, func: str) -> bool:
+        """Check if we should trace a decorated function."""
+        if self._files_to_profile:
+            # Only trace files that have been specifically marked for profiling
+            return filename in self._files_to_profile
+        return False
+
+    def _passes_exclusion_rules(self, filename: Filename) -> bool:
+        """Check if filename passes exclusion patterns."""
+        if not self._args.profile_exclude:
+            return True
+
+        # Check explicit exclude patterns
+        profile_exclude_list = self._args.profile_exclude.split(",")
+        return not any(prof in filename for prof in profile_exclude_list if prof != "")
+
+    def _handle_jupyter_cell(self, filename: Filename) -> bool:
+        """Handle Jupyter cell tracing rules."""
+        # If in a Jupyter cell but cells are disabled, don't trace
+        # NOTE(review): this condition was garbled in the reviewed patch and is
+        # reconstructed from the comment above -- confirm the exact marker.
+        if "<ipython" in filename and not self._args.profile_jupyter_cells:
+            return False
+        return True
+
+    def _passes_profile_only_rules(self, filename: Filename) -> bool:
+        """Check if filename passes profile-only patterns."""
+        if not self._args.profile_only:
+            return True
+
+        profile_only_set = set(self._args.profile_only.split(","))
+        return not (profile_only_set and all(
+            prof not in filename for prof in profile_only_set
+        ))
+
+    def _should_trace_by_location(self, filename: Filename) -> bool:
+        """Check if we should trace based on file location."""
+        # Don't trace standard library files unless explicitly requested
+        if not self._args.profile_all:
+            # Skip files in site-packages unless explicitly included
+            if "site-packages" in filename:
+                return False
+
+            # Skip files in the Python standard library
+            import sysconfig
+            stdlib_paths = [
+                sysconfig.get_path('stdlib'),
+                sysconfig.get_path('platstdlib')
+            ]
+            for stdlib_path in stdlib_paths:
+                if not stdlib_path:
+                    continue
+                try:
+                    common = os.path.commonpath([filename, stdlib_path])
+                except ValueError:
+                    # Mixed absolute/relative paths or different drives
+                    # (e.g. "<string>" or "test.py"): not under the stdlib.
+                    continue
+                if common == stdlib_path:
+                    return False
+
+        return True
+
+    def register_files_to_profile(self, file_patterns: List[str]) -> None:
+        """Register files that should be profiled based on patterns."""
+        for pattern in file_patterns:
+            # For now, treat patterns as exact filenames
+            # In a full implementation, this would support glob patterns
+            self._files_to_profile.add(Filename(pattern))
diff --git a/tests/test_scalene_refactor.py b/tests/test_scalene_refactor.py
new file mode 100644
index 000000000..5be0a7594
--- /dev/null
+++ b/tests/test_scalene_refactor.py
@@ -0,0 +1,153 @@
+"""Tests for refactored scalene profiler components."""
+
+import pytest
+from scalene.scalene_profiler_core import ProfilerCore
+from scalene.scalene_signal_handler import SignalHandler
+from scalene.scalene_trace_manager import TraceManager
+from scalene.scalene_process_manager import ProcessManager
+from scalene.scalene_profiler_lifecycle import ProfilerLifecycle
+from scalene.scalene_statistics import ScaleneStatistics, Filename, LineNumber
+from scalene.scalene_arguments import ScaleneArguments
+
+
+def test_profiler_core_initialization():
+    """Test that ProfilerCore initializes properly."""
+    stats = ScaleneStatistics()
+    core = ProfilerCore(stats)
+
+    # Test that last profiled is initialized
+    last_profiled = core.get_last_profiled()
+    assert len(last_profiled) == 3
+    assert last_profiled[0] == Filename("NADA")
+    assert last_profiled[1] == LineNumber(0)
+
+    # Test setting last profiled
+    core.set_last_profiled(Filename("test.py"), LineNumber(10), 0)
+    last_profiled = core.get_last_profiled()
+    assert last_profiled[0] == Filename("test.py")
+    assert last_profiled[1] == LineNumber(10)
+
+
+def test_signal_handler_initialization():
+    """Test that SignalHandler initializes properly."""
+    handler = SignalHandler()
+
+    # Test that signals are accessible
+    signals = handler.get_signals()
+    assert signals is not None
+
+    # Test that signal manager is accessible
+    signal_manager = handler.get_signal_manager()
+    assert signal_manager is not None
+
+
+def test_trace_manager_initialization():
+    """Test that TraceManager initializes properly."""
+    # Create mock args
+    args = ScaleneArguments()
+    args.profile_exclude = ""
+    args.profile_only = ""
+    args.profile_all = False
+    args.profile_jupyter_cells = True
+
+    manager = TraceManager(args)
+
+    # Test files to profile management
+    test_file = Filename("test.py")
+    manager.add_file_to_profile(test_file)
+    files = manager.get_files_to_profile()
+    assert test_file in files
+
+    # Test profile_this_code with no files
+    manager = TraceManager(args)  # Fresh instance
+    assert manager.profile_this_code(Filename("any.py"), LineNumber(1)) is True
+
+    # Relative filenames must not crash the standard-library check
+    assert manager._should_trace_by_location(Filename("test.py")) is True
+
+
+def test_process_manager_initialization():
+    """Test that ProcessManager initializes properly."""
+    # Create mock args
+    args = ScaleneArguments()
+    args.pid = None  # Parent process
+
+    manager = ProcessManager(args)
+
+    # Test child PID management
+    manager.add_child_pid(12345)
+    child_pids = manager.get_child_pids()
+    assert 12345 in child_pids
+
+    manager.remove_child_pid(12345)
+    child_pids = manager.get_child_pids()
+    assert 12345 not in child_pids
+
+    # Remove the temporary alias directory created for the parent process
+    manager.cleanup_process_resources()
+
+
+def test_profiler_lifecycle_initialization():
+    """Test that ProfilerLifecycle initializes properly."""
+    stats = ScaleneStatistics()
+    args = ScaleneArguments()
+
+    lifecycle = ProfilerLifecycle(stats, args)
+
+    # Test initial state
+    assert not lifecycle.is_initialized()
+    assert not lifecycle.is_running()
+    assert lifecycle.get_start_time() == 0
+
+    # Test initialization
+    lifecycle.set_initialized()
+    assert lifecycle.is_initialized()
+
+
+def test_integration_components_work_together():
+    """Test that components can work together."""
+    # Create all components
+    stats = ScaleneStatistics()
+    args = ScaleneArguments()
+    args.profile_exclude = ""
+    args.profile_only = ""
+    args.profile_all = False
+    args.profile_jupyter_cells = True
+    args.pid = None
+
+    core = ProfilerCore(stats)
+    signal_handler = SignalHandler()
+    trace_manager = TraceManager(args)
+    process_manager = ProcessManager(args)
+    lifecycle = ProfilerLifecycle(stats, args)
+
+    # Test that they can be used together
+    test_file = Filename("test.py")
+    trace_manager.add_file_to_profile(test_file)
+
+    # Test profile_this_code
+    result = trace_manager.profile_this_code(test_file, LineNumber(1))
+    assert isinstance(result, bool)
+
+    # Test signal handler methods
+    signals = signal_handler.get_signals()
+    assert signals is not None
+
+    # Test process manager
+    process_manager.add_child_pid(999)
+    assert 999 in process_manager.get_child_pids()
+
+    # Remove the temporary alias directory created for the parent process
+    process_manager.cleanup_process_resources()
+
+    print("All component integration tests passed!")
+
+
+if __name__ == "__main__":
+    test_profiler_core_initialization()
+    test_signal_handler_initialization()
+    test_trace_manager_initialization()
+    test_process_manager_initialization()
+    test_profiler_lifecycle_initialization()
+    test_integration_components_work_together()
+    print("All tests passed!")