From 17411d4b94f18e0abd3e1b8adaeac3483c49c3bd Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 11:25:37 +0100 Subject: [PATCH] style: apply black and add some comments --- codebleu/bleu.py | 5 ++++- codebleu/dataflow_match.py | 4 ++-- codebleu/parser/build.py | 30 +++++++++++++++--------------- codebleu/syntax_match.py | 25 +++++++++++++++---------- 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/codebleu/bleu.py b/codebleu/bleu.py index be0738c..f528746 100644 --- a/codebleu/bleu.py +++ b/codebleu/bleu.py @@ -18,8 +18,11 @@ from .utils import ngrams -# _normalize=False was removed in 3.12, add custom class for back-compatibility class Fraction(_Fraction): + """Fraction class with _normalize=False support. + _normalize=False was removed in 3.12, add custom class for back-compatibility + """ + # We're immutable, so use __new__ not __init__ def __new__(cls, numerator: Any = 0, denominator: Any = None, *, _normalize: bool = True) -> "Fraction": if sys.version_info >= (3, 12): diff --git a/codebleu/dataflow_match.py b/codebleu/dataflow_match.py index 30e3871..a4dd6d4 100644 --- a/codebleu/dataflow_match.py +++ b/codebleu/dataflow_match.py @@ -47,11 +47,11 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file): candidate = candidates[i] for reference in references_sample: try: - candidate = remove_comments_and_docstrings(candidate, "java") + candidate = remove_comments_and_docstrings(candidate, lang) except Exception: pass try: - reference = remove_comments_and_docstrings(reference, "java") + reference = remove_comments_and_docstrings(reference, lang) except Exception: pass diff --git a/codebleu/parser/build.py b/codebleu/parser/build.py index c1f1a3a..3e409e2 100644 --- a/codebleu/parser/build.py +++ b/codebleu/parser/build.py @@ -1,20 +1,20 @@ # Copyright (c) Microsoft Corporation. # Copyright (c) 2023 Konstantin Chernyshev. # Licensed under the MIT license. - from tree_sitter import Language -Language.build_library( - "my-languages.so", - [ - "tree-sitter/go", - "tree-sitter/javascript", - "tree-sitter/python", - "tree-sitter/php", - "tree-sitter/java", - "tree-sitter/ruby", - "tree-sitter/c-sharp", - "tree-sitter/c", - "tree-sitter/cpp", - ], -) +if __name__ == "__main__": + Language.build_library( + "my-languages.so", + [ + "tree-sitter/go", + "tree-sitter/javascript", + "tree-sitter/python", + "tree-sitter/php", + "tree-sitter/java", + "tree-sitter/ruby", + "tree-sitter/c-sharp", + "tree-sitter/c", + "tree-sitter/cpp", + ], + ) diff --git a/codebleu/syntax_match.py b/codebleu/syntax_match.py index 5f14e76..92dd6db 100644 --- a/codebleu/syntax_match.py +++ b/codebleu/syntax_match.py @@ -30,11 +30,11 @@ def calc_syntax_match(references, candidate, lang, lang_so_file): def corpus_syntax_match(references, candidates, lang, lang_so_file): - # print(os.listdir()) - JAVA_LANGUAGE = Language(lang_so_file, lang) + tree_sitter_language = Language(lang_so_file, lang) parser = Parser() - parser.set_language(JAVA_LANGUAGE) + parser.set_language(tree_sitter_language) match_count = 0 + match_count_candidate_to_reference = 0 total_count = 0 for i in range(len(candidates)): @@ -42,11 +42,11 @@ def corpus_syntax_match(references, candidates, lang, lang_so_file): candidate = candidates[i] for reference in references_sample: try: - candidate = remove_comments_and_docstrings(candidate, "java") + candidate = remove_comments_and_docstrings(candidate, lang) except Exception: pass try: - reference = remove_comments_and_docstrings(reference, "java") + reference = remove_comments_and_docstrings(reference, lang) except Exception: pass @@ -69,14 +69,19 @@ def get_all_sub_trees(root_node): return sub_tree_sexp_list cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)] - ref_sexps = get_all_sub_trees(reference_tree) + ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)] - # print(cand_sexps) - # print(ref_sexps) - - for sub_tree, depth in ref_sexps: + # TODO: fix, now we count number of reference subtrees matching candidate, + # but we should count number of candidate subtrees matching reference + # See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf + for sub_tree in ref_sexps: if sub_tree in cand_sexps: match_count += 1 + + for sub_tree in cand_sexps: + if sub_tree in ref_sexps: + match_count_candidate_to_reference += 1 + total_count += len(ref_sexps) score = match_count / total_count