Commit b81ce4e

tgasser-nv authored and Pouyanpi committed

chore(types): Type-clean logging (43 errors) (#1395)

* Cleaned logging directory
* Add nemoguardrails/logging to pyright pre-commit checks

Signed-off-by: Tim Gasser <[email protected]>

1 parent 8c7dfd2 commit b81ce4e

3 files changed (+78 -39 lines)

nemoguardrails/logging/callbacks.py

Lines changed: 25 additions & 13 deletions

@@ -15,11 +15,15 @@
 import logging
 import uuid
 from time import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 from uuid import UUID

 from langchain.callbacks import StdOutCallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
+from langchain.callbacks.base import (
+    AsyncCallbackHandler,
+    BaseCallbackHandler,
+    BaseCallbackManager,
+)
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
 from langchain.schema import AgentAction, AgentFinish, AIMessage, BaseMessage, LLMResult
 from langchain_core.outputs import ChatGeneration
@@ -33,7 +37,7 @@
 log = logging.getLogger(__name__)


-class LoggingCallbackHandler(AsyncCallbackHandler, StdOutCallbackHandler):
+class LoggingCallbackHandler(AsyncCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""

     async def on_llm_start(
@@ -203,10 +207,17 @@ async def on_llm_end(
         )

         log.info("Output Stats :: %s", response.llm_output)
-        took = llm_call_info.finished_at - llm_call_info.started_at
-        log.info("--- :: LLM call took %.2f seconds", took)
-        llm_stats.inc("total_time", took)
-        llm_call_info.duration = took
+        if (
+            llm_call_info.finished_at is not None
+            and llm_call_info.started_at is not None
+        ):
+            took = llm_call_info.finished_at - llm_call_info.started_at
+            log.info("--- :: LLM call took %.2f seconds", took)
+            llm_stats.inc("total_time", took)
+            llm_call_info.duration = took
+        else:
+            log.warning("LLM call timing information incomplete")
+            llm_call_info.duration = 0.0

         # Update the token usage stats as well
         token_stats_found = False
@@ -278,7 +289,7 @@ async def on_llm_end(

     async def on_llm_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -309,7 +320,7 @@ async def on_chain_end(

     async def on_chain_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -340,7 +351,7 @@ async def on_tool_end(

     async def on_tool_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -381,14 +392,15 @@ async def on_agent_finish(

 handlers = [LoggingCallbackHandler()]
 logging_callbacks = BaseCallbackManager(
-    handlers=handlers, inheritable_handlers=handlers
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
 )

 logging_callback_manager_for_chain = AsyncCallbackManagerForChainRun(
     run_id=uuid.uuid4(),
     parent_run_id=None,
-    handlers=handlers,
-    inheritable_handlers=handlers,
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
     tags=[],
     inheritable_tags=[],
 )
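The recurring fix in on_llm_end is to narrow the Optional timestamps before subtracting them, since the subtraction does not type-check while either operand may still be None. A minimal, self-contained sketch of the pattern; CallInfo below is an illustrative stand-in for the object behind llm_call_info, not the real class:

import logging
from dataclasses import dataclass
from typing import Optional

log = logging.getLogger(__name__)


@dataclass
class CallInfo:
    # Stand-in for an LLM call record whose timestamps may not be set yet.
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    duration: float = 0.0


def record_duration(info: CallInfo) -> None:
    # Checking both operands against None narrows them to float, so the
    # subtraction type-checks; otherwise fall back to a safe default.
    if info.finished_at is not None and info.started_at is not None:
        info.duration = info.finished_at - info.started_at
    else:
        log.warning("Call timing information incomplete")
        info.duration = 0.0


record_duration(CallInfo(started_at=1.0, finished_at=3.5))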

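The remaining changes in this file are about matching langchain's declared types: the on_llm_error, on_chain_error, and on_tool_error parameters are widened to BaseException, presumably to stay compatible with the base-class hook signatures, and the handler lists passed to BaseCallbackManager and AsyncCallbackManagerForChainRun are cast to List[BaseCallbackHandler], because List is invariant and a List[LoggingCallbackHandler] is not assignable to it on its own. A minimal sketch of the cast pattern, using hypothetical Handler classes rather than the langchain ones:

from typing import List, cast


class Handler:
    """Stand-in for a base callback handler type."""


class LoggingHandler(Handler):
    """Stand-in for a concrete handler subclass."""


def make_manager(handlers: List[Handler]) -> List[Handler]:
    # Stand-in for a manager constructor that declares List[Handler].
    return handlers


concrete: List[LoggingHandler] = [LoggingHandler()]

# List is invariant: List[LoggingHandler] is not assignable to List[Handler],
# even though LoggingHandler subclasses Handler, so the argument is cast.
manager_handlers = make_manager(cast(List[Handler], concrete))
print(manager_handlers)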
nemoguardrails/logging/processing_log.py

Lines changed: 52 additions & 26 deletions

@@ -153,25 +153,36 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
                 action_params=event_data["action_params"],
                 started_at=event["timestamp"],
             )
-            activated_rail.executed_actions.append(executed_action)
+            if activated_rail is not None:
+                activated_rail.executed_actions.append(executed_action)

         elif event_type == "InternalSystemActionFinished":
             action_name = event_data["action_name"]
             if action_name in ignored_actions:
                 continue

-            executed_action.finished_at = event["timestamp"]
-            executed_action.duration = (
-                executed_action.finished_at - executed_action.started_at
-            )
-            executed_action.return_value = event_data["return_value"]
+            if executed_action is not None:
+                executed_action.finished_at = event["timestamp"]
+                if (
+                    executed_action.finished_at is not None
+                    and executed_action.started_at is not None
+                ):
+                    executed_action.duration = (
+                        executed_action.finished_at - executed_action.started_at
+                    )
+                executed_action.return_value = event_data["return_value"]
             executed_action = None

         elif event_type in ["InputRailFinished", "OutputRailFinished"]:
-            activated_rail.finished_at = event["timestamp"]
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if activated_rail is not None:
+                activated_rail.finished_at = event["timestamp"]
+                if (
+                    activated_rail.finished_at is not None
+                    and activated_rail.started_at is not None
+                ):
+                    activated_rail.duration = (
+                        activated_rail.finished_at - activated_rail.started_at
+                    )
             activated_rail = None

         elif event_type == "InputRailsFinished":
@@ -181,14 +192,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
             output_rails_finished_at = event["timestamp"]

         elif event["type"] == "llm_call_info":
-            executed_action.llm_calls.append(event["data"])
+            if executed_action is not None:
+                executed_action.llm_calls.append(event["data"])

     # If at the end of the processing we still have an active rail, it is because
     # we have hit a stop. In this case, we take the last timestamp as the timestamp for
     # finishing the rail.
     if activated_rail is not None:
         activated_rail.finished_at = last_timestamp
-        activated_rail.duration = activated_rail.finished_at - activated_rail.started_at
+        if (
+            activated_rail.finished_at is not None
+            and activated_rail.started_at is not None
+        ):
+            activated_rail.duration = (
+                activated_rail.finished_at - activated_rail.started_at
+            )

         if activated_rail.type in ["input", "output"]:
             activated_rail.stop = True
@@ -213,9 +231,13 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
         if activated_rail.type in ["dialog", "generation"]:
             next_rail = generation_log.activated_rails[i + 1]
             activated_rail.finished_at = next_rail.started_at
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if (
+                activated_rail.finished_at is not None
+                and activated_rail.started_at is not None
+            ):
+                activated_rail.duration = (
+                    activated_rail.finished_at - activated_rail.started_at
+                )

     # If we have output rails, we also record the general stats
     if output_rails_started_at:
@@ -257,17 +279,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:

         for executed_action in activated_rail.executed_actions:
             for llm_call in executed_action.llm_calls:
-                generation_log.stats.llm_calls_count += 1
-                generation_log.stats.llm_calls_duration += llm_call.duration
-                generation_log.stats.llm_calls_total_prompt_tokens += (
-                    llm_call.prompt_tokens or 0
-                )
-                generation_log.stats.llm_calls_total_completion_tokens += (
-                    llm_call.completion_tokens or 0
-                )
-                generation_log.stats.llm_calls_total_tokens += (
-                    llm_call.total_tokens or 0
-                )
+                generation_log.stats.llm_calls_count = (
+                    generation_log.stats.llm_calls_count or 0
+                ) + 1
+                generation_log.stats.llm_calls_duration = (
+                    generation_log.stats.llm_calls_duration or 0
+                ) + (llm_call.duration or 0)
+                generation_log.stats.llm_calls_total_prompt_tokens = (
+                    generation_log.stats.llm_calls_total_prompt_tokens or 0
+                ) + (llm_call.prompt_tokens or 0)
+                generation_log.stats.llm_calls_total_completion_tokens = (
+                    generation_log.stats.llm_calls_total_completion_tokens or 0
+                ) + (llm_call.completion_tokens or 0)
+                generation_log.stats.llm_calls_total_tokens = (
+                    generation_log.stats.llm_calls_total_tokens or 0
+                ) + (llm_call.total_tokens or 0)

     generation_log.stats.total_duration = (
         processing_log[-1]["timestamp"] - processing_log[0]["timestamp"]
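compute_generation_log applies the same discipline to values that start out as None: member access on executed_action and activated_rail is wrapped in explicit None checks, and the stats counters move from += to the (value or 0) + increment form, presumably because the fields are typed as Optional. A small sketch of the accumulation pattern, using a hypothetical Stats container rather than the real generation-log stats model:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Stats:
    # Optional counters that remain None until the first update.
    llm_calls_count: Optional[int] = None
    llm_calls_duration: Optional[float] = None


def add_call(stats: Stats, duration: Optional[float]) -> None:
    # "value or 0" collapses None to 0 so the addition type-checks; for pure
    # accumulation, treating an existing 0 the same as None is harmless.
    stats.llm_calls_count = (stats.llm_calls_count or 0) + 1
    stats.llm_calls_duration = (stats.llm_calls_duration or 0) + (duration or 0)


stats = Stats()
add_call(stats, 0.42)
add_call(stats, None)
assert stats.llm_calls_count == 2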

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -157,6 +157,7 @@ pyright = "^1.1.405"
 include = [
     "nemoguardrails/rails/**",
     "nemoguardrails/actions/**",
+    "nemoguardrails/logging/**",
     "nemoguardrails/tracing/**",
     "tests/test_callbacks.py",
 ]
