diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9c256e1..aa36d11 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -39,7 +39,7 @@ jobs:
           fetch-depth: 0
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.11'
+          python-version: '3.12'
           cache: 'pip' # caching pip dependencies
       - name: Install dependencies
         run: |
diff --git a/codebleu/bleu.py b/codebleu/bleu.py
index be0738c..5bb2c9a 100644
--- a/codebleu/bleu.py
+++ b/codebleu/bleu.py
@@ -9,25 +9,11 @@
 """BLEU score implementation."""

 import math
-import sys
-import warnings
 from collections import Counter
-from fractions import Fraction as _Fraction
-from typing import Any

 from .utils import ngrams

-# _normalize=False was removed in 3.12, add custom class for back-compatibility
-class Fraction(_Fraction):
-    # We're immutable, so use __new__ not __init__
-    def __new__(cls, numerator: Any = 0, denominator: Any = None, *, _normalize: bool = True) -> "Fraction":
-        if sys.version_info >= (3, 12):
-            return super(Fraction, cls).__new__(cls, numerator, denominator)
-        else:
-            return super(Fraction, cls).__new__(cls, numerator, denominator, _normalize=False)
-
-
 def sentence_bleu(
     references,
     hypothesis,
@@ -163,9 +149,9 @@ def corpus_bleu(
         # For each order of ngram, calculate the numerator and
         # denominator for the corpus-level modified precision.
         for i, _ in enumerate(weights, start=1):
-            p_i = modified_precision(references, hypothesis, i)
-            p_numerators[i] += p_i.numerator
-            p_denominators[i] += p_i.denominator
+            p_i_numerator, p_i_denominator = modified_precision(references, hypothesis, i)
+            p_numerators[i] += p_i_numerator
+            p_denominators[i] += p_i_denominator

         # Calculate the hypothesis length and the closest reference length.
         # Adds them to the corpus-level hypothesis and reference counts.
@@ -182,8 +168,8 @@
     if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
         weights = (1 / hyp_lengths,) * hyp_lengths

-    # Collects the various precision values for the different ngram orders.
-    p_n = [Fraction(p_numerators[i], p_denominators[i], _normalize=False) for i, _ in enumerate(weights, start=1)]
+    # Collects the various precision values for the different ngram orders.
+    p_n = [(p_numerators[i], p_denominators[i]) for i, _ in enumerate(weights, start=1)]

     # Returns 0 if there's no matching n-grams
     # We only need to check for p_numerators[1] == 0, since if there's
@@ -199,7 +185,7 @@
     # it tries to retain the Fraction object as much as the
     # smoothing method allows.
     p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths)
-    s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
+    s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n))
     s = bp * math.exp(math.fsum(s))
     return s

@@ -295,7 +281,8 @@ def modified_precision(references, hypothesis, n):
     # Usually this happens when the ngram order is > len(reference).
     denominator = max(1, sum(counts.values()))

-    return Fraction(numerator, denominator, _normalize=False)
+    # return Fraction(numerator, denominator, _normalize=False)
+    return numerator, denominator


 def closest_ref_length(references, hyp_len):
@@ -444,133 +431,8 @@ def __init__(self, epsilon=0.1, alpha=5, k=5):
         self.alpha = alpha
         self.k = k

-    def method0(self, p_n, *args, **kwargs):
-        """
-        No smoothing.
- """ - p_n_new = [] - for i, p_i in enumerate(p_n): - if p_i.numerator != 0: - p_n_new.append(p_i) - else: - _msg = str( - "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n" - "Therefore the BLEU score evaluates to 0, independently of\n" - "how many N-gram overlaps of lower order it contains.\n" - "Consider using lower n-gram order or use " - "SmoothingFunction()" - ).format(i + 1) - warnings.warn(_msg) - # When numerator==0 where denonminator==0 or !=0, the result - # for the precision score should be equal to 0 or undefined. - # Due to BLEU geometric mean computation in logarithm space, - # we we need to take the return sys.float_info.min such that - # math.log(sys.float_info.min) returns a 0 precision score. - p_n_new.append(sys.float_info.min) - return p_n_new - def method1(self, p_n, *args, **kwargs): """ Smoothing method 1: Add *epsilon* counts to precision with 0 counts. """ - return [(p_i.numerator + self.epsilon) / p_i.denominator if p_i.numerator == 0 else p_i for p_i in p_n] - - def method2(self, p_n, *args, **kwargs): - """ - Smoothing method 2: Add 1 to both numerator and denominator from - Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of - machine translation quality using longest common subsequence and - skip-bigram statistics. In ACL04. - """ - return [Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False) for p_i in p_n] - - def method3(self, p_n, *args, **kwargs): - """ - Smoothing method 3: NIST geometric sequence smoothing - The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each - precision score whose matching n-gram count is null. - k is 1 for the first 'n' value for which the n-gram match count is null/ - For example, if the text contains: - - one 2-gram match - - and (consequently) two 1-gram matches - the n-gram count for each individual precision score would be: - - n=1 => prec_count = 2 (two unigrams) - - n=2 => prec_count = 1 (one bigram) - - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1) - - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2) - """ - incvnt = 1 # From the mteval-v13a.pl, it's referred to as k. - for i, p_i in enumerate(p_n): - if p_i.numerator == 0: - p_n[i] = 1 / (2**incvnt * p_i.denominator) - incvnt += 1 - return p_n - - def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 4: - Shorter translations may have inflated precision values due to having - smaller denominators; therefore, we give them proportionally - smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry - suggests dividing by 1/ln(len(T)), where T is the length of the translation. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - for i, p_i in enumerate(p_n): - if p_i.numerator == 0 and hyp_len != 0: - incvnt = i + 1 * self.k / math.log(hyp_len) # Note that this K is different from the K from NIST. - p_n[i] = incvnt / p_i.denominator - return p_n - - def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 5: - The matched counts for similar values of n should be similar. To a - calculate the n-gram matched count, it averages the n−1, n and n+1 gram - matched counts. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - m = {} - # Requires an precision value for an addition ngram order. 
-        p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
-        m[-1] = p_n[0] + 1
-        for i, p_i in enumerate(p_n):
-            p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
-            m[i] = p_n[i]
-        return p_n
-
-    def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
-        """
-        Smoothing method 6:
-        Interpolates the maximum likelihood estimate of the precision *p_n* with
-        a prior estimate *pi0*. The prior is estimated by assuming that the ratio
-        between pn and pn−1 will be the same as that between pn−1 and pn−2; from
-        Gao and He (2013) Training MRF-Based Phrase Translation Models using
-        Gradient Ascent. In NAACL.
-        """
-        hyp_len = hyp_len if hyp_len else len(hypothesis)
-        # This smoothing only works when p_1 and p_2 is non-zero.
-        # Raise an error with an appropriate message when the input is too short
-        # to use this smoothing technique.
-        assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
-        for i, p_i in enumerate(p_n):
-            if i in [0, 1]:  # Skips the first 2 orders of ngrams.
-                continue
-            else:
-                pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
-                # No. of ngrams in translation that matches the reference.
-                m = p_i.numerator
-                # No. of ngrams in translation.
-                ngrams_count = sum(1 for _ in ngrams(hypothesis, i + 1))
-                # Calculates the interpolated precision.
-                p_n[i] = (m + self.alpha * pi0) / (ngrams_count + self.alpha)
-        return p_n
-
-    def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
-        """
-        Smoothing method 7:
-        Interpolates methods 4 and 5.
-        """
-        hyp_len = hyp_len if hyp_len else len(hypothesis)
-        p_n = self.method4(p_n, references, hypothesis, hyp_len)
-        p_n = self.method5(p_n, references, hypothesis, hyp_len)
-        return p_n
+        return [((p_i[0] + self.epsilon), p_i[1]) if p_i[0] == 0 else p_i for p_i in p_n]
diff --git a/codebleu/dataflow_match.py b/codebleu/dataflow_match.py
index 30e3871..a4dd6d4 100644
--- a/codebleu/dataflow_match.py
+++ b/codebleu/dataflow_match.py
@@ -47,11 +47,11 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file):
         candidate = candidates[i]
         for reference in references_sample:
             try:
-                candidate = remove_comments_and_docstrings(candidate, "java")
+                candidate = remove_comments_and_docstrings(candidate, lang)
             except Exception:
                 pass
             try:
-                reference = remove_comments_and_docstrings(reference, "java")
+                reference = remove_comments_and_docstrings(reference, lang)
             except Exception:
                 pass
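The two hunks above (and the matching ones in `codebleu/syntax_match.py` further down) replace a hardcoded `"java"` with the caller's `lang`. Before the fix, comment stripping always used the Java grammar, so for any other language the call typically raised and the bare `except` silently kept comments in the compared code. A minimal sketch of the fixed control flow — the import path and the sample snippet are assumptions for illustration, not taken from this diff:

```python
# Hedged sketch of the bug fixed above; the import location of
# remove_comments_and_docstrings is an assumption about the package layout.
from codebleu.parser import remove_comments_and_docstrings  # assumed path

candidate = "def foo ( x ) :\n    # helper comment\n    return x"  # made-up sample
lang = "python"

try:
    # Previously this was remove_comments_and_docstrings(candidate, "java"):
    # for non-Java inputs it failed or mis-parsed, and the bare except below
    # silently fell back to the raw, comment-laden string.
    candidate = remove_comments_and_docstrings(candidate, lang)
except Exception:
    pass  # keep the unstripped source, mirroring the library's fallback
```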
diff --git a/codebleu/parser/build.py b/codebleu/parser/build.py
index c1f1a3a..3e409e2 100644
--- a/codebleu/parser/build.py
+++ b/codebleu/parser/build.py
@@ -1,20 +1,20 @@
 # Copyright (c) Microsoft Corporation.
 # Copyright (c) 2023 Konstantin Chernyshev.
 # Licensed under the MIT license.
-
 from tree_sitter import Language

-Language.build_library(
-    "my-languages.so",
-    [
-        "tree-sitter/go",
-        "tree-sitter/javascript",
-        "tree-sitter/python",
-        "tree-sitter/php",
-        "tree-sitter/java",
-        "tree-sitter/ruby",
-        "tree-sitter/c-sharp",
-        "tree-sitter/c",
-        "tree-sitter/cpp",
-    ],
-)
+if __name__ == "__main__":
+    Language.build_library(
+        "my-languages.so",
+        [
+            "tree-sitter/go",
+            "tree-sitter/javascript",
+            "tree-sitter/python",
+            "tree-sitter/php",
+            "tree-sitter/java",
+            "tree-sitter/ruby",
+            "tree-sitter/c-sharp",
+            "tree-sitter/c",
+            "tree-sitter/cpp",
+        ],
+    )
diff --git a/codebleu/syntax_match.py b/codebleu/syntax_match.py
index 5f14e76..0050c1a 100644
--- a/codebleu/syntax_match.py
+++ b/codebleu/syntax_match.py
@@ -30,11 +30,11 @@ def calc_syntax_match(references, candidate, lang, lang_so_file):


 def corpus_syntax_match(references, candidates, lang, lang_so_file):
-    # print(os.listdir())
-    JAVA_LANGUAGE = Language(lang_so_file, lang)
+    tree_sitter_language = Language(lang_so_file, lang)
     parser = Parser()
-    parser.set_language(JAVA_LANGUAGE)
+    parser.set_language(tree_sitter_language)
     match_count = 0
+    match_count_candidate_to_reference = 0
     total_count = 0

     for i in range(len(candidates)):
@@ -42,11 +42,11 @@ def corpus_syntax_match(references, candidates, lang, lang_so_file):
         candidate = candidates[i]
         for reference in references_sample:
             try:
-                candidate = remove_comments_and_docstrings(candidate, "java")
+                candidate = remove_comments_and_docstrings(candidate, lang)
             except Exception:
                 pass
             try:
-                reference = remove_comments_and_docstrings(reference, "java")
+                reference = remove_comments_and_docstrings(reference, lang)
             except Exception:
                 pass
@@ -69,15 +69,21 @@ def get_all_sub_trees(root_node):
             return sub_tree_sexp_list

         cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
-        ref_sexps = get_all_sub_trees(reference_tree)
+        ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)]

-        # print(cand_sexps)
-        # print(ref_sexps)
-
-        for sub_tree, depth in ref_sexps:
+        # TODO: fix, now we count number of reference subtrees matching candidate,
+        # but we should count number of candidate subtrees matching reference
+        # See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf
+        for sub_tree in ref_sexps:
             if sub_tree in cand_sexps:
                 match_count += 1
-        total_count += len(ref_sexps)
+        for sub_tree in cand_sexps:
+            if sub_tree in ref_sexps:
+                match_count_candidate_to_reference += 1
+
+        total_count += len(ref_sexps)
+        # print(f'match_count {match_count} / {total_count}')
+        # print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}')

     score = match_count / total_count
     return score
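The TODO added to `corpus_syntax_match` refers to equation (4) in section 3.2 of the CodeBLEU paper: the AST match should count candidate subtrees that occur in the reference, while the current code counts the opposite direction. A self-contained sketch of the two counts, with made-up S-expression lists standing in for real tree-sitter subtrees:

```python
# Made-up subtree S-expressions; real ones come from tree-sitter parse trees.
ref_sexps = ["(module)", "(call)", "(arguments)", "(identifier)"]
cand_sexps = ["(module)", "(call)", "(call)"]

# Current behaviour: reference subtrees found in the candidate.
match_count = sum(1 for t in ref_sexps if t in cand_sexps)  # -> 2

# Direction suggested by the paper's Eq. (4): candidate subtrees found in the
# reference (this is what match_count_candidate_to_reference tracks above).
match_count_fixed = sum(1 for t in cand_sexps if t in ref_sexps)  # -> 3

# Both variants are normalised by the reference subtree count.
total_count = len(ref_sexps)
print(match_count / total_count, match_count_fixed / total_count)  # 0.5 0.75
```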
diff --git a/codebleu/weighted_ngram_match.py b/codebleu/weighted_ngram_match.py
index 507cb76..d651d7c 100644
--- a/codebleu/weighted_ngram_match.py
+++ b/codebleu/weighted_ngram_match.py
@@ -13,11 +13,8 @@
 """BLEU score implementation."""

 import math
-import sys
-import warnings
 from collections import Counter

-from .bleu import modified_precision
 from .utils import ngrams

@@ -156,8 +153,8 @@ def corpus_bleu(
         # For each order of ngram, calculate the numerator and
         # denominator for the corpus-level modified precision.
         for i, _ in enumerate(weights, start=1):
-            p_i_numeraotr, p_i_denominator = modified_recall(references, hypothesis, i)
-            p_numerators[i] += p_i_numeraotr
+            p_i_numerator, p_i_denominator = modified_recall(references, hypothesis, i)
+            p_numerators[i] += p_i_numerator
             p_denominators[i] += p_i_denominator

         # Calculate the hypothesis length and the closest reference length.
@@ -400,133 +397,8 @@ def __init__(self, epsilon=0.1, alpha=5, k=5):
         self.alpha = alpha
         self.k = k

-    def method0(self, p_n, *args, **kwargs):
-        """
-        No smoothing.
-        """
-        p_n_new = []
-        for i, p_i in enumerate(p_n):
-            if p_i[0] != 0:
-                p_n_new.append(p_i)
-            else:
-                _msg = str(
-                    "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n"
-                    "Therefore the BLEU score evaluates to 0, independently of\n"
-                    "how many N-gram overlaps of lower order it contains.\n"
-                    "Consider using lower n-gram order or use "
-                    "SmoothingFunction()"
-                ).format(i + 1)
-                warnings.warn(_msg)
-                # When numerator==0 where denonminator==0 or !=0, the result
-                # for the precision score should be equal to 0 or undefined.
-                # Due to BLEU geometric mean computation in logarithm space,
-                # we we need to take the return sys.float_info.min such that
-                # math.log(sys.float_info.min) returns a 0 precision score.
-                p_n_new.append(sys.float_info.min)
-        return p_n_new
-
     def method1(self, p_n, *args, **kwargs):
         """
         Smoothing method 1: Add *epsilon* counts to precision with 0 counts.
         """
         return [((p_i[0] + self.epsilon), p_i[1]) if p_i[0] == 0 else p_i for p_i in p_n]
-
-    def method2(self, p_n, *args, **kwargs):
-        """
-        Smoothing method 2: Add 1 to both numerator and denominator from
-        Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of
-        machine translation quality using longest common subsequence and
-        skip-bigram statistics. In ACL04.
-        """
-        return [(p_i[0] + 1, p_i[1] + 1) for p_i in p_n]
-
-    def method3(self, p_n, *args, **kwargs):
-        """
-        Smoothing method 3: NIST geometric sequence smoothing
-        The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each
-        precision score whose matching n-gram count is null.
-        k is 1 for the first 'n' value for which the n-gram match count is null/
-        For example, if the text contains:
-        - one 2-gram match
-        - and (consequently) two 1-gram matches
-        the n-gram count for each individual precision score would be:
-        - n=1 => prec_count = 2 (two unigrams)
-        - n=2 => prec_count = 1 (one bigram)
-        - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1)
-        - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2)
-        """
-        incvnt = 1  # From the mteval-v13a.pl, it's referred to as k.
-        for i, p_i in enumerate(p_n):
-            if p_i.numerator == 0:
-                p_n[i] = 1 / (2**incvnt * p_i.denominator)
-                incvnt += 1
-        return p_n
-
-    def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
-        """
-        Smoothing method 4:
-        Shorter translations may have inflated precision values due to having
-        smaller denominators; therefore, we give them proportionally
-        smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry
-        suggests dividing by 1/ln(len(T)), where T is the length of the translation.
-        """
-        hyp_len = hyp_len if hyp_len else len(hypothesis)
-        for i, p_i in enumerate(p_n):
-            if p_i.numerator == 0 and hyp_len != 0:
-                incvnt = i + 1 * self.k / math.log(hyp_len)  # Note that this K is different from the K from NIST.
-                p_n[i] = incvnt / p_i.denominator
-        return p_n
-
-    def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
-        """
-        Smoothing method 5:
-        The matched counts for similar values of n should be similar. To a
-        calculate the n-gram matched count, it averages the n−1, n and n+1 gram
-        matched counts.
-        """
-        hyp_len = hyp_len if hyp_len else len(hypothesis)
-        m = {}
-        # Requires an precision value for an addition ngram order.
-        p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
-        m[-1] = p_n[0] + 1
-        for i, p_i in enumerate(p_n):
-            p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
-            m[i] = p_n[i]
-        return p_n
-
-    def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
-        """
-        Smoothing method 6:
-        Interpolates the maximum likelihood estimate of the precision *p_n* with
-        a prior estimate *pi0*. The prior is estimated by assuming that the ratio
-        between pn and pn−1 will be the same as that between pn−1 and pn−2; from
-        Gao and He (2013) Training MRF-Based Phrase Translation Models using
-        Gradient Ascent. In NAACL.
-        """
-        hyp_len = hyp_len if hyp_len else len(hypothesis)
-        # This smoothing only works when p_1 and p_2 is non-zero.
-        # Raise an error with an appropriate message when the input is too short
-        # to use this smoothing technique.
-        assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
-        for i, p_i in enumerate(p_n):
-            if i in [0, 1]:  # Skips the first 2 orders of ngrams.
-                continue
-            else:
-                pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
-                # No. of ngrams in translation that matches the reference.
-                m = p_i.numerator
-                # No. of ngrams in translation.
-                ngrams_count = sum(1 for _ in ngrams(hypothesis, i + 1))
-                # Calculates the interpolated precision.
-                p_n[i] = (m + self.alpha * pi0) / (ngrams_count + self.alpha)
-        return p_n
-
-    def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
-        """
-        Smoothing method 7:
-        Interpolates methods 4 and 5.
-        """
-        hyp_len = hyp_len if hyp_len else len(hypothesis)
-        p_n = self.method4(p_n, references, hypothesis, hyp_len)
-        p_n = self.method5(p_n, references, hypothesis, hyp_len)
-        return p_n
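With the Python 3.12-incompatible `Fraction` subclass gone, both BLEU variants now pass `(numerator, denominator)` tuples through smoothing and only divide inside the final logarithm. A runnable sketch of that flow with invented counts — the names mirror `corpus_bleu` and `method1`, but this is an illustration, not the library code:

```python
import math

# Invented corpus-level counts for n-gram orders 1..4:
# (matching n-grams, total n-grams in the hypothesis).
p_n = [(5, 6), (3, 5), (0, 4), (0, 3)]
weights = (0.25, 0.25, 0.25, 0.25)
epsilon = 0.1

# Equivalent of SmoothingFunction().method1 after the change: add epsilon
# to zero numerators so math.log below never sees zero.
p_n = [(num + epsilon, den) if num == 0 else (num, den) for num, den in p_n]

# Geometric mean in log space, dividing the tuple components only here
# (this replaces math.log(Fraction(...)) from the old code).
score = math.exp(math.fsum(w * math.log(num / den) for w, (num, den) in zip(weights, p_n)))
print(round(score, 3))  # ~0.143 for these made-up counts
```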
- """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - p_n = self.method4(p_n, references, hypothesis, hyp_len) - p_n = self.method5(p_n, references, hypothesis, hyp_len) - return p_n diff --git a/pyproject.toml b/pyproject.toml index ddc4a28..22d561c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,14 +40,12 @@ exclude = ["tests", "tests.*", "codebleu.parser.tree-sitter"] "*" = ["py.typed", "*.txt", "*.so", "*.dylib", "*.dll", "keywords/*"] - [project.scripts] codebleu = "codebleu.__main__:main" [project.urls] homepage = "https://github.com/k4black/codebleu" - [project.optional-dependencies] test = [ "pytest >=7.0.0,<8.0.0", @@ -60,10 +58,9 @@ test = [ "flake8 >=6.0.0,<7.0.0", "ruff >=0.0.275,<0.2.0", "isort >=5.0.0,<6.0.0", + "nltk >=3.0.0,<4.0.0", ] - - [tool.setuptools.dynamic] version = {file = "VERSION"} @@ -102,7 +99,7 @@ skip = ["build", "dist", ".venv", ".eggs", ".mypy_cache", ".pytest_cache", ".git [tool.black] line_length=120 -target_version=["py38","py39","py310","py311"] +target_version=["py38","py39","py310","py311", "py312"] [tool.ruff] line-length=120 diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index 9bf4f32..2c41935 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -10,21 +10,17 @@ @pytest.mark.parametrize( ["predictions", "references", "codebleu"], [ - ( - ["some rannnndom words in length more than 3"], - ["def test ( ) :\n pass"], - 0.25, - ), # 'cause data_flow is 0 and considered as 1 - (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.4), - (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.6), - (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.8), + (["some rannnndom words in length more than 3"], ["def test ( ) :\n pass"], 0.25), # cause data_flow=1 + (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.36), + (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.61), + (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.85), (["def foo ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 1.0), ], ) def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None: result = calc_codebleu(references, predictions, "python") logging.debug(result) - assert result["codebleu"] == pytest.approx(codebleu, 0.1) + assert result["codebleu"] == pytest.approx(codebleu, 0.01) @pytest.mark.parametrize(["lang"], [(lang,) for lang in AVAILABLE_LANGS]) @@ -48,7 +44,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None: def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None: result = calc_codebleu(references, predictions, lang) logging.debug(result) - assert result["codebleu"] == pytest.approx(0.6, 0.1) + assert result["codebleu"] == pytest.approx(0.6, 0.05) def test_error_when_lang_not_supported() -> None: @@ -61,27 +57,58 @@ def test_error_when_input_length_mismatch() -> None: calc_codebleu(["def foo : pass"], ["def bar : pass", "def buz : pass"], "python") -# https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png @pytest.mark.parametrize( - ["predictions", "references", "codebleu"], + ["predictions", "references", "bleu", "syntax_match", "dataflow_match", "codebleu"], [ - # ( - # ['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'], - # ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? 
diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py
index 9bf4f32..2c41935 100644
--- a/tests/test_codebleu.py
+++ b/tests/test_codebleu.py
@@ -10,21 +10,17 @@
 @pytest.mark.parametrize(
     ["predictions", "references", "codebleu"],
     [
-        (
-            ["some rannnndom words in length more than 3"],
-            ["def test ( ) :\n pass"],
-            0.25,
-        ),  # 'cause data_flow is 0 and considered as 1
-        (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.4),
-        (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.6),
-        (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.8),
+        (["some rannnndom words in length more than 3"], ["def test ( ) :\n pass"], 0.25),  # data_flow is 0, counted as 1
+        (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.36),
+        (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.61),
+        (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.85),
         (["def foo ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 1.0),
     ],
 )
 def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None:
     result = calc_codebleu(references, predictions, "python")
     logging.debug(result)
-    assert result["codebleu"] == pytest.approx(codebleu, 0.1)
+    assert result["codebleu"] == pytest.approx(codebleu, 0.01)


 @pytest.mark.parametrize(["lang"], [(lang,) for lang in AVAILABLE_LANGS])
@@ -48,7 +44,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None:
 def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None:
     result = calc_codebleu(references, predictions, lang)
     logging.debug(result)
-    assert result["codebleu"] == pytest.approx(0.6, 0.1)
+    assert result["codebleu"] == pytest.approx(0.6, 0.05)


 def test_error_when_lang_not_supported() -> None:
@@ -61,27 +57,58 @@ def test_error_when_input_length_mismatch() -> None:
     calc_codebleu(["def foo : pass"], ["def bar : pass", "def buz : pass"], "python")


-# https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png
 @pytest.mark.parametrize(
-    ["predictions", "references", "codebleu"],
+    ["predictions", "references", "bleu", "syntax_match", "dataflow_match", "codebleu"],
     [
-        # (
-        #     ['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'],
-        #     ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
-        #     0.7238  # TODO: lol, not working at <3.12
-        # ),
-        # (
-        #     ['public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }'],
-        #     ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
-        #     0.8397  # TODO: check, lol, not working
-        # ),
+        # https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png
+        (
+            ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }"],
+            ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"],
+            0.7846,
+            11/19,  # the example gives 13/21, but the newer tree-sitter version yields 11/19
+            2/3,
+            0.7019,  # would be 0.7238 with the paper's AST=13/21; current tree-sitter gives 11/19
+        ),
+        # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" on page 4
+        (
+            ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ;"],
+            ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"],
+            0.7543,
+            11/19,  # the example gives 13/21, but the newer tree-sitter version yields 11/19
+            2/3,
+            0.6873,  # would be 0.6973 with the paper's AST=13/21; current tree-sitter gives 11/19
+        ),
+        # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" on page 4
+        (
+            ["public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }"],
+            ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"],
+            0.7571,  # the "Example 2" text gives 0.7571; Figure 4 erroneously shows 0.6814
+            1.0,
+            1.0,
+            0.8804,  # the "Example 2" text gives 0.8804; Figure 4 erroneously shows 0.8397
+        ),
     ],
 )
-def test_code_x_glue_readme_examples(predictions: List[Any], references: List[Any], codebleu: float) -> None:
+def test_code_x_glue_readme_examples(
+    predictions: List[Any],
+    references: List[Any],
+    bleu: float,
+    syntax_match: float,
+    dataflow_match: float,
+    codebleu: float,
+) -> None:
     result = calc_codebleu(references, predictions, "java")
     logging.debug(result)
+
+    print(result)
+
+    assert result["ngram_match_score"] == pytest.approx(bleu, 0.01)
+    assert result["syntax_match_score"] == pytest.approx(syntax_match, 0.01)
+    assert result["dataflow_match_score"] == pytest.approx(dataflow_match, 0.01)
     assert result["codebleu"] == pytest.approx(codebleu, 0.01)

+    # assert False
+

 @pytest.mark.parametrize(
     ["predictions", "references", "codebleu"],