diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index ee46ac5..e33ad8a 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -62,13 +62,15 @@ def test_error_when_input_length_mismatch() -> None: @pytest.mark.parametrize( - ["predictions", "references", "bleu", "codebleu"], + ["predictions", "references", "bleu", "syntax_match", "dataflow_match", "codebleu"], [ # https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png ( ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7846, + 14/21, + 2/3, 0.7238, # TODO: lol, not working at <3.12 ), # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 @@ -76,6 +78,8 @@ def test_error_when_input_length_mismatch() -> None: ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ;"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7543, + 14/21, + 2/3, 0.7091, # Should be 0.6973 if AST=13/21, however at the moment tee-sitter AST is 14/21 ), # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 @@ -83,17 +87,26 @@ def test_error_when_input_length_mismatch() -> None: ["public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7571, # Error in the Figure 4, text "Example 2" states 0.7571, not 0.6814, + 1.0, + 1.0, 0.8804, # Error in the Figure 4, text "Example 2" states 0.8804, not 0.8397, ), ], ) def test_code_x_glue_readme_examples( - predictions: List[Any], references: List[Any], bleu: float, codebleu: float + predictions: List[Any], + references: List[Any], + bleu: float, + syntax_match: float, + dataflow_match: float, + codebleu: float, ) -> None: result = calc_codebleu(references, predictions, "java") logging.debug(result) assert result["ngram_match_score"] == pytest.approx(bleu, 0.01) + assert result["syntax_match_score"] == pytest.approx(syntax_match, 0.01) + assert result["dataflow_match_score"] == pytest.approx(dataflow_match, 0.01) assert result["codebleu"] == pytest.approx(codebleu, 0.01)