From e5c78e8b6d7228437947a4f4dccc16a04a8ee2ff Mon Sep 17 00:00:00 2001
From: marco-nicola
Date: Sat, 12 Dec 2020 16:22:11 +0100
Subject: [PATCH] Create PreTokenizedString.IntoEncoding

---
 pretokenizedstring/pretokenizedstring.go | 66 ++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/pretokenizedstring/pretokenizedstring.go b/pretokenizedstring/pretokenizedstring.go
index b8782ad..49c31d3 100644
--- a/pretokenizedstring/pretokenizedstring.go
+++ b/pretokenizedstring/pretokenizedstring.go
@@ -5,6 +5,8 @@ package pretokenizedstring
 
 import (
+	"fmt"
+	"github.com/nlpodyssey/gotokenizers/encodings"
 	"github.com/nlpodyssey/gotokenizers/models"
 	"github.com/nlpodyssey/gotokenizers/normalizedstring"
 	"github.com/nlpodyssey/gotokenizers/strutils"
 )
@@ -156,3 +158,67 @@ func (p *PreTokenizedString) GetNormalizedByteSplits() []NormalizedByteSplit {
 func (p *PreTokenizedString) Splits() []Split {
 	return p.splits
 }
+
+// IntoEncoding transforms the current PreTokenizedString into an
+// encodings.Encoding.
+//
+// If a wordIndex is provided (i.e. >= 0), every word index in the generated
+// Encoding is set to this value. This is generally used with pre-tokenized
+// input that does not need the PreTokenizedString to generate word IDs.
+//
+// This method fails if any split does not have associated Tokens.
+//
+// Offset indices are based on bytes (not runes).
+func (p *PreTokenizedString) IntoEncoding(wordIndex int, typeID int) (*encodings.Encoding, error) {
+	if len(p.splits) == 0 {
+		return encodings.NewDefaultEncoding(), nil
+	}
+	if !p.allSplitsHaveTokens() {
+		return nil, fmt.Errorf("splits have not been tokenized, call `PreTokenizedString.Tokenize` first")
+	}
+
+	sequence := make([]encodings.EncodableToken, 0)
+
+	for splitIndex, split := range p.splits {
+		nsOffsets := split.NormalizedString.OriginalOffsets()
+
+		actualWordIndex := wordIndex
+		if actualWordIndex < 0 {
+			actualWordIndex = splitIndex
+		}
+
+		for _, token := range *split.Tokens {
+			var offsets strutils.ByteOffsets
+
+			tokenOrigRange, ok := split.NormalizedString.CoerceRangeToOriginal(
+				normalizedstring.NewNormalizedRange(token.Offsets.Start, token.Offsets.End))
+			if ok {
+				offsets = strutils.ByteOffsets{
+					Start: nsOffsets.Start + tokenOrigRange.Start(),
+					End:   nsOffsets.Start + tokenOrigRange.End(),
+				}
+			} else {
+				offsets = token.Offsets
+			}
+
+			sequence = append(sequence, encodings.EncodableToken{
+				ID:        token.ID,
+				Token:     token.Value,
+				Offsets:   offsets,
+				WordIndex: actualWordIndex,
+				TypeID:    typeID,
+			})
+		}
+	}
+
+	return encodings.EncodingFromEncodableTokens(sequence), nil
+}
+
+func (p *PreTokenizedString) allSplitsHaveTokens() bool {
+	for _, split := range p.splits {
+		if split.Tokens == nil {
+			return false
+		}
+	}
+	return true
+}
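
A note on the offset arithmetic in the inner loop, using assumed numbers for
illustration: a token's offsets are relative to its split's normalized text,
so they are first coerced back onto the split's original text and then
shifted by the split's own position in the full input. If a split covers
original bytes [10, 20) (nsOffsets), a token spans bytes [2, 5) of the
split's normalized text, and CoerceRangeToOriginal maps that range to
[3, 6), the encoded token ends up with absolute offsets [13, 16). When the
coercion fails, the token's normalized offsets are kept unchanged as a
fallback.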
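
A minimal usage sketch of the new method, for reference. It is a sketch
under assumptions: the FromString constructor and the tokenization step are
stand-ins for the surrounding package's actual API; only IntoEncoding's
signature, the wordIndex/typeID semantics, and the error path come from
this patch.

	package main

	import (
		"fmt"

		"github.com/nlpodyssey/gotokenizers/pretokenizedstring"
	)

	func main() {
		// FromString is an assumed constructor name; this patch only
		// adds IntoEncoding.
		pts := pretokenizedstring.FromString("hello world")

		// Pre-tokenization and tokenization (e.g. via a models.Model and
		// PreTokenizedString.Tokenize) must happen here, so that every
		// Split carries Tokens; otherwise IntoEncoding returns an error.

		// wordIndex -1 (< 0): each split's index becomes its word index.
		// typeID 0: a plain single-sequence encoding.
		encoding, err := pts.IntoEncoding(-1, 0)
		if err != nil {
			fmt.Println(err) // splits have not been tokenized, ...
			return
		}
		fmt.Printf("%+v\n", encoding)
	}

Passing a non-negative wordIndex instead forces that single word index onto
every token, which matches the pre-tokenized-input case described in the
method's doc comment.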