Utilize CLS-Token of transformers in textcat component #7178
Replies: 4 comments
-
Yes, you can either build the whole model yourself in PyTorch (including using the transformer directly), or define the model from spaCy and Thinc layers. Here's how to do it the latter way. This turned out to be a great question, because it pointed to a few gaps in the layers we provide. I've drafted the solution, but I haven't run it yet.

### Getting one [CLS] per Doc

The transformer component processes texts as batches of (possibly overlapping) spans rather than whole documents. Using this overlapping span-based approach, you'll have multiple spans, and therefore multiple class tokens, per `Doc`. If you can configure the span getter to produce one span per doc, each doc will have exactly one class token. In case it's not practical to do this, you'll need to define the model such that it can handle multiple class tokens per doc.

### Model definition and config

Next you'll need to define a model for your textcat. Here's a model that connects to a shared transformer, gets the class tokens for each span, mean-pools them, and passes the result through a linear layer with softmax activation to predict the class probabilities.

**Model definition**

```python
from typing import List

from thinc.api import Model, Softmax, chain, reduce_mean, list2ragged, array_getitem
from thinc.types import Floats2d
from spacy.tokens import Doc
from spacy.util import registry
from spacy_transformers.data_classes import TransformerData
from spacy_transformers.layers import TransformerListener


@registry.architectures.register("TransformerListenerClassTokenTextcat.v1")
def transformer_listener_class_tok2vec_v1(
    tensor_index: int,
    class_index: int,
) -> Model[List[Doc], Floats2d]:
    # I'm assuming that we can have more than one span per doc, and we're going
    # to average the class vectors for the spans if there are multiple.
    # trf2tensor and foreach are defined further down.
    return chain(
        TransformerListener(upstream_name="*"),  # List[Doc] -> List[TransformerData]
        trf2tensor(tensor_index),  # List[TransformerData] -> List[Floats3d]
        foreach(
            # This does array[:, class_index],
            # i.e. we're getting the class array for each span.
            array_getitem((slice(0, None), class_index))  # Floats3d -> Floats2d
        ),  # List[Floats3d] -> List[Floats2d]
        list2ragged(),  # List[Floats2d] -> Ragged
        reduce_mean(),  # Ragged -> Floats2d
        Softmax(),  # Floats2d -> Floats2d
    )
```

### Config for transformer and textcat components
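The config section wasn't included in the draft. Here's a sketch of what it might look like, assuming a shared `transformer` component and the architecture registered above; `tensor_index = -1` (last hidden-state tensor) and `class_index = 0` (BERT-style [CLS] as the first wordpiece) are assumptions you'd adjust for your model:

```ini
[components.transformer]
factory = "transformer"

[components.textcat]
factory = "textcat"

[components.textcat.model]
@architectures = "TransformerListenerClassTokenTextcat.v1"
tensor_index = -1
class_index = 0
```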
### trf2tensor layer (we should add this to spacy-transformers)

```python
from typing import Callable, List, Tuple, TypeVar

from thinc.api import Model
from thinc.types import FloatsXd
from spacy_transformers.data_classes import TransformerData

OutT = TypeVar("OutT", bound=FloatsXd)


def trf2tensor(index: int) -> Model[List[TransformerData], OutT]:
    """Extract just one tensor from each TransformerData."""
    return Model(
        "trf2tensor",
        forward,
        attrs={"index": index}
    )


def forward(model: Model, Xs: List[TransformerData], is_train: bool) -> Tuple[OutT, Callable]:
    index = model.attrs["index"]
    Ys = [x.tensors[index] for x in Xs]

    def backprop_trfs2tensor(dYs: List[OutT]) -> List[TransformerData]:
        dXs = []
        for X, dY in zip(Xs, dYs):
            d_tensors = []
            for j, tensor in enumerate(X.tensors):
                if j == index:
                    d_tensors.append(dY)
                else:
                    # Zero gradient for the tensors we didn't use.
                    d_tensors.append(model.ops.alloc(tensor.shape, dtype=tensor.dtype))
            dXs.append(
                TransformerData(
                    tensors=d_tensors,
                    wordpieces=X.wordpieces,
                    align=X.align
                )
            )
        return dXs

    return Ys, backprop_trfs2tensor
```

### foreach layer (should go in Thinc)

```python
from typing import Callable, List, Tuple, TypeVar

from thinc.api import Model

InT = TypeVar("InT")
OutT = TypeVar("OutT")
# I could've sworn I implemented this in thinc already =/. Maybe it was in a branch
# that got abandoned or something?
# In any case, it maps a layer across a list.
def foreach(layer: Model[InT, OutT]) -> Model[List[InT], List[OutT]]:
return Model("foreach", forward_foreach, layers=[layer])
def forward_foreach(
    model: Model[List[InT], List[OutT]],
    Xs: List[InT],
    is_train: bool
) -> Tuple[List[OutT], Callable[[List[OutT]], List[InT]]]:
    layer = model.layers[0]
    Ys = []
    callbacks = []
    for X in Xs:
        Y, get_dX = layer(X, is_train)
        Ys.append(Y)
        callbacks.append(get_dX)

    def backprop_foreach(dYs: List[OutT]) -> List[InT]:
        return [callback(dY) for callback, dY in zip(callbacks, dYs)]
    return Ys, backprop_foreach
```

### Other tips

You would need to ensure that the code that registers your architecture actually runs before the config is loaded, e.g. by passing the file to `spacy train` via the `--code` argument.
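The forward/backprop pairing that `foreach` relies on can be illustrated without Thinc at all. This is just a sketch with a made-up toy "doubling layer" (not part of the draft above), showing how one backprop callback is collected per list item:

```python
import numpy as np


def make_double_layer():
    """A toy 'layer': forward doubles the input, backprop doubles the gradient."""
    def forward(X):
        def backprop(dY):
            return dY * 2.0
        return X * 2.0, backprop
    return forward


def foreach(layer):
    """Map a forward/backprop-style layer over a list of inputs,
    collecting one backprop callback per item."""
    def forward(Xs):
        Ys, callbacks = [], []
        for X in Xs:
            Y, get_dX = layer(X)
            Ys.append(Y)
            callbacks.append(get_dX)

        def backprop(dYs):
            # Route each item's gradient through its own callback.
            return [cb(dY) for cb, dY in zip(callbacks, dYs)]

        return Ys, backprop
    return forward


mapped = foreach(make_double_layer())
Xs = [np.ones((2, 3)), np.ones((4, 3))]
Ys, backprop = mapped(Xs)
dXs = backprop([np.ones_like(Y) for Y in Ys])
```

Each input keeps its own shape through the round trip, which is exactly what the spans-per-doc lists need.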
-
There are a couple of follow-up tasks someone could help with here.
-
This is so helpful. Thank you!
-
The "foreach" layer is now in thinc v8.0.2 as `map_list`.
-
I was going over spaCy 3.0 and building a classification model with transformer + textcat components. I just realized that the textcat inputs are the aligned outputs of the transformer tokens. Since the CLS token encodes document-wide information, is there a way to utilize the CLS token of the transformer and pass it to the textcat component? Or is it already being utilized and am I missing something?
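To illustrate what "the CLS token" refers to here (a toy NumPy sketch with made-up shapes): for BERT-style models the class token is the first wordpiece, so given a per-doc tensor of shape `(n_spans, n_wordpieces, width)`, taking index 0 along the wordpiece axis gives one class vector per span, which can then be pooled into a single doc vector.

```python
import numpy as np

# Toy last-layer hidden states for one doc processed as 2 spans:
# shape (n_spans, n_wordpieces, width); the class token is wordpiece 0.
tensor = np.random.rand(2, 6, 4)

class_vectors = tensor[:, 0]             # (n_spans, width): one [CLS] vector per span
doc_vector = class_vectors.mean(axis=0)  # (width,): mean-pooled over the doc's spans
```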