10 changes: 10 additions & 0 deletions scalene/scalene_analysis.py
@@ -20,6 +20,13 @@ def is_native(package_name: str) -> bool:
"""
Returns whether a package is native or not.
"""
import platform

# Skip platform-specific modules on incompatible platforms
if (package_name == "scalene.scalene_apple_gpu" and
platform.system() != "Darwin"):
return False

result = False
try:
package = importlib.import_module(package_name)
@@ -41,6 +48,9 @@ def is_native(package_name: str) -> bool:
except TypeError:
# __file__ is there, but empty or None (os.path.dirname() raises TypeError). Let's call it native.
result = True
except OSError:
# Platform-specific import failed (e.g., missing system libraries)
result = False
return result

@staticmethod
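
As a standalone illustration of the pattern these hunks add (the helper name below is made up, and the real native-extension check in is_native is elided), a guarded import with the same OSError handling might look like:

import importlib
import platform

def guarded_import_ok(package_name: str) -> bool:
    """Sketch only: mirrors the platform guard and OSError fallback above."""
    # Skip the macOS-only GPU module everywhere except Darwin.
    if package_name == "scalene.scalene_apple_gpu" and platform.system() != "Darwin":
        return False
    try:
        importlib.import_module(package_name)
    except OSError:
        # A platform-specific import failed (e.g., a missing system library).
        return False
    except ImportError:
        return False
    return True
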
116 changes: 116 additions & 0 deletions scalene/scalene_process_manager.py
@@ -0,0 +1,116 @@
"""Process management functionality for Scalene profiler."""

from __future__ import annotations

import os
import tempfile
import pathlib
from typing import Set

from scalene.scalene_preload import ScalenePreload


class ProcessManager:
"""Manages multiprocessing and child process functionality."""

def __init__(self, args):
self._args = args
self._child_pids: Set[int] = set()
self._is_child = -1
self._parent_pid = -1
self._pid = os.getpid()
self._python_alias_dir: pathlib.Path
self._orig_python = ""

# Set up process-specific configuration
self._setup_process_config()

def _setup_process_config(self) -> None:
"""Set up process-specific configuration."""
if self._args.pid:
# Child process
self._is_child = 1
self._parent_pid = self._args.pid
# Use the same directory as the parent
dirname = os.environ["PATH"].split(os.pathsep)[0]
self._python_alias_dir = pathlib.Path(dirname)
else:
# Parent process
self._is_child = 0
self._parent_pid = self._pid
# Create a temporary directory for Python aliases
self._python_alias_dir = pathlib.Path(
tempfile.mkdtemp(prefix="scalene")
)

def get_child_pids(self) -> Set[int]:
"""Get the set of child process IDs."""
return self._child_pids

def add_child_pid(self, pid: int) -> None:
"""Add a child process ID to tracking."""
self._child_pids.add(pid)

def remove_child_pid(self, pid: int) -> None:
"""Remove a child process ID from tracking."""
self._child_pids.discard(pid)

def is_child_process(self) -> bool:
"""Check if this is a child process."""
return self._is_child == 1

def get_parent_pid(self) -> int:
"""Get the parent process ID."""
return self._parent_pid

def get_current_pid(self) -> int:
"""Get the current process ID."""
return self._pid

def get_python_alias_dir(self) -> pathlib.Path:
"""Get the directory containing Python aliases."""
return self._python_alias_dir

def set_original_python_executable(self, executable: str) -> None:
"""Set the original Python executable path."""
self._orig_python = executable

def get_original_python_executable(self) -> str:
"""Get the original Python executable path."""
return self._orig_python

def before_fork(self) -> None:
"""Handle operations before forking a new process."""
# Placeholder: disabling signals before fork (to avoid race conditions) would go here
pass

def after_fork_in_parent(self, child_pid: int) -> None:
"""Handle operations in parent process after forking."""
self.add_child_pid(child_pid)

def after_fork_in_child(self) -> None:
"""Handle operations in child process after forking."""
# Reset child process state
self._child_pids.clear()
self._is_child = 1
self._pid = os.getpid()

# Set up preloading for child process
if hasattr(self._args, 'preload') and self._args.preload:
ScalenePreload.setup_preload(self._args)

def setup_multiprocessing_redirection(self) -> None:
"""Set up redirection for multiprocessing calls."""
# This would contain the logic for redirecting Python calls
# to go through Scalene for child processes
pass

def cleanup_process_resources(self) -> None:
"""Clean up process-specific resources."""
# Clean up temporary directories and aliases
if not self.is_child_process() and self._python_alias_dir.exists():
try:
import shutil
shutil.rmtree(self._python_alias_dir)
except Exception:
pass # Best effort cleanup
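
A minimal usage sketch for the new class, wiring its fork hooks through os.register_at_fork; the argparse.Namespace arguments are hypothetical stand-ins for Scalene's parsed options, and the real profiler wires this up differently:

import argparse
import os

from scalene.scalene_process_manager import ProcessManager

# Hypothetical arguments; Scalene's real argument parser supplies many more fields.
args = argparse.Namespace(pid=0, preload=False)
manager = ProcessManager(args)  # pid=0 -> parent process; a temp alias dir is created

# POSIX-only sketch: keep child bookkeeping in sync across fork().
os.register_at_fork(
    before=manager.before_fork,
    after_in_child=manager.after_fork_in_child,
)

child_pid = os.fork()
if child_pid:
    manager.after_fork_in_parent(child_pid)  # parent records the new child pid
    os.waitpid(child_pid, 0)
    manager.cleanup_process_resources()      # parent removes its temp alias dir
else:
    assert manager.is_child_process()        # after_fork_in_child ran in the fork hook
    os._exit(0)
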
81 changes: 66 additions & 15 deletions scalene/scalene_profiler.py
@@ -108,6 +108,11 @@
from scalene.scalene_parseargs import ScaleneParseArgs, StopJupyterExecution
from scalene.scalene_sigqueue import ScaleneSigQueue
from scalene.scalene_accelerator import ScaleneAccelerator
from scalene.scalene_profiler_core import ProfilerCore
from scalene.scalene_signal_handler import SignalHandler
from scalene.scalene_trace_manager import TraceManager
from scalene.scalene_process_manager import ProcessManager
from scalene.scalene_profiler_lifecycle import ProfilerLifecycle

console = Console(style="white on blue")

@@ -234,9 +239,17 @@ def last_profiled_tuple() -> Tuple[Filename, LineNumber, ByteCodeIndex]:
__accelerator: Optional[ScaleneAccelerator] = (
None # initialized after parsing arguments in `main`
)

# New component classes for better separation of concerns
__profiler_core: Optional[ProfilerCore] = None
__signal_handler: Optional[SignalHandler] = None
__trace_manager: Optional[TraceManager] = None
__process_manager: Optional[ProcessManager] = None
__profiler_lifecycle: Optional[ProfilerLifecycle] = None

__invalidate_queue: List[Tuple[Filename, LineNumber]] = []
__invalidate_mutex: threading.Lock
__profiler_base: str
__profiler_base: str = "" # Will be set in __init__

@staticmethod
def get_original_lock() -> threading.Lock:
@@ -303,18 +316,27 @@ def get_all_signals_set() -> Set[int]:

Used by replacement_signal_fns.py to shim signals used by the client program.
"""
if Scalene.__signal_handler:
return Scalene.__signal_handler.get_all_signals_set()
return set(Scalene.__signal_manager.get_signals().get_all_signals())

@staticmethod
def get_lifecycle_signals() -> Tuple[signal.Signals, signal.Signals]:
if Scalene.__signal_handler:
return Scalene.__signal_handler.get_lifecycle_signals()
return Scalene.__signal_manager.get_signals().get_lifecycle_signals()

@staticmethod
def disable_lifecycle() -> None:
Scalene.__lifecycle_disabled = True
if Scalene.__signal_handler:
Scalene.__signal_handler.disable_lifecycle()
else:
Scalene.__lifecycle_disabled = True

@staticmethod
def get_lifecycle_disabled() -> bool:
if Scalene.__signal_handler:
return Scalene.__signal_handler.get_lifecycle_disabled()
return Scalene.__lifecycle_disabled

@staticmethod
@@ -323,6 +345,8 @@ def get_timer_signals() -> Tuple[int, signal.Signals]:

Used by replacement_signal_fns.py to shim timers used by the client program.
"""
if Scalene.__signal_handler:
return Scalene.__signal_handler.get_timer_signals()
return Scalene.__signal_manager.get_signals().get_timer_signals()

@staticmethod
@@ -365,17 +389,24 @@ def update_profiled() -> None:
@classmethod
def clear_metrics(cls) -> None:
"""Clear the various states for forked processes."""
cls.__stats.clear()
if cls.__profiler_lifecycle:
cls.__profiler_lifecycle.clear_metrics()
else:
cls.__stats.clear()
cls.child_pids.clear()

@classmethod
def add_child_pid(cls, pid: int) -> None:
"""Add this pid to the set of children. Used when forking."""
if cls.__process_manager:
cls.__process_manager.add_child_pid(pid)
cls.child_pids.add(pid)

@classmethod
def remove_child_pid(cls, pid: int) -> None:
"""Remove a child once we have joined with it (used by replacement_pjoin.py)."""
if cls.__process_manager:
cls.__process_manager.remove_child_pid(pid)
with contextlib.suppress(KeyError):
cls.child_pids.remove(pid)
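
Several methods in this class now share a transitional shape: consult the newly extracted component when it has been initialized, while keeping the legacy path (or legacy state, as with child_pids above) working for calls made before __init__ runs. A compact, generic sketch of the fallback form, with illustrative names:

from typing import Optional, Set

class LegacyManager:
    def child_pids(self) -> Set[int]:
        return set()

class ProcessComponent(LegacyManager):
    pass

class Profiler:
    __process_component: Optional[ProcessComponent] = None
    __legacy_manager = LegacyManager()

    @classmethod
    def child_pids(cls) -> Set[int]:
        # Prefer the extracted component once it exists;
        # otherwise keep the pre-refactor behavior.
        if cls.__process_component:
            return cls.__process_component.child_pids()
        return cls.__legacy_manager.child_pids()
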

@@ -444,12 +475,18 @@ def windows_timer_loop() -> None:
@staticmethod
def start_signal_queues() -> None:
"""Start the signal processing queues (i.e., their threads)."""
Scalene.__signal_manager.start_signal_queues()
if Scalene.__signal_handler:
Scalene.__signal_handler.start_signal_queues()
else:
Scalene.__signal_manager.start_signal_queues()

@staticmethod
def stop_signal_queues() -> None:
"""Stop the signal processing queues (i.e., their threads)."""
Scalene.__signal_manager.stop_signal_queues()
if Scalene.__signal_handler:
Scalene.__signal_handler.stop_signal_queues()
else:
Scalene.__signal_manager.stop_signal_queues()

@staticmethod
def term_signal_handler(
@@ -595,17 +632,31 @@ def __init__(
import scalene.replacement_poll_selector # noqa: F401

Scalene.__args = ScaleneArguments(**vars(arguments))
Scalene.__alloc_sigq = ScaleneSigQueue(
Scalene.alloc_sigqueue_processor
)
Scalene.__memcpy_sigq = ScaleneSigQueue(

# Initialize component classes for better separation of concerns
Scalene.__profiler_core = ProfilerCore(Scalene.__stats)
Scalene.__signal_handler = SignalHandler()
Scalene.__trace_manager = TraceManager(Scalene.__args)
Scalene.__process_manager = ProcessManager(Scalene.__args)
Scalene.__profiler_lifecycle = ProfilerLifecycle(Scalene.__stats, Scalene.__args)

# Synchronize files to profile between main class and trace manager
for filename in Scalene.__files_to_profile:
Scalene.__trace_manager.add_file_to_profile(filename)

# Set up signal queues through the signal handler
Scalene.__signal_handler.setup_signal_queues(
Scalene.alloc_sigqueue_processor,
Scalene.memcpy_sigqueue_processor
)

Scalene.__alloc_sigq = Scalene.__signal_handler.get_alloc_sigqueue()
Scalene.__memcpy_sigq = Scalene.__signal_handler.get_memcpy_sigqueue()
Scalene.__sigqueues = [
Scalene.__alloc_sigq,
Scalene.__memcpy_sigq,
]
# Add signal queues to the signal manager
# Add signal queues to the signal manager (keeping original for backward compatibility)
Scalene.__signal_manager.add_signal_queue(Scalene.__alloc_sigq)
Scalene.__signal_manager.add_signal_queue(Scalene.__memcpy_sigq)
Scalene.__invalidate_mutex = Scalene.get_original_lock()
@@ -843,8 +894,11 @@ def get_line_info(

@staticmethod
def profile_this_code(fname: Filename, lineno: LineNumber) -> bool:
# sourcery skip: inline-immediately-returned-variable
"""When using @profile, only profile files & lines that have been decorated."""
if Scalene.__trace_manager:
return Scalene.__trace_manager.profile_this_code(fname, lineno)

# Fallback to original logic if trace manager not initialized
if not Scalene.__files_to_profile:
return True
if fname not in Scalene.__files_to_profile:
@@ -1269,10 +1323,7 @@ def _passes_exclusion_rules(filename: Filename) -> bool:

# Check explicit exclude patterns
profile_exclude_list = Scalene.__args.profile_exclude.split(",")
if any(prof in filename for prof in profile_exclude_list if prof != ""):
return False

return True
return not any(prof in filename for prof in profile_exclude_list if prof != "")

@staticmethod
def _handle_jupyter_cell(filename: Filename) -> bool:
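
The last hunk above folds an if/return pair in _passes_exclusion_rules into a single boolean expression. A small self-contained check that the two forms agree (the filenames and patterns below are made up):

def old_style(filename: str, exclude: list[str]) -> bool:
    if any(prof in filename for prof in exclude if prof != ""):
        return False
    return True

def new_style(filename: str, exclude: list[str]) -> bool:
    return not any(prof in filename for prof in exclude if prof != "")

cases = [
    ("pkg/module.py", ["tests", ""]),
    ("tests/test_module.py", ["tests", ""]),
    ("pkg/module.py", [""]),
]
assert all(old_style(f, ex) == new_style(f, ex) for f, ex in cases)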