ollama source for Momentry Core verification
This commit is contained in:
783
x/imagegen/tokenizer/tokenizer_test.go
Normal file
783
x/imagegen/tokenizer/tokenizer_test.go
Normal file
@@ -0,0 +1,783 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPatternCompilation validates that HuggingFace pretokenizer patterns
|
||||
// can be rewritten for Go's RE2 regexp engine and compiled successfully.
|
||||
func TestPatternCompilation(t *testing.T) {
|
||||
patterns := []struct {
|
||||
name string
|
||||
pattern string
|
||||
}{
|
||||
{"llama3", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"qwen2", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"gpt4o", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"gpt2", `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`},
|
||||
{"deepseek_cjk", `[一-龥\x{3040}-ゟ゠-ヿ]+`},
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
rewritten := rewritePatternForRE2(p.pattern)
|
||||
if _, err := regexp.Compile(rewritten); err != nil {
|
||||
t.Errorf("failed to compile pattern: %v\noriginal: %s\nrewritten: %s", err, p.pattern, rewritten)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoundtrip verifies the fundamental property: encode(text) -> decode -> text
|
||||
// This is the key invariant from tiktoken's test suite.
|
||||
func TestRoundtrip(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// Test cases covering key edge cases from tiktoken
|
||||
inputs := []string{
|
||||
// Empty and simple
|
||||
"",
|
||||
"a",
|
||||
"hello",
|
||||
"hello world",
|
||||
|
||||
// Whitespace edge cases
|
||||
" ",
|
||||
" ",
|
||||
" ",
|
||||
" hello",
|
||||
"hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"hello world",
|
||||
"\t",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"hello\nworld",
|
||||
"hello\n\nworld",
|
||||
|
||||
// Contractions
|
||||
"don't",
|
||||
"I'm",
|
||||
"we'll",
|
||||
"they're",
|
||||
"it's",
|
||||
"DON'T", // uppercase
|
||||
|
||||
// Numbers
|
||||
"123",
|
||||
"1234567890",
|
||||
"3.14159",
|
||||
"$100",
|
||||
"50%",
|
||||
|
||||
// Unicode
|
||||
"こんにちは", // Japanese
|
||||
"你好", // Chinese
|
||||
"مرحبا", // Arabic (RTL)
|
||||
"🎉", // Emoji
|
||||
"Hello 世界", // Mixed
|
||||
"café", // Accented
|
||||
"naïve", // Diaeresis
|
||||
"Ω≈ç√∫", // Math symbols
|
||||
|
||||
// Code
|
||||
"func main() {}",
|
||||
"if (x == 0) { return; }",
|
||||
"import \"fmt\"",
|
||||
"x := 42",
|
||||
"// comment",
|
||||
"/* block */",
|
||||
|
||||
// Repetition (tiktoken specifically tests this)
|
||||
"aaaa",
|
||||
"aaaaaaaaaaaa",
|
||||
strings.Repeat("a", 100),
|
||||
strings.Repeat("hello ", 50),
|
||||
|
||||
// Punctuation
|
||||
"...",
|
||||
"!!!",
|
||||
"???",
|
||||
"hello, world!",
|
||||
"(parentheses)",
|
||||
"[brackets]",
|
||||
"{braces}",
|
||||
|
||||
// Mixed complexity
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
|
||||
"func TestRoundtrip(t *testing.T) { t.Run(\"test\", func(t *testing.T) {}) }",
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
name := input
|
||||
if len(name) > 30 {
|
||||
name = name[:30] + "..."
|
||||
}
|
||||
if name == "" {
|
||||
name = "<empty>"
|
||||
}
|
||||
name = strings.ReplaceAll(name, "\n", "\\n")
|
||||
name = strings.ReplaceAll(name, "\t", "\\t")
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tokens := tok.Encode(input, false)
|
||||
decoded := tok.Decode(tokens)
|
||||
if decoded != input {
|
||||
t.Errorf("roundtrip failed:\n input: %q\n tokens: %v\n decoded: %q", input, tokens, decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpecialTokens verifies that special tokens are handled correctly
|
||||
func TestSpecialTokens(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// Special tokens should be preserved through encode/decode
|
||||
t.Run("bos_preserved", func(t *testing.T) {
|
||||
if tok.BOS() < 0 {
|
||||
t.Skip("no BOS token")
|
||||
}
|
||||
tokens := tok.Encode("hello", true)
|
||||
if len(tokens) == 0 || tokens[0] != tok.BOS() {
|
||||
t.Errorf("BOS not prepended: got %v, want first token to be %d", tokens, tok.BOS())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special_token_split", func(t *testing.T) {
|
||||
// If we have special tokens, verify they're split correctly
|
||||
for tokenStr, tokenID := range tok.specialTokens {
|
||||
input := "before" + tokenStr + "after"
|
||||
tokens := tok.Encode(input, false)
|
||||
|
||||
found := false
|
||||
for _, id := range tokens {
|
||||
if id == tokenID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("special token %q (id=%d) not found in encoding of %q: %v",
|
||||
tokenStr, tokenID, input, tokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrency verifies thread-safe encoding
|
||||
func TestConcurrency(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "The quick brown fox jumps over the lazy dog."
|
||||
expected := tok.Encode(input, false)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
got := tok.Encode(input, false)
|
||||
if len(got) != len(expected) {
|
||||
errors <- nil // just signal error
|
||||
return
|
||||
}
|
||||
for j := range got {
|
||||
if got[j] != expected[j] {
|
||||
errors <- nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("concurrent encoding produced inconsistent results")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration runs against real model directories, comparing with Python transformers.
|
||||
// Skips if model weights are not available.
|
||||
func TestIntegration(t *testing.T) {
|
||||
models := []string{
|
||||
"../weights/Llama-3.2-1B",
|
||||
"../weights/gemma-3-1b-it",
|
||||
"../weights/gpt-oss-20b",
|
||||
}
|
||||
|
||||
// Test inputs covering various edge cases
|
||||
inputs := []string{
|
||||
"Hello, world!",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"こんにちは世界",
|
||||
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
|
||||
"1234567890",
|
||||
" spaces ",
|
||||
"don't won't can't",
|
||||
}
|
||||
|
||||
for _, modelPath := range models {
|
||||
modelName := filepath.Base(modelPath)
|
||||
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer.json")
|
||||
if _, err := os.Stat(tokenizerPath); err != nil {
|
||||
t.Skipf("skipping: %s not found", tokenizerPath)
|
||||
}
|
||||
|
||||
tok, err := Load(tokenizerPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
t.Run(truncate(input, 20), func(t *testing.T) {
|
||||
// Test roundtrip
|
||||
tokens := tok.Encode(input, false)
|
||||
decoded := tok.Decode(tokens)
|
||||
if decoded != input {
|
||||
t.Errorf("roundtrip failed:\n input: %q\n decoded: %q", input, decoded)
|
||||
}
|
||||
|
||||
// Compare with Python if available
|
||||
if pythonTokens, err := pythonEncode(modelPath, input); err == nil {
|
||||
if !equalInt32Slice(tokens, pythonTokens) {
|
||||
t.Errorf("mismatch with Python:\n go: %v\n python: %v", tokens, pythonTokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// pythonEncode calls Python transformers to encode text, for comparison
|
||||
func pythonEncode(modelPath, text string) ([]int32, error) {
|
||||
script := `
|
||||
import sys, json
|
||||
from transformers import AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained(sys.argv[1])
|
||||
tokens = tok.encode(sys.argv[2], add_special_tokens=False)
|
||||
print(json.dumps(tokens))
|
||||
`
|
||||
cmd := exec.Command("python3", "-c", script, modelPath, text)
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = nil
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse JSON array
|
||||
var tokens []int32
|
||||
output := strings.TrimSpace(out.String())
|
||||
if output == "" || output == "[]" {
|
||||
return []int32{}, nil
|
||||
}
|
||||
|
||||
// Simple parsing for [1, 2, 3] format
|
||||
output = strings.Trim(output, "[]")
|
||||
if output == "" {
|
||||
return []int32{}, nil
|
||||
}
|
||||
|
||||
for _, s := range strings.Split(output, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
var v int32
|
||||
if _, err := parseIntSimple(s, &v); err == nil {
|
||||
tokens = append(tokens, v)
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// intParseError is the error type returned by parseIntSimple.
type intParseError string

// Error implements the error interface.
func (e intParseError) Error() string { return string(e) }

// parseIntSimple parses s as a base-10 integer with an optional leading '+'
// or '-' and stores the result in *v.
//
// On success it returns (true, nil). It returns (false, err) and leaves *v
// untouched when v is nil, when s has no digits, when s contains any
// character other than ASCII digits after the optional sign, or when the
// value does not fit in an int32.
func parseIntSimple(s string, v *int32) (bool, error) {
	if v == nil {
		return false, intParseError("parseIntSimple: nil destination")
	}

	digits := s
	negative := false
	if len(digits) > 0 && (digits[0] == '+' || digits[0] == '-') {
		negative = digits[0] == '-'
		digits = digits[1:]
	}
	if digits == "" {
		return false, intParseError("parseIntSimple: no digits in \"" + s + "\"")
	}

	const limit = int64(1) << 31 // magnitude of the smallest int32
	var n int64
	for i := 0; i < len(digits); i++ {
		c := digits[i]
		if c < '0' || c > '9' {
			return false, intParseError("parseIntSimple: invalid character in \"" + s + "\"")
		}
		n = n*10 + int64(c-'0')
		// Bail out early so the int64 accumulator itself can never overflow.
		if n > limit {
			return false, intParseError("parseIntSimple: value out of int32 range: \"" + s + "\"")
		}
	}
	if negative {
		n = -n
	}
	if n < -limit || n > limit-1 {
		return false, intParseError("parseIntSimple: value out of int32 range: \"" + s + "\"")
	}

	*v = int32(n)
	return true, nil
}
|
||||
|
||||
// equalInt32Slice reports whether a and b have the same length and hold the
// same values at every index. A nil slice and an empty slice compare equal.
func equalInt32Slice(a, b []int32) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, want := range a {
		if b[idx] != want {
			return false
		}
	}
	return true
}
|
||||
|
||||
// truncate shortens s to at most n bytes and appends "..." when anything was
// cut off; strings of n bytes or fewer are returned unchanged.
//
// The cut position is moved back to the nearest UTF-8 rune boundary so a
// multi-byte character is never split in half (which would put invalid UTF-8
// into, for example, a subtest name). A negative n is treated as 0.
func truncate(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}

	cut := n
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back
	// until the cut lands on the first byte of a rune.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
|
||||
// TestBPEPretokenizer verifies BPE pretokenizer splits text correctly
|
||||
// using the GPT-2 style regex pattern (no dependency on tokenizer files)
|
||||
func TestBPEPretokenizer(t *testing.T) {
|
||||
pattern := `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
|
||||
re := regexp.MustCompile(rewritePatternForRE2(pattern))
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello", []string{"Hello"}},
|
||||
{"Hello world", []string{"Hello", " world"}},
|
||||
{"Hello, world!", []string{"Hello", ",", " world", "!"}},
|
||||
{"don't", []string{"don", "'t"}},
|
||||
{"I'm", []string{"I", "'m"}},
|
||||
{"123", []string{"123"}},
|
||||
{"12345", []string{"12345"}}, // GPT-2 pattern matches any digit sequence
|
||||
{"a b", []string{"a", " ", " b"}}, // whitespace boundary: last space prepends to word
|
||||
{" ", []string{" "}}, // pure whitespace stays together
|
||||
{"\n\n", []string{"\n\n"}}, // newlines stay together
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
// Get regex matches
|
||||
matches := re.FindAllStringIndex(tt.input, -1)
|
||||
var chunks []string
|
||||
for _, m := range matches {
|
||||
chunks = append(chunks, tt.input[m[0]:m[1]])
|
||||
}
|
||||
|
||||
// Apply whitespace boundary fix (same logic as Encode)
|
||||
for i := 0; i < len(chunks)-1; i++ {
|
||||
if isNonNewlineWhitespace(chunks[i]) && len(chunks[i+1]) > 0 {
|
||||
r, _ := []rune(chunks[i+1])[0], 0
|
||||
if r >= 'A' && r <= 'z' { // simplified letter check
|
||||
// Move last space to next chunk
|
||||
if len(chunks[i]) > 0 {
|
||||
lastSpace := chunks[i][len(chunks[i])-1:]
|
||||
chunks[i] = chunks[i][:len(chunks[i])-1]
|
||||
chunks[i+1] = lastSpace + chunks[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter empty chunks
|
||||
var result []string
|
||||
for _, c := range chunks {
|
||||
if c != "" {
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("got %v, want %v", result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("chunk %d: got %q, want %q", i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSentencePiecePretokenizer documents the SentencePiece-style space
// handling as a table: the whole text is treated as one chunk and every
// space is replaced with "▁". The test only exercises strings.ReplaceAll
// and needs no tokenizer file.
func TestSentencePiecePretokenizer(t *testing.T) {
	tests := []struct {
		input    string
		expected string // after space replacement
	}{
		{"Hello", "Hello"},
		{"Hello world", "Hello▁world"},
		{"Hello, world!", "Hello,▁world!"},
		// Three spaces on each side, matching the three "▁" on each side
		// of the expected value.
		{"   spaces   ", "▁▁▁spaces▁▁▁"},
		{" Hello", "▁Hello"},
		{"Hello ", "Hello▁"},
		{"a b c", "a▁b▁c"},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			// SentencePiece encoding: replace space with ▁
			result := strings.ReplaceAll(tt.input, " ", "▁")
			if result != tt.expected {
				t.Errorf("got %q, want %q", result, tt.expected)
			}
		})
	}
}
|
||||
|
||||
// TestWordPiecePretokenizer verifies WordPiece (BERT) pretokenizer splits correctly
|
||||
// BertPreTokenizer splits on whitespace and punctuation
|
||||
func TestWordPiecePretokenizer(t *testing.T) {
|
||||
// BertPreTokenizer behavior: split on whitespace and punctuation
|
||||
// Whitespace is stripped, punctuation becomes separate tokens
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello", []string{"Hello"}},
|
||||
{"Hello world", []string{"Hello", "world"}}, // whitespace stripped
|
||||
{"Hello, world!", []string{"Hello", ",", "world", "!"}}, // punct separate
|
||||
{"don't", []string{"don", "'", "t"}}, // apostrophe separate (unlike BPE)
|
||||
{" spaces ", []string{"spaces"}}, // whitespace stripped
|
||||
{"Hello.World", []string{"Hello", ".", "World"}}, // punct splits
|
||||
{"test@email.com", []string{"test", "@", "email", ".", "com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := splitBertStyle(tt.input)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("got %v, want %v", result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("token %d: got %q, want %q", i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// splitBertStyle mimics BertPreTokenizer: split on whitespace and punctuation
|
||||
func splitBertStyle(s string) []string {
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
|
||||
for _, r := range s {
|
||||
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
|
||||
// Whitespace: flush current token, don't add whitespace
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
} else if isPunct(r) {
|
||||
// Punctuation: flush current, add punct as separate token
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
result = append(result, string(r))
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// isPunct reports whether r falls in one of the four ASCII punctuation
// ranges: '!'..'/', ':'..'@', '['..'`' or '{'..'~'. Non-ASCII runes are
// never reported as punctuation.
func isPunct(r rune) bool {
	switch {
	case '!' <= r && r <= '/':
	case ':' <= r && r <= '@':
	case '[' <= r && r <= '`':
	case '{' <= r && r <= '~':
	default:
		return false
	}
	return true
}
|
||||
|
||||
// TestRepeatedDigits verifies correct tokenization of repeated digit sequences.
|
||||
// Llama-style tokenizers split digits in groups of 1-3 due to the \p{N}{1,3} pattern.
|
||||
func TestRepeatedDigits(t *testing.T) {
|
||||
tok, err := Load("./testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Skipf("mini_llama.json not available: %v", err)
|
||||
}
|
||||
|
||||
// Pattern: 1 digit, 2 digits, 3 digits, then repeats
|
||||
// "0" -> [single], "00" -> [double], "000" -> [triple]
|
||||
// "0000" -> [triple, single], etc.
|
||||
tests := []struct {
|
||||
input string
|
||||
count int // expected token count
|
||||
}{
|
||||
{"0", 1},
|
||||
{"00", 1},
|
||||
{"000", 1},
|
||||
{"0000", 2}, // 3 + 1
|
||||
{"00000", 2}, // 3 + 2
|
||||
{"000000", 2}, // 3 + 3
|
||||
{"0000000", 3},
|
||||
{"00000000", 3},
|
||||
{"000000000", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
ids := tok.Encode(tt.input, false)
|
||||
if len(ids) != tt.count {
|
||||
t.Errorf("Encode(%q) = %d tokens, want %d", tt.input, len(ids), tt.count)
|
||||
}
|
||||
// Verify roundtrip
|
||||
decoded := tok.Decode(ids)
|
||||
if decoded != tt.input {
|
||||
t.Errorf("Decode(Encode(%q)) = %q", tt.input, decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNullByte verifies that null bytes roundtrip correctly
|
||||
func TestNullByte(t *testing.T) {
|
||||
tok, err := Load("./testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Skipf("mini_llama.json not available: %v", err)
|
||||
}
|
||||
|
||||
ids := tok.Encode("\x00", false)
|
||||
decoded := tok.Decode(ids)
|
||||
if decoded != "\x00" {
|
||||
t.Errorf("null byte roundtrip failed: got %q, want %q", decoded, "\x00")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenizerTypeDetection verifies correct detection of tokenizer types
|
||||
func TestTokenizerTypeDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decoder string
|
||||
expected TokenizerType
|
||||
}{
|
||||
{
|
||||
name: "ByteLevel decoder (BPE)",
|
||||
decoder: `{"type": "ByteLevel"}`,
|
||||
expected: TokenizerBPE,
|
||||
},
|
||||
{
|
||||
name: "Sequence with Replace ▁ (SentencePiece)",
|
||||
decoder: `{
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "}
|
||||
]
|
||||
}`,
|
||||
expected: TokenizerSentencePiece,
|
||||
},
|
||||
{
|
||||
name: "null decoder (BPE default)",
|
||||
decoder: `null`,
|
||||
expected: TokenizerBPE,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isSPM := detectSentencePiece([]byte(tt.decoder))
|
||||
var got TokenizerType
|
||||
if isSPM {
|
||||
got = TokenizerSentencePiece
|
||||
} else {
|
||||
got = TokenizerBPE
|
||||
}
|
||||
if got != tt.expected {
|
||||
t.Errorf("got %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenDefault verifies PAD() returns -1 when not configured
|
||||
func TestPADTokenDefault(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// mini_llama.json has no PAD token configured, should return -1
|
||||
if got := tok.PAD(); got != -1 {
|
||||
t.Errorf("PAD() = %d, want -1 (not configured)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenFromConfig verifies PAD token is loaded from tokenizer_config.json
|
||||
func TestPADTokenFromConfig(t *testing.T) {
|
||||
// Create temp directory with tokenizer files
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"<|endoftext|>": 0, "hello": 1, "world": 2},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "<|endoftext|>", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write tokenizer_config.json with pad_token
|
||||
configJSON := `{
|
||||
"pad_token": "<|endoftext|>"
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer_config.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 (<|endoftext|>)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenFromSpecialTokensMap verifies PAD falls back to special_tokens_map.json
|
||||
func TestPADTokenFromSpecialTokensMap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"<pad>": 0, "hello": 1, "world": 2},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "<pad>", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write special_tokens_map.json with pad_token
|
||||
mapJSON := `{
|
||||
"pad_token": "<pad>"
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "special_tokens_map.json"), []byte(mapJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write special_tokens_map.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 (<pad>)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenWithContentObject verifies PAD token works with {"content": "..."} format
|
||||
func TestPADTokenWithContentObject(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"[PAD]": 0, "hello": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "[PAD]", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write tokenizer_config.json with pad_token as object (HuggingFace format)
|
||||
configJSON := `{
|
||||
"pad_token": {"content": "[PAD]", "lstrip": false, "normalized": false}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer_config.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 ([PAD])", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks
|
||||
|
||||
func BenchmarkEncode(b *testing.B) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{"short", "Hello, world!"},
|
||||
{"medium", "The quick brown fox jumps over the lazy dog. " + strings.Repeat("This is a test. ", 10)},
|
||||
{"long", strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok.Encode(input.text, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecode(b *testing.B) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
text := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)
|
||||
tokens := tok.Encode(text, false)
|
||||
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok.Decode(tokens)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user