Skip to content

Commit

Permalink
Merge pull request #19 from DoodleBears/18-fix-newline-detect-error
Browse files Browse the repository at this point in the history
  • Loading branch information
DoodleBears authored Oct 3, 2024
2 parents 117ea2c + 27a33ba commit a765ae6
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 45 deletions.
1 change: 1 addition & 0 deletions split_lang/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class LangSectionType(Enum):
ZH_JA = "zh_ja"
KO = "ko"
PUNCTUATION = "punctuation"
NEWLINE = "newline"
DIGIT = "digit"
OTHERS = "others"
ALL = "all"
Expand Down
66 changes: 48 additions & 18 deletions split_lang/split/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from ..config import DEFAULT_LANG, DEFAULT_LANG_MAP
from ..detect_lang.detector import (
detect_lang_combined,
possible_detection_list,
is_word_freq_higher_in_lang_b,
possible_detection_list,
)
from ..model import LangSectionType, SubString, SubStringSection
from .utils import PUNCTUATION, contains_hangul, contains_zh_ja, contains_ja
from .utils import PUNCTUATION, contains_hangul, contains_ja, contains_zh_ja

logging.basicConfig(
level=logging.WARNING,
Expand All @@ -23,13 +23,12 @@


class LangSplitter:

def __init__(
self,
lang_map: Dict = None,
default_lang: str = DEFAULT_LANG,
punctuation: str = PUNCTUATION,
not_merge_punctuation: str = "",
not_merge_punctuation: str = "\n",
merge_across_punctuation: bool = True,
merge_across_digit: bool = True,
debug: bool = False,
Expand Down Expand Up @@ -77,7 +76,6 @@ def split_by_lang(
substrings: List[SubString] = []
for section in sections:
substrings.extend(section.substrings)

substrings = self._merge_digit(substrings=substrings)

if self.merge_across_digit:
Expand Down Expand Up @@ -130,6 +128,17 @@ def _split(
length=section_len,
)
)
elif section.lang_section_type is LangSectionType.NEWLINE:
section.substrings.append(
SubString(
is_digit=False,
is_punctuation=False,
text=section.text,
lang="newline",
index=section_index,
length=section_len,
)
)
else:
substrings: List[str] = []
lang_section_type = LangSectionType.OTHERS
Expand Down Expand Up @@ -160,7 +169,11 @@ def _split(
# MARK: smart merge substring together
wtpsplit_section = pre_split_section
for section in wtpsplit_section:
if section.lang_section_type is LangSectionType.PUNCTUATION:
if (
section.lang_section_type is LangSectionType.PUNCTUATION
or section.lang_section_type is LangSectionType.NEWLINE
):
# print(section.text)
continue
smart_concat_result = self._smart_merge(
substr_list=section.substrings,
Expand Down Expand Up @@ -248,7 +261,8 @@ def add_substring(lang_section_type: LangSectionType):

for index, char in enumerate(text):
is_space = char.isspace()
if is_space is False:

if is_space is False: # Exclude newlines from processing
if contains_zh_ja(char):
if current_lang != LangSectionType.ZH_JA:
add_substring(current_lang)
Expand All @@ -271,6 +285,12 @@ def add_substring(lang_section_type: LangSectionType):
if current_lang != LangSectionType.OTHERS:
add_substring(current_lang)
current_lang = LangSectionType.OTHERS
else:
if current_lang != LangSectionType.NEWLINE:
# print(f"detect newline {char}")
add_substring(current_lang)
current_lang = LangSectionType.NEWLINE

current_text.append(char)

add_substring(current_lang)
Expand Down Expand Up @@ -400,7 +420,11 @@ def _merge_middle_substr_to_two_side(self, substrings: List[SubString]):
middle_block = substrings[index + 1]
right_block = substrings[index + 2]

if left_block.lang == right_block.lang and left_block.lang != "x":
if (
left_block.lang == right_block.lang
and left_block.lang != "x"
and left_block.lang != "newline"
):
# if different detectors results contains near block's language, then combine

if self._is_merge_middle_to_two_side(
Expand Down Expand Up @@ -554,6 +578,7 @@ def _merge_digit(

if (
left_block.lang == right_block.lang
and left_block.lang != "newline"
and left_block.is_digit
and middle_block.is_punctuation
):
Expand All @@ -570,19 +595,26 @@ def _merge_substring_across_digit(

for _, substring in enumerate(substrings):
if substring.is_digit:

if new_substrings and new_substrings[-1].lang != "punctuation":
if (
new_substrings
and new_substrings[-1].lang != "punctuation"
and new_substrings[-1].lang != "newline"
):
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
else:
new_substrings.append(substring)
else:
if new_substrings and new_substrings[-1].lang == "digit":
substring.text = new_substrings[-1].text + substring.text
substring.index = new_substrings[-1].index
substring.length = new_substrings[-1].length + substring.length

new_substrings.pop()
if (
new_substrings
and new_substrings[-1].lang == "digit"
and substring.lang != "newline"
):
temp = new_substrings.pop()
temp.text = temp.text + substring.text
temp.index = temp.index
temp.length = temp.length + substring.length
substring = temp
new_substrings.append(substring)

new_substrings = self._merge_substrings(substrings=new_substrings)
Expand Down Expand Up @@ -620,7 +652,6 @@ def _get_languages(
lang_text_list: List[SubString],
lang_section_type: LangSectionType,
):

if lang_section_type in [
LangSectionType.DIGIT,
LangSectionType.KO,
Expand All @@ -644,7 +675,6 @@ def _smart_concat_logic(
lang_text_list: List[SubString],
lang_section_type: LangSectionType,
):

lang_text_list = self._merge_substrings(lang_text_list)
lang_text_list = self._merge_middle_substr_to_two_side(lang_text_list)
lang_text_list = self._merge_substrings(lang_text_list)
Expand Down
8 changes: 8 additions & 0 deletions tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,11 @@
"Wir reisen nach Deutschland nous voyageons en Allemagne and we are excited.",
"Ich bin müde je suis fatigué and I need some rest.",
]


# Sample input whose segments are separated by line breaks: a word with a
# trailing comma, a number with a trailing period, a bare word, and a bare
# number. Written as adjacent literals so each "\n" is explicit.
texts_with_newline = [
    "newline,\n"
    "123.\n"
    "abc\n"
    "233",
]
59 changes: 32 additions & 27 deletions tests/test_split_to_substrings.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,69 @@
from split_lang import split_by_lang
from split_lang.split.utils import DEFAULT_THRESHOLD
from tests.data.test_data import texts_de_fr_en, texts_with_digit, texts_zh_jp_ko_en
from split_lang import LangSplitter
from tests.data.test_data import (
texts_de_fr_en,
texts_with_digit,
texts_with_newline,
texts_zh_jp_ko_en,
)

# Module-level splitter built with default constructor arguments and shared by
# the test functions in this module.
lang_splitter = LangSplitter()


def test_split_to_substring():
for text in texts_zh_jp_ko_en:
substr = split_by_lang(
substr = lang_splitter.split_by_lang(
text=text,
verbose=False,
threshold=4.9e-5,
# threshold=DEFAULT_THRESHOLD,
# default_lang="en",
merge_across_punctuation=True,
)
for index, item in enumerate(substr):
for _, item in enumerate(substr):
print(item)
# print(f"{index}|{item.lang}:{item.text}")
print("----------------------")

for text in texts_de_fr_en:
substr = split_by_lang(
substr = lang_splitter.split_by_lang(
text=text,
verbose=False,
# lang_map=new_lang_map,
threshold=DEFAULT_THRESHOLD,
# default_lang="en",
)
for index, item in enumerate(substr):
for _, item in enumerate(substr):
print(item)
# print(f"{index}|{item.lang}:{item.text}")
print("----------------------")

lang_splitter.merge_across_digit = False
for text in texts_with_digit:
substr = split_by_lang(
substr = lang_splitter.split_by_lang(
text=text,
verbose=False,
threshold=4.9e-5,
# merge_across_punctuation=False,
merge_across_digit=False,
)
for index, item in enumerate(substr):
for _, item in enumerate(substr):
print(item)
# print(f"{index}|{item.lang}:{item.text}")
print("----------------------")

lang_splitter.merge_across_digit = True
lang_splitter.merge_across_punctuation = True
for text in texts_with_digit:
substr = split_by_lang(
substr = lang_splitter.split_by_lang(
text=text,
)
for _, item in enumerate(substr):
print(item)
# print(f"{index}|{item.lang}:{item.text}")
print("----------------------")


def test_split_to_substring_newline():
for text in texts_with_newline:
substr = lang_splitter.split_by_lang(
text=text,
verbose=False,
threshold=4.9e-5,
merge_across_punctuation=True,
)
for index, item in enumerate(substr):
for _, item in enumerate(substr):
print(item)
# print(f"{index}|{item.lang}:{item.text}")
print("----------------------")


def main():
    """Run this module's test functions one after another.

    Lets the file be executed directly as a script in addition to being
    collected by a test runner.
    """
    for run_case in (test_split_to_substring, test_split_to_substring_newline):
        run_case()


if __name__ == "__main__":
Expand Down

0 comments on commit a765ae6

Please sign in to comment.