1 change: 1 addition & 0 deletions .gitignore
@@ -16,6 +16,7 @@ __pycache__/
poetry.lock
/site
uv*/
uv.lock

# Build and Distribution Files
/dist
35 changes: 29 additions & 6 deletions openjudge/graders/agent/tool/tool_call_sequence_match.py
@@ -44,16 +44,31 @@ def __init__(
self,
strict_mode: bool = True,
use_jaccard_similarity: bool = True,
metric_type: str = "recall",
Contributor comment (severity: medium):
For improved type safety and code clarity, consider using typing.Literal for the metric_type parameter instead of str. This makes the allowed string values ('recall', 'precision') explicit for static analysis tools and developers reading the code.

You would need to add from typing import Literal to the imports at the top of the file and change the signature to:

metric_type: Literal["recall", "precision"] = "recall",

The runtime validation on line 67 is still valuable and should be kept.
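For illustration, a minimal sketch of what the constructor could look like with that change. This is a simplified stand-in (the real class extends BaseGrader and forwards **kwargs to super().__init__), shown as a suggestion only, not code from this PR:

```python
from typing import Literal


class ToolCallSequenceMatchGrader:  # simplified stand-in; the real class extends BaseGrader
    def __init__(
        self,
        strict_mode: bool = True,
        use_jaccard_similarity: bool = True,
        # Literal makes the accepted values visible to type checkers (mypy/pyright).
        metric_type: Literal["recall", "precision"] = "recall",
    ) -> None:
        # Runtime validation stays valuable, as noted above.
        if metric_type not in ("recall", "precision"):
            raise ValueError(
                f"metric_type must be 'recall' or 'precision', got '{metric_type}'"
            )
        self.strict_mode = strict_mode
        self.use_jaccard_similarity = use_jaccard_similarity
        self.metric_type = metric_type
```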

**kwargs,
):
"""
Initialize the ToolCallSequenceMatchGrader.

Args:
strict_mode: If True, matches both tool_call name and arguments; if False, only matches tool_call name
use_jaccard_similarity: If True, use Jaccard similarity for loose mode (ignores step order)
metric_type: Metric type for step matching when use_jaccard_similarity=False and strict_mode=False.
- "recall": matched_count / reference_count (default)
- "precision": matched_count / predicted_count
**kwargs: Additional arguments passed to BaseGrader
"""
super().__init__(
name="tool_call_sequence",
mode=GraderMode.POINTWISE,
description="Evaluate tool call sequence matching against reference",
**kwargs,
)
if metric_type not in ("recall", "precision"):
raise ValueError(f"metric_type must be 'recall' or 'precision', got '{metric_type}'")
self.strict_mode = strict_mode
self.use_jaccard_similarity = use_jaccard_similarity
self.metric_type = metric_type

def extract_predicted_tool_sequence(
self,
@@ -267,11 +282,19 @@ def calculate_step_matching_score(
step_score = sum(tool_scores) / len(tool_scores) if tool_scores else 0.0
else:
# In loose mode, calculate step score based on the ratio of matched tools
if len(gt_tool_names) > 0:
matched_count = len(gt_tool_names) - len(missing)
step_score = matched_count / len(gt_tool_names)
else:
step_score = 1.0
matched_count = len(gt_tool_names) - len(missing)
if self.metric_type == "recall":
# Recall: matched / reference
if len(gt_tool_names) > 0:
step_score = matched_count / len(gt_tool_names)
else:
step_score = 1.0
else: # precision
# Precision: matched / predicted
if len(pred_tool_names) > 0:
step_score = matched_count / len(pred_tool_names)
else:
step_score = 0.0 if len(gt_tool_names) > 0 else 1.0
Comment on lines +285 to +297
Contributor comment (severity: medium):
The logic for calculating the step_score is correct, but the nested if/else statements can be simplified to improve readability. By defining the denominator based on the metric type and then handling the zero-denominator edge cases separately, the code becomes flatter and easier to follow.

Suggested change
Current:
    matched_count = len(gt_tool_names) - len(missing)
    if self.metric_type == "recall":
        # Recall: matched / reference
        if len(gt_tool_names) > 0:
            step_score = matched_count / len(gt_tool_names)
        else:
            step_score = 1.0
    else:  # precision
        # Precision: matched / predicted
        if len(pred_tool_names) > 0:
            step_score = matched_count / len(pred_tool_names)
        else:
            step_score = 0.0 if len(gt_tool_names) > 0 else 1.0
Proposed:
    matched_count = len(gt_tool_names) - len(missing)
    if self.metric_type == "recall":
        denominator = len(gt_tool_names)
        if denominator == 0:
            step_score = 1.0  # Perfect recall if no reference tools are expected
        else:
            step_score = matched_count / denominator
    else:  # precision
        denominator = len(pred_tool_names)
        if denominator == 0:
            # If no tools predicted, score is 1.0 only if no tools were expected
            step_score = 1.0 if not gt_tool_names else 0.0
        else:
            step_score = matched_count / denominator

Contributor:
As we have metadata, should we also consider exposing this info, like matched_count/denominator?
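For illustration, one way those per-step counts could be surfaced. This is a rough sketch only; the helper name and metadata keys below are hypothetical and not part of this PR:

```python
from typing import Any


def build_step_metadata(
    gt_tool_names: set[str],
    pred_tool_names: set[str],
    missing: set[str],
    metric_type: str,
) -> dict[str, Any]:
    """Hypothetical helper: package per-step counts for the result metadata."""
    matched_count = len(gt_tool_names) - len(missing)
    denominator = (
        len(gt_tool_names) if metric_type == "recall" else len(pred_tool_names)
    )
    return {
        "matched_count": matched_count,
        "denominator": denominator,
        "metric_type": metric_type,
    }
```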

else:
step_score = 0.0 # No matching step in model
total_score += step_score
@@ -420,7 +443,7 @@ async def aevaluate(
score_type = "step_matching"
# Generate detailed reason
mode_str = "strict" if self.strict_mode else "loose"
method_str = "jaccard" if self.use_jaccard_similarity else "step-by-step"
method_str = "jaccard" if self.use_jaccard_similarity else f"step-by-step/{self.metric_type}"
reason = f"Tool call sequence evaluation ({mode_str} mode, {method_str}): {score_type}={final_score:.3f}"
# Count tools for metadata
predicted_tool_count = sum(len(tools) for tools in predicted_tool_steps.values())
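Taken together, the new option is exercised like this. A usage sketch only: the import path is assumed from the module location above, and the message/reference shapes mirror the tests below.

```python
import asyncio

# Import path assumed from the module location in this PR.
from openjudge.graders.agent.tool.tool_call_sequence_match import ToolCallSequenceMatchGrader


async def main() -> None:
    grader = ToolCallSequenceMatchGrader(
        strict_mode=False,
        use_jaccard_similarity=False,
        metric_type="precision",  # or "recall" (the default)
    )
    result = await grader.aevaluate(
        messages=[
            {
                "role": "assistant",
                "tool_calls": [{"function": {"name": "search", "arguments": "{}"}}],
            }
        ],
        reference_tool_calls=[
            [{"name": "search", "arguments": {}}, {"name": "calculate", "arguments": {}}]
        ],
    )
    # With precision, the single matching predicted tool gives 1/1 = 1.0;
    # the same input scored with recall would give 1/2 = 0.5.
    print(result.score)


asyncio.run(main())
```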
137 changes: 137 additions & 0 deletions tests/graders/agent/tool/test_tool_call_sequence_match.py
@@ -178,3 +178,140 @@ def test_tool_call_sequence_match_grader_extract_predicted_tool_sequence():
assert 1 in sequence
assert sequence[0][0]["name"] == "search"
assert sequence[1][0]["name"] == "analyze"


def test_tool_call_sequence_match_grader_metric_type_default():
"""Test that default metric_type is recall"""
grader = ToolCallSequenceMatchGrader(strict_mode=False, use_jaccard_similarity=False)
assert grader.metric_type == "recall"


def test_tool_call_sequence_match_grader_metric_type_precision():
"""Test creating grader with precision metric_type"""
grader = ToolCallSequenceMatchGrader(
strict_mode=False,
use_jaccard_similarity=False,
metric_type="precision",
)
assert grader.metric_type == "precision"


def test_tool_call_sequence_match_grader_invalid_metric_type():
"""Test that invalid metric_type raises ValueError"""
with pytest.raises(ValueError, match="metric_type must be 'recall' or 'precision'"):
ToolCallSequenceMatchGrader(metric_type="invalid")


@pytest.mark.asyncio
async def test_tool_call_sequence_match_grader_recall_metric():
"""Test loose mode with recall metric (matched / reference)"""
grader = ToolCallSequenceMatchGrader(
strict_mode=False,
use_jaccard_similarity=False,
metric_type="recall",
)

# Predicted has 1 tool, reference has 2 tools, 1 match
messages = [
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "search", "arguments": "{}"}},
],
},
]

reference_tool_calls = [
[
{"name": "search", "arguments": {}},
{"name": "calculate", "arguments": {}},
],
]

result = await grader.aevaluate(
messages=messages,
reference_tool_calls=reference_tool_calls,
)

# Recall = 1 matched / 2 reference = 0.5
assert result.score == 0.5


@pytest.mark.asyncio
async def test_tool_call_sequence_match_grader_precision_metric():
"""Test loose mode with precision metric (matched / predicted)"""
grader = ToolCallSequenceMatchGrader(
strict_mode=False,
use_jaccard_similarity=False,
metric_type="precision",
)

# Predicted has 2 tools, reference has 1 tool, 1 match
messages = [
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "search", "arguments": "{}"}},
{"function": {"name": "calculate", "arguments": "{}"}},
],
},
]

reference_tool_calls = [
[
{"name": "search", "arguments": {}},
],
]

result = await grader.aevaluate(
messages=messages,
reference_tool_calls=reference_tool_calls,
)

# Precision = 1 matched / 2 predicted = 0.5
assert result.score == 0.5


@pytest.mark.asyncio
async def test_tool_call_sequence_match_grader_recall_vs_precision():
"""Test that recall and precision give different scores for same input"""
messages = [
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "search", "arguments": "{}"}},
{"function": {"name": "extra_tool", "arguments": "{}"}},
],
},
]

reference_tool_calls = [
[
{"name": "search", "arguments": {}},
],
]

# Recall grader: 1 matched / 1 reference = 1.0
recall_grader = ToolCallSequenceMatchGrader(
strict_mode=False,
use_jaccard_similarity=False,
metric_type="recall",
)
recall_result = await recall_grader.aevaluate(
messages=messages,
reference_tool_calls=reference_tool_calls,
)

# Precision grader: 1 matched / 2 predicted = 0.5
precision_grader = ToolCallSequenceMatchGrader(
strict_mode=False,
use_jaccard_similarity=False,
metric_type="precision",
)
precision_result = await precision_grader.aevaluate(
messages=messages,
reference_tool_calls=reference_tool_calls,
)

assert recall_result.score == 1.0
assert precision_result.score == 0.5