Skip to content

Commit

Permalink
Update TextCatBOW to use the fixed SparseLinear layer (explosion#…
Browse files Browse the repository at this point in the history
…13149)

* Update `TextCatBOW` to use the fixed `SparseLinear` layer

A while ago, we fixed the `SparseLinear` layer to use all available
parameters: explosion/thinc#754

This change updates `TextCatBOW` to `v3` which uses the new
`SparseLinear_v2` layer. This results in a sizeable improvement on a
text categorization task that was tested.

While at it, this `spacy.TextCatBOW.v3` also adds the `length_exponent`
option to make it possible to change the hidden size. Ideally, we'd just
have an option called `length`. But the way that `TextCatBOW` uses
hashes results in a non-uniform distribution of parameters when the
length is not a power of two.

* Replace TexCatBOW `length_exponent` parameter by `length`

We now round up the length to the next power of two if it isn't
a power of two.

* Remove some tests for TextCatBOW.v2

* Fix missing import
  • Loading branch information
danieldk authored Nov 29, 2023
1 parent bf7c2ea commit da7ad97
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 50 deletions.
17 changes: 11 additions & 6 deletions spacy/cli/templates/quickstart_training.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ grad_factor = 1.0
@layers = "reduce_mean.v1"

[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false

Expand Down Expand Up @@ -308,8 +309,9 @@ grad_factor = 1.0
@layers = "reduce_mean.v1"

[components.textcat_multilabel.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false

Expand Down Expand Up @@ -542,14 +544,15 @@ nO = null
width = ${components.tok2vec.model.encode.width}

[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false

{% else -%}
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
ngram_size = 1
no_output_layer = false
Expand All @@ -570,15 +573,17 @@ nO = null
width = ${components.tok2vec.model.encode.width}

[components.textcat_multilabel.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false

{% else -%}
[components.textcat_multilabel.model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false
{%- endif %}
Expand Down
1 change: 1 addition & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ class Errors(metaclass=ErrorsWithCodes):
"predicted docs when training {component}.")
E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
"but only callbacks with one or three parameters are supported")
E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")


# Deprecated model shortcuts, only used in errors and warnings
Expand Down
46 changes: 43 additions & 3 deletions spacy/ml/models/textcat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import List, Optional, cast
from typing import List, Optional, Tuple, cast

from thinc.api import (
Dropout,
Expand All @@ -12,6 +12,7 @@
Relu,
Softmax,
SparseLinear,
SparseLinear_v2,
chain,
clone,
concatenate,
Expand All @@ -25,9 +26,10 @@
)
from thinc.layers.chain import init as init_chain
from thinc.layers.resizable import resize_linear_weighted, resize_model
from thinc.types import Floats2d
from thinc.types import ArrayXd, Floats2d

from ...attrs import ORTH
from ...errors import Errors
from ...tokens import Doc
from ...util import registry
from ..extract_ngrams import extract_ngrams
Expand Down Expand Up @@ -95,10 +97,48 @@ def build_bow_text_classifier(
ngram_size: int,
no_output_layer: bool,
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
return _build_bow_text_classifier(
exclusive_classes=exclusive_classes,
ngram_size=ngram_size,
no_output_layer=no_output_layer,
nO=nO,
sparse_linear=SparseLinear(nO=nO),
)


@registry.architectures("spacy.TextCatBOW.v3")
def build_bow_text_classifier_v3(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
length: int = 262144,
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
if length < 1:
raise ValueError(Errors.E1056.format(length=length))

# Find k such that 2**(k-1) < length <= 2**k.
length = 2 ** (length - 1).bit_length()

return _build_bow_text_classifier(
exclusive_classes=exclusive_classes,
ngram_size=ngram_size,
no_output_layer=no_output_layer,
nO=nO,
sparse_linear=SparseLinear_v2(nO=nO, length=length),
)


def _build_bow_text_classifier(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
sparse_linear: Model[Tuple[ArrayXd, ArrayXd, ArrayXd], ArrayXd],
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
fill_defaults = {"b": 0, "W": 0}
with Model.define_operators({">>": chain}):
sparse_linear = SparseLinear(nO=nO)
output_layer = None
if not no_output_layer:
fill_defaults["b"] = NEG_VALUE
Expand Down
6 changes: 4 additions & 2 deletions spacy/pipeline/textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@
depth = 2
[model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false
"""
DEFAULT_SINGLE_TEXTCAT_MODEL = Config().from_str(single_label_default_config)["model"]

single_label_bow_config = """
[model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false
"""
Expand Down
5 changes: 3 additions & 2 deletions spacy/pipeline/textcat_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@
depth = 2
[model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false
"""
DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["model"]

multi_label_bow_config = """
[model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
Expand Down
2 changes: 1 addition & 1 deletion spacy/tests/pipeline/test_pipe_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_pipe_class_component_model():
"@architectures": "spacy.TextCatEnsemble.v2",
"tok2vec": DEFAULT_TOK2VEC_MODEL,
"linear_model": {
"@architectures": "spacy.TextCatBOW.v2",
"@architectures": "spacy.TextCatBOW.v3",
"exclusive_classes": False,
"ngram_size": 1,
"no_output_layer": False,
Expand Down
31 changes: 18 additions & 13 deletions spacy/tests/pipeline/test_textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def test_implicit_label(name, get_examples):
@pytest.mark.parametrize(
"name,textcat_config",
[
# BOW
# BOW V1
("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
Expand Down Expand Up @@ -451,11 +451,11 @@ def test_no_resize(name, textcat_config):
@pytest.mark.parametrize(
"name,textcat_config",
[
# BOW
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# BOW V3
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# CNN
("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
Expand All @@ -480,11 +480,11 @@ def test_resize(name, textcat_config):
@pytest.mark.parametrize(
"name,textcat_config",
[
# BOW
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# BOW v3
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# CNN
("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
Expand Down Expand Up @@ -693,9 +693,14 @@ def test_overfitting_IO_multi():
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
# BOW V3
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
# ENSEMBLE V2
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
# CNN V2
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/test_cli_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_project_push_pull(project_dir):

def test_find_function_valid():
# example of architecture in main code base
function = "spacy.TextCatBOW.v2"
function = "spacy.TextCatBOW.v3"
result = CliRunner().invoke(app, ["find-function", function, "-r", "architectures"])
assert f"Found registered function '{function}'" in result.stdout
assert "textcat.py" in result.stdout
Expand All @@ -257,7 +257,7 @@ def test_find_function_valid():

def test_find_function_invalid():
# invalid registry
function = "spacy.TextCatBOW.v2"
function = "spacy.TextCatBOW.v3"
registry = "foobar"
result = CliRunner().invoke(
app, ["find-function", function, "--registry", registry]
Expand Down
3 changes: 2 additions & 1 deletion spacy/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,9 @@ def test_util_dot_section():
factory = "textcat"
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false
"""
Expand Down
Loading

0 comments on commit da7ad97

Please sign in to comment.