Skip to content

Commit

Permalink
feat: add Rust parser (#31)
Browse files Browse the repository at this point in the history
* adding Rust parser

* fix formatting issues
  • Loading branch information
yijunyu authored Mar 1, 2024
1 parent 8fe0915 commit a30cb9f
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 2 deletions.
2 changes: 1 addition & 1 deletion codebleu/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main(
"--lang",
type=str,
required=True,
choices=["java", "js", "c_sharp", "php", "go", "python", "ruby"],
choices=["java", "js", "c_sharp", "php", "go", "python", "ruby", "rust"],
)
parser.add_argument("--params", type=str, default="0.25,0.25,0.25,0.25", help="alpha, beta and gamma")

Expand Down
13 changes: 12 additions & 1 deletion codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
from . import bleu, dataflow_match, syntax_match, weighted_ngram_match

PACKAGE_DIR = Path(__file__).parent
AVAILABLE_LANGS = ["java", "javascript", "c_sharp", "php", "c", "cpp", "python", "go", "ruby"] # keywords available
# Language identifiers this module lists as available. Per the trailing
# comment these are the languages with a keyword list; "rust" pairs with
# codebleu/keywords/rust.txt (presumably one keywords/<lang>.txt per entry --
# TODO confirm against the package data).
AVAILABLE_LANGS = [
"java",
"javascript",
"c_sharp",
"php",
"c",
"cpp",
"python",
"go",
"ruby",
"rust",
] # keywords available


def calc_codebleu(
Expand Down
2 changes: 2 additions & 0 deletions codebleu/dataflow_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DFG_php,
DFG_python,
DFG_ruby,
DFG_rust,
index_to_code_token,
remove_comments_and_docstrings,
tree_to_token_index,
Expand All @@ -27,6 +28,7 @@
"c_sharp": DFG_csharp,
"c": DFG_csharp, # XLCoST uses C# parser for C
"cpp": DFG_csharp, # XLCoST uses C# parser for C++
"rust": DFG_rust,
}


Expand Down
71 changes: 71 additions & 0 deletions codebleu/keywords/rust.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
as
async
await
block
bool
break
char
const
continue
crate
default
dyn
else
enum
expr
extern
f32
f64
false
fn
for
i128
i16
i32
i64
i8
ident
if
impl
in
isize
item
let
lifetime
literal
loop
macro_rules!
match
meta
mod
move
mut
pat
path
pub
ref
return
self
static
stmt
str
struct
super
trait
true
tt
ty
type
u128
u16
u32
u64
u8
union
unsafe
use
usize
vis
where
while
yield
172 changes: 172 additions & 0 deletions codebleu/parser/DFG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,3 +1213,175 @@ def DFG_javascript(root_node, index_to_code, states):
DFG += temp

return sorted(DFG, key=lambda x: x[1]), states


def _merge_rust_dfg_edges(DFG):
    """Collapse edges sharing (code, idx, relation) into one edge.

    Parent codes and parent indices of duplicate edges are unioned; the
    result is ordered by token index.
    """
    merged = {}
    for code, idx, relation, parent_codes, parent_idxs in DFG:
        key = (code, idx, relation)
        if key not in merged:
            merged[key] = [parent_codes, parent_idxs]
        else:
            merged[key][0] = list(set(merged[key][0] + parent_codes))
            merged[key][1] = sorted(set(merged[key][1] + parent_idxs))
    return [
        (key[0], key[1], key[2], parents[0], parents[1])
        for key, parents in sorted(merged.items(), key=lambda item: item[0][1])
    ]


def _rust_dfg_over_children(root_node, index_to_code, states):
    """Visit every child in order, threading ``states`` through each visit."""
    DFG = []
    for child in root_node.children:
        temp, states = DFG_rust(child, index_to_code, states)
        DFG += temp
    return sorted(DFG, key=lambda x: x[1]), states


def DFG_rust(root_node, index_to_code, states):
    """Extract data-flow edges from a Rust syntax tree node.

    Args:
        root_node: Syntax tree node exposing ``type``, ``children``,
            ``start_point``, ``end_point`` and ``child_by_field_name``.
        index_to_code: Mapping from ``(start_point, end_point)`` of a leaf
            to an ``(idx, code)`` pair.
        states: Mapping from a code token to the list of token indices that
            currently define it. It is copied, never modified in place.

    Returns:
        A ``(DFG, states)`` pair. ``DFG`` is a list of
        ``(code, idx, relation, parent_codes, parent_idxs)`` tuples sorted
        by ``idx``, where ``relation`` is ``"comesFrom"`` or
        ``"computedFrom"``; ``states`` is the updated definition mapping.
    """
    # NOTE(review): several node-type names below look inherited from the
    # other language handlers ("let_expression", "for_each_statement",
    # "local_variable_declaration") -- confirm them against the
    # tree-sitter-rust grammar version pinned in setup.py.
    assignment = ["assignment_expression", "compound_assignment_expr", "let_expression"]
    def_statement = ["function_item"]
    if_statement = ["if_expression", "if_let_expression", "match_expression", "else"]
    for_statement = ["for_expression"]
    enhanced_for_statement = ["for_each_statement"]
    while_statement = ["while_expression", "while_let_expression", "loop_expression"]
    states = states.copy()

    if (
        len(root_node.children) == 0 or root_node.type in ["string_literal", "string", "character_literal"]
    ) and root_node.type != "comment":
        # Leaf token (or a literal treated as one).
        idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
        if root_node.type == code:
            # Node type equals its own text (keywords, punctuation): no edge.
            return [], states
        if code in states:
            return [(code, idx, "comesFrom", [code], states[code].copy())], states
        if root_node.type == "identifier":
            states[code] = [idx]
        return [(code, idx, "comesFrom", [], [])], states

    if root_node.type in def_statement:
        if len(root_node.children) < 2:
            # Malformed/partial definition: previously an IndexError.
            return _rust_dfg_over_children(root_node, index_to_code, states)
        # NOTE(review): name/value are taken positionally (children[1] and
        # children[2]); a leading modifier such as `pub` would presumably
        # shift them, and the remaining children are not visited -- confirm
        # this is intended before changing, since it affects scores.
        name = root_node.children[1]
        value = root_node.children[2] if len(root_node.children) >= 3 else None
        DFG = []
        if value is None:
            for index in tree_to_variable_index(name, index_to_code):
                idx, code = index_to_code[index]
                DFG.append((code, idx, "comesFrom", [], []))
                states[code] = [idx]
            return sorted(DFG, key=lambda x: x[1]), states
        name_indexs = tree_to_variable_index(name, index_to_code)
        value_indexs = tree_to_variable_index(value, index_to_code)
        temp, states = DFG_rust(value, index_to_code, states)
        DFG += temp
        for index1 in name_indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in value_indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "comesFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in assignment:
        left_nodes = root_node.child_by_field_name("left")
        right_nodes = root_node.child_by_field_name("right")
        if left_nodes is None or right_nodes is None:
            # The node has no left/right fields; recursing into None used to
            # raise AttributeError. Fall back to a plain child traversal.
            return _rust_dfg_over_children(root_node, index_to_code, states)
        DFG = []
        temp, states = DFG_rust(right_nodes, index_to_code, states)
        DFG += temp
        name_indexs = tree_to_variable_index(left_nodes, index_to_code)
        value_indexs = tree_to_variable_index(right_nodes, index_to_code)
        for index1 in name_indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in value_indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "computedFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in if_statement:
        DFG = []
        current_states = states.copy()
        others_states = []
        flag = False
        # tag records whether an "else"-like node was seen; without one the
        # incoming states are also a possible outcome of the conditional.
        tag = "else" in root_node.type
        for child in root_node.children:
            if "else" in child.type:
                tag = True
            if child.type not in if_statement and flag is False:
                temp, current_states = DFG_rust(child, index_to_code, current_states)
                DFG += temp
            else:
                # From the first nested conditional onwards, each child is
                # evaluated against the incoming states.
                flag = True
                temp, new_states = DFG_rust(child, index_to_code, states)
                DFG += temp
                others_states.append(new_states)
        others_states.append(current_states)
        if tag is False:
            others_states.append(states)
        # Union the definitions reaching the end of every branch.
        new_states = {}
        for dic in others_states:
            for key in dic:
                if key not in new_states:
                    new_states[key] = dic[key].copy()
                else:
                    new_states[key] += dic[key]
        for key in new_states:
            new_states[key] = sorted(set(new_states[key]))
        return sorted(DFG, key=lambda x: x[1]), new_states

    if root_node.type in for_statement:
        DFG = []
        for child in root_node.children:
            temp, states = DFG_rust(child, index_to_code, states)
            DFG += temp
        # Second pass over the children that follow a
        # "local_variable_declaration" child, if there is one.
        flag = False
        for child in root_node.children:
            if flag:
                temp, states = DFG_rust(child, index_to_code, states)
                DFG += temp
            elif child.type == "local_variable_declaration":
                flag = True
        return _merge_rust_dfg_edges(DFG), states

    if root_node.type in enhanced_for_statement:
        name = root_node.child_by_field_name("left")
        value = root_node.child_by_field_name("right")
        body = root_node.child_by_field_name("body")
        if name is None or value is None or body is None:
            # Missing field: avoid recursing into None (AttributeError).
            return _rust_dfg_over_children(root_node, index_to_code, states)
        DFG = []
        # Two passes so the second sees definitions made by the first.
        for _ in range(2):
            temp, states = DFG_rust(value, index_to_code, states)
            DFG += temp
            name_indexs = tree_to_variable_index(name, index_to_code)
            value_indexs = tree_to_variable_index(value, index_to_code)
            for index1 in name_indexs:
                idx1, code1 = index_to_code[index1]
                for index2 in value_indexs:
                    idx2, code2 = index_to_code[index2]
                    DFG.append((code1, idx1, "computedFrom", [code2], [idx2]))
                states[code1] = [idx1]
            temp, states = DFG_rust(body, index_to_code, states)
            DFG += temp
        return _merge_rust_dfg_edges(DFG), states

    if root_node.type in while_statement:
        DFG = []
        # Two passes so the second sees definitions made by the first.
        for _ in range(2):
            for child in root_node.children:
                temp, states = DFG_rust(child, index_to_code, states)
                DFG += temp
        return _merge_rust_dfg_edges(DFG), states

    return _rust_dfg_over_children(root_node, index_to_code, states)
2 changes: 2 additions & 0 deletions codebleu/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DFG_php,
DFG_python,
DFG_ruby,
DFG_rust,
)
from .utils import (
index_to_code_token,
Expand All @@ -25,6 +26,7 @@
"DFG_php",
"DFG_python",
"DFG_ruby",
"DFG_rust",
"index_to_code_token",
"remove_comments_and_docstrings",
"tree_to_token_index",
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"c-sharp": "https://github.com/tree-sitter/tree-sitter-c-sharp/archive/refs/tags/v0.20.0.zip",
"c": "https://github.com/tree-sitter/tree-sitter-c/archive/refs/tags/v0.20.7.zip",
"cpp": "https://github.com/tree-sitter/tree-sitter-cpp/archive/refs/tags/v0.20.3.zip",
"rust": "https://github.com/tree-sitter/tree-sitter-rust/archive/refs/tags/v0.20.1.zip",
}


Expand Down
25 changes: 25 additions & 0 deletions use.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from codebleu import calc_codebleu

# Manual smoke test for calc_codebleu. Only the Rust example at the bottom
# is active; the earlier examples are kept, disabled, for reference.
#
# Python example (disabled):
#   prediction = "def add ( a , b ) :\n return a + b"
#   reference = "def sum ( first , second ) :\n return second + first"
#   result = calc_codebleu([reference], [prediction], lang="python", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None)
#   print(result)
#
# Output recorded alongside the Python example:
#   {
#     'codebleu': 0.5537,
#     'ngram_match_score': 0.1041,
#     'weighted_ngram_match_score': 0.1109,
#     'syntax_match_score': 1.0,
#     'dataflow_match_score': 1.0
#   }
#
# C example (disabled):
#   prediction = "void add (int a ,int b ) {\n return a + b;}"
#   reference = "void sum ( int first , int second ) {\n return second + first;}"
#   result = calc_codebleu([reference], [prediction], lang="c", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None)
#   print(result)

# Rust example: two functions that differ only in their identifier names.
prediction = "fn add ( a , b )->i8 {\n a + b}"
reference = "fn sum ( first , second )->i8 {\n second + first}"
result = calc_codebleu(
    [reference],
    [prediction],
    lang="rust",
    weights=(0.25, 0.25, 0.25, 0.25),
    tokenizer=None,
)
print(result)

0 comments on commit a30cb9f

Please sign in to comment.