Skip to content

Commit

Permalink
Support (de-)projectivization with MaltParser
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Aug 28, 2023
1 parent 8a9fd99 commit 06ea307
Showing 1 changed file with 129 additions and 33 deletions.
162 changes: 129 additions & 33 deletions supar/models/dep/biaffine/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
import tempfile
from io import StringIO
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -283,11 +284,92 @@ def order(adjs, head):
return left + [head] + right
return [i for head in adjs[0] for i in order(adjs, head)]

@classmethod
def projectivize(cls, file: str, fproj: str, malt: str) -> str:
r"""
Projectivizes the non-projective input trees to pseudo-projective ones with MaltParser.
Args:
file (str):
Path to the input file containing non-projective trees that need to be handled.
fproj (str):
Path to the output file containing produced pseudo-projective trees.
malt (str):
Path to the MaltParser, which requires the Java execution environment.
Returns:
The name of the output file.
"""

import hashlib
import subprocess
file, fproj, malt = os.path.abspath(file), os.path.abspath(fproj), os.path.abspath(malt)
path, parser = os.path.dirname(malt), os.path.basename(malt)
cfg = hashlib.sha256(file.encode('ascii')).hexdigest()[:8]
subprocess.check_output([f"cd {path}; java -jar {parser} -c {cfg} -m proj -i {file} -o {fproj} -pp head"],
stderr=subprocess.STDOUT,
shell=True)
return fproj

@classmethod
def deprojectivize(
cls,
sentences: Iterable[Sentence],
arcs: Iterable,
rels: Iterable,
data: str,
malt: str
) -> Tuple[Iterable, Iterable]:
r"""
Recover the projectivized sentences to the orginal format with MaltParser.
Args:
sentences (Iterable[Sentence]):
Sentences in CoNLL-like format.
arcs (Iterable):
Sequences of arcs for pseudo projective trees.
rels (Iterable):
Sequences of dependency relations for pseudo projective trees.
data (str):
The data file used for projectivization, typically the training file.
malt (str):
Path to the MaltParser, which requires the Java execution environment.
Returns:
Recovered arcs and dependency relations.
"""

import hashlib
import subprocess
data, malt = os.path.abspath(data), os.path.abspath(malt)
path, parser = os.path.dirname(malt), os.path.basename(malt)
cfg = hashlib.sha256(data.encode('ascii')).hexdigest()[:8]
with tempfile.TemporaryDirectory() as tdir:
fproj, file = os.path.join(tdir, 'proj.conll'), os.path.join(tdir, 'nonproj.conll')
with open(fproj, 'w') as f:
f.write('\n'.join([s.conll_format(arcs[i], rels[i]) for i, s in enumerate(sentences)]))
subprocess.check_output([f"cd {path}; java -jar {parser} -c {cfg} -m deproj -i {fproj} -o {file}"],
stderr=subprocess.STDOUT,
shell=True)
arcs, rels, sent = [], [], []
with open(file) as f:
for line in f:
line = line.strip()
if len(line) == 0:
sent = [line for line in sent if line[0].isdigit()]
arcs.append([int(line[6]) for line in sent])
rels.append([line[7] for line in sent])
sent = []
else:
sent.append(line.split('\t'))
return arcs, rels

def load(
self,
data: Union[str, Iterable],
lang: Optional[str] = None,
proj: bool = False,
malt: str = None,
**kwargs
) -> Iterable[CoNLLSentence]:
r"""
Expand All @@ -302,7 +384,11 @@ def load(
``None`` if tokenization is not required.
Default: ``None``.
proj (bool):
If ``True``, discards all non-projective sentences. Default: ``False``.
If ``True``, discards all non-projective sentences.
Default: ``False``.
malt (bool):
If specified, projectivizes all the non-projective trees to pseudo-projective ones.
Default: ``None``.
Returns:
A list of :class:`CoNLLSentence` instances.
Expand All @@ -311,35 +397,38 @@ def load(
isconll = False
if lang is not None:
tokenizer = Tokenizer(lang)
if isinstance(data, str) and os.path.exists(data):
f = open(data)
if data.endswith('.txt'):
lines = (i
for s in f
if len(s) > 1
for i in StringIO(self.toconll(s.split() if lang is None else tokenizer(s)) + '\n'))
else:
lines, isconll = f, True
else:
if lang is not None:
data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)]
else:
data = [data] if isinstance(data[0], str) else data
lines = (i for s in data for i in StringIO(self.toconll(s) + '\n'))

index, sentence = 0, []
for line in lines:
line = line.strip()
if len(line) == 0:
sentence = CoNLLSentence(self, sentence, index)
if isconll and self.training and proj and not sentence.projective:
logger.warning(f"Sentence {index} is not projective. Discarding it!")
with tempfile.TemporaryDirectory() as tdir:
if isinstance(data, str) and os.path.exists(data):
f = open(data)
if data.endswith('.txt'):
lines = (i
for s in f
if len(s) > 1
for i in StringIO(self.toconll(s.split() if lang is None else tokenizer(s)) + '\n'))
else:
yield sentence
index += 1
sentence = []
if malt is not None:
f = open(CoNLL.projectivize(data, os.path.join(tdir, f"{os.path.basename(data)}.proj"), malt))
lines, isconll = f, True
else:
sentence.append(line)
if lang is not None:
data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)]
else:
data = [data] if isinstance(data[0], str) else data
lines = (i for s in data for i in StringIO(self.toconll(s) + '\n'))

index, sentence = 0, []
for line in lines:
line = line.strip()
if len(line) == 0:
sentence = CoNLLSentence(self, sentence, index)
if isconll and self.training and proj and not sentence.projective:
logger.warning(f"Sentence {index} is not projective. Discarding it!")
else:
yield sentence
index += 1
sentence = []
else:
sentence.append(line)


class CoNLLSentence(Sentence):
Expand Down Expand Up @@ -408,12 +497,19 @@ def __init__(self, transform: CoNLL, lines: Sequence[str], index: Optional[int]
self.values = list(zip(*self.values))

def __repr__(self):
# cover the raw lines
merged = {**self.annotations,
**{i: '\t'.join(map(str, line))
for i, line in enumerate(zip(*self.values))}}
return '\n'.join(merged.values()) + '\n'
return self.conll_format()

@property
def projective(self):
return CoNLL.isprojective(CoNLL.get_arcs(self.values[6]))

def conll_format(self, arcs: Iterable[int] = None, rels: Iterable[str] = None):
if arcs is None:
arcs = self.values[6]
if rels is None:
rels = self.values[7]
# cover the raw lines
merged = {**self.annotations,
**{i: '\t'.join(map(str, line))
for i, line in enumerate(zip(*self.values[:6], arcs, rels, *self.values[8:]))}}
return '\n'.join(merged.values()) + '\n'

0 comments on commit 06ea307

Please sign in to comment.