Skip to content

Commit

Permalink
fix: Clean up code (#1)
Browse files Browse the repository at this point in the history
* Clean up code

* Fix tests + logging

* update
  • Loading branch information
maximus12793 authored Oct 3, 2023
1 parent dda27ee commit d79e2aa
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 28 deletions.
15 changes: 2 additions & 13 deletions codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from . import bleu, dataflow_match, syntax_match, weighted_ngram_match

PACKAGE_DIR = Path(__file__).parent
# AVAILABLE_LANGS = ['java', 'javascript', 'c_sharp', 'php', 'go', 'python', 'ruby']
AVAILABLE_LANGS = ["java", "javascript", "c_sharp", "php", "c", "cpp", "python"] # keywords available


Expand Down Expand Up @@ -56,7 +55,8 @@ def tokenizer(s):
ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

# calculate weighted ngram match
keywords = [x.strip() for x in open(keywords_dir / (lang + ".txt"), "r", encoding="utf-8").readlines()]
with open(keywords_dir / (lang + ".txt"), "r", encoding="utf-8") as f:
keywords = [x.strip() for x in f.readlines()]

def make_weights(reference_tokens, key_word_list):
    """Map each reference token to its weight for the weighted n-gram match.

    Tokens that appear in ``key_word_list`` get weight 1; every other
    token gets weight 0.2.

    Args:
        reference_tokens: iterable of hashable tokens (they become dict keys).
        key_word_list: collection of keyword tokens for the language.

    Returns:
        dict mapping each distinct token to 1 (keyword) or 0.2 (other).
    """
    # Build the lookup once: set membership is O(1) per token, whereas
    # testing against the list rescans every keyword for every token.
    keyword_set = frozenset(key_word_list)
    return {token: 1 if token in keyword_set else 0.2 for token in reference_tokens}
Expand All @@ -74,15 +74,6 @@ def make_weights(reference_tokens, key_word_list):
# calculate dataflow match
dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang, lang_so_file)

# print(
# "ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}".format(
# ngram_match_score,
# weighted_ngram_match_score,
# syntax_match_score,
# dataflow_match_score,
# )
# )

alpha, beta, gamma, theta = weights
code_bleu_score = (
alpha * ngram_match_score
Expand All @@ -91,8 +82,6 @@ def make_weights(reference_tokens, key_word_list):
+ theta * (dataflow_match_score or 1)
)

# print("CodeBLEU score: ", code_bleu_score)

return {
"codebleu": code_bleu_score,
"ngram_match_score": ngram_match_score,
Expand Down
3 changes: 2 additions & 1 deletion codebleu/dataflow_match.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging

from tree_sitter import Language, Parser

Expand Down Expand Up @@ -67,7 +68,7 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file):
match_count += 1
normalized_cand_dfg.remove(dataflow)
if total_count == 0:
print(
logging.warning(
"WARNING: There is no reference data-flows extracted from the whole corpus, "
"and the data-flow match score degenerates to 0. Please consider ignoring this score."
)
Expand Down
2 changes: 0 additions & 2 deletions codebleu/weighted_ngram_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def corpus_bleu(
# it tries to retain the Fraction object as much as the
# smoothing method allows.
p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths)
# pdb.set_trace()
s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n))
s = bp * math.exp(math.fsum(s))
return s
Expand All @@ -212,7 +211,6 @@ def modified_recall(references, hypothesis, n):
"""
# Extracts all ngrams in hypothesis
# Set an empty Counter if hypothesis is empty.
# pdb.set_trace()
numerator = 0
denominator = 0

Expand Down
2 changes: 1 addition & 1 deletion evaluate_app/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _info(self):
def _download_and_prepare(self, dl_manager):
"""Optional: download external resources useful to compute the scores"""
# Workaround: this file has to be named codebleu (evaluate library requirement), so the package is loaded via importlib
self.codebleu_package = importlib.import_module('codebleu')
self.codebleu_package = importlib.import_module("codebleu")
pass

def _compute(self, predictions, references, lang, weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None):
Expand Down
24 changes: 13 additions & 11 deletions tests/test_codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,26 @@
from typing import Any, List

import pytest
import logging

from codebleu.codebleu import AVAILABLE_LANGS, calc_codebleu


@pytest.mark.parametrize(['predictions', 'references', 'codebleu'], [
(['some rannnndom words in length more than 3'], ['def test ( ) :\n pass'], 0.25), # 'cause data_flow is 0 and considered as 1
(['some rannnndom words in length more than 3'],
['def test ( ) :\n pass'], 0.25), # 'cause data_flow is 0 and considered as 1
(['def bar ( y , x ) :\n a = x * x\n return a'], ['def foo ( x ) :\n return x'], 0.4),
(['def foo ( x ) :\n return x * x'], ['def bar ( x ) :\n return x'], 0.6),
(['def bar ( x ) :\n return x'], ['def foo ( x ) :\n return x'], 0.8),
(['def foo ( x ) :\n return x'], ['def foo ( x ) :\n return x'], 1.0),
])
def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None:
result = calc_codebleu(references, predictions, 'python')
print(result)
logging.debug(result)
assert result['codebleu'] == pytest.approx(codebleu, 0.1)


@pytest.mark.parametrize(['lang'], [(l,) for l in AVAILABLE_LANGS])
@pytest.mark.parametrize(['lang'], [(lang,) for lang in AVAILABLE_LANGS])
def test_exact_match_works_for_all_langs(lang: str) -> None:
predictions = references = ['some matching string a couple of times']
assert calc_codebleu(references, predictions, lang)['codebleu'] == 1.0
Expand All @@ -36,7 +38,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None:
])
def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None:
result = calc_codebleu(references, predictions, lang)
print(result)
logging.debug(result)
assert result['codebleu'] == pytest.approx(0.6, 0.1)


Expand All @@ -54,17 +56,17 @@ def test_error_when_input_length_mismatch() -> None:
(
['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'],
['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
0.7238
0.7019
),
(
['public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }'],
['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
0.8804
),
# (
# ['public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }'],
# ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
# 0.8397
# ),
])
def test_code_x_glue_readme_examples(predictions: List[Any], references: List[Any], codebleu: float) -> None:
result = calc_codebleu(references, predictions, 'java')
print(result)
logging.debug(result)
assert result['codebleu'] == pytest.approx(codebleu, 0.01)


Expand Down

0 comments on commit d79e2aa

Please sign in to comment.