ollama source for Momentry Core verification
This commit is contained in:
367
tokenizer/bytepairencoding.go
Normal file
367
tokenizer/bytepairencoding.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
spaceToSpmSep bool // When true, normalize spaces to ▁ instead of GPT-2 byte-level encoding
|
||||
}
|
||||
|
||||
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||
|
||||
// BPEOption configures BytePairEncoding behavior
|
||||
type BPEOption func(*BytePairEncoding)
|
||||
|
||||
// WithSentencePieceNormalizer enables ▁ space normalization instead of GPT-2 byte-level encoding.
|
||||
func WithSentencePieceNormalizer() BPEOption {
|
||||
return func(bpe *BytePairEncoding) {
|
||||
bpe.spaceToSpmSep = true
|
||||
}
|
||||
}
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||
return newBytePairEncoding(vocab, pretokenizer)
|
||||
}
|
||||
|
||||
func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||
bpe := newBytePairEncoding(vocab, pretokenizer, opts...)
|
||||
return bpe
|
||||
}
|
||||
|
||||
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||
bpe := BytePairEncoding{
|
||||
vocab: vocab,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&bpe)
|
||||
}
|
||||
|
||||
if len(pretokenizer) == 0 && !bpe.spaceToSpmSep {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
bpe.regexps = slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizer {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return bpe
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[offset:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
type fragment struct {
|
||||
value string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
// pair is a pair of runes and its rank
|
||||
type pair struct {
|
||||
a, b int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
p, n int
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := bpe.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var normalized string
|
||||
if bpe.spaceToSpmSep {
|
||||
// SentencePiece-style: replace spaces with ▁
|
||||
normalized = strings.ReplaceAll(split, " ", spmWhitespaceSep)
|
||||
} else {
|
||||
// GPT-2 byte-level: map bytes to shifted Unicode codepoints
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
normalized = sb.String()
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(normalized); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(normalized)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *pair {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
rank := bpe.vocab.Merge(left, right)
|
||||
if rank < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &pair{
|
||||
a: a,
|
||||
b: b,
|
||||
rank: rank,
|
||||
value: left + right,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := heap.NewWith(func(i, j *pair) int {
|
||||
return cmp.Compare(i.rank, j.rank)
|
||||
})
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for !pairs.Empty() {
|
||||
pair, _ := pairs.Pop()
|
||||
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||
string(left.runes)+string(right.runes) != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pairs.Push(pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
} else if bpe.spaceToSpmSep {
|
||||
// SentencePiece byte fallback: encode each UTF-8 byte as <0xHH>
|
||||
for _, b := range []byte(string(merge.runes)) {
|
||||
if id := bpe.vocab.Encode(fmt.Sprintf("<0x%02X>", b)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
} else {
|
||||
slog.Debug("unknown byte token", "byte", b)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type lazyIdsString struct {
|
||||
ids []int32
|
||||
}
|
||||
|
||||
func (l lazyIdsString) LogValue() slog.Value {
|
||||
return slog.AnyValue(fmt.Sprint(l.ids))
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
// SentencePiece-style BPE stores true Unicode codepoints in the vocab
|
||||
// (plus ▁ as a whitespace marker), so decoding should pass runes through
|
||||
// directly instead of applying the GPT-2 byte-level reverse mapping.
|
||||
// Without this, codepoints in the 0x0100-0x0142 range (e.g. ą ę ć ł)
|
||||
// get mangled by the GPT-2 reversal into control characters.
|
||||
if bpe.spaceToSpmSep {
|
||||
for _, id := range ids {
|
||||
data := bpe.vocab.Decode(id)
|
||||
|
||||
// SentencePiece byte tokens: "<0xHH>" → raw byte
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
if b, err := strconv.ParseUint(data[3:5], 16, 8); err == nil {
|
||||
sb.WriteByte(byte(b))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range data {
|
||||
if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
|
||||
sb.WriteByte(' ')
|
||||
} else {
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
|
||||
// range to represent bytes. Remap them back to actual bytes.
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
case r > 0x0143:
|
||||
// Non-GPT2 rune (e.g., SentencePiece-style BPE).
|
||||
// Handle ▁ as word separator, otherwise write the rune as-is.
|
||||
if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
|
||||
sb.WriteByte(' ')
|
||||
} else {
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
// encoding of the rune which is _not_ what we want
|
||||
if err := sb.WriteByte(byte(r)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
639
tokenizer/bytepairencoding_test.go
Normal file
639
tokenizer/bytepairencoding_test.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func llama(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
types := make([]int32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = 1
|
||||
}
|
||||
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
tokens = append(tokens, token) //nolint:makezero
|
||||
types = append(types, 3) //nolint:makezero
|
||||
vocab[token] = int32(len(vocab))
|
||||
}
|
||||
}
|
||||
|
||||
f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
merges := make([]string, 0, 50000)
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "#") {
|
||||
merges = append(merges, scanner.Text())
|
||||
}
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
}
|
||||
|
||||
func TestLlama(t *testing.T) {
|
||||
tokenizer := llama(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids, err := tokenizer.Encode("hello world", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
|
||||
s, err := tokenizer.Decode([]int32{15339, 1917})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s != "hello world" {
|
||||
t.Errorf("got %q, want hello world", s)
|
||||
}
|
||||
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple repeated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
strings.Repeat("0", 1): {15},
|
||||
strings.Repeat("0", 2): {410},
|
||||
strings.Repeat("0", 3): {931},
|
||||
strings.Repeat("0", 4): {931, 15},
|
||||
strings.Repeat("0", 5): {931, 410},
|
||||
strings.Repeat("0", 6): {931, 931},
|
||||
strings.Repeat("0", 7): {931, 931, 15},
|
||||
strings.Repeat("0", 8): {931, 931, 410},
|
||||
strings.Repeat("0", 9): {931, 931, 931},
|
||||
strings.Repeat("0", 10): {931, 931, 931, 15},
|
||||
strings.Repeat("0", 11): {931, 931, 931, 410},
|
||||
strings.Repeat("0", 12): {931, 931, 931, 931},
|
||||
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
|
||||
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
|
||||
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
|
||||
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, ids); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
|
||||
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
|
||||
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
|
||||
"Hello World": {"Hello", " ", " World"},
|
||||
"Hello\nWorld": {"Hello", "\n", "World"},
|
||||
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for b := 0x00; b <= 0xFF; b++ {
|
||||
input := string(rune(b))
|
||||
ids, err := tokenizer.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if b == 0x00 {
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != input {
|
||||
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// spmBPE builds a SentencePiece-style BPE tokenizer for testing.
|
||||
//
|
||||
// Models that use SentencePiece BPE differ from GPT-2 BPE in how they
|
||||
// handle spaces: the vocabulary stores ▁ (U+2581) instead of GPT-2's
|
||||
// shifted-byte encoding (0x0100–0x0143). Without WithSentencePieceNormalizer,
|
||||
// spaces are mapped through the GPT-2 byte table which produces wrong token
|
||||
// IDs for any vocabulary that uses ▁-prefixed tokens. The decode path has
|
||||
// the inverse problem: high codepoints like CJK characters and ▁ itself
|
||||
// would be mangled by the GPT-2 reverse mapping instead of being passed
|
||||
// through (or converted to spaces in the ▁ case).
|
||||
func spmBPE(t testing.TB) BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
tokens := []string{
|
||||
// Control tokens (low IDs, as in real SentencePiece vocabs)
|
||||
"<pad>", // 0
|
||||
"<eos>", // 1
|
||||
"<bos>", // 2
|
||||
"<|start>", // 3 - asymmetric open/close special tokens
|
||||
"<end|>", // 4
|
||||
"<|q>", // 5 - short special token (like <|"|>)
|
||||
|
||||
// ▁-prefixed word tokens (the core of what SPM BPE changes)
|
||||
"▁hello", // 6
|
||||
"▁world", // 7
|
||||
"hello", // 8
|
||||
"▁Run", // 9
|
||||
"▁a", // 10
|
||||
|
||||
// Punctuation and structure
|
||||
",", // 11
|
||||
"!", // 12
|
||||
":", // 13
|
||||
"{", // 14
|
||||
"}", // 15
|
||||
|
||||
// Whitespace separator
|
||||
"▁", // 16
|
||||
|
||||
// Subword tokens used in tool-declaration-like patterns
|
||||
"description", // 17
|
||||
"▁command", // 18
|
||||
"declaration", // 19
|
||||
|
||||
// Unicode token for decode passthrough testing (must be > U+0143
|
||||
// to exercise the SPM decode path rather than GPT-2 byte reversal)
|
||||
"▁中文", // 20
|
||||
|
||||
// Unicode tokens with codepoints in the GPT-2 byte range (0x0100-0x0142).
|
||||
// Without the SPM decode path, these get mangled by GPT-2 byte reversal.
|
||||
"ą", // 21 (U+0105) — would become 0x05 via GPT-2 reversal
|
||||
"ę", // 22 (U+0119) — would become 0x19
|
||||
"ć", // 23 (U+0107) — would become 0x07
|
||||
"ł", // 24 (U+0142) — would become 0xA0
|
||||
|
||||
// Byte fallback tokens (SentencePiece BYTE type)
|
||||
"<0x00>", // 25
|
||||
"<0x01>", // 26
|
||||
}
|
||||
|
||||
// Add all 256 byte tokens starting at index 27
|
||||
for b := 2; b < 256; b++ {
|
||||
tokens = append(tokens, fmt.Sprintf("<0x%02X>", b))
|
||||
}
|
||||
|
||||
types := make([]int32, len(tokens))
|
||||
for i := range types {
|
||||
types[i] = TOKEN_TYPE_NORMAL
|
||||
}
|
||||
types[0] = TOKEN_TYPE_CONTROL // <pad>
|
||||
types[1] = TOKEN_TYPE_CONTROL // <eos>
|
||||
types[2] = TOKEN_TYPE_CONTROL // <bos>
|
||||
types[3] = TOKEN_TYPE_USER_DEFINED // <|start>
|
||||
types[4] = TOKEN_TYPE_USER_DEFINED // <end|>
|
||||
types[5] = TOKEN_TYPE_USER_DEFINED // <|q>
|
||||
for i := 21; i < len(types); i++ {
|
||||
types[i] = TOKEN_TYPE_BYTE
|
||||
}
|
||||
|
||||
return NewBytePairEncodingWithOptions(
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
BOS: []int32{2},
|
||||
EOS: []int32{1},
|
||||
AddBOS: false,
|
||||
},
|
||||
// Empty pretokenizer list: falls back to the default pattern.
|
||||
// Real SentencePiece BPE models are configured this way.
|
||||
[]string{},
|
||||
WithSentencePieceNormalizer(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestSentencePieceBPE(t *testing.T) {
|
||||
tok := spmBPE(t)
|
||||
|
||||
// Test 1: Space-to-▁ normalization and roundtrip.
|
||||
//
|
||||
// SentencePiece BPE has no pretokenizer — the BPE merges handle word
|
||||
// boundaries via ▁ markers. With no merges in the test vocab, multi-char
|
||||
// tokens won't be found, but the roundtrip must still be lossless.
|
||||
t.Run("spm space normalization roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, input := range []string{
|
||||
"hello",
|
||||
" hello",
|
||||
"hello, world!",
|
||||
" leading spaces",
|
||||
"multiple spaces",
|
||||
} {
|
||||
ids, err := tok.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Encode(%q): %v", input, err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
t.Fatalf("Encode(%q) returned empty IDs", input)
|
||||
}
|
||||
|
||||
got, err := tok.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode(%v): %v", ids, err)
|
||||
}
|
||||
if got != input {
|
||||
t.Errorf("roundtrip %q: Decode(Encode) = %q", input, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test 2: Special tokens interleaved with SPM-normalized text.
|
||||
//
|
||||
// This mimics tool declaration patterns like:
|
||||
// <|tool>declaration:bash{description:<|"|>Run a command<|"|>}<tool|>
|
||||
// where special tokens (<|tool>, <|"|>, <tool|>) must be extracted
|
||||
// first, then the remaining text fragments go through SPM normalization.
|
||||
t.Run("special tokens with spm text fragments", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "<|start>declaration:description:<|q> Run a command<|q>}<end|>"
|
||||
ids, err := tok.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Special tokens should be extracted as single IDs at the right positions.
|
||||
// The text between them is SPM-normalized and BPE-encoded (specific IDs
|
||||
// depend on merges, so we verify the special token positions + roundtrip).
|
||||
specialPositions := map[int32]bool{3: true, 4: true, 5: true} // <|start>, <end|>, <|q>
|
||||
foundSpecials := 0
|
||||
for _, id := range ids {
|
||||
if specialPositions[id] {
|
||||
foundSpecials++
|
||||
}
|
||||
}
|
||||
if foundSpecials != 4 { // <|start>, <|q>, <|q>, <end|>
|
||||
t.Errorf("expected 4 special tokens, found %d in %v", foundSpecials, ids)
|
||||
}
|
||||
|
||||
// First token must be <|start>(3), last must be <end|>(4)
|
||||
if ids[0] != 3 {
|
||||
t.Errorf("first token = %d, want 3 (<|start>)", ids[0])
|
||||
}
|
||||
if ids[len(ids)-1] != 4 {
|
||||
t.Errorf("last token = %d, want 4 (<end|>)", ids[len(ids)-1])
|
||||
}
|
||||
})
|
||||
|
||||
// Test 3: Byte fallback for characters not in the vocabulary.
|
||||
//
|
||||
// SentencePiece vocabs include <0xHH> byte tokens for every byte value.
|
||||
// When a character (e.g. "ą" = U+0105 = C4 85) isn't in the vocab as a
|
||||
// direct token, the encoder must fall back to its UTF-8 bytes:
|
||||
// <0xC4> <0x85>. Without this fallback, the character is silently dropped.
|
||||
// See: https://github.com/ollama/ollama/issues/15229
|
||||
t.Run("byte fallback for unknown chars", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// "ą" is not in the vocab — should fall back to byte tokens
|
||||
ids, err := tok.Encode("ą", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Encode(ą): %v", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
t.Fatal("Encode(ą) returned empty IDs — character was silently dropped")
|
||||
}
|
||||
|
||||
got, err := tok.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode: %v", err)
|
||||
}
|
||||
if got != "ą" {
|
||||
t.Errorf("roundtrip = %q, want %q", got, "ą")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 4: Byte fallback preserves known tokens around unknown chars.
|
||||
t.Run("byte fallback mixed with known tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// "hello" is in vocab, "é" is not
|
||||
ids, err := tok.Encode("helloé", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Encode: %v", err)
|
||||
}
|
||||
|
||||
got, err := tok.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode: %v", err)
|
||||
}
|
||||
if got != "helloé" {
|
||||
t.Errorf("roundtrip = %q, want %q", got, "helloé")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 5: Decode doesn't mangle Unicode in the GPT-2 byte range.
|
||||
//
|
||||
// Characters like ą (U+0105), ę (U+0119), ć (U+0107), ł (U+0142) have
|
||||
// codepoints in the 0x0100-0x0142 range that GPT-2 byte reversal would
|
||||
// remap to control characters. SentencePiece decode must pass them through.
|
||||
t.Run("decode unicode in gpt2 byte range", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Token IDs 21-24 are ą, ę, ć, ł
|
||||
ids := []int32{21, 22, 23, 24}
|
||||
got, err := tok.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode: %v", err)
|
||||
}
|
||||
if got != "ąęćł" {
|
||||
t.Errorf("Decode = %q, want %q", got, "ąęćł")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 6: Decode handles non-GPT2 Unicode correctly.
|
||||
//
|
||||
// GPT-2 BPE decode reverses the byte→codepoint shift for runes in
|
||||
// 0x0100–0x0143. But SentencePiece vocabs store real Unicode (CJK,
|
||||
// accented chars, etc.) which have codepoints well above 0x0143.
|
||||
// Without the > 0x0143 passthrough in Decode, these would be mangled
|
||||
// by the GPT-2 reverse mapping (e.g., written as raw bytes instead
|
||||
// of the original characters).
|
||||
t.Run("decode non-gpt2 unicode passthrough", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string][]int32{
|
||||
" 中文": {20}, // ▁→space, then CJK passes through as-is
|
||||
}
|
||||
|
||||
for want, ids := range cases {
|
||||
got, err := tok.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode(%v): %v", ids, err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("Decode(%v) = %q, want %q", ids, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
n := min(int(math.Pow10(i)), len(bts))
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
ids, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytePairEncodingSplitMultipleRegexpsPreservesOffsets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bpe := NewBytePairEncoding(
|
||||
nil,
|
||||
`(?:\r?\n)+(?!\r?\n)`,
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
)
|
||||
|
||||
input := "One line\nTwo lines\n\nThree"
|
||||
got := slices.Collect(bpe.split(input))
|
||||
want := []string{"One", " line", "\n", "Two", " lines", "\n\n", "Three"}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf("split mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytePairEncodingSplitRefactPreservesOffsets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bpe := NewBytePairEncoding(
|
||||
nil,
|
||||
`\p{N}`,
|
||||
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
|
||||
)
|
||||
|
||||
input := "One line\nTwo lines\n\nThree"
|
||||
got := slices.Collect(bpe.split(input))
|
||||
want := []string{"One", " line", "\n", "Two", " lines", "\n", "\n", "Three"}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf("split mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytePairEncodingSplitDeepSeekV3PreservesOffsets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bpe := NewBytePairEncoding(
|
||||
nil,
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
|
||||
input := "One line\nTwo lines\n\nThree"
|
||||
got := slices.Collect(bpe.split(input))
|
||||
want := []string{"One", " line", "\n", "Two", " lines", "\n\n", "Three"}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf("split mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
patterns,
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
patterns: []string{
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||
},
|
||||
{
|
||||
name: "individual digits",
|
||||
patterns: []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
249
tokenizer/sentencepiece.go
Normal file
249
tokenizer/sentencepiece.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
type SentencePiece struct {
|
||||
maxTokenLen int
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ Tokenizer = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
var maxTokenLen int
|
||||
for cnt := range vocab.Types {
|
||||
switch vocab.Types[cnt] {
|
||||
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
|
||||
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
|
||||
fallthrough
|
||||
default:
|
||||
counter[int(vocab.Types[cnt])] += 1
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
return SentencePiece{
|
||||
maxTokenLen: maxTokenLen,
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
id := spm.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
if len(frag.ids) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var middle []fragment
|
||||
switch i := strings.Index(frag.value, special); {
|
||||
case i < 0:
|
||||
middle = append(middle, frag)
|
||||
case i > 0:
|
||||
middle = append(middle, fragment{value: frag.value[:i]})
|
||||
fallthrough
|
||||
default:
|
||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||
if rest := frag.value[i+len(special):]; rest != "" {
|
||||
middle = append(middle, fragment{value: rest})
|
||||
}
|
||||
}
|
||||
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
if len(frag.ids) > 0 {
|
||||
ids = append(ids, frag.ids...)
|
||||
continue
|
||||
}
|
||||
|
||||
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
|
||||
|
||||
if id := spm.vocab.Encode(text); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
q := &queue{}
|
||||
heap.Init(q)
|
||||
|
||||
runes := []rune(text)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
size: len(left) + len(right),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for q.Len() > 0 {
|
||||
pair := heap.Pop(q).(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
|
||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if token := string(merge.runes); token != "" {
|
||||
id := spm.vocab.Encode(token)
|
||||
|
||||
if id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Fallback to byte tokenization
|
||||
var result []int32
|
||||
for _, b := range []byte(token) {
|
||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||
unknownID := spm.vocab.Encode(byteToken)
|
||||
if unknownID >= 0 {
|
||||
result = append(result, unknownID)
|
||||
} else {
|
||||
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
|
||||
}
|
||||
}
|
||||
|
||||
ids = append(ids, result...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = spm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
a, b int
|
||||
score float32
|
||||
size int
|
||||
}
|
||||
|
||||
type queue []*candidate
|
||||
|
||||
func (q queue) Len() int { return len(q) }
|
||||
|
||||
func (q queue) Less(i, j int) bool {
|
||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||
}
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*q = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
|
||||
// For tokenizer that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||
}
|
||||
|
||||
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
172
tokenizer/sentencepiece_test.go
Normal file
172
tokenizer/sentencepiece_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
)
|
||||
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var spm sentencepiece.ModelProto
|
||||
if err := proto.Unmarshal(bts, &spm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var v Vocabulary
|
||||
|
||||
for _, piece := range spm.GetPieces() {
|
||||
v.Values = append(v.Values, piece.GetPiece())
|
||||
v.Scores = append(v.Scores, piece.GetScore())
|
||||
switch t := piece.GetType(); t {
|
||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
|
||||
sentencepiece.ModelProto_SentencePiece_CONTROL,
|
||||
sentencepiece.ModelProto_SentencePiece_UNUSED,
|
||||
sentencepiece.ModelProto_SentencePiece_BYTE:
|
||||
v.Types = append(v.Types, int32(t))
|
||||
default:
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
// todo parse the special tokens file
|
||||
// - this will roundtrip correctly but the <start_of_turn> and
|
||||
// <end_of_turn> tokens aren't processed
|
||||
v.Types = append(v.Types, tt)
|
||||
}
|
||||
}
|
||||
|
||||
return NewSentencePiece(&v)
|
||||
}
|
||||
|
||||
func TestSentencePieceEncode(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
tokenizer := loadSentencePieceVocab(t)
|
||||
|
||||
t.Run("basic roundtrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
"你好",
|
||||
"Hello 你好 world!",
|
||||
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
|
||||
"Numbers and symbols: 123456789 +- */",
|
||||
"Special tokens: <bos> text <eos>",
|
||||
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
|
||||
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
|
||||
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
|
||||
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, err := tokenizer.Decode(ids); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != want {
|
||||
t.Errorf("got %q, want %q [%#v]", got, want, ids)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special tokens", func(t *testing.T) {
|
||||
type candidate struct {
|
||||
token string
|
||||
ids []int32
|
||||
}
|
||||
|
||||
cases := []candidate{
|
||||
{"<bos>", []int32{2}},
|
||||
{"<eos>", []int32{1}},
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want.token, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !slices.Equal(ids, want.ids) {
|
||||
t.Errorf("got %#v, want %#v", ids, want.ids)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"normal",
|
||||
"<0xEA>",
|
||||
"<0x41>",
|
||||
"<0xC3>",
|
||||
"<0xA3>",
|
||||
},
|
||||
Types: []int32{
|
||||
TOKEN_TYPE_NORMAL,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
},
|
||||
Scores: []float32{0, 0, 0, 0, 0},
|
||||
}
|
||||
|
||||
spm := NewSentencePiece(vocab)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single byte token",
|
||||
ids: []int32{1},
|
||||
expected: "\xea",
|
||||
},
|
||||
{
|
||||
name: "ASCII byte token",
|
||||
ids: []int32{2},
|
||||
expected: "A",
|
||||
},
|
||||
{
|
||||
name: "multiple byte tokens forming UTF-8 character",
|
||||
ids: []int32{3, 4},
|
||||
expected: "ã",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := spm.Decode(tt.ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
BIN
tokenizer/testdata/gemma2/tokenizer.model
vendored
Normal file
BIN
tokenizer/testdata/gemma2/tokenizer.model
vendored
Normal file
Binary file not shown.
128002
tokenizer/testdata/llama3.2/encoder.json
vendored
Normal file
128002
tokenizer/testdata/llama3.2/encoder.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
280147
tokenizer/testdata/llama3.2/vocab.bpe
vendored
Normal file
280147
tokenizer/testdata/llama3.2/vocab.bpe
vendored
Normal file
File diff suppressed because it is too large
Load Diff
63845
tokenizer/testdata/war-and-peace.txt
vendored
Normal file
63845
tokenizer/testdata/war-and-peace.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
17
tokenizer/tokenizer.go
Normal file
17
tokenizer/tokenizer.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package tokenizer
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
TOKEN_TYPE_UNKNOWN
|
||||
TOKEN_TYPE_CONTROL
|
||||
TOKEN_TYPE_USER_DEFINED
|
||||
TOKEN_TYPE_UNUSED
|
||||
TOKEN_TYPE_BYTE
|
||||
)
|
||||
|
||||
type Tokenizer interface {
|
||||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
Vocabulary() *Vocabulary
|
||||
}
|
||||
112
tokenizer/vocabulary.go
Normal file
112
tokenizer/vocabulary.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Special int32
|
||||
|
||||
const (
|
||||
SpecialBOS Special = iota
|
||||
SpecialEOS
|
||||
)
|
||||
|
||||
type Vocabulary struct {
|
||||
Values []string
|
||||
Types []int32
|
||||
Scores []float32
|
||||
Merges []string
|
||||
|
||||
BOS, EOS []int32
|
||||
AddBOS, AddEOS bool
|
||||
|
||||
specialOnce sync.Once
|
||||
special []string
|
||||
|
||||
valuesOnce sync.Once
|
||||
values map[string]int32
|
||||
|
||||
mergeOnce sync.Once
|
||||
merge map[string]int32
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Is(id int32, special Special) bool {
|
||||
switch special {
|
||||
case SpecialBOS:
|
||||
return slices.Contains(v.BOS, id)
|
||||
case SpecialEOS:
|
||||
return slices.Contains(v.EOS, id)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
||||
if v.AddBOS && len(v.BOS) > 0 {
|
||||
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
|
||||
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
|
||||
ids = append([]int32{v.BOS[0]}, ids...)
|
||||
}
|
||||
|
||||
if v.AddEOS && len(v.EOS) > 0 {
|
||||
if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
|
||||
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
|
||||
ids = append(ids, v.EOS[0])
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Encode(s string) int32 {
|
||||
v.valuesOnce.Do(func() {
|
||||
v.values = make(map[string]int32, len(v.Values))
|
||||
for i, value := range v.Values {
|
||||
v.values[value] = int32(i)
|
||||
}
|
||||
})
|
||||
|
||||
if id, ok := v.values[s]; ok {
|
||||
return id
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Decode(id int32) string {
|
||||
return v.Values[id]
|
||||
}
|
||||
|
||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||
v.specialOnce.Do(func() {
|
||||
for i := range v.Values {
|
||||
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
|
||||
v.special = append(v.special, v.Values[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return v.special
|
||||
}
|
||||
|
||||
func (v *Vocabulary) Merge(left, right string) int {
|
||||
v.mergeOnce.Do(func() {
|
||||
v.merge = make(map[string]int32, len(v.Merges))
|
||||
for i, merge := range v.Merges {
|
||||
v.merge[merge] = int32(i)
|
||||
}
|
||||
})
|
||||
|
||||
if id, ok := v.merge[left+" "+right]; ok {
|
||||
return int(id)
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
107
tokenizer/vocabulary_test.go
Normal file
107
tokenizer/vocabulary_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestSpecialVocabulary(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
||||
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
||||
}
|
||||
|
||||
specialVocab := vocab.SpecialVocabulary()
|
||||
|
||||
if len(specialVocab) != 4 {
|
||||
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddSpecialVocabulary(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
vocab *Vocabulary
|
||||
input []int32
|
||||
want []int32
|
||||
}{
|
||||
{
|
||||
name: "add bos",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: false,
|
||||
},
|
||||
input: []int32{2, 3, 4},
|
||||
want: []int32{0, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
// TODO(mxyng): this is to match previous behaviour
|
||||
name: "add bos when already present",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: false,
|
||||
},
|
||||
input: []int32{0, 2, 3, 4},
|
||||
want: []int32{0, 0, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "add eos",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: false,
|
||||
AddEOS: true,
|
||||
},
|
||||
input: []int32{2, 3, 4},
|
||||
want: []int32{2, 3, 4, 1},
|
||||
},
|
||||
{
|
||||
// TODO(mxyng): this is to match previous behaviour
|
||||
name: "add eos when already present",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: false,
|
||||
AddEOS: true,
|
||||
},
|
||||
input: []int32{2, 3, 4, 1},
|
||||
want: []int32{2, 3, 4, 1, 1},
|
||||
},
|
||||
{
|
||||
name: "add both",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
},
|
||||
input: []int32{2, 3, 4},
|
||||
want: []int32{0, 2, 3, 4, 1},
|
||||
},
|
||||
{
|
||||
name: "add bos to empty inputs",
|
||||
vocab: &Vocabulary{
|
||||
BOS: []int32{0},
|
||||
EOS: []int32{1},
|
||||
AddBOS: true,
|
||||
AddEOS: false,
|
||||
},
|
||||
input: []int32{},
|
||||
want: []int32{0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.vocab.addSpecials(tt.input)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("no match (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
171
tokenizer/wordpiece.go
Normal file
171
tokenizer/wordpiece.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type WordPiece struct {
|
||||
vocab *Vocabulary
|
||||
lowercase bool
|
||||
}
|
||||
|
||||
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
|
||||
// this differs from original word piece which uses "##" to indicate subwords.
|
||||
const ggmlPrefix = "▁"
|
||||
|
||||
var wordPieceReplacer = strings.NewReplacer(
|
||||
" .", ".",
|
||||
" ?", "?",
|
||||
" !", "!",
|
||||
" ,", ",",
|
||||
" ' ", "'",
|
||||
" n't", "n't",
|
||||
" 'm", "'m",
|
||||
" do not", " don't",
|
||||
" 's", "'s",
|
||||
" 've", "'ve",
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements Tokenizer.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
if id < 0 || int(id) >= len(wpm.vocab.Values) {
|
||||
return "", fmt.Errorf("invalid token id: %d", id)
|
||||
}
|
||||
|
||||
var separator string
|
||||
piece := wpm.vocab.Values[id]
|
||||
if i > 0 &&
|
||||
(strings.HasPrefix(piece, ggmlPrefix) ||
|
||||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
|
||||
separator = " "
|
||||
}
|
||||
|
||||
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// words splits a string into words, treating CJK characters as separate words.
|
||||
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
|
||||
func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
runes := make([]rune, 0, len(s)*3)
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 0x4E00 && r <= 0x9FFF,
|
||||
r >= 0x3400 && r <= 0x4DBF,
|
||||
r >= 0x20000 && r <= 0x2A6DF,
|
||||
r >= 0x2A700 && r <= 0x2B73F,
|
||||
r >= 0x2B740 && r <= 0x2B81F,
|
||||
r >= 0x2B820 && r <= 0x2CEAF,
|
||||
r >= 0xF900 && r <= 0xFAFF,
|
||||
r >= 0x2F800 && r <= 0x2FA1F:
|
||||
runes = append(runes, ' ', r, ' ')
|
||||
default:
|
||||
runes = append(runes, r)
|
||||
}
|
||||
}
|
||||
|
||||
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
|
||||
// split on but keep punctuation
|
||||
var start int
|
||||
for start < len(w) {
|
||||
end := strings.IndexFunc(w[start:], unicode.IsPunct)
|
||||
if end < 0 {
|
||||
end = len(w) - start
|
||||
} else if end == 0 {
|
||||
end = 1
|
||||
}
|
||||
|
||||
if !yield(w[start : start+end]) {
|
||||
return
|
||||
}
|
||||
|
||||
start += end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements Tokenizer.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
// TODO: use [UNK] from config
|
||||
unk := wpm.vocab.Encode("[UNK]")
|
||||
for word := range wpm.words(s) {
|
||||
var start int
|
||||
var pieces []int32
|
||||
for start < len(word) {
|
||||
end := len(word)
|
||||
|
||||
var piece int32
|
||||
for start < end {
|
||||
subword := word[start:end]
|
||||
if start == 0 {
|
||||
subword = ggmlPrefix + subword
|
||||
}
|
||||
|
||||
if wpm.lowercase {
|
||||
subword = strings.ToLower(subword)
|
||||
}
|
||||
piece = wpm.vocab.Encode(subword)
|
||||
if piece >= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
end--
|
||||
}
|
||||
|
||||
if piece < 0 {
|
||||
// Unknown token
|
||||
pieces = pieces[:0]
|
||||
break
|
||||
}
|
||||
|
||||
pieces = append(pieces, piece)
|
||||
start = end
|
||||
}
|
||||
|
||||
if len(pieces) > 0 {
|
||||
ids = append(ids, pieces...)
|
||||
} else {
|
||||
ids = append(ids, unk)
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial {
|
||||
ids = wpm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements Tokenizer.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements Tokenizer.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ Tokenizer = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
|
||||
return WordPiece{
|
||||
vocab: vocab,
|
||||
lowercase: lowercase,
|
||||
}
|
||||
}
|
||||
53
tokenizer/wordpiece_test.go
Normal file
53
tokenizer/wordpiece_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWordPiece(t *testing.T) {
|
||||
wpm := NewWordPiece(
|
||||
&Vocabulary{
|
||||
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
BOS: []int32{1},
|
||||
EOS: []int32{2},
|
||||
},
|
||||
true, // lowercase
|
||||
)
|
||||
|
||||
ids, err := wpm.Encode("Hello world!", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
words, err := wpm.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user