Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Rust parser #31

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codebleu/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main(
"--lang",
type=str,
required=True,
choices=["java", "js", "c_sharp", "php", "go", "python", "ruby"],
choices=["java", "js", "c_sharp", "php", "go", "python", "ruby", "rust"],
)
parser.add_argument("--params", type=str, default="0.25,0.25,0.25,0.25", help="alpha, beta and gamma")

Expand Down
13 changes: 12 additions & 1 deletion codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
from . import bleu, dataflow_match, syntax_match, weighted_ngram_match

PACKAGE_DIR = Path(__file__).parent
AVAILABLE_LANGS = ["java", "javascript", "c_sharp", "php", "c", "cpp", "python", "go", "ruby"] # keywords available
AVAILABLE_LANGS = [
"java",
"javascript",
"c_sharp",
"php",
"c",
"cpp",
"python",
"go",
"ruby",
"rust",
] # keywords available


def calc_codebleu(
Expand Down
2 changes: 2 additions & 0 deletions codebleu/dataflow_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DFG_php,
DFG_python,
DFG_ruby,
DFG_rust,
index_to_code_token,
remove_comments_and_docstrings,
tree_to_token_index,
Expand All @@ -27,6 +28,7 @@
"c_sharp": DFG_csharp,
"c": DFG_csharp, # XLCoST uses C# parser for C
"cpp": DFG_csharp, # XLCoST uses C# parser for C++
"rust": DFG_rust,
}


Expand Down
71 changes: 71 additions & 0 deletions codebleu/keywords/rust.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
as
async
await
block
bool
break
char
const
continue
crate
default
dyn
else
enum
expr
extern
f32
f64
false
fn
for
i128
i16
i32
i64
i8
ident
if
impl
in
isize
item
let
lifetime
literal
loop
macro_rules!
match
meta
mod
move
mut
pat
path
pub
ref
return
self
static
stmt
str
struct
super
trait
true
tt
ty
type
u128
u16
u32
u64
u8
union
unsafe
use
usize
vis
where
while
yield
172 changes: 172 additions & 0 deletions codebleu/parser/DFG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,3 +1213,175 @@ def DFG_javascript(root_node, index_to_code, states):
DFG += temp

return sorted(DFG, key=lambda x: x[1]), states


# Review note carried over from the PR thread (repository owner):
#   the pipeline failed; apply the black formatter (python -m black .).
#   The pipeline also checks isort/ruff/mypy.


def _merge_duplicate_edges(edges):
    """Collapse edges that share (code, idx, relation) into a single edge.

    Loop bodies are walked more than once, so the same target token can be
    emitted several times with different sources. The sources of such
    duplicates are unioned; the result is ordered by token index.
    """
    merged = {}
    for code, idx, relation, src_codes, src_idxs in edges:
        key = (code, idx, relation)
        if key not in merged:
            merged[key] = [src_codes, src_idxs]
        else:
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # result does not depend on string hash randomization (a plain
            # list(set(...)) would).
            merged[key][0] = list(dict.fromkeys(merged[key][0] + src_codes))
            merged[key][1] = sorted(set(merged[key][1] + src_idxs))
    ordered = sorted(merged.items(), key=lambda item: item[0][1])
    return [(key[0], key[1], key[2], srcs[0], srcs[1]) for key, srcs in ordered]


def _variable_indexes_or_empty(node, index_to_code):
    """Return the variable token spans under ``node``; an absent node has none."""
    if node is None:
        return []
    return tree_to_variable_index(node, index_to_code)


def DFG_rust(root_node, index_to_code, states):
    """Extract data-flow edges from a parsed Rust syntax tree.

    Args:
        root_node: Syntax node exposing ``type``, ``children``, ``start_point``,
            ``end_point`` and ``child_by_field_name`` (presumably a tree-sitter
            node -- confirm against the caller). ``None`` yields no edges.
        index_to_code: Mapping from ``(start_point, end_point)`` of a token to
            its ``(idx, code)`` pair.
        states: Mapping from a variable's code text to the token indexes it was
            last bound at. It is copied, never mutated in place.

    Returns:
        A ``(edges, states)`` tuple. ``edges`` is a list of
        ``(code, idx, relation, source_codes, source_idxs)`` tuples sorted by
        ``idx``, where ``relation`` is ``"comesFrom"`` or ``"computedFrom"``;
        ``states`` is the updated binding map.
    """
    assignment = ["assignment_expression", "compound_assignment_expr", "let_expression"]
    def_statement = ["function_item"]
    if_statement = ["if_expression", "if_let_expression", "match_expression", "else"]
    for_statement = ["for_expression"]
    # NOTE(review): "for_each_statement" does not look like a tree-sitter-rust
    # node type, so this branch may be unreachable -- verify against the grammar.
    enhanced_for_statement = ["for_each_statement"]
    while_statement = ["while_expression", "while_let_expression", "loop_expression"]
    do_first_statement = []
    states = states.copy()

    # child_by_field_name() hands back None for a missing field; treat that as
    # an empty subtree instead of failing on attribute access.
    if root_node is None:
        return [], states

    is_leaf = not root_node.children or root_node.type in ["string_literal", "string", "character_literal"]
    if is_leaf and root_node.type != "comment":
        idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
        if root_node.type == code:
            # The token text equals its node type: nothing to track.
            return [], states
        if code in states:
            return [(code, idx, "comesFrom", [code], states[code].copy())], states
        if root_node.type == "identifier":
            states[code] = [idx]
        return [(code, idx, "comesFrom", [], [])], states

    if root_node.type in def_statement:
        children = root_node.children
        name = children[1]
        value = children[2] if len(children) >= 3 else None
        DFG = []
        if value is None:
            for index in tree_to_variable_index(name, index_to_code):
                idx, code = index_to_code[index]
                DFG.append((code, idx, "comesFrom", [], []))
                states[code] = [idx]
            return sorted(DFG, key=lambda x: x[1]), states
        name_indexs = tree_to_variable_index(name, index_to_code)
        value_indexs = tree_to_variable_index(value, index_to_code)
        temp, states = DFG_rust(value, index_to_code, states)
        DFG += temp
        for index1 in name_indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in value_indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "comesFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in assignment:
        left_nodes = root_node.child_by_field_name("left")
        right_nodes = root_node.child_by_field_name("right")
        DFG = []
        # The right-hand side is evaluated before the targets are rebound.
        temp, states = DFG_rust(right_nodes, index_to_code, states)
        DFG += temp
        name_indexs = _variable_indexes_or_empty(left_nodes, index_to_code)
        value_indexs = _variable_indexes_or_empty(right_nodes, index_to_code)
        for index1 in name_indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in value_indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "computedFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in if_statement:
        DFG = []
        current_states = states.copy()
        others_states = []
        seen_branch = False
        has_else = "else" in root_node.type
        for child in root_node.children:
            if "else" in child.type:
                has_else = True
            if child.type not in if_statement and not seen_branch:
                temp, current_states = DFG_rust(child, index_to_code, current_states)
                DFG += temp
            else:
                # Nested branches start from the states on entry, not from the
                # states produced by the branch walked so far.
                seen_branch = True
                temp, branch_states = DFG_rust(child, index_to_code, states)
                DFG += temp
                others_states.append(branch_states)
        others_states.append(current_states)
        if not has_else:
            # Without an else the construct may be skipped entirely.
            others_states.append(states)
        merged_states = {}
        for branch_states in others_states:
            for key, positions in branch_states.items():
                merged_states.setdefault(key, []).extend(positions)
        merged_states = {key: sorted(set(positions)) for key, positions in merged_states.items()}
        return sorted(DFG, key=lambda x: x[1]), merged_states

    if root_node.type in for_statement:
        DFG = []
        for child in root_node.children:
            temp, states = DFG_rust(child, index_to_code, states)
            DFG += temp
        # NOTE(review): "local_variable_declaration" looks inherited from another
        # language's extractor; confirm whether a Rust node type belongs here.
        after_declaration = False
        for child in root_node.children:
            if after_declaration:
                temp, states = DFG_rust(child, index_to_code, states)
                DFG += temp
            elif child.type == "local_variable_declaration":
                after_declaration = True
        return _merge_duplicate_edges(DFG), states

    if root_node.type in enhanced_for_statement:
        name = root_node.child_by_field_name("left")
        value = root_node.child_by_field_name("right")
        body = root_node.child_by_field_name("body")
        DFG = []
        # Two passes so that bindings made in the body reach the next iteration.
        for _ in range(2):
            temp, states = DFG_rust(value, index_to_code, states)
            DFG += temp
            name_indexs = _variable_indexes_or_empty(name, index_to_code)
            value_indexs = _variable_indexes_or_empty(value, index_to_code)
            for index1 in name_indexs:
                idx1, code1 = index_to_code[index1]
                for index2 in value_indexs:
                    idx2, code2 = index_to_code[index2]
                    DFG.append((code1, idx1, "computedFrom", [code2], [idx2]))
                states[code1] = [idx1]
            temp, states = DFG_rust(body, index_to_code, states)
            DFG += temp
        return _merge_duplicate_edges(DFG), states

    if root_node.type in while_statement:
        DFG = []
        # Two passes so that bindings made in the body reach the next iteration.
        for _ in range(2):
            for child in root_node.children:
                temp, states = DFG_rust(child, index_to_code, states)
                DFG += temp
        return _merge_duplicate_edges(DFG), states

    DFG = []
    for child in root_node.children:
        if child.type in do_first_statement:
            temp, states = DFG_rust(child, index_to_code, states)
            DFG += temp
    for child in root_node.children:
        if child.type not in do_first_statement:
            temp, states = DFG_rust(child, index_to_code, states)
            DFG += temp

    return sorted(DFG, key=lambda x: x[1]), states
2 changes: 2 additions & 0 deletions codebleu/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DFG_php,
DFG_python,
DFG_ruby,
DFG_rust,
)
from .utils import (
index_to_code_token,
Expand All @@ -25,6 +26,7 @@
"DFG_php",
"DFG_python",
"DFG_ruby",
"DFG_rust",
"index_to_code_token",
"remove_comments_and_docstrings",
"tree_to_token_index",
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"c-sharp": "https://github.com/tree-sitter/tree-sitter-c-sharp/archive/refs/tags/v0.20.0.zip",
"c": "https://github.com/tree-sitter/tree-sitter-c/archive/refs/tags/v0.20.7.zip",
"cpp": "https://github.com/tree-sitter/tree-sitter-cpp/archive/refs/tags/v0.20.3.zip",
"rust": "https://github.com/tree-sitter/tree-sitter-rust/archive/refs/tags/v0.20.1.zip",
}


Expand Down
25 changes: 25 additions & 0 deletions use.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Manual smoke check: score a tiny Rust prediction against its reference."""
from codebleu import calc_codebleu


def main():
    """Run calc_codebleu on one Rust reference/prediction pair and print the result."""
    prediction = "fn add ( a , b )->i8 {\n a + b}"
    reference = "fn sum ( first , second )->i8 {\n second + first}"
    result = calc_codebleu(
        [reference],
        [prediction],
        lang="rust",
        weights=(0.25, 0.25, 0.25, 0.25),
        tokenizer=None,
    )
    print(result)


# Guarded so that importing this module does not trigger the computation.
if __name__ == "__main__":
    main()

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please delete this file?

Loading