import asyncio
import contextlib
import os
import threading
from typing import TYPE_CHECKING, Any, TypeVar

# NOTE(review): the patch this section was taken from does not show the
# unchanged module code that sits between the imports above and the
# sync-runner helpers below.

# FNLLM does not support sync operations, so we work around that by running
# coroutines on a dedicated background event-loop thread.
T = TypeVar("T")

# Runner state, initialized lazily and rebuilt per process.
_loop: asyncio.AbstractEventLoop | None = None
_thr: threading.Thread | None = None
_pid: int | None = None

# Serializes (re)initialization of the runner state across caller threads.
_init_lock = threading.Lock()


def _reset_after_fork() -> None:
    """Forget the parent's runner state in a freshly forked child.

    Threads do not survive ``fork()``, so the inherited loop has nothing
    driving it. The lock is replaced as well, because another thread of the
    parent may have been holding it at the moment of the fork.
    """
    global _loop, _thr, _pid, _init_lock

    _init_lock = threading.Lock()
    _loop = None
    _thr = None
    _pid = None


# ``os.register_at_fork`` only exists on platforms that support fork().
if hasattr(os, "register_at_fork"):
    os.register_at_fork(after_in_child=_reset_after_fork)


def run_coroutine_sync(coroutine: Coroutine[Any, Any, T]) -> T:
    """Run a coroutine synchronously, handling process forks safely.

    The coroutine is scheduled on a background event loop that runs in a
    daemon thread named "Async Runner". The loop and thread are created on
    first use and recreated when the current process id differs from the
    one that created them, or when the runner thread is no longer alive.

    Parameters
    ----------
    coroutine
        The coroutine to run.

    Returns
    -------
    The result of the coroutine.

    Raises
    ------
    Any exception raised by the coroutine is re-raised in the calling thread.
    """
    global _loop, _thr, _pid

    with _init_lock:
        current_pid = os.getpid()
        # Kept in addition to the at-fork hook for platforms without it.
        forked = _pid != current_pid

        if forked or _loop is None or _thr is None or not _thr.is_alive():
            stale_loop = _loop
            if stale_loop is not None and not forked:
                # Same process, runner thread has exited: the loop is no
                # longer running, so it can be closed to release its
                # selector and wake-up sockets.
                with contextlib.suppress(RuntimeError):
                    stale_loop.close()
            # After a fork the inherited loop is only dropped: it still looks
            # "running" in the child and nothing here can stop it, so it is
            # left untouched.

            _loop = asyncio.new_event_loop()
            _thr = threading.Thread(
                target=_loop.run_forever,
                name="Async Runner",
                daemon=True,
            )
            _thr.start()
            _pid = current_pid

        # Bind locally so the loop used below is the one validated under the lock.
        loop = _loop

    # Schedule the coroutine on the runner loop and block until it finishes.
    future = asyncio.run_coroutine_threadsafe(coroutine, loop)
    return future.result()