Skip to content

Commit

Permalink
test: fix test output with new version of tree-sitter
Browse files Browse the repository at this point in the history
  • Loading branch information
k4black committed Nov 16, 2023
1 parent 8d17961 commit eceb842
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
11 changes: 2 additions & 9 deletions codebleu/syntax_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@ def get_all_sub_trees(root_node):
cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)]

print('cand_sexps')
for tree, depth in get_all_sub_trees(candidate_tree):
print(' ', depth, tree)
print('ref_sexps')
for tree, depth in get_all_sub_trees(reference_tree):
print(' ', depth, tree)

# TODO: fix, now we count number of reference subtrees matching candidate,
# but we should count number of candidate subtrees matching reference
# See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf
Expand All @@ -90,7 +83,7 @@ def get_all_sub_trees(root_node):
match_count_candidate_to_reference += 1

total_count += len(ref_sexps)
print(f'match_count {match_count} / {total_count}')
print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}')
# print(f'match_count {match_count} / {total_count}')
# print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}')
score = match_count / total_count
return score
12 changes: 8 additions & 4 deletions tests/test_codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,18 @@ def test_error_when_input_length_mismatch() -> None:
["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }"],
["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"],
0.7846,
14/21,
11/19, # In example, it is 13/21, but with new version of tree-sitter it is 11/19
2/3,
0.7238, # TODO: lol, not working at <3.12
0.7019, # Should be 0.7238 if AST=13/21 in the paper, however at the moment tee-sitter AST is 11/19
),
# https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4
(
["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ;"],
["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"],
0.7543,
14/21,
11/19, # In example, it is 13/21, but with new version of tree-sitter it is 11/19
2/3,
0.7091, # Should be 0.6973 if AST=13/21, however at the moment tee-sitter AST is 14/21
0.6873, # Should be 0.6973 if AST=13/21 in the paper, however at the moment tee-sitter AST is 11/19
),
# https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4
(
Expand All @@ -104,11 +104,15 @@ def test_code_x_glue_readme_examples(
result = calc_codebleu(references, predictions, "java")
logging.debug(result)

print(result)

assert result["ngram_match_score"] == pytest.approx(bleu, 0.01)
assert result["syntax_match_score"] == pytest.approx(syntax_match, 0.01)
assert result["dataflow_match_score"] == pytest.approx(dataflow_match, 0.01)
assert result["codebleu"] == pytest.approx(codebleu, 0.01)

# assert False


@pytest.mark.parametrize(
["predictions", "references", "codebleu"],
Expand Down

0 comments on commit eceb842

Please sign in to comment.