
Commit 76cfdef

Merge pull request #116 from sbintuitions/add_correlation
Add `Correlation` metric class
2 parents: 59976fe + 7399579

5 files changed: +224, -7 lines

flexeval/core/metric/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 from .code_eval import CodeEval
 from .common_prefix_length import CommonPrefixLength
 from .common_string_length import CommonStringLength
+from .correlation import Correlation
 from .exact_match import ExactMatch
 from .llm_label import ChatLLMLabel, LLMLabel
 from .llm_score import ChatLLMScore, LLMScore
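This one-line re-export is what lets the docstring and tests below use the short import path; presumably the package's top-level `flexeval/__init__.py` re-exports these names in turn (that file is not part of this diff). A minimal sanity check under that assumption:

from flexeval import Correlation

# The class itself still lives in the submodule added by this commit.
assert Correlation.__module__ == "flexeval.core.metric.correlation"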

flexeval/core/metric/correlation.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
from __future__ import annotations

import functools
import warnings
from typing import Literal

from scipy.stats import kendalltau, pearsonr, spearmanr

from .base import Metric, MetricResult
from .string_processor import StringProcessor


class Correlation(Metric):
    """
    Correlation metric to compute Pearson, Spearman, or Kendall correlation coefficients.
    The lm_outputs and references should be numeric values, optionally preprocessed by StringProcessor.

    Args:
        method: The correlation method to use ('pearson', 'spearman', 'kendall').
        lm_output_processor: StringProcessor or a list of StringProcessor to be applied to the model outputs before
            computing the correlation. If a list is provided, the processors will be applied in order.
        reference_processor: StringProcessor or a list of StringProcessor to be applied to the references before
            computing the correlation. If a list is provided, the processors will be applied in order.

    Examples:
        >>> from flexeval import Correlation
        >>> correlation = Correlation(method='pearson')
        >>> lm_outputs = ["1", "2", "3", "4", "5"]
        >>> references = [["5"], ["4"], ["3"], ["2"], ["1"]]
        >>> result = correlation.evaluate(lm_outputs, references)
        >>> print(result)
        MetricResult(
            summary={"pearson_correlation": -1.0, "pearson_pvalue": 0.0},
            instance_details=[],
        )
    """

    def __init__(
        self,
        method: Literal["pearson", "spearman", "kendall"] = "pearson",
        lm_output_processor: StringProcessor | list[StringProcessor] | None = None,
        reference_processor: StringProcessor | list[StringProcessor] | None = None,
    ) -> None:
        if method not in {"pearson", "spearman", "kendall"}:
            msg = f"Invalid method '{method}'. Choose from 'pearson', 'spearman', 'kendall'."
            raise ValueError(msg)
        self.method = method

        if isinstance(lm_output_processor, StringProcessor):
            lm_output_processor = [lm_output_processor]
        if isinstance(reference_processor, StringProcessor):
            reference_processor = [reference_processor]
        self.lm_output_processors = lm_output_processor
        self.reference_processors = reference_processor

    def evaluate(
        self,
        lm_outputs: list[str],
        references_list: list[list[str]],
        task_inputs_list: list[dict[str, str]] | None = None,
    ) -> MetricResult:
        if len(lm_outputs) != len(references_list):
            msg = (
                f"Number of model outputs ({len(lm_outputs)}) and number of references ({len(references_list)}) "
                "should be the same."
            )
            raise ValueError(msg)

        # We only use the first reference here
        references = [refs[0] for refs in references_list]

        if self.lm_output_processors:
            lm_outputs = [
                functools.reduce(lambda x, norm: norm(x), self.lm_output_processors, output) for output in lm_outputs
            ]

        if self.reference_processors:
            references = [
                functools.reduce(lambda x, norm: norm(x), self.reference_processors, ref) for ref in references
            ]

        # The model output should be converted to float, if fails it will be treated as 0
        lm_outputs_as_float: list[float] = []
        for output in lm_outputs:
            try:
                lm_outputs_as_float.append(float(output))
            except ValueError:  # noqa:PERF203
                warnings.warn(f"Failed to convert model output '{output}' to float. Treating it as 0.", stacklevel=2)
                lm_outputs_as_float.append(0.0)

        # The reference should be converted to float
        references_as_float = [float(ref) for ref in references]

        # Compute correlation
        if self.method == "pearson":
            correlation, pvalue = pearsonr(lm_outputs_as_float, references_as_float)
        elif self.method == "spearman":
            correlation, pvalue = spearmanr(lm_outputs_as_float, references_as_float)
        elif self.method == "kendall":
            correlation, pvalue = kendalltau(lm_outputs_as_float, references_as_float)
        else:
            msg = f"Unsupported method: {self.method}"
            raise ValueError(msg)

        return MetricResult(
            {f"{self.method}_correlation": correlation, f"{self.method}_pvalue": pvalue},
            instance_details=[],
        )
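Taken together, the new class turns two parallel lists of numeric strings into a single correlation coefficient and p-value, keyed by the chosen method. A minimal usage sketch based on the docstring and `evaluate` signature above (the numbers are illustrative, not taken from the commit):

from flexeval import Correlation

metric = Correlation(method="spearman")
lm_outputs = ["10", "20", "30", "40"]           # model predictions, as strings
references_list = [["1"], ["2"], ["3"], ["4"]]  # only the first reference per instance is used

result = metric.evaluate(lm_outputs, references_list)
# The summary keys are prefixed with the chosen method.
print(result.summary["spearman_correlation"])   # 1.0: the rankings match perfectly
print(result.summary["spearman_pvalue"])

Because outputs that fail float conversion are scored as 0.0 (after a warning), pairing the metric with an `lm_output_processor` that extracts the numeric part of a verbose model response is likely the intended workflow.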

poetry.lock

Lines changed: 48 additions & 7 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ vllm = {version = "^0.6.4.post1", optional = true }
 loguru = "^0.7.2"
 wandb = {version = "^0.17.2", optional = true}
 pyarrow = "16.1.0" # set the version because we get "Unable to find installation candidates" with 17.0.0
+scipy = "1.13.0"

 [tool.poetry.extras]
 vllm = ["vllm"]
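The pinned scipy release supplies the three `scipy.stats` functions that the metric unpacks as `(statistic, p-value)` pairs. A quick standalone check of that unpacking assumption (illustrative data, not from the commit):

from scipy.stats import kendalltau, pearsonr, spearmanr

x = [1.0, 2.0, 3.0, 4.0]
y = [2.0, 4.0, 6.0, 8.0]
for fn in (pearsonr, spearmanr, kendalltau):
    # Each result object supports tuple unpacking into (statistic, p-value),
    # which is exactly how Correlation.evaluate consumes it.
    statistic, pvalue = fn(x, y)
    print(fn.__name__, float(statistic), float(pvalue))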

tests/core/metric/test_correlation.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
from __future__ import annotations

import pytest

from flexeval import Correlation, MetricResult


@pytest.mark.parametrize(
    ("method", "lm_outputs", "references", "expected_correlation"),
    [
        ("pearson", [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 1.0),
        ("pearson", [1, 2, 3, 4, 5], [5, 4, 3, 2, 1], -1.0),
        ("spearman", [1, 2, 3, 4, 5], [1, 20, 30, 400, 500], 1.0),
        ("spearman", [1, 2, 3, 4, 5], [500, 400, 30, 20, 1], -1.0),
        ("kendall", [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], 1.0),
        ("kendall", [1, 2, 3, 4, 5], [5, 4, 3, 2, 1], -1.0),
    ],
)
def test_correlation(
    method: str, lm_outputs: list[float], references: list[float], expected_correlation: float
) -> None:
    correlation = Correlation(method=method)
    references_list = [[ref] for ref in references]  # Wrap references in a list for each instance

    result = correlation.evaluate(lm_outputs, references_list)

    assert isinstance(result, MetricResult)
    assert f"{method}_correlation" in result.summary
    assert result.summary[f"{method}_correlation"] == pytest.approx(expected_correlation, rel=1e-3)


def test_instantiation_fails_with_invalid_method() -> None:
    with pytest.raises(ValueError, match="Invalid method"):  # Expecting an error for invalid method
        Correlation(method="invalid")


def test_evaluation_fails_with_mismatched_lengths() -> None:
    correlation = Correlation(method="pearson")

    lm_outputs = [1, 2, 3]
    references_list = [[1], [2]]  # Mismatched lengths

    with pytest.raises(ValueError):
        correlation.evaluate(lm_outputs, references_list)


def test_evaluation_does_not_fail_with_non_numeric_lm_outputs() -> None:
    correlation = Correlation(method="pearson")

    lm_outputs = ["1", "a", "3"]
    references_list = [["1.0"], ["2.0"], ["3.0"]]

    with pytest.warns(UserWarning, match="Failed to convert model output 'a' to float"):
        result = correlation.evaluate(lm_outputs, references_list)

    assert result.summary["pearson_correlation"] is not None


def test_evaluation_fails_with_non_numeric_references() -> None:
    correlation = Correlation(method="pearson")

    lm_outputs = ["1", "2", "3"]
    references_list = [["1.0"], ["non-numeric"], ["3.0"]]

    with pytest.raises(ValueError):
        correlation.evaluate(lm_outputs, references_list)
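The last two tests pin down an asymmetry in the implementation: a model output that cannot be parsed as a number is scored as 0.0 and only triggers a `UserWarning`, whereas a non-numeric reference raises `ValueError`. A small sketch of that behaviour outside pytest (hypothetical inputs, standard-library `warnings` only):

import warnings

from flexeval import Correlation

metric = Correlation(method="pearson")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "oops" cannot be converted to float, so it is treated as 0.0.
    result = metric.evaluate(["1", "oops", "3"], [["1"], ["2"], ["3"]])
print(len(caught))                            # 1 warning about the failed conversion
print(result.summary["pearson_correlation"])  # computed over [1.0, 0.0, 3.0] vs [1.0, 2.0, 3.0]

try:
    metric.evaluate(["1", "2"], [["a"], ["b"]])  # non-numeric references are not coerced
except ValueError as err:
    print("references must be numeric:", err)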
