diff --git a/env.d/development/summary.dist b/env.d/development/summary.dist index e5777226b..5980144d8 100644 --- a/env.d/development/summary.dist +++ b/env.d/development/summary.dist @@ -21,3 +21,8 @@ WEBHOOK_URL="https://configure-your-url.com" POSTHOG_API_KEY="your-posthog-key" POSTHOG_ENABLED="False" + +LANGFUSE_SECRET_KEY="your-secret-key" +LANGFUSE_PUBLIC_KEY="your-public-key" +LANGFUSE_HOST="https://cloud.langfuse.com" +LANGFUSE_IS_ENABLED="False" diff --git a/src/summary/pyproject.toml b/src/summary/pyproject.toml index 4d1b934b1..cc9eaa2a3 100644 --- a/src/summary/pyproject.toml +++ b/src/summary/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "posthog==6.0.3", "requests==2.32.4", "sentry-sdk[fastapi, celery]==2.30.0", + "langfuse==3.4.0", ] [project.optional-dependencies] diff --git a/src/summary/summary/core/celery_worker.py b/src/summary/summary/core/celery_worker.py index 506d4c40c..763b00585 100644 --- a/src/summary/summary/core/celery_worker.py +++ b/src/summary/summary/core/celery_worker.py @@ -9,18 +9,19 @@ from pathlib import Path from typing import Optional -import openai import sentry_sdk from celery import Celery, signals from celery.utils.log import get_task_logger from minio import Minio from mutagen import File +from openai import OpenAI from requests import Session, exceptions from requests.adapters import HTTPAdapter from urllib3.util import Retry from summary.core.analytics import MetadataManager, get_analytics from summary.core.config import get_settings +from summary.core.observability import Observability from summary.core.prompt import ( PROMPT_SYSTEM_CLEANING, PROMPT_SYSTEM_NEXT_STEP, @@ -46,6 +47,14 @@ celery.config_from_object("summary.core.celery_config") +obs = Observability( + is_enabled=settings.langfuse_is_enabled, + langfuse_host=settings.langfuse_host, + langfuse_public_key=settings.langfuse_public_key, + langfuse_secret_key=settings.langfuse_secret_key, +) +logger.info("Observability enabled: %s", obs.is_enabled) + if 
settings.sentry_dsn and settings.sentry_is_enabled: @signals.celeryd_init.connect @@ -111,9 +120,10 @@ class LLMService: def __init__(self): """Init the LLMService once.""" - self._client = openai.OpenAI( + self._client = OpenAI( base_url=settings.llm_base_url, api_key=settings.llm_api_key ) + self.gen_ctx = obs.generation def call(self, system_prompt: str, user_prompt: str): """Call the LLM service. @@ -134,6 +144,14 @@ def call(self, system_prompt: str, user_prompt: str): logger.error("LLM call failed: %s", e) raise LLMException("LLM call failed.") from e + def call_llm_gen(self, name, system, user): + """Call the LLM service within a generation context.""" + with self.gen_ctx( + name=name, + model=settings.llm_model, + ): + return self.call(system, user) + def format_segments(transcription_data): """Format transcription segments from WhisperX into a readable conversation format. @@ -201,7 +219,8 @@ def task_failure_handler(task_id, exception=None, **kwargs): autoretry_for=[exceptions.HTTPError], max_retries=settings.celery_max_retries, ) -def process_audio_transcribe_summarize_v2( +@obs.observe(name="process-audio", capture_input=True, capture_output=False) +def process_audio_transcribe_summarize_v2( # noqa: PLR0915 self, filename: str, email: str, @@ -222,6 +241,21 @@ def process_audio_transcribe_summarize_v2( logger.info("Notification received") logger.debug("filename: %s", filename) + try: + obs.update_current_trace( + user_id=sub or email, + tags=["celery", "transcription", "whisperx"], + metadata={ + "filename": filename, + "room": room, + "recording_date": recording_date, + "recording_time": recording_time, + }, + ) + except Exception as e: + logger.warning("Langfuse update trace failed: %s", e) + pass + task_id = self.request.id minio_client = Minio( @@ -236,7 +270,6 @@ def process_audio_transcribe_summarize_v2( audio_file_stream = minio_client.get_object( settings.aws_storage_bucket_name, object_name=filename ) - temp_file_path = 
save_audio_stream(audio_file_stream) logger.info("Recording successfully downloaded") @@ -257,7 +290,7 @@ def process_audio_transcribe_summarize_v2( raise AudioValidationError(error_msg) logger.info("Initiating WhisperX client") - whisperx_client = openai.OpenAI( + whisperx_client = OpenAI( api_key=settings.whisperx_api_key, base_url=settings.whisperx_base_url, max_retries=settings.whisperx_max_retries, @@ -266,20 +299,25 @@ def process_audio_transcribe_summarize_v2( try: logger.info("Querying transcription …") transcription_start_time = time.time() - with open(temp_file_path, "rb") as audio_file: - transcription = whisperx_client.audio.transcriptions.create( - model=settings.whisperx_asr_model, file=audio_file - ) - metadata_manager.track( - task_id, - { - "transcription_time": round( - time.time() - transcription_start_time, 2 - ) - }, - ) - logger.info("Transcription received.") - logger.debug("Transcription: \n %s", transcription) + with obs.span( + name="whisperx.transcribe", + input={ + "model": settings.whisperx_asr_model, + "audio_seconds": round(audio_file.info.length, 2), + "endpoint": settings.whisperx_base_url, + }, + ): + with open(temp_file_path, "rb") as audio_file_rb: + transcription = whisperx_client.audio.transcriptions.create( + model=settings.whisperx_asr_model, + file=audio_file_rb, + ) + metadata_manager.track( + task_id, + {"transcription_time": round(time.time() - transcription_start_time, 2)}, + ) + logger.info("Transcription received.") + logger.debug("Transcription: \n %s", transcription) finally: if os.path.exists(temp_file_path): os.remove(temp_file_path) @@ -337,6 +375,7 @@ def process_audio_transcribe_summarize_v2( max_retries=settings.celery_max_retries, queue=settings.summarize_queue, ) +@obs.observe(name="summarize-transcription", capture_input=False, capture_output=False) def summarize_transcription(self, transcript: str, email: str, sub: str, title: str): """Generate a summary from the provided transcription text. 
@@ -351,11 +390,21 @@ def summarize_transcription(self, transcript: str, email: str, sub: str, title: llm_service = LLMService() - tldr = llm_service.call(PROMPT_SYSTEM_TLDR, transcript) + try: + obs.update_current_trace( + user_id=sub or email, + tags=["celery", "summarization"], + metadata={"title": title}, + ) + except Exception as e: + logger.warning("Langfuse update trace failed: %s", e) + pass + + tldr = llm_service.call_llm_gen("tldr", PROMPT_SYSTEM_TLDR, transcript) logger.info("TLDR generated") - parts = llm_service.call(PROMPT_SYSTEM_PLAN, transcript) + parts = llm_service.call_llm_gen("plan", PROMPT_SYSTEM_PLAN, transcript) logger.info("Plan generated") parts = parts.split("\n") @@ -366,16 +415,22 @@ def summarize_transcription(self, transcript: str, email: str, sub: str, title: for part in parts: prompt_user_part = PROMPT_USER_PART.format(part=part, transcript=transcript) logger.info("Summarizing part: %s", part) - parts_summarized.append(llm_service.call(PROMPT_SYSTEM_PART, prompt_user_part)) + parts_summarized.append( + llm_service.call_llm_gen("part", PROMPT_SYSTEM_PART, prompt_user_part) + ) logger.info("Parts summarized") raw_summary = "\n\n".join(parts_summarized) - next_steps = llm_service.call(PROMPT_SYSTEM_NEXT_STEP, transcript) + next_steps = llm_service.call_llm_gen( + "next_steps", PROMPT_SYSTEM_NEXT_STEP, transcript + ) logger.info("Next steps generated") - cleaned_summary = llm_service.call(PROMPT_SYSTEM_CLEANING, raw_summary) + cleaned_summary = llm_service.call_llm_gen( + "cleaning", PROMPT_SYSTEM_CLEANING, raw_summary + ) logger.info("Summary cleaned") summary = tldr + "\n\n" + cleaned_summary + "\n\n" + next_steps @@ -395,3 +450,8 @@ def summarize_transcription(self, transcript: str, email: str, sub: str, title: logger.info("Webhook submitted successfully. 
Status: %s", response.status_code) logger.debug("Response body: %s", response.text) + try: + obs.flush() + except Exception as e: + logger.warning("Langfuse flush failed: %s", e) + pass diff --git a/src/summary/summary/core/config.py b/src/summary/summary/core/config.py index 77dc7aa3f..2f97ca83c 100644 --- a/src/summary/summary/core/config.py +++ b/src/summary/summary/core/config.py @@ -4,6 +4,7 @@ from typing import Annotated, List, Optional from fastapi import Depends +from pydantic import SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict @@ -75,6 +76,12 @@ class Settings(BaseSettings): task_tracker_redis_url: str = "redis://redis/0" task_tracker_prefix: str = "task_metadata:" + # Langfuse + langfuse_is_enabled: bool = True + langfuse_host: Optional[str] = None + langfuse_public_key: Optional[str] = None + langfuse_secret_key: Optional[SecretStr] = None + @lru_cache def get_settings(): diff --git a/src/summary/summary/core/observability.py b/src/summary/summary/core/observability.py new file mode 100644 index 000000000..dad8d6633 --- /dev/null +++ b/src/summary/summary/core/observability.py @@ -0,0 +1,101 @@ +"""Wrapper around Langfuse observability.""" + +from __future__ import annotations + +import logging +from contextlib import nullcontext +from typing import Any, Callable, ContextManager + +logger = logging.getLogger(__name__) + +try: + from langfuse import Langfuse as _Langfuse + from langfuse import observe as _lf_observe +except Exception as e: + logger.debug("Langfuse import failed: %s", e) + _Langfuse = None + _lf_observe = None + + +class Observability: + """Wrapper around Langfuse observability.""" + + def __init__( + self, is_enabled, langfuse_host, langfuse_public_key, langfuse_secret_key + ) -> None: + """Initialize the Observability instance.""" + self._client = None + if hasattr(langfuse_secret_key, "get_secret_value"): + langfuse_secret_key = langfuse_secret_key.get_secret_value() + + self._enabled = bool( + is_enabled 
and langfuse_host and langfuse_public_key and langfuse_secret_key + ) + + if not self._enabled or _Langfuse is None: + self._enabled = False + return + + try: + self._client = _Langfuse( + public_key=langfuse_public_key, + secret_key=langfuse_secret_key, + host=langfuse_host, + ) + except Exception as e: + logger.warning("Langfuse init failed: %s", e) + self._enabled = False + self._client = None + + def observe( + self, **decorator_kwargs + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorator to observe a function with Langfuse. If disabled, returns a no-op decorator.""" # noqa: E501 + if self._enabled and self._client and _lf_observe is not None: + return _lf_observe(**decorator_kwargs) + + def _noop(fn): + return fn + + return _noop + + def span(self, name: str, **kwargs) -> ContextManager[Any]: + """Context manager to create a span with Langfuse.""" + if self._enabled and self._client: + start_span = getattr(self._client, "start_as_current_span", None) + if callable(start_span): + return start_span(name=name, **kwargs) + return nullcontext() + + def generation(self, **kwargs) -> ContextManager[Any]: + """Context manager to create a generation with Langfuse.""" + if self._enabled and self._client: + start_gen = getattr(self._client, "start_as_current_generation", None) + if callable(start_gen): + return start_gen(**kwargs) + return nullcontext() + + def update_current_trace(self, **kwargs) -> None: + """Update the current trace with additional metadata.""" + if not (self._enabled and self._client): + return + try: + self._client.update_current_trace(**kwargs) + except Exception as e: + logger.warning("Langfuse update_current_trace failed: %s", e) + pass + + def flush(self) -> None: + """Flush any buffered data to Langfuse.""" + if not (self._enabled and self._client): + return + try: + self._client.flush() + except Exception as e: + logger.warning("Langfuse flush failed: %s", e) + pass + + @property + def is_enabled(self) -> bool: + """Check 
if observability is enabled.""" + return bool(self._enabled and self._client)