ollama source for Momentry Core verification
This commit is contained in:
85
x/imagegen/tokenizer/README.md
Normal file
85
x/imagegen/tokenizer/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Tokenizer
|
||||
|
||||
Tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms. The goal of this package is to see if a pure Go tokenizer can be fast and correct. It primarily supports the `imagegen` models however it (or parts of it) could be considered to replace Ollama's tokenizer in the `model` package.
|
||||
|
||||
## Features
|
||||
|
||||
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
|
||||
- **SentencePiece** - Gemma style with `▁` space handling
|
||||
- **WordPiece** - BERT style with `##` continuation tokens
|
||||
- **Parallel encoding** - Automatic parallelization for inputs >4KB
|
||||
- **HuggingFace compatible** - Loads `tokenizer.json` directly
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import "github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
|
||||
// Load from HuggingFace model directory
|
||||
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Encode text to token IDs
|
||||
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
|
||||
|
||||
// Decode back to text
|
||||
text := tok.Decode(ids)
|
||||
|
||||
// Check special tokens
|
||||
if tok.IsEOS(ids[len(ids)-1]) {
|
||||
// End of sequence
|
||||
}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
Benchmarks on Apple M3 Max:
|
||||
|
||||
| Input Size | Encode | Decode | Tokens |
|
||||
|------------|--------|--------|--------|
|
||||
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
|
||||
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
|
||||
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
|
||||
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
|
||||
|
||||
Comparison with other implementations (10 MB input):
|
||||
|
||||
| Implementation | Encode Speed | Notes |
|
||||
|----------------|--------------|-------|
|
||||
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
|
||||
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
|
||||
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
|
||||
|
||||
## Performance Opportunities
|
||||
|
||||
Potential optimizations not yet implemented:
|
||||
|
||||
| Optimization | Expected Gain | Complexity |
|
||||
|--------------|---------------|------------|
|
||||
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
|
||||
| Custom regex engine (like tiktoken) | 1.5-2x | High |
|
||||
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
|
||||
| Assembly BPE merge loop | 1.2-1.5x | High |
|
||||
| Memoization for repeated substrings | Variable | Low |
|
||||
|
||||
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
|
||||
|
||||
## Not Yet Implemented
|
||||
|
||||
| Feature | Used By | Notes |
|
||||
|---------|---------|-------|
|
||||
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
|
||||
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
|
||||
| Custom pretokenizers | Model-specific | Beyond standard patterns |
|
||||
|
||||
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `tokenizer.go` | Main implementation (~1000 lines) |
|
||||
| `tokenizer_test.go` | Tests and benchmarks |
|
||||
| `testdata/` | Mini tokenizer for unit tests |
|
||||
1
x/imagegen/tokenizer/testdata/mini_llama.json
vendored
Normal file
1
x/imagegen/tokenizer/testdata/mini_llama.json
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"model": {"type": "BPE", "vocab": {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "fé": 59958, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "r": 81, "q": 80, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "¡": 94, "¢": 95, "£": 96, "¤": 97, "¥": 98, "¦": 99, "§": 100, "¨": 101, "World": 10343, "©": 102, "ª": 103, "«": 104, "¬": 105, "®": 106, "world": 14957, "¯": 107, "°": 108, "±": 109, "²": 110, "³": 111, "´": 112, "µ": 113, "¶": 114, "·": 115, "¸": 116, "¹": 117, "º": 118, "»": 119, "¼": 120, "½": 121, "¾": 122, "¿": 123, "À": 124, "Á": 125, "Â": 126, "Ã": 127, "Ä": 128, "Å": 129, "Æ": 130, "Ç": 131, "È": 132, "É": 133, "Ê": 134, "Ë": 135, "Ì": 136, "Í": 137, "Î": 138, "Ï": 139, "Ð": 140, "Ñ": 141, "Ò": 142, "Ó": 143, "Ô": 144, "Õ": 145, "Ö": 146, "×": 147, "Ø": 148, "Ù": 149, "Ú": 150, "Û": 151, "Ü": 152, "Ý": 153, "Þ": 154, "ß": 155, "à": 156, "á": 157, "â": 158, "ã": 159, "ä": 160, "å": 161, "æ": 162, "ç": 163, "è": 164, "é": 165, "ê": 166, "ë": 167, "ì": 168, "Ġhello": 24748, "í": 169, "î": 170, "ï": 171, "ð": 172, "ñ": 173, "Hello": 9906, "ò": 174, "ó": 175, "ô": 176, "õ": 177, "ö": 178, "Ġ{}": 4792, "÷": 179, "ø": 180, "ù": 181, "ú": 182, "û": 183, "ü": 184, "ý": 185, "þ": 186, "ÿ": 187, "Ā": 188, "ā": 189, "Ă": 190, "ă": 191, "Ċ": 198, "Ą": 192, "ą": 193, "Ć": 194, "ć": 195, "Ĉ": 196, "ĉ": 197, "ċ": 199, "Č": 200, "č": 201, "Ď": 202, "ď": 203, "Đ": 204, "đ": 205, "Ē": 206, "ē": 207, "Ĕ": 208, "ĕ": 209, "Ė": 210, "ė": 211, "Ę": 212, "ę": 213, "Ġ": 220, "Ě": 214, "ě": 215, "Ĝ": 216, "ĝ": 217, "Ğ": 218, "ğ": 219, "ġ": 221, "Ģ": 222, "ģ": 223, "Ĥ": 224, "ĥ": 225, "Ħ": 226, "ħ": 227, "Ĩ": 228, "ĩ": 229, "Ī": 230, "ī": 231, "Ĭ": 232, "ĭ": 233, "Į": 234, "į": 235, "İ": 236, "ı": 237, "IJ": 238, "ij": 239, "Ĵ": 240, "ĵ": 241, "Ķ": 242, "ķ": 243, "ĸ": 244, "Ĺ": 245, "ĺ": 246, "Ļ": 247, "ļ": 248, "Ľ": 249, "ĠĠ": 256, "ľ": 250, "Ŀ": 251, "ŀ": 252, "Ł": 253, "rer": 38149, "ĠĠĠ": 262, "ł": 254, "Ń": 255, "'m": 2846, "'re": 2351, "can": 4919, "func": 2900, "()": 368, "Ġworld": 1917, "Ġmain": 1925, "00": 410, "123": 4513, "000": 931, "ca": 936, "'t": 956, "é": 978, "hello": 15339, "Ġw": 289, "orld": 1410, "Ġwor": 4191, "ld": 509, "main": 3902, "Ġm": 296, "ain": 467, "Ġma": 7643, "in": 258, "Ġmai": 17154, "re": 265, "'r": 97670, "unc": 1371, "fun": 12158, "fu": 33721, "nc": 1031, "ma": 1764, "mai": 77585, "wor": 50810, "or": 269, "Ġwo": 24670, "23": 1419, "12": 717, "{}": 6390, "Ġ{": 314, "an": 276, "ello": 4896, "Hel": 33813, "lo": 385, "Hell": 81394, "un": 359, "hel": 50222, "hell": 57195, "ai": 2192, "wo": 1146, "Ġh": 305, "Ġhel": 11591, "Ġhell": 15123, "el": 301, "He": 1548, "er": 261, "he": 383, "ell": 616, "ll": 657}, "merges": ["Ġ Ġ", "Ġ ĠĠ", "ĠĠ Ġ", "( )", "0 0", "0 00", "00 0", "c a", "' t", "à ©", "Ġ world", "Ġw orld", "Ġwor ld", "Ġ main", "Ġm ain", "Ġma in", "Ġmai n", "' re", "'r e", "' m", "f unc", "fun c", "fu nc", "m ain", "ma in", "mai n", "Ġ wor", "Ġw or", "Ġwo r", "1 23", "12 3", "Ġ {}", "Ġ{ }", "c an", "ca n", "{ }", "Ġ ma", "Ġm a", "H ello", "Hel lo", "Hell o", "W orld", "f un", "fu n", "w orld", "wor ld", "h ello", "hel lo", "hell o", "Ġ mai", "Ġm ai", "Ġma i", "Ġ wo", "Ġw o", "Ġ hello", "Ġh ello", "Ġhel lo", "Ġhell o", "f u", "H el", "He l", "r er", "re r", "h el", "he l", "w or", "wo r", "h ell", "he ll", "hel l", "f é", "m ai", "ma i", "H ell", "He ll", "Hel l", "' r"]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}, "behavior": "Isolated", "invert": false}, {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": false}]}, "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true}, "added_tokens": [{"id": 128000, "content": "<|begin_of_text|>", "special": true}, {"id": 128001, "content": "<|end_of_text|>", "special": true}]}
|
||||
1171
x/imagegen/tokenizer/tokenizer.go
Normal file
1171
x/imagegen/tokenizer/tokenizer.go
Normal file
File diff suppressed because it is too large
Load Diff
783
x/imagegen/tokenizer/tokenizer_test.go
Normal file
783
x/imagegen/tokenizer/tokenizer_test.go
Normal file
@@ -0,0 +1,783 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPatternCompilation validates that HuggingFace pretokenizer patterns
|
||||
// can be rewritten for Go's RE2 regexp engine and compiled successfully.
|
||||
func TestPatternCompilation(t *testing.T) {
|
||||
patterns := []struct {
|
||||
name string
|
||||
pattern string
|
||||
}{
|
||||
{"llama3", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"qwen2", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"gpt4o", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"gpt2", `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`},
|
||||
{"deepseek_cjk", `[一-龥\x{3040}-ゟ゠-ヿ]+`},
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
rewritten := rewritePatternForRE2(p.pattern)
|
||||
if _, err := regexp.Compile(rewritten); err != nil {
|
||||
t.Errorf("failed to compile pattern: %v\noriginal: %s\nrewritten: %s", err, p.pattern, rewritten)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoundtrip verifies the fundamental property: encode(text) -> decode -> text
|
||||
// This is the key invariant from tiktoken's test suite.
|
||||
func TestRoundtrip(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// Test cases covering key edge cases from tiktoken
|
||||
inputs := []string{
|
||||
// Empty and simple
|
||||
"",
|
||||
"a",
|
||||
"hello",
|
||||
"hello world",
|
||||
|
||||
// Whitespace edge cases
|
||||
" ",
|
||||
" ",
|
||||
" ",
|
||||
" hello",
|
||||
"hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"hello world",
|
||||
"\t",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"hello\nworld",
|
||||
"hello\n\nworld",
|
||||
|
||||
// Contractions
|
||||
"don't",
|
||||
"I'm",
|
||||
"we'll",
|
||||
"they're",
|
||||
"it's",
|
||||
"DON'T", // uppercase
|
||||
|
||||
// Numbers
|
||||
"123",
|
||||
"1234567890",
|
||||
"3.14159",
|
||||
"$100",
|
||||
"50%",
|
||||
|
||||
// Unicode
|
||||
"こんにちは", // Japanese
|
||||
"你好", // Chinese
|
||||
"مرحبا", // Arabic (RTL)
|
||||
"🎉", // Emoji
|
||||
"Hello 世界", // Mixed
|
||||
"café", // Accented
|
||||
"naïve", // Diaeresis
|
||||
"Ω≈ç√∫", // Math symbols
|
||||
|
||||
// Code
|
||||
"func main() {}",
|
||||
"if (x == 0) { return; }",
|
||||
"import \"fmt\"",
|
||||
"x := 42",
|
||||
"// comment",
|
||||
"/* block */",
|
||||
|
||||
// Repetition (tiktoken specifically tests this)
|
||||
"aaaa",
|
||||
"aaaaaaaaaaaa",
|
||||
strings.Repeat("a", 100),
|
||||
strings.Repeat("hello ", 50),
|
||||
|
||||
// Punctuation
|
||||
"...",
|
||||
"!!!",
|
||||
"???",
|
||||
"hello, world!",
|
||||
"(parentheses)",
|
||||
"[brackets]",
|
||||
"{braces}",
|
||||
|
||||
// Mixed complexity
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
|
||||
"func TestRoundtrip(t *testing.T) { t.Run(\"test\", func(t *testing.T) {}) }",
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
name := input
|
||||
if len(name) > 30 {
|
||||
name = name[:30] + "..."
|
||||
}
|
||||
if name == "" {
|
||||
name = "<empty>"
|
||||
}
|
||||
name = strings.ReplaceAll(name, "\n", "\\n")
|
||||
name = strings.ReplaceAll(name, "\t", "\\t")
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tokens := tok.Encode(input, false)
|
||||
decoded := tok.Decode(tokens)
|
||||
if decoded != input {
|
||||
t.Errorf("roundtrip failed:\n input: %q\n tokens: %v\n decoded: %q", input, tokens, decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpecialTokens verifies that special tokens are handled correctly
|
||||
func TestSpecialTokens(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// Special tokens should be preserved through encode/decode
|
||||
t.Run("bos_preserved", func(t *testing.T) {
|
||||
if tok.BOS() < 0 {
|
||||
t.Skip("no BOS token")
|
||||
}
|
||||
tokens := tok.Encode("hello", true)
|
||||
if len(tokens) == 0 || tokens[0] != tok.BOS() {
|
||||
t.Errorf("BOS not prepended: got %v, want first token to be %d", tokens, tok.BOS())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special_token_split", func(t *testing.T) {
|
||||
// If we have special tokens, verify they're split correctly
|
||||
for tokenStr, tokenID := range tok.specialTokens {
|
||||
input := "before" + tokenStr + "after"
|
||||
tokens := tok.Encode(input, false)
|
||||
|
||||
found := false
|
||||
for _, id := range tokens {
|
||||
if id == tokenID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("special token %q (id=%d) not found in encoding of %q: %v",
|
||||
tokenStr, tokenID, input, tokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrency verifies thread-safe encoding
|
||||
func TestConcurrency(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "The quick brown fox jumps over the lazy dog."
|
||||
expected := tok.Encode(input, false)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
got := tok.Encode(input, false)
|
||||
if len(got) != len(expected) {
|
||||
errors <- nil // just signal error
|
||||
return
|
||||
}
|
||||
for j := range got {
|
||||
if got[j] != expected[j] {
|
||||
errors <- nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("concurrent encoding produced inconsistent results")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration runs against real model directories, comparing with Python transformers.
|
||||
// Skips if model weights are not available.
|
||||
func TestIntegration(t *testing.T) {
|
||||
models := []string{
|
||||
"../weights/Llama-3.2-1B",
|
||||
"../weights/gemma-3-1b-it",
|
||||
"../weights/gpt-oss-20b",
|
||||
}
|
||||
|
||||
// Test inputs covering various edge cases
|
||||
inputs := []string{
|
||||
"Hello, world!",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"こんにちは世界",
|
||||
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
|
||||
"1234567890",
|
||||
" spaces ",
|
||||
"don't won't can't",
|
||||
}
|
||||
|
||||
for _, modelPath := range models {
|
||||
modelName := filepath.Base(modelPath)
|
||||
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer.json")
|
||||
if _, err := os.Stat(tokenizerPath); err != nil {
|
||||
t.Skipf("skipping: %s not found", tokenizerPath)
|
||||
}
|
||||
|
||||
tok, err := Load(tokenizerPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
t.Run(truncate(input, 20), func(t *testing.T) {
|
||||
// Test roundtrip
|
||||
tokens := tok.Encode(input, false)
|
||||
decoded := tok.Decode(tokens)
|
||||
if decoded != input {
|
||||
t.Errorf("roundtrip failed:\n input: %q\n decoded: %q", input, decoded)
|
||||
}
|
||||
|
||||
// Compare with Python if available
|
||||
if pythonTokens, err := pythonEncode(modelPath, input); err == nil {
|
||||
if !equalInt32Slice(tokens, pythonTokens) {
|
||||
t.Errorf("mismatch with Python:\n go: %v\n python: %v", tokens, pythonTokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// pythonEncode calls Python transformers to encode text, for comparison
|
||||
func pythonEncode(modelPath, text string) ([]int32, error) {
|
||||
script := `
|
||||
import sys, json
|
||||
from transformers import AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained(sys.argv[1])
|
||||
tokens = tok.encode(sys.argv[2], add_special_tokens=False)
|
||||
print(json.dumps(tokens))
|
||||
`
|
||||
cmd := exec.Command("python3", "-c", script, modelPath, text)
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = nil
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse JSON array
|
||||
var tokens []int32
|
||||
output := strings.TrimSpace(out.String())
|
||||
if output == "" || output == "[]" {
|
||||
return []int32{}, nil
|
||||
}
|
||||
|
||||
// Simple parsing for [1, 2, 3] format
|
||||
output = strings.Trim(output, "[]")
|
||||
if output == "" {
|
||||
return []int32{}, nil
|
||||
}
|
||||
|
||||
for _, s := range strings.Split(output, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
var v int32
|
||||
if _, err := parseIntSimple(s, &v); err == nil {
|
||||
tokens = append(tokens, v)
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func parseIntSimple(s string, v *int32) (bool, error) {
|
||||
var n int64
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int64(c-'0')
|
||||
}
|
||||
}
|
||||
*v = int32(n)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func equalInt32Slice(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
// TestBPEPretokenizer verifies BPE pretokenizer splits text correctly
|
||||
// using the GPT-2 style regex pattern (no dependency on tokenizer files)
|
||||
func TestBPEPretokenizer(t *testing.T) {
|
||||
pattern := `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
|
||||
re := regexp.MustCompile(rewritePatternForRE2(pattern))
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello", []string{"Hello"}},
|
||||
{"Hello world", []string{"Hello", " world"}},
|
||||
{"Hello, world!", []string{"Hello", ",", " world", "!"}},
|
||||
{"don't", []string{"don", "'t"}},
|
||||
{"I'm", []string{"I", "'m"}},
|
||||
{"123", []string{"123"}},
|
||||
{"12345", []string{"12345"}}, // GPT-2 pattern matches any digit sequence
|
||||
{"a b", []string{"a", " ", " b"}}, // whitespace boundary: last space prepends to word
|
||||
{" ", []string{" "}}, // pure whitespace stays together
|
||||
{"\n\n", []string{"\n\n"}}, // newlines stay together
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
// Get regex matches
|
||||
matches := re.FindAllStringIndex(tt.input, -1)
|
||||
var chunks []string
|
||||
for _, m := range matches {
|
||||
chunks = append(chunks, tt.input[m[0]:m[1]])
|
||||
}
|
||||
|
||||
// Apply whitespace boundary fix (same logic as Encode)
|
||||
for i := 0; i < len(chunks)-1; i++ {
|
||||
if isNonNewlineWhitespace(chunks[i]) && len(chunks[i+1]) > 0 {
|
||||
r, _ := []rune(chunks[i+1])[0], 0
|
||||
if r >= 'A' && r <= 'z' { // simplified letter check
|
||||
// Move last space to next chunk
|
||||
if len(chunks[i]) > 0 {
|
||||
lastSpace := chunks[i][len(chunks[i])-1:]
|
||||
chunks[i] = chunks[i][:len(chunks[i])-1]
|
||||
chunks[i+1] = lastSpace + chunks[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter empty chunks
|
||||
var result []string
|
||||
for _, c := range chunks {
|
||||
if c != "" {
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("got %v, want %v", result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("chunk %d: got %q, want %q", i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSentencePiecePretokenizer verifies SentencePiece doesn't use pretokenizer
|
||||
// and correctly replaces spaces with ▁ (no dependency on tokenizer files)
|
||||
func TestSentencePiecePretokenizer(t *testing.T) {
|
||||
// SentencePiece has no pretokenizer - whole text is one chunk
|
||||
// Spaces are replaced with ▁ during encoding
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string // after space replacement
|
||||
}{
|
||||
{"Hello", "Hello"},
|
||||
{"Hello world", "Hello▁world"},
|
||||
{"Hello, world!", "Hello,▁world!"},
|
||||
{" spaces ", "▁▁▁spaces▁▁▁"},
|
||||
{" Hello", "▁Hello"},
|
||||
{"Hello ", "Hello▁"},
|
||||
{"a b c", "a▁b▁c"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
// SentencePiece encoding: replace space with ▁
|
||||
result := strings.ReplaceAll(tt.input, " ", "▁")
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWordPiecePretokenizer verifies WordPiece (BERT) pretokenizer splits correctly
|
||||
// BertPreTokenizer splits on whitespace and punctuation
|
||||
func TestWordPiecePretokenizer(t *testing.T) {
|
||||
// BertPreTokenizer behavior: split on whitespace and punctuation
|
||||
// Whitespace is stripped, punctuation becomes separate tokens
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello", []string{"Hello"}},
|
||||
{"Hello world", []string{"Hello", "world"}}, // whitespace stripped
|
||||
{"Hello, world!", []string{"Hello", ",", "world", "!"}}, // punct separate
|
||||
{"don't", []string{"don", "'", "t"}}, // apostrophe separate (unlike BPE)
|
||||
{" spaces ", []string{"spaces"}}, // whitespace stripped
|
||||
{"Hello.World", []string{"Hello", ".", "World"}}, // punct splits
|
||||
{"test@email.com", []string{"test", "@", "email", ".", "com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := splitBertStyle(tt.input)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("got %v, want %v", result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("token %d: got %q, want %q", i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// splitBertStyle mimics BertPreTokenizer: split on whitespace and punctuation
|
||||
func splitBertStyle(s string) []string {
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
|
||||
for _, r := range s {
|
||||
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
|
||||
// Whitespace: flush current token, don't add whitespace
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
} else if isPunct(r) {
|
||||
// Punctuation: flush current, add punct as separate token
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
result = append(result, string(r))
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func isPunct(r rune) bool {
|
||||
// Common ASCII punctuation
|
||||
return (r >= '!' && r <= '/') || (r >= ':' && r <= '@') ||
|
||||
(r >= '[' && r <= '`') || (r >= '{' && r <= '~')
|
||||
}
|
||||
|
||||
// TestRepeatedDigits verifies correct tokenization of repeated digit sequences.
|
||||
// Llama-style tokenizers split digits in groups of 1-3 due to the \p{N}{1,3} pattern.
|
||||
func TestRepeatedDigits(t *testing.T) {
|
||||
tok, err := Load("./testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Skipf("mini_llama.json not available: %v", err)
|
||||
}
|
||||
|
||||
// Pattern: 1 digit, 2 digits, 3 digits, then repeats
|
||||
// "0" -> [single], "00" -> [double], "000" -> [triple]
|
||||
// "0000" -> [triple, single], etc.
|
||||
tests := []struct {
|
||||
input string
|
||||
count int // expected token count
|
||||
}{
|
||||
{"0", 1},
|
||||
{"00", 1},
|
||||
{"000", 1},
|
||||
{"0000", 2}, // 3 + 1
|
||||
{"00000", 2}, // 3 + 2
|
||||
{"000000", 2}, // 3 + 3
|
||||
{"0000000", 3},
|
||||
{"00000000", 3},
|
||||
{"000000000", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
ids := tok.Encode(tt.input, false)
|
||||
if len(ids) != tt.count {
|
||||
t.Errorf("Encode(%q) = %d tokens, want %d", tt.input, len(ids), tt.count)
|
||||
}
|
||||
// Verify roundtrip
|
||||
decoded := tok.Decode(ids)
|
||||
if decoded != tt.input {
|
||||
t.Errorf("Decode(Encode(%q)) = %q", tt.input, decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNullByte verifies that null bytes roundtrip correctly
|
||||
func TestNullByte(t *testing.T) {
|
||||
tok, err := Load("./testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Skipf("mini_llama.json not available: %v", err)
|
||||
}
|
||||
|
||||
ids := tok.Encode("\x00", false)
|
||||
decoded := tok.Decode(ids)
|
||||
if decoded != "\x00" {
|
||||
t.Errorf("null byte roundtrip failed: got %q, want %q", decoded, "\x00")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenizerTypeDetection verifies correct detection of tokenizer types
|
||||
func TestTokenizerTypeDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decoder string
|
||||
expected TokenizerType
|
||||
}{
|
||||
{
|
||||
name: "ByteLevel decoder (BPE)",
|
||||
decoder: `{"type": "ByteLevel"}`,
|
||||
expected: TokenizerBPE,
|
||||
},
|
||||
{
|
||||
name: "Sequence with Replace ▁ (SentencePiece)",
|
||||
decoder: `{
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "}
|
||||
]
|
||||
}`,
|
||||
expected: TokenizerSentencePiece,
|
||||
},
|
||||
{
|
||||
name: "null decoder (BPE default)",
|
||||
decoder: `null`,
|
||||
expected: TokenizerBPE,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isSPM := detectSentencePiece([]byte(tt.decoder))
|
||||
var got TokenizerType
|
||||
if isSPM {
|
||||
got = TokenizerSentencePiece
|
||||
} else {
|
||||
got = TokenizerBPE
|
||||
}
|
||||
if got != tt.expected {
|
||||
t.Errorf("got %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenDefault verifies PAD() returns -1 when not configured
|
||||
func TestPADTokenDefault(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// mini_llama.json has no PAD token configured, should return -1
|
||||
if got := tok.PAD(); got != -1 {
|
||||
t.Errorf("PAD() = %d, want -1 (not configured)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenFromConfig verifies PAD token is loaded from tokenizer_config.json
|
||||
func TestPADTokenFromConfig(t *testing.T) {
|
||||
// Create temp directory with tokenizer files
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"<|endoftext|>": 0, "hello": 1, "world": 2},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "<|endoftext|>", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write tokenizer_config.json with pad_token
|
||||
configJSON := `{
|
||||
"pad_token": "<|endoftext|>"
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer_config.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 (<|endoftext|>)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenFromSpecialTokensMap verifies PAD falls back to special_tokens_map.json
|
||||
func TestPADTokenFromSpecialTokensMap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"<pad>": 0, "hello": 1, "world": 2},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "<pad>", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write special_tokens_map.json with pad_token
|
||||
mapJSON := `{
|
||||
"pad_token": "<pad>"
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "special_tokens_map.json"), []byte(mapJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write special_tokens_map.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 (<pad>)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenWithContentObject verifies PAD token works with {"content": "..."} format
|
||||
func TestPADTokenWithContentObject(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"[PAD]": 0, "hello": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "[PAD]", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write tokenizer_config.json with pad_token as object (HuggingFace format)
|
||||
configJSON := `{
|
||||
"pad_token": {"content": "[PAD]", "lstrip": false, "normalized": false}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer_config.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 ([PAD])", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks
|
||||
|
||||
func BenchmarkEncode(b *testing.B) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{"short", "Hello, world!"},
|
||||
{"medium", "The quick brown fox jumps over the lazy dog. " + strings.Repeat("This is a test. ", 10)},
|
||||
{"long", strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok.Encode(input.text, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecode(b *testing.B) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
text := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)
|
||||
tokens := tok.Encode(text, false)
|
||||
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok.Decode(tokens)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user