diff --git a/codebleu/__main__.py b/codebleu/__main__.py index 92af515..ee5859e 100644 --- a/codebleu/__main__.py +++ b/codebleu/__main__.py @@ -57,7 +57,7 @@ def main( "--lang", type=str, required=True, - choices=["java", "js", "c_sharp", "php", "go", "python", "ruby"], + choices=["java", "js", "c_sharp", "php", "go", "python", "ruby", "rust"], ) parser.add_argument("--params", type=str, default="0.25,0.25,0.25,0.25", help="alpha, beta and gamma") diff --git a/codebleu/codebleu.py b/codebleu/codebleu.py index f1beb86..65d8959 100644 --- a/codebleu/codebleu.py +++ b/codebleu/codebleu.py @@ -7,7 +7,18 @@ from . import bleu, dataflow_match, syntax_match, weighted_ngram_match PACKAGE_DIR = Path(__file__).parent -AVAILABLE_LANGS = ["java", "javascript", "c_sharp", "php", "c", "cpp", "python", "go", "ruby"] # keywords available +AVAILABLE_LANGS = [ + "java", + "javascript", + "c_sharp", + "php", + "c", + "cpp", + "python", + "go", + "ruby", + "rust", +] # keywords available def calc_codebleu( diff --git a/codebleu/dataflow_match.py b/codebleu/dataflow_match.py index a4dd6d4..f110a91 100644 --- a/codebleu/dataflow_match.py +++ b/codebleu/dataflow_match.py @@ -12,6 +12,7 @@ DFG_php, DFG_python, DFG_ruby, + DFG_rust, index_to_code_token, remove_comments_and_docstrings, tree_to_token_index, @@ -27,6 +28,7 @@ "c_sharp": DFG_csharp, "c": DFG_csharp, # XLCoST uses C# parser for C "cpp": DFG_csharp, # XLCoST uses C# parser for C++ + "rust": DFG_rust, } diff --git a/codebleu/keywords/rust.txt b/codebleu/keywords/rust.txt new file mode 100644 index 0000000..49eeaee --- /dev/null +++ b/codebleu/keywords/rust.txt @@ -0,0 +1,71 @@ +as +async +await +block +bool +break +char +const +continue +crate +default +dyn +else +enum +expr +extern +f32 +f64 +false +fn +for +i128 +i16 +i32 +i64 +i8 +ident +if +impl +in +isize +item +let +lifetime +literal +loop +macro_rules! +match +meta +mod +move +mut +pat +path +pub +ref +return +self +static +stmt +str +struct +super +trait +true +tt +ty +type +u128 +u16 +u32 +u64 +u8 +union +unsafe +use +usize +vis +where +while +yield diff --git a/codebleu/parser/DFG.py b/codebleu/parser/DFG.py index 146eb63..74e4b0b 100644 --- a/codebleu/parser/DFG.py +++ b/codebleu/parser/DFG.py @@ -1213,3 +1213,175 @@ def DFG_javascript(root_node, index_to_code, states): DFG += temp return sorted(DFG, key=lambda x: x[1]), states + + +def DFG_rust(root_node, index_to_code, states): + assignment = ["assignment_expression", "compound_assignment_expr", "let_expression"] + def_statement = ["function_item"] + if_statement = ["if_expression", "if_let_expression", "match_expression", "else"] + for_statement = ["for_expression"] + enhanced_for_statement = ["for_each_statement"] + while_statement = ["while_expression", "while_let_expression", "loop_expression"] + do_first_statement = [] + states = states.copy() + if ( + len(root_node.children) == 0 or root_node.type in ["string_literal", "string", "character_literal"] + ) and root_node.type != "comment": + idx, code = index_to_code[(root_node.start_point, root_node.end_point)] + if root_node.type == code: + return [], states + elif code in states: + return [(code, idx, "comesFrom", [code], states[code].copy())], states + else: + if root_node.type == "identifier": + states[code] = [idx] + return [(code, idx, "comesFrom", [], [])], states + elif root_node.type in def_statement: + if len(root_node.children) >= 3: + name = root_node.children[1] + value = root_node.children[2] + else: + name = root_node.children[1] + value = None + DFG = [] + if value is None: + indexs = tree_to_variable_index(name, index_to_code) + for index in indexs: + idx, code = index_to_code[index] + DFG.append((code, idx, "comesFrom", [], [])) + states[code] = [idx] + return sorted(DFG, key=lambda x: x[1]), states + else: + name_indexs = tree_to_variable_index(name, index_to_code) + value_indexs = tree_to_variable_index(value, index_to_code) + temp, states = DFG_rust(value, index_to_code, states) + DFG += temp + for index1 in name_indexs: + idx1, code1 = index_to_code[index1] + for index2 in value_indexs: + idx2, code2 = index_to_code[index2] + DFG.append((code1, idx1, "comesFrom", [code2], [idx2])) + states[code1] = [idx1] + return sorted(DFG, key=lambda x: x[1]), states + elif root_node.type in assignment: + left_nodes = root_node.child_by_field_name("left") + right_nodes = root_node.child_by_field_name("right") + DFG = [] + temp, states = DFG_rust(right_nodes, index_to_code, states) + DFG += temp + name_indexs = tree_to_variable_index(left_nodes, index_to_code) + value_indexs = tree_to_variable_index(right_nodes, index_to_code) + for index1 in name_indexs: + idx1, code1 = index_to_code[index1] + for index2 in value_indexs: + idx2, code2 = index_to_code[index2] + DFG.append((code1, idx1, "computedFrom", [code2], [idx2])) + states[code1] = [idx1] + return sorted(DFG, key=lambda x: x[1]), states + elif root_node.type in if_statement: + DFG = [] + current_states = states.copy() + others_states = [] + flag = False + tag = False + if "else" in root_node.type: + tag = True + for child in root_node.children: + if "else" in child.type: + tag = True + if child.type not in if_statement and flag is False: + temp, current_states = DFG_rust(child, index_to_code, current_states) + DFG += temp + else: + flag = True + temp, new_states = DFG_rust(child, index_to_code, states) + DFG += temp + others_states.append(new_states) + others_states.append(current_states) + if tag is False: + others_states.append(states) + new_states = {} + for dic in others_states: + for key in dic: + if key not in new_states: + new_states[key] = dic[key].copy() + else: + new_states[key] += dic[key] + for key in new_states: + new_states[key] = sorted(list(set(new_states[key]))) + return sorted(DFG, key=lambda x: x[1]), new_states + elif root_node.type in for_statement: + DFG = [] + for child in root_node.children: + temp, states = DFG_rust(child, index_to_code, states) + DFG += temp + flag = False + for child in root_node.children: + if flag: + temp, states = DFG_rust(child, index_to_code, states) + DFG += temp + elif child.type == "local_variable_declaration": + flag = True + dic = {} + for x in DFG: + if (x[0], x[1], x[2]) not in dic: + dic[(x[0], x[1], x[2])] = [x[3], x[4]] + else: + dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) + dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) + DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] + return sorted(DFG, key=lambda x: x[1]), states + elif root_node.type in enhanced_for_statement: + name = root_node.child_by_field_name("left") + value = root_node.child_by_field_name("right") + body = root_node.child_by_field_name("body") + DFG = [] + for i in range(2): + temp, states = DFG_rust(value, index_to_code, states) + DFG += temp + name_indexs = tree_to_variable_index(name, index_to_code) + value_indexs = tree_to_variable_index(value, index_to_code) + for index1 in name_indexs: + idx1, code1 = index_to_code[index1] + for index2 in value_indexs: + idx2, code2 = index_to_code[index2] + DFG.append((code1, idx1, "computedFrom", [code2], [idx2])) + states[code1] = [idx1] + temp, states = DFG_rust(body, index_to_code, states) + DFG += temp + dic = {} + for x in DFG: + if (x[0], x[1], x[2]) not in dic: + dic[(x[0], x[1], x[2])] = [x[3], x[4]] + else: + dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) + dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) + DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] + return sorted(DFG, key=lambda x: x[1]), states + elif root_node.type in while_statement: + DFG = [] + for i in range(2): + for child in root_node.children: + temp, states = DFG_rust(child, index_to_code, states) + DFG += temp + dic = {} + for x in DFG: + if (x[0], x[1], x[2]) not in dic: + dic[(x[0], x[1], x[2])] = [x[3], x[4]] + else: + dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) + dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) + DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] + return sorted(DFG, key=lambda x: x[1]), states + else: + DFG = [] + for child in root_node.children: + if child.type in do_first_statement: + temp, states = DFG_rust(child, index_to_code, states) + DFG += temp + for child in root_node.children: + if child.type not in do_first_statement: + temp, states = DFG_rust(child, index_to_code, states) + DFG += temp + + return sorted(DFG, key=lambda x: x[1]), states diff --git a/codebleu/parser/__init__.py b/codebleu/parser/__init__.py index 7d41751..2b4f1a5 100644 --- a/codebleu/parser/__init__.py +++ b/codebleu/parser/__init__.py @@ -9,6 +9,7 @@ DFG_php, DFG_python, DFG_ruby, + DFG_rust, ) from .utils import ( index_to_code_token, @@ -25,6 +26,7 @@ "DFG_php", "DFG_python", "DFG_ruby", + "DFG_rust", "index_to_code_token", "remove_comments_and_docstrings", "tree_to_token_index", diff --git a/setup.py b/setup.py index 9613e44..e4d6aac 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ "c-sharp": "https://github.com/tree-sitter/tree-sitter-c-sharp/archive/refs/tags/v0.20.0.zip", "c": "https://github.com/tree-sitter/tree-sitter-c/archive/refs/tags/v0.20.7.zip", "cpp": "https://github.com/tree-sitter/tree-sitter-cpp/archive/refs/tags/v0.20.3.zip", + "rust": "https://github.com/tree-sitter/tree-sitter-rust/archive/refs/tags/v0.20.1.zip", } diff --git a/use.py b/use.py new file mode 100644 index 0000000..42f1552 --- /dev/null +++ b/use.py @@ -0,0 +1,25 @@ +from codebleu import calc_codebleu + +#prediction = "def add ( a , b ) :\n return a + b" +#reference = "def sum ( first , second ) :\n return second + first" +#result = calc_codebleu([reference], [prediction], lang="python", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None) +#print(result) + +# { +# 'codebleu': 0.5537, +# 'ngram_match_score': 0.1041, +# 'weighted_ngram_match_score': 0.1109, +# 'syntax_match_score': 1.0, +# 'dataflow_match_score': 1.0 +# } + +# prediction = "void add (int a ,int b ) {\n return a + b;}" +# reference = "void sum ( int first , int second ) {\n return second + first;}" +# result = calc_codebleu([reference], [prediction], lang="c", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None) +# print(result) + +prediction = "fn add ( a , b )->i8 {\n a + b}" +reference = "fn sum ( first , second )->i8 {\n second + first}" +result = calc_codebleu([reference], [prediction], lang="rust", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None) +print(result) +