-
Notifications
You must be signed in to change notification settings - Fork 14
feat(grader): add metric_type parameter to ToolCallSequenceMatchGrader #64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ __pycache__/ | |
| poetry.lock | ||
| /site | ||
| uv*/ | ||
| uv.lock | ||
|
|
||
| # Build and Distribution Files | ||
| /dist | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -44,16 +44,31 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| strict_mode: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_jaccard_similarity: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metric_type: str = "recall", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Initialize the ToolCallSequenceMatchGrader. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| strict_mode: If True, matches both tool_call name and arguments; if False, only matches tool_call name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_jaccard_similarity: If True, use Jaccard similarity for loose mode (ignores step order) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metric_type: Metric type for step matching when use_jaccard_similarity=False and strict_mode=False. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - "recall": matched_count / reference_count (default) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - "precision": matched_count / predicted_count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs: Additional arguments passed to BaseGrader | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name="tool_call_sequence", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mode=GraderMode.POINTWISE, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description="Evaluate tool call sequence matching against reference", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if metric_type not in ("recall", "precision"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"metric_type must be 'recall' or 'precision', got '{metric_type}'") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.strict_mode = strict_mode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_jaccard_similarity = use_jaccard_similarity | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.metric_type = metric_type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def extract_predicted_tool_sequence( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -267,11 +282,19 @@ def calculate_step_matching_score( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = sum(tool_scores) / len(tool_scores) if tool_scores else 0.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # In loose mode, calculate step score based on the ratio of matched tools | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(gt_tool_names) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| matched_count = len(gt_tool_names) - len(missing) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = matched_count / len(gt_tool_names) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = 1.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| matched_count = len(gt_tool_names) - len(missing) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.metric_type == "recall": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Recall: matched / reference | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(gt_tool_names) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = matched_count / len(gt_tool_names) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = 1.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: # precision | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Precision: matched / predicted | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(pred_tool_names) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = matched_count / len(pred_tool_names) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = 0.0 if len(gt_tool_names) > 0 else 1.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+285
to
+297
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for calculating the
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we have metadata, should we also consider expose these info like matched_count/denominator? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_score = 0.0 # No matching step in model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_score += step_score | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -420,7 +443,7 @@ async def aevaluate( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| score_type = "step_matching" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Generate detailed reason | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mode_str = "strict" if self.strict_mode else "loose" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| method_str = "jaccard" if self.use_jaccard_similarity else "step-by-step" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| method_str = "jaccard" if self.use_jaccard_similarity else f"step-by-step/{self.metric_type}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reason = f"Tool call sequence evaluation ({mode_str} mode, {method_str}): {score_type}={final_score:.3f}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Count tools for metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| predicted_tool_count = sum(len(tools) for tools in predicted_tool_steps.values()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For improved type safety and code clarity, consider using
typing.Literalfor themetric_typeparameter instead ofstr. This makes the allowed string values ('recall', 'precision') explicit for static analysis tools and developers reading the code.You would need to add
from typing import Literalto the imports at the top of the file and change the signature to:The runtime validation on line 67 is still valuable and should be kept.