Skip to content

Commit

Permalink
Merge master into v4 (#11034)
Browse files Browse the repository at this point in the history
* Add "Aim-spaCy" to spaCy Universe (#10943)

* Add Aim-spaCy to spaCy universe

* Update Aim thumbnail

* Fix author links

Co-authored-by: Paul O'Leary McCann <[email protected]>

* Auto-format code with black (#10945)

Co-authored-by: explosion-bot <[email protected]>

* precomputable_biaffine: avoid concatenation (#10911)

The `forward` of `precomputable_biaffine` performs matrix multiplication
and then `vstack`s the result with padding. This creates a temporary
array used for the output of matrix concatenation.

This change avoids the temporary by pre-allocating an array that is
large enough for the output of matrix multiplication plus padding and
fills the array in-place.

This gave me a small speedup (a bit over 100 WPS) on de_core_news_lg on
M1 Max (after changing thinc-apple-ops to support in-place gemm as BLIS
does).

* Add failing test: `test_matcher_extension_in_set_predicate` (#10948)

* vectors: remove use of float as row number (#10955)

The float -1 was returned rather than the integer -1 as the row for
unknown keys. This doesn't introduce a realy bug, since such floats
cast (without issues) to int in the conversion to NumPy arrays. Still,
it's nice to to do the correct thing :).

* Update for CBlas changes in Thinc 8.1.0.dev2 (#10970)

* Workaround for Typer optional default values with Python calls (#10788)

* Workaround for Typer optional default values with Python calls: added test and workaround.

* @rmitsch Workaround for Typer optional default values with Python calls: reverting some black formatting changes.

Co-authored-by: Sofie Van Landeghem <[email protected]>

* @rmitsch Workaround for Typer optional default values with Python calls: removing return type hint.

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Workaround for Typer optional default values with Python calls: fixed imports, added GitHub issue marker.

* Workaround for Typer optional default values with Python calls: removed forcing of default values for optional arguments in init_config_cli(). Added default values for init_config(). Synchronized default values for init_config_cli() and init_config().

* Workaround for Typer optional default values with Python calls: removed unused import.

* Workaround for Typer optional default values with Python calls: fixed usage of optimize in init_config_cli().

* Workaround for Typer optional default values with Pythhon calls: remove output_file from InitDefaultValues.

* Workaround for Typer optional default values with Python calls: rename class for default init values.

* Workaround for Typer optional default values with Python calls: remove newline.

* remove introduced newlines

* Remove test_init_config_from_python_without_optional_args().

* remove leftover import

* reformat import

* remove duplicate

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Made _initialize_X() methods private. (#10978)

* Auto-format code with black (#10977)

Co-authored-by: explosion-bot <[email protected]>

* account for NER labels with a hyphen in the name (#10960)

* account for NER labels with a hyphen in the name

* cleanup

* fix docstring

* add return type to helper method

* shorter method and few more occurrences

* user helper method across repo

* fix circular import

* partial revert to avoid circular import

* `enable` argument for spacy.load() (#10784)

* Enable flag on spacy.load: foundation for include, enable arguments.

* Enable flag on spacy.load: fixed tests.

* Enable flag on spacy.load: switched from pretrained model to empty model with added pipes for tests.

* Enable flag on spacy.load: switched to more consistent error on misspecification of component activity. Test refactoring. Added  to default config.

* Enable flag on spacy.load: added support for fields not in pipeline.

* Enable flag on spacy.load: removed serialization fields from supported fields.

* Enable flag on spacy.load: removed 'enable' from config again.

* Enable flag on spacy.load: relaxed checks in _resolve_component_activation_status() to allow non-standard pipes.

* Enable flag on spacy.load: fixed relaxed checks for _resolve_component_activation_status() to allow non-standard pipes. Extended tests.

* Enable flag on spacy.load: comments w.r.t. resolution workarounds.

* Enable flag on spacy.load: remove include fields. Update website docs.

* Enable flag on spacy.load: updates w.r.t. changes in master.

* Implement Doc.from_json(): update docstrings.

Co-authored-by: Adriane Boyd <[email protected]>

* Implement Doc.from_json(): remove newline.

Co-authored-by: Adriane Boyd <[email protected]>

* Implement Doc.from_json(): change error message for E1038.

Co-authored-by: Adriane Boyd <[email protected]>

* Enable flag on spacy.load: wrapped docstring for _resolve_component_status() at 80 chars.

* Enable flag on spacy.load: changed exmples for enable flag.

* Remove newline.

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Fix docstring for Language._resolve_component_status().

* Rename E1038 to E1042.

Co-authored-by: Adriane Boyd <[email protected]>
Co-authored-by: Sofie Van Landeghem <[email protected]>

* add counts to verbose list of NER labels (#10957)

* Update linguistic-features.md (#10993)

Change link for downloading fasttext word vectors

* Use thinc-apple-ops>=0.1.0.dev0 with `apple` extras (#10904)

* Use thinc-apple-ops>=0.1.0.dev0 with `apple` extras

Also test with thinc-apple-ops that is at least 0.1.0.dev0.

* Check thinc-apple-ops on macOS with Python 3.10

Co-authored-by: Adriane Boyd <[email protected]>

* Use `pip install --pre` for installing thinc-apple-ops in CI

Co-authored-by: Adriane Boyd <[email protected]>

Co-authored-by: Gor Arakelyan <[email protected]>
Co-authored-by: Paul O'Leary McCann <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: explosion-bot <[email protected]>
Co-authored-by: Madeesh Kannan <[email protected]>
Co-authored-by: Raphael Mitsch <[email protected]>
Co-authored-by: Sofie Van Landeghem <[email protected]>
Co-authored-by: Adriane Boyd <[email protected]>
Co-authored-by: Victoria <[email protected]>
  • Loading branch information
10 people committed Jun 27, 2022
1 parent 7f3842f commit 3f76bc1
Show file tree
Hide file tree
Showing 31 changed files with 300 additions and 69 deletions.
4 changes: 2 additions & 2 deletions .github/azure-steps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ steps:
condition: eq(variables['python_version'], '3.8')
- script: |
${{ parameters.prefix }} python -m pip install thinc-apple-ops
${{ parameters.prefix }} python -m pip install --pre thinc-apple-ops
${{ parameters.prefix }} python -m pytest --pyargs spacy
displayName: "Run CPU tests with thinc-apple-ops"
condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.9'))
condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.10'))
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.1.0.dev0,<8.2.0",
"thinc>=8.1.0.dev2,<8.2.0",
"pathy",
"numpy>=1.15.0",
]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ spacy-legacy>=3.0.9,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.1.0.dev0,<8.2.0
thinc>=8.1.0.dev2,<8.2.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.1.0
Expand Down
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc>=8.1.0.dev0,<8.2.0
thinc>=8.1.0.dev2,<8.2.0
install_requires =
# Our libraries
spacy-legacy>=3.0.9,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.1.0.dev0,<8.2.0
thinc>=8.1.0.dev2,<8.2.0
wasabi>=0.9.1,<1.1.0
srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0
Expand Down Expand Up @@ -104,7 +104,7 @@ cuda114 =
cuda115 =
cupy-cuda115>=5.0.0b4,<11.0.0
apple =
thinc-apple-ops>=0.0.4,<1.0.0
thinc-apple-ops>=0.1.0.dev0,<1.0.0
# Language tokenizers with external dependencies
ja =
sudachipy>=0.5.2,!=0.6.1
Expand Down
10 changes: 9 additions & 1 deletion spacy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def load(
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = util.SimpleFrozenList(),
enable: Iterable[str] = util.SimpleFrozenList(),
exclude: Iterable[str] = util.SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language:
Expand All @@ -42,14 +43,21 @@ def load(
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
enable (Iterable[str]): Names of pipeline components to enable. All other
pipes will be disabled (but can be enabled later using nlp.enable_pipe).
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
return util.load_model(
name, vocab=vocab, disable=disable, exclude=exclude, config=config
name,
vocab=vocab,
disable=disable,
enable=enable,
exclude=exclude,
config=config,
)


Expand Down
10 changes: 5 additions & 5 deletions spacy/cli/debug_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
from ..training import Example
from ..training import Example, remove_bilu_prefix
from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining
from ..pipeline._parser_internals import nonproj
Expand Down Expand Up @@ -361,7 +361,7 @@ def debug_data(
if label != "-"
]
labels_with_counts = _format_labels(labels_with_counts, counts=True)
msg.text(f"Labels in train data: {_format_labels(labels)}", show=verbose)
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
missing_labels = model_labels - labels
if missing_labels:
msg.warn(
Expand Down Expand Up @@ -758,9 +758,9 @@ def _compile_gold(
# "Illegal" whitespace entity
data["ws_ents"] += 1
if label.startswith(("B-", "U-")):
combined_label = label.split("-")[1]
combined_label = remove_bilu_prefix(label)
data["ner"][combined_label] += 1
if sent_starts[i] == True and label.startswith(("I-", "L-")):
if sent_starts[i] and label.startswith(("I-", "L-")):
data["boundary_cross_ents"] += 1
elif label == "-":
data["ner"]["-"] += 1
Expand Down Expand Up @@ -908,7 +908,7 @@ def _get_examples_without_label(
for eg in data:
if component == "ner":
labels = [
label.split("-")[1]
remove_bilu_prefix(label)
for label in eg.get_aligned_ner()
if label not in ("O", "-", None)
]
Expand Down
37 changes: 26 additions & 11 deletions spacy/cli/init_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .. import util
from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
from ..schemas import RecommendationSchema
from ..util import SimpleFrozenList
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
from ._util import string_to_list, import_code

Expand All @@ -24,16 +25,30 @@ class Optimizations(str, Enum):
accuracy = "accuracy"


class InitValues:
"""
Default values for initialization. Dedicated class to allow synchronized default values for init_config_cli() and
init_config(), i.e. initialization calls via CLI respectively Python.
"""

lang = "en"
pipeline = SimpleFrozenList(["tagger", "parser", "ner"])
optimize = Optimizations.efficiency
gpu = False
pretraining = False
force_overwrite = False


@init_cli.command("config")
def init_config_cli(
# fmt: off
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
lang: str = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
gpu: bool = Opt(False, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
force_overwrite: bool = Opt(False, "--force", "-F", help="Force overwriting the output file"),
lang: str = Opt(InitValues.lang, "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: str = Opt(",".join(InitValues.pipeline), "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(InitValues.optimize, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
gpu: bool = Opt(InitValues.gpu, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
pretraining: bool = Opt(InitValues.pretraining, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
force_overwrite: bool = Opt(InitValues.force_overwrite, "--force", "-F", help="Force overwriting the output file"),
# fmt: on
):
"""
Expand Down Expand Up @@ -133,11 +148,11 @@ def fill_config(

def init_config(
*,
lang: str,
pipeline: List[str],
optimize: str,
gpu: bool,
pretraining: bool = False,
lang: str = InitValues.lang,
pipeline: List[str] = InitValues.pipeline,
optimize: str = InitValues.optimize,
gpu: bool = InitValues.gpu,
pretraining: bool = InitValues.pretraining,
silent: bool = True,
) -> Config:
msg = Printer(no_print=silent)
Expand Down
2 changes: 2 additions & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,8 @@ class Errors(metaclass=ErrorsWithCodes):
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
"Some tokens do not contain annotation for: {partial_attrs}")
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
E1042 = ("Function was called with `{arg1}`={arg1_values} and "
"`{arg2}`={arg2_values} but these arguments are conflicting.")


# Deprecated model shortcuts, only used in errors and warnings
Expand Down
22 changes: 11 additions & 11 deletions spacy/kb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ cdef class KnowledgeBase:
self.vocab = vocab
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])

def initialize_entities(self, int64_t nr_entities):
def _initialize_entities(self, int64_t nr_entities):
self._entry_index = PreshMap(nr_entities + 1)
self._entries = entry_vec(nr_entities + 1)

def initialize_vectors(self, int64_t nr_entities):
def _initialize_vectors(self, int64_t nr_entities):
self._vectors_table = float_matrix(nr_entities + 1)

def initialize_aliases(self, int64_t nr_aliases):
def _initialize_aliases(self, int64_t nr_aliases):
self._alias_index = PreshMap(nr_aliases + 1)
self._aliases_table = alias_vec(nr_aliases + 1)

Expand Down Expand Up @@ -155,8 +155,8 @@ cdef class KnowledgeBase:
raise ValueError(Errors.E140)

nr_entities = len(set(entity_list))
self.initialize_entities(nr_entities)
self.initialize_vectors(nr_entities)
self._initialize_entities(nr_entities)
self._initialize_vectors(nr_entities)

i = 0
cdef KBEntryC entry
Expand Down Expand Up @@ -388,9 +388,9 @@ cdef class KnowledgeBase:
nr_entities = header[0]
nr_aliases = header[1]
entity_vector_length = header[2]
self.initialize_entities(nr_entities)
self.initialize_vectors(nr_entities)
self.initialize_aliases(nr_aliases)
self._initialize_entities(nr_entities)
self._initialize_vectors(nr_entities)
self._initialize_aliases(nr_aliases)
self.entity_vector_length = entity_vector_length

def deserialize_vectors(b):
Expand Down Expand Up @@ -512,8 +512,8 @@ cdef class KnowledgeBase:
cdef int64_t entity_vector_length
reader.read_header(&nr_entities, &entity_vector_length)

self.initialize_entities(nr_entities)
self.initialize_vectors(nr_entities)
self._initialize_entities(nr_entities)
self._initialize_vectors(nr_entities)
self.entity_vector_length = entity_vector_length

# STEP 1: load entity vectors
Expand Down Expand Up @@ -552,7 +552,7 @@ cdef class KnowledgeBase:
# STEP 3: load aliases
cdef int64_t nr_aliases
reader.read_alias_length(&nr_aliases)
self.initialize_aliases(nr_aliases)
self._initialize_aliases(nr_aliases)

cdef int64_t nr_candidates
cdef vector[int64_t] entry_indices
Expand Down
50 changes: 48 additions & 2 deletions spacy/language.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
from typing import Iterator, Optional, Any, Dict, Callable, Iterable, Collection
from typing import Union, Tuple, List, Set, Pattern, Sequence
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload

Expand Down Expand Up @@ -1694,6 +1694,7 @@ def from_config(
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
meta: Dict[str, Any] = SimpleFrozenDict(),
auto_fill: bool = True,
Expand All @@ -1708,6 +1709,8 @@ def from_config(
disable (Iterable[str]): Names of pipeline components to disable.
Disabled pipes will be loaded but they won't be run unless you
explicitly enable them by calling nlp.enable_pipe.
enable (Iterable[str]): Names of pipeline components to enable. All other
pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
exclude (Iterable[str]): Names of pipeline components to exclude.
Excluded components won't be loaded.
meta (Dict[str, Any]): Meta overrides for nlp.meta.
Expand Down Expand Up @@ -1861,8 +1864,15 @@ def from_config(
# Restore the original vocab after sourcing if necessary
if vocab_b is not None:
nlp.vocab.from_bytes(vocab_b)
disabled_pipes = [*config["nlp"]["disabled"], *disable]

# Resolve disabled/enabled settings.
disabled_pipes = cls._resolve_component_status(
[*config["nlp"]["disabled"], *disable],
[*config["nlp"].get("enabled", []), *enable],
config["nlp"]["pipeline"],
)
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)

nlp.batch_size = config["nlp"]["batch_size"]
nlp.config = filled if auto_fill else config
if after_pipeline_creation is not None:
Expand Down Expand Up @@ -2014,6 +2024,42 @@ def to_disk(
serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
util.to_disk(path, serializers, exclude)

@staticmethod
def _resolve_component_status(
disable: Iterable[str], enable: Iterable[str], pipe_names: Collection[str]
) -> Tuple[str, ...]:
"""Derives whether (1) `disable` and `enable` values are consistent and (2)
resolves those to a single set of disabled components. Raises an error in
case of inconsistency.
disable (Iterable[str]): Names of components or serialization fields to disable.
enable (Iterable[str]): Names of pipeline components to enable.
pipe_names (Iterable[str]): Names of all pipeline components.
RETURNS (Tuple[str, ...]): Names of components to exclude from pipeline w.r.t.
specified includes and excludes.
"""

if disable is not None and isinstance(disable, str):
disable = [disable]
to_disable = disable

if enable:
to_disable = [
pipe_name for pipe_name in pipe_names if pipe_name not in enable
]
if disable and disable != to_disable:
raise ValueError(
Errors.E1042.format(
arg1="enable",
arg2="disable",
arg1_values=enable,
arg2_values=disable,
)
)

return tuple(to_disable)

def from_disk(
self,
path: Union[str, Path],
Expand Down
6 changes: 4 additions & 2 deletions spacy/ml/_precomputable_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ def forward(model, X, is_train):
nP = model.get_dim("nP")
nI = model.get_dim("nI")
W = model.get_param("W")
Yf = model.ops.gemm(X, W.reshape((nF * nO * nP, nI)), trans2=True)
# Preallocate array for layer output, including padding.
Yf = model.ops.alloc2f(X.shape[0] + 1, nF * nO * nP, zeros=False)
model.ops.gemm(X, W.reshape((nF * nO * nP, nI)), trans2=True, out=Yf[1:])
Yf = Yf.reshape((Yf.shape[0], nF, nO, nP))
Yf = model.ops.xp.vstack((model.get_param("pad"), Yf))
Yf[0] = model.get_param("pad")

def backward(dY_ids):
# This backprop is particularly tricky, because we get back a different
Expand Down
5 changes: 3 additions & 2 deletions spacy/ml/parser_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from libc.math cimport exp
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from thinc.backends.linalg cimport Vec, VecVec
from thinc.backends.cblas cimport saxpy, sgemm

import numpy
import numpy.random
Expand Down Expand Up @@ -112,7 +113,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
else:
# Compute hidden-to-output
cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
1.0, <const float *>A.hiddens, n.hiddens,
<const float *>W.hidden_weights, n.hiddens,
0.0, A.scores, n.classes)
Expand Down Expand Up @@ -147,7 +148,7 @@ cdef void sum_state_features(CBlas cblas, float* output,
else:
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F


Expand Down
3 changes: 2 additions & 1 deletion spacy/pipeline/_parser_internals/arc_eager.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from ...strings cimport hash_string
from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads
from ...tokens.token cimport MISSING_DEP
from ...training import split_bilu_label
from ...training.example cimport Example
from .stateclass cimport StateClass
from ._state cimport StateC, ArcC
Expand Down Expand Up @@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem):
return self.c[name_or_id]
name = name_or_id
if '-' in name:
move_str, label_str = name.split('-', 1)
move_str, label_str = split_bilu_label(name)
label = self.strings[label_str]
else:
move_str = name
Expand Down
Loading

0 comments on commit 3f76bc1

Please sign in to comment.