22
33import json
44import logging
5+ import time
56from collections .abc import Callable
67from logging import Logger
78from typing import Any
@@ -52,14 +53,14 @@ def _format_message(self, message: dict[str, str | int]) -> str:
5253 else :
5354 return " " .join ([f"{ k } ={ v } " for k , v in message .items ()])
5455
55- def _get_timestamp_from_context (self , context : MiddlewareContext [Any ]) -> str :
56- """Get a timestamp from the context."""
57- return context .timestamp .isoformat ()
58-
5956 def _create_before_message (
60- self , context : MiddlewareContext [Any ], event : str
57+ self , context : MiddlewareContext [Any ]
6158 ) -> dict [str , str | int ]:
62- message = self ._create_base_message (context , event )
59+ message = {
60+ "event" : context .type + "_start" ,
61+ "method" : context .method or "unknown" ,
62+ "source" : context .source ,
63+ }
6364
6465 if (
6566 self .include_payloads
@@ -85,57 +86,61 @@ def _create_before_message(
8586
8687 return message
8788
88- def _create_after_message (
89- self , context : MiddlewareContext [Any ], event : str
90- ) -> dict [str , str | int ]:
91- return self ._create_base_message (context , event )
92-
93- def _create_base_message (
89+ def _create_error_message (
9490 self ,
9591 context : MiddlewareContext [Any ],
96- event : str ,
97- ) -> dict [str , str | int ]:
98- """Format a message for logging."""
92+ start_time : float ,
93+ error : Exception ,
94+ ) -> dict [str , str | int | float ]:
95+ duration_ms : float = _get_duration_ms (start_time )
96+ message = {
97+ "event" : context .type + "_error" ,
98+ "method" : context .method or "unknown" ,
99+ "source" : context .source ,
100+ "duration_ms" : duration_ms ,
101+ "error" : str (object = error ),
102+ }
103+ return message
99104
100- parts : dict [str , str | int ] = {
101- "event" : event ,
102- "timestamp" : self ._get_timestamp_from_context (context ),
105+ def _create_after_message (
106+ self ,
107+ context : MiddlewareContext [Any ],
108+ start_time : float ,
109+ ) -> dict [str , str | int | float ]:
110+ duration_ms : float = _get_duration_ms (start_time )
111+ message = {
112+ "event" : context .type + "_success" ,
103113 "method" : context .method or "unknown" ,
104- "type" : context .type ,
105114 "source" : context .source ,
115+ "duration_ms" : duration_ms ,
106116 }
117+ return message
107118
108- return parts
119+ def _log_message (
120+ self , message : dict [str , str | int | float ], log_level : int | None = None
121+ ):
122+ self .logger .log (log_level or self .log_level , self ._format_message (message ))
109123
110124 async def on_message (
111125 self , context : MiddlewareContext [Any ], call_next : CallNext [Any , Any ]
112126 ) -> Any :
113- """Log all messages."""
127+ """Log messages for configured methods ."""
114128
115129 if self .methods and context .method not in self .methods :
116130 return await call_next (context )
117131
118- request_start_log_message = self ._create_before_message (
119- context , "request_start"
120- )
121-
122- formatted_message = self ._format_message (request_start_log_message )
123- self .logger .log (self .log_level , f"Processing message: { formatted_message } " )
132+ self ._log_message (self ._create_before_message (context ))
124133
134+ start_time = time .perf_counter ()
125135 try :
126136 result = await call_next (context )
127137
128- request_success_log_message = self ._create_after_message (
129- context , "request_success"
130- )
131-
132- formatted_message = self ._format_message (request_success_log_message )
133- self .logger .log (self .log_level , f"Completed message: { formatted_message } " )
138+ self ._log_message (self ._create_after_message (context , start_time ))
134139
135140 return result
136141 except Exception as e :
137- self .logger . log (
138- logging . ERROR , f"Failed message: { context . method or 'unknown' } - { e } "
142+ self ._log_message (
143+ self . _create_error_message ( context , start_time , e ), logging . ERROR
139144 )
140145 raise
141146
@@ -184,7 +189,7 @@ def __init__(
184189 payload_serializer: Callable that converts objects to a JSON string for the
185190 payload. If not provided, uses FastMCP's default tool serializer.
186191 """
187- self .logger : Logger = logger or logging .getLogger ("fastmcp.requests " )
192+ self .logger : Logger = logger or logging .getLogger ("fastmcp.middleware.logging " )
188193 self .log_level = log_level
189194 self .include_payloads : bool = include_payloads
190195 self .include_payload_length : bool = include_payload_length
@@ -234,7 +239,9 @@ def __init__(
234239 payload_serializer: Callable that converts objects to a JSON string for the
235240 payload. If not provided, uses FastMCP's default tool serializer.
236241 """
237- self .logger : Logger = logger or logging .getLogger ("fastmcp.structured" )
242+ self .logger : Logger = logger or logging .getLogger (
243+ "fastmcp.middleware.structured_logging"
244+ )
238245 self .log_level : int = log_level
239246 self .include_payloads : bool = include_payloads
240247 self .include_payload_length : bool = include_payload_length
@@ -243,3 +250,7 @@ def __init__(
243250 self .payload_serializer : Callable [[Any ], str ] | None = payload_serializer
244251 self .max_payload_length : int | None = None
245252 self .structured_logging : bool = True
253+
254+
255+ def _get_duration_ms (start_time : float , / ) -> float :
256+ return round (number = (time .perf_counter () - start_time ) * 1000 , ndigits = 2 )
0 commit comments