Skip to content

Commit 033dc0f

Browse files
committed
Cleaned logging directory
1 parent 71d00f0 commit 033dc0f

File tree

4 files changed

+98
-45
lines changed

4 files changed

+98
-45
lines changed

nemoguardrails/context.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,21 +14,31 @@
1414
# limitations under the License.
1515

1616
import contextvars
17-
from typing import Optional
17+
from typing import TYPE_CHECKING, Optional
18+
19+
if TYPE_CHECKING:
20+
from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
21+
from nemoguardrails.logging.stats import LLMStats
1822

1923
streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
2024

2125
# The object that holds additional explanation information.
22-
explain_info_var = contextvars.ContextVar("explain_info", default=None)
26+
explain_info_var: contextvars.ContextVar[
27+
Optional["ExplainInfo"]
28+
] = contextvars.ContextVar("explain_info", default=None)
2329

2430
# The current LLM call.
25-
llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
31+
llm_call_info_var: contextvars.ContextVar[
32+
Optional["LLMCallInfo"]
33+
] = contextvars.ContextVar("llm_call_info", default=None)
2634

2735
# All the generation options applicable to the current context.
2836
generation_options_var = contextvars.ContextVar("generation_options", default=None)
2937

3038
# The stats about the LLM calls.
31-
llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
39+
llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
40+
"llm_stats", default=None
41+
)
3242

3343
# The raw LLM request that comes from the user.
3444
# This is used in passthrough mode.

nemoguardrails/logging/callbacks.py

Lines changed: 25 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,15 @@
1515
import logging
1616
import uuid
1717
from time import time
18-
from typing import Any, Dict, List, Optional, Union
18+
from typing import Any, Dict, List, Optional, Union, cast
1919
from uuid import UUID
2020

2121
from langchain.callbacks import StdOutCallbackHandler
22-
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
22+
from langchain.callbacks.base import (
23+
AsyncCallbackHandler,
24+
BaseCallbackHandler,
25+
BaseCallbackManager,
26+
)
2327
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
2428
from langchain.schema import AgentAction, AgentFinish, AIMessage, BaseMessage, LLMResult
2529
from langchain_core.outputs import ChatGeneration
@@ -33,7 +37,7 @@
3337
log = logging.getLogger(__name__)
3438

3539

36-
class LoggingCallbackHandler(AsyncCallbackHandler, StdOutCallbackHandler):
40+
class LoggingCallbackHandler(AsyncCallbackHandler):
3741
"""Async callback handler that can be used to handle callbacks from langchain."""
3842

3943
async def on_llm_start(
@@ -184,10 +188,17 @@ async def on_llm_end(
184188
)
185189

186190
log.info("Output Stats :: %s", response.llm_output)
187-
took = llm_call_info.finished_at - llm_call_info.started_at
188-
log.info("--- :: LLM call took %.2f seconds", took)
189-
llm_stats.inc("total_time", took)
190-
llm_call_info.duration = took
191+
if (
192+
llm_call_info.finished_at is not None
193+
and llm_call_info.started_at is not None
194+
):
195+
took = llm_call_info.finished_at - llm_call_info.started_at
196+
log.info("--- :: LLM call took %.2f seconds", took)
197+
llm_stats.inc("total_time", took)
198+
llm_call_info.duration = took
199+
else:
200+
log.warning("LLM call timing information incomplete")
201+
llm_call_info.duration = 0.0
191202

192203
# Update the token usage stats as well
193204
token_stats_found = False
@@ -259,7 +270,7 @@ async def on_llm_end(
259270

260271
async def on_llm_error(
261272
self,
262-
error: Union[Exception, KeyboardInterrupt],
273+
error: BaseException,
263274
*,
264275
run_id: UUID,
265276
parent_run_id: Optional[UUID] = None,
@@ -290,7 +301,7 @@ async def on_chain_end(
290301

291302
async def on_chain_error(
292303
self,
293-
error: Union[Exception, KeyboardInterrupt],
304+
error: BaseException,
294305
*,
295306
run_id: UUID,
296307
parent_run_id: Optional[UUID] = None,
@@ -321,7 +332,7 @@ async def on_tool_end(
321332

322333
async def on_tool_error(
323334
self,
324-
error: Union[Exception, KeyboardInterrupt],
335+
error: BaseException,
325336
*,
326337
run_id: UUID,
327338
parent_run_id: Optional[UUID] = None,
@@ -362,14 +373,15 @@ async def on_agent_finish(
362373

363374
handlers = [LoggingCallbackHandler()]
364375
logging_callbacks = BaseCallbackManager(
365-
handlers=handlers, inheritable_handlers=handlers
376+
handlers=cast(List[BaseCallbackHandler], handlers),
377+
inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
366378
)
367379

368380
logging_callback_manager_for_chain = AsyncCallbackManagerForChainRun(
369381
run_id=uuid.uuid4(),
370382
parent_run_id=None,
371-
handlers=handlers,
372-
inheritable_handlers=handlers,
383+
handlers=cast(List[BaseCallbackHandler], handlers),
384+
inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
373385
tags=[],
374386
inheritable_tags=[],
375387
)

nemoguardrails/logging/processing_log.py

Lines changed: 52 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -153,25 +153,36 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
153153
action_params=event_data["action_params"],
154154
started_at=event["timestamp"],
155155
)
156-
activated_rail.executed_actions.append(executed_action)
156+
if activated_rail is not None:
157+
activated_rail.executed_actions.append(executed_action)
157158

158159
elif event_type == "InternalSystemActionFinished":
159160
action_name = event_data["action_name"]
160161
if action_name in ignored_actions:
161162
continue
162163

163-
executed_action.finished_at = event["timestamp"]
164-
executed_action.duration = (
165-
executed_action.finished_at - executed_action.started_at
166-
)
167-
executed_action.return_value = event_data["return_value"]
164+
if executed_action is not None:
165+
executed_action.finished_at = event["timestamp"]
166+
if (
167+
executed_action.finished_at is not None
168+
and executed_action.started_at is not None
169+
):
170+
executed_action.duration = (
171+
executed_action.finished_at - executed_action.started_at
172+
)
173+
executed_action.return_value = event_data["return_value"]
168174
executed_action = None
169175

170176
elif event_type in ["InputRailFinished", "OutputRailFinished"]:
171-
activated_rail.finished_at = event["timestamp"]
172-
activated_rail.duration = (
173-
activated_rail.finished_at - activated_rail.started_at
174-
)
177+
if activated_rail is not None:
178+
activated_rail.finished_at = event["timestamp"]
179+
if (
180+
activated_rail.finished_at is not None
181+
and activated_rail.started_at is not None
182+
):
183+
activated_rail.duration = (
184+
activated_rail.finished_at - activated_rail.started_at
185+
)
175186
activated_rail = None
176187

177188
elif event_type == "InputRailsFinished":
@@ -181,14 +192,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
181192
output_rails_finished_at = event["timestamp"]
182193

183194
elif event["type"] == "llm_call_info":
184-
executed_action.llm_calls.append(event["data"])
195+
if executed_action is not None:
196+
executed_action.llm_calls.append(event["data"])
185197

186198
# If at the end of the processing we still have an active rail, it is because
187199
# we have hit a stop. In this case, we take the last timestamp as the timestamp for
188200
# finishing the rail.
189201
if activated_rail is not None:
190202
activated_rail.finished_at = last_timestamp
191-
activated_rail.duration = activated_rail.finished_at - activated_rail.started_at
203+
if (
204+
activated_rail.finished_at is not None
205+
and activated_rail.started_at is not None
206+
):
207+
activated_rail.duration = (
208+
activated_rail.finished_at - activated_rail.started_at
209+
)
192210

193211
if activated_rail.type in ["input", "output"]:
194212
activated_rail.stop = True
@@ -213,9 +231,13 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
213231
if activated_rail.type in ["dialog", "generation"]:
214232
next_rail = generation_log.activated_rails[i + 1]
215233
activated_rail.finished_at = next_rail.started_at
216-
activated_rail.duration = (
217-
activated_rail.finished_at - activated_rail.started_at
218-
)
234+
if (
235+
activated_rail.finished_at is not None
236+
and activated_rail.started_at is not None
237+
):
238+
activated_rail.duration = (
239+
activated_rail.finished_at - activated_rail.started_at
240+
)
219241

220242
# If we have output rails, we also record the general stats
221243
if output_rails_started_at:
@@ -257,17 +279,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
257279

258280
for executed_action in activated_rail.executed_actions:
259281
for llm_call in executed_action.llm_calls:
260-
generation_log.stats.llm_calls_count += 1
261-
generation_log.stats.llm_calls_duration += llm_call.duration
262-
generation_log.stats.llm_calls_total_prompt_tokens += (
263-
llm_call.prompt_tokens or 0
264-
)
265-
generation_log.stats.llm_calls_total_completion_tokens += (
266-
llm_call.completion_tokens or 0
267-
)
268-
generation_log.stats.llm_calls_total_tokens += (
269-
llm_call.total_tokens or 0
270-
)
282+
generation_log.stats.llm_calls_count = (
283+
generation_log.stats.llm_calls_count or 0
284+
) + 1
285+
generation_log.stats.llm_calls_duration = (
286+
generation_log.stats.llm_calls_duration or 0
287+
) + (llm_call.duration or 0)
288+
generation_log.stats.llm_calls_total_prompt_tokens = (
289+
generation_log.stats.llm_calls_total_prompt_tokens or 0
290+
) + (llm_call.prompt_tokens or 0)
291+
generation_log.stats.llm_calls_total_completion_tokens = (
292+
generation_log.stats.llm_calls_total_completion_tokens or 0
293+
) + (llm_call.completion_tokens or 0)
294+
generation_log.stats.llm_calls_total_tokens = (
295+
generation_log.stats.llm_calls_total_tokens or 0
296+
) + (llm_call.total_tokens or 0)
271297

272298
generation_log.stats.total_duration = (
273299
processing_log[-1]["timestamp"] - processing_log[0]["timestamp"]

nemoguardrails/logging/verbose.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,10 @@ def emit(self, record) -> None:
5454
skip_print = True
5555
if verbose_llm_calls:
5656
console.print("")
57-
console.print(f"[cyan]LLM {title} ({record.id[:5]}..)[/]")
57+
record_id = getattr(record, "id", "unknown")
58+
console.print(
59+
f"[cyan]LLM {title} ({record_id[:5] if record_id != 'unknown' else record_id}..)[/]"
60+
)
5861
for line in body.split("\n"):
5962
text = Text(line, style="black on #006600", end="\n")
6063
text.pad_right(console.width)
@@ -66,8 +69,10 @@ def emit(self, record) -> None:
6669
if verbose_llm_calls:
6770
skip_print = True
6871
console.print("")
72+
record_id = getattr(record, "id", "unknown")
73+
record_task = getattr(record, "task", "unknown")
6974
console.print(
70-
f"[cyan]LLM Prompt ({record.id[:5]}..) - {record.task}[/]"
75+
f"[cyan]LLM Prompt ({record_id[:5] if record_id != 'unknown' else record_id}..) - {record_task}[/]"
7176
)
7277

7378
for line in body.split("\n"):

0 commit comments

Comments
 (0)