Skip to content

Commit

Permalink
fix(splitter): when merge across punctuation on section stage the lan…
Browse files Browse the repository at this point in the history
…g of substring inside is wrong
  • Loading branch information
DoodleBears committed Oct 17, 2024
1 parent 6639274 commit bd94849
Showing 1 changed file with 57 additions and 20 deletions.
77 changes: 57 additions & 20 deletions split_lang/split/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,37 @@ def _special_merge_for_zh_ja(
new_substrings = self._merge_substrings(substrings=new_substrings)
return new_substrings

# MARK: _merge_substrings_across_newline
def _merge_substrings_across_newline(
    self,
    substrings: List[SubString],
) -> List[SubString]:
    """Fold newline substrings into their neighbours and join same-language runs.

    Walks ``substrings`` in order and builds a new list where:

    * a substring whose ``lang`` is ``"newline"`` is appended to the text of
      the previously kept substring (its ``length`` is added as well);
    * a non-newline substring is merged into the previously kept substring
      when it has the same language as the last non-newline substring seen,
      or when no language has been seen yet (only newlines so far);
    * otherwise the substring starts a new entry.

    When a merge target still carries the ``"newline"`` language it adopts
    the language of the substring merged into it.

    Args:
        substrings: Substrings to merge, in document order. The kept
            objects are mutated in place (``text``, ``length``, ``lang``).

    Returns:
        The merged list. It reuses objects from ``substrings``; an empty
        input yields an empty list.
    """
    merged: List[SubString] = []
    # Language of the most recent non-newline substring; "" until one is seen.
    last_lang = ""

    for substring in substrings:
        if not merged:
            merged.append(substring)
            # Seed the tracker from a real leading language. Without this the
            # "no language seen yet" rule below would swallow a following
            # substring of a *different* language into the first one.
            if substring.lang != "newline":
                last_lang = substring.lang
            continue

        tail = merged[-1]

        if substring.lang == "newline":
            # Newlines never stand alone: glue them onto whatever precedes.
            tail.text += substring.text
            tail.length += substring.length
            continue

        if last_lang in ("", substring.lang):
            tail.text += substring.text
            tail.length += substring.length
            # A run made only of newlines takes the language of the first
            # real text that joins it.
            if tail.lang == "newline":
                tail.lang = substring.lang
        else:
            merged.append(substring)
        last_lang = substring.lang

    return merged

# MARK: _merge_substrings_across_newline_based_on_sections
def _merge_substrings_across_newline_based_on_sections(
self,
Expand Down Expand Up @@ -769,6 +800,11 @@ def _merge_substrings_across_newline_based_on_sections(
section.substrings[substr_index - 1].index
+ section.substrings[substr_index - 1].length
)
# NOTE: merge the text of the substrings inside each section
for section in new_sections_merged:
section.substrings = self._merge_substrings_across_newline(
substrings=section.substrings
)
if self.debug:
logger.debug(
"---------------------------------after_merge_newline_sections:"
Expand All @@ -788,12 +824,8 @@ def _merge_substrings_across_digit(
for _, substring in enumerate(substrings):
if new_substrings:
if substring.lang == "digit":
if new_substrings[-1].lang == "digit":
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
else:
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
else:
if substring.lang == last_lang or last_lang == "":
new_substrings[-1].text += substring.text
Expand Down Expand Up @@ -907,24 +939,29 @@ def _merge_substrings_across_punctuation(
substrings: List[SubString],
) -> List[SubString]:
new_substrings: List[SubString] = []
lang = ""
for substring in substrings:
if (
substring.lang == "punctuation"
and substring.text.strip() not in self.not_merge_punctuation
):
if new_substrings and new_substrings[-1].lang == lang:
last_lang = "" # Changed from 'lang' to 'last_lang' for consistency

for _, substring in enumerate(substrings):
if new_substrings:
if substring.lang == "punctuation":
# If the last substring is also a punctuation, merge them
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
else:
new_substrings.append(substring)
if substring.lang == last_lang or last_lang == "":
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
new_substrings[-1].lang = (
substring.lang
if new_substrings[-1].lang == "punctuation"
else new_substrings[-1].lang
)
else:
new_substrings.append(substring)
last_lang = substring.lang
else:
if substring.lang != lang:
new_substrings.append(substring)
else:
new_substrings[-1].text += substring.text
new_substrings[-1].length += substring.length
lang = substring.lang if substring.lang != "punctuation" else lang
new_substrings.append(substring)

return new_substrings

# MARK: _merge_substrings_across_punctuation based on sections
Expand Down

0 comments on commit bd94849

Please sign in to comment.