Commit
fix: to work with 3.12 and add examples tests (#10)
* test: recalculate examples tests

* style: apply black and add some comments

* ci: fast test on 3.12

* test: add AST and data to examples

* fix: add debug print in syntax match

* test: fix test output with new version of tree-sitter

* test: debug fail first test

* ci: tmp true fast tests

* refactor: bleu file to work with 3.12

* style: fix ruff
k4black authored Nov 16, 2023
1 parent 12cbf9b commit 8a944ea
Showing 8 changed files with 98 additions and 334 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -39,7 +39,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: '3.12'
cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |
156 changes: 9 additions & 147 deletions codebleu/bleu.py
@@ -9,25 +9,11 @@

"""BLEU score implementation."""
import math
import sys
import warnings
from collections import Counter
from fractions import Fraction as _Fraction
from typing import Any

from .utils import ngrams


# _normalize=False was removed in 3.12, add custom class for back-compatibility
class Fraction(_Fraction):
# We're immutable, so use __new__ not __init__
def __new__(cls, numerator: Any = 0, denominator: Any = None, *, _normalize: bool = True) -> "Fraction":
if sys.version_info >= (3, 12):
return super(Fraction, cls).__new__(cls, numerator, denominator)
else:
return super(Fraction, cls).__new__(cls, numerator, denominator, _normalize=False)


def sentence_bleu(
references,
hypothesis,
@@ -163,9 +149,9 @@ def corpus_bleu(
# For each order of ngram, calculate the numerator and
# denominator for the corpus-level modified precision.
for i, _ in enumerate(weights, start=1):
p_i = modified_precision(references, hypothesis, i)
p_numerators[i] += p_i.numerator
p_denominators[i] += p_i.denominator
p_i_numerator, p_i_denominator = modified_precision(references, hypothesis, i)
p_numerators[i] += p_i_numerator
p_denominators[i] += p_i_denominator

# Calculate the hypothesis length and the closest reference length.
# Adds them to the corpus-level hypothesis and reference counts.
@@ -182,8 +168,8 @@ def corpus_bleu(
if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
weights = (1 / hyp_lengths,) * hyp_lengths

# Collects the various precision values for the different ngram orders.
p_n = [Fraction(p_numerators[i], p_denominators[i], _normalize=False) for i, _ in enumerate(weights, start=1)]
# Collects the various recall values for the different ngram orders.
p_n = [(p_numerators[i], p_denominators[i]) for i, _ in enumerate(weights, start=1)]

# Returns 0 if there's no matching n-grams
# We only need to check for p_numerators[1] == 0, since if there's
@@ -199,7 +185,7 @@ def corpus_bleu(
# it tries to retain the Fraction object as much as the
# smoothing method allows.
p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths)
s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n))
s = bp * math.exp(math.fsum(s))
return s

@@ -295,7 +281,8 @@ def modified_precision(references, hypothesis, n):
# Usually this happens when the ngram order is > len(reference).
denominator = max(1, sum(counts.values()))

return Fraction(numerator, denominator, _normalize=False)
# return Fraction(numerator, denominator, _normalize=False)
return numerator, denominator


def closest_ref_length(references, hyp_len):
@@ -444,133 +431,8 @@ def __init__(self, epsilon=0.1, alpha=5, k=5):
self.alpha = alpha
self.k = k

def method0(self, p_n, *args, **kwargs):
"""
No smoothing.
"""
p_n_new = []
for i, p_i in enumerate(p_n):
if p_i.numerator != 0:
p_n_new.append(p_i)
else:
_msg = str(
"\nThe hypothesis contains 0 counts of {}-gram overlaps.\n"
"Therefore the BLEU score evaluates to 0, independently of\n"
"how many N-gram overlaps of lower order it contains.\n"
"Consider using lower n-gram order or use "
"SmoothingFunction()"
).format(i + 1)
warnings.warn(_msg)
# When numerator==0 where denonminator==0 or !=0, the result
# for the precision score should be equal to 0 or undefined.
# Due to BLEU geometric mean computation in logarithm space,
# we we need to take the return sys.float_info.min such that
# math.log(sys.float_info.min) returns a 0 precision score.
p_n_new.append(sys.float_info.min)
return p_n_new

def method1(self, p_n, *args, **kwargs):
"""
Smoothing method 1: Add *epsilon* counts to precision with 0 counts.
"""
return [(p_i.numerator + self.epsilon) / p_i.denominator if p_i.numerator == 0 else p_i for p_i in p_n]

def method2(self, p_n, *args, **kwargs):
"""
Smoothing method 2: Add 1 to both numerator and denominator from
Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of
machine translation quality using longest common subsequence and
skip-bigram statistics. In ACL04.
"""
return [Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False) for p_i in p_n]

def method3(self, p_n, *args, **kwargs):
"""
Smoothing method 3: NIST geometric sequence smoothing
The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each
precision score whose matching n-gram count is null.
k is 1 for the first 'n' value for which the n-gram match count is null/
For example, if the text contains:
- one 2-gram match
- and (consequently) two 1-gram matches
the n-gram count for each individual precision score would be:
- n=1 => prec_count = 2 (two unigrams)
- n=2 => prec_count = 1 (one bigram)
- n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1)
- n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2)
"""
incvnt = 1 # From the mteval-v13a.pl, it's referred to as k.
for i, p_i in enumerate(p_n):
if p_i.numerator == 0:
p_n[i] = 1 / (2**incvnt * p_i.denominator)
incvnt += 1
return p_n

def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 4:
Shorter translations may have inflated precision values due to having
smaller denominators; therefore, we give them proportionally
smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry
suggests dividing by 1/ln(len(T)), where T is the length of the translation.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
for i, p_i in enumerate(p_n):
if p_i.numerator == 0 and hyp_len != 0:
incvnt = i + 1 * self.k / math.log(hyp_len) # Note that this K is different from the K from NIST.
p_n[i] = incvnt / p_i.denominator
return p_n

def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 5:
The matched counts for similar values of n should be similar. To a
calculate the n-gram matched count, it averages the n−1, n and n+1 gram
matched counts.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
m = {}
# Requires an precision value for an addition ngram order.
p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
m[-1] = p_n[0] + 1
for i, p_i in enumerate(p_n):
p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
m[i] = p_n[i]
return p_n

def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 6:
Interpolates the maximum likelihood estimate of the precision *p_n* with
a prior estimate *pi0*. The prior is estimated by assuming that the ratio
between pn and pn−1 will be the same as that between pn−1 and pn−2; from
Gao and He (2013) Training MRF-Based Phrase Translation Models using
Gradient Ascent. In NAACL.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
# This smoothing only works when p_1 and p_2 is non-zero.
# Raise an error with an appropriate message when the input is too short
# to use this smoothing technique.
assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
for i, p_i in enumerate(p_n):
if i in [0, 1]: # Skips the first 2 orders of ngrams.
continue
else:
pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
# No. of ngrams in translation that matches the reference.
m = p_i.numerator
# No. of ngrams in translation.
ngrams_count = sum(1 for _ in ngrams(hypothesis, i + 1))
# Calculates the interpolated precision.
p_n[i] = (m + self.alpha * pi0) / (ngrams_count + self.alpha)
return p_n

def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
"""
Smoothing method 7:
Interpolates methods 4 and 5.
"""
hyp_len = hyp_len if hyp_len else len(hypothesis)
p_n = self.method4(p_n, references, hypothesis, hyp_len)
p_n = self.method5(p_n, references, hypothesis, hyp_len)
return p_n
return [((p_i[0] + self.epsilon), p_i[1]) if p_i[0] == 0 else p_i for p_i in p_n]
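
The bleu.py rewrite drops the custom Fraction subclass, because Python 3.12's fractions.Fraction no longer accepts the private _normalize argument, and instead carries plain (numerator, denominator) tuples from modified_precision through smoothing and into the final log-space geometric mean. A minimal sketch of the idea (not part of the commit; the counts and weights are made up):

import math
from fractions import Fraction

# Fraction reduces its arguments, so the raw clipped count (2) and total (4) are lost;
# the old escape hatch, Fraction(..., _normalize=False), is gone in Python 3.12.
assert Fraction(2, 4).numerator == 1 and Fraction(2, 4).denominator == 2

# Unreduced (numerator, denominator) tuples keep the counts, as in the patched corpus_bleu:
p_n = [(2, 4), (1, 3)]  # hypothetical clipped n-gram counts per order
weights = (0.5, 0.5)
bp = 1.0  # brevity penalty, taken as 1 for this sketch
score = bp * math.exp(math.fsum(w * math.log(num / den) for w, (num, den) in zip(weights, p_n)))
print(round(score, 4))  # 0.4082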
4 changes: 2 additions & 2 deletions codebleu/dataflow_match.py
@@ -47,11 +47,11 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file):
candidate = candidates[i]
for reference in references_sample:
try:
candidate = remove_comments_and_docstrings(candidate, "java")
candidate = remove_comments_and_docstrings(candidate, lang)
except Exception:
pass
try:
reference = remove_comments_and_docstrings(reference, "java")
reference = remove_comments_and_docstrings(reference, lang)
except Exception:
pass

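
The dataflow_match.py fix threads the caller's lang into remove_comments_and_docstrings instead of the hard-coded "java", so non-Java candidates and references are stripped with the right comment syntax. A rough sketch of the call pattern with a toy stand-in for the helper (the real implementation lives in the package and handles many languages):

def remove_comments_and_docstrings(source: str, lang: str) -> str:
    # Toy stand-in: only strips Python line comments, just to keep the example self-contained.
    if lang == "python":
        return "\n".join(line for line in source.splitlines() if not line.lstrip().startswith("#"))
    raise NotImplementedError(lang)


candidate = "def add(a, b):\n    # sum two numbers\n    return a + b"
lang = "python"
try:
    candidate = remove_comments_and_docstrings(candidate, lang)  # was always "java" before this commit
except Exception:
    pass  # keep the raw string on failure, mirroring the library's try/except
print(candidate)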
30 changes: 15 additions & 15 deletions codebleu/parser/build.py
@@ -1,20 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Copyright (c) 2023 Konstantin Chernyshev.
# Licensed under the MIT license.

from tree_sitter import Language

Language.build_library(
"my-languages.so",
[
"tree-sitter/go",
"tree-sitter/javascript",
"tree-sitter/python",
"tree-sitter/php",
"tree-sitter/java",
"tree-sitter/ruby",
"tree-sitter/c-sharp",
"tree-sitter/c",
"tree-sitter/cpp",
],
)
if __name__ == "__main__":
Language.build_library(
"my-languages.so",
[
"tree-sitter/go",
"tree-sitter/javascript",
"tree-sitter/python",
"tree-sitter/php",
"tree-sitter/java",
"tree-sitter/ruby",
"tree-sitter/c-sharp",
"tree-sitter/c",
"tree-sitter/cpp",
],
)
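
Wrapping the tree-sitter build in a __main__ guard means importing codebleu.parser.build no longer compiles the grammars as an import side effect; the shared library is built only when the script is run directly. A hedged sketch of the same pattern with a trimmed grammar list (the output name and tree-sitter/... paths follow the diff; Language.build_library is the API of older py-tree-sitter releases, which is what this file targets):

from tree_sitter import Language


def build(output: str = "my-languages.so") -> None:
    # Compile a couple of grammar checkouts into a single shared library; extend the list as needed.
    Language.build_library(output, ["tree-sitter/python", "tree-sitter/java"])


if __name__ == "__main__":
    # e.g. `python build.py`; a plain `import` of this module now stays side-effect free.
    build()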
28 changes: 17 additions & 11 deletions codebleu/syntax_match.py
@@ -30,23 +30,23 @@ def calc_syntax_match(references, candidate, lang, lang_so_file):


def corpus_syntax_match(references, candidates, lang, lang_so_file):
# print(os.listdir())
JAVA_LANGUAGE = Language(lang_so_file, lang)
tree_sitter_language = Language(lang_so_file, lang)
parser = Parser()
parser.set_language(JAVA_LANGUAGE)
parser.set_language(tree_sitter_language)
match_count = 0
match_count_candidate_to_reference = 0
total_count = 0

for i in range(len(candidates)):
references_sample = references[i]
candidate = candidates[i]
for reference in references_sample:
try:
candidate = remove_comments_and_docstrings(candidate, "java")
candidate = remove_comments_and_docstrings(candidate, lang)
except Exception:
pass
try:
reference = remove_comments_and_docstrings(reference, "java")
reference = remove_comments_and_docstrings(reference, lang)
except Exception:
pass

@@ -69,15 +69,21 @@ def get_all_sub_trees(root_node):
return sub_tree_sexp_list

cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
ref_sexps = get_all_sub_trees(reference_tree)
ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)]

# print(cand_sexps)
# print(ref_sexps)

for sub_tree, depth in ref_sexps:
# TODO: fix, now we count number of reference subtrees matching candidate,
# but we should count number of candidate subtrees matching reference
# See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf
for sub_tree in ref_sexps:
if sub_tree in cand_sexps:
match_count += 1
total_count += len(ref_sexps)

for sub_tree in cand_sexps:
if sub_tree in ref_sexps:
match_count_candidate_to_reference += 1

total_count += len(ref_sexps)
# print(f'match_count {match_count} / {total_count}')
# print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}')
score = match_count / total_count
return score
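
The syntax_match.py hunk renames JAVA_LANGUAGE to tree_sitter_language, passes lang through comment stripping, and tallies both matching directions: the returned score still counts reference subtrees found in the candidate, while the new match_count_candidate_to_reference counts candidate subtrees found in the reference, the direction the TODO and Eq. (4) in Section 3.2 of the CodeBLEU paper call for. A toy illustration of the two directions on made-up s-expressions (not parsed from real code):

ref_sexps = ["(call)", "(args)", "(name)", "(name)"]                 # hypothetical reference subtrees
cand_sexps = ["(call)", "(args)", "(name)", "(string)", "(string)"]  # hypothetical candidate subtrees

# Current behaviour: reference subtrees found in the candidate.
match_ref_to_cand = sum(1 for t in ref_sexps if t in cand_sexps)  # 4
# Direction flagged by the TODO: candidate subtrees found in the reference.
match_cand_to_ref = sum(1 for t in cand_sexps if t in ref_sexps)  # 3

total = len(ref_sexps)
print(match_ref_to_cand / total, match_cand_to_ref / total)  # 1.0 0.75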
