281 lines
7.2 KiB
Go
281 lines
7.2 KiB
Go
package gemma4
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math"
|
|
"math/cmplx"
|
|
)
|
|
|
|
// Audio preprocessing parameters used by the WAV decoder and the mel
// spectrogram front end in this file.
const (
	// audioSampleRate is the sample rate, in Hz, that decoded audio is
	// resampled to and that the spectrogram assumes.
	audioSampleRate = 16000

	// melBins is the number of mel filters, i.e. values per spectrogram frame.
	melBins = 128

	// frameLengthMs and hopLengthMs are the analysis window length and the
	// stride between consecutive windows, both in milliseconds.
	frameLengthMs = 20.0
	hopLengthMs   = 10.0

	// minFrequency and maxFrequency bound the mel filter bank, in Hz.
	minFrequency = 0.0
	maxFrequency = 8000.0

	// melFloor is the lower clamp applied to each mel value before the
	// logarithm is taken.
	melFloor = 1e-3

	// NOTE(review): not referenced in this file; presumably the cap on audio
	// tokens produced per clip — confirm against callers.
	maxAudioSoftTokens = 750
)

// Window and stride expressed in samples, derived from the millisecond
// durations above at audioSampleRate.
var (
	frameLength = int(math.Round(frameLengthMs * audioSampleRate / 1000.0)) // 320
	hopLength   = int(math.Round(hopLengthMs * audioSampleRate / 1000.0))   // 160
)
|
|
|
|
// decodeWAV extracts mono float32 PCM samples from a WAV file, resampled to 16kHz.
|
|
func decodeWAV(data []byte) ([]float32, error) {
|
|
if len(data) < 12 {
|
|
return nil, fmt.Errorf("WAV file too short")
|
|
}
|
|
if string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
|
|
return nil, fmt.Errorf("not a WAV file")
|
|
}
|
|
|
|
var audioFormat uint16
|
|
var numChannels, sampleRate, bitsPerSample int
|
|
var audioData []byte
|
|
foundFmt := false
|
|
|
|
offset := 12
|
|
for offset+8 <= len(data) {
|
|
chunkID := string(data[offset : offset+4])
|
|
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
|
|
chunkData := data[offset+8 : min(offset+8+chunkSize, len(data))]
|
|
|
|
switch chunkID {
|
|
case "fmt ":
|
|
if len(chunkData) < 16 {
|
|
return nil, fmt.Errorf("fmt chunk too short")
|
|
}
|
|
audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
|
|
numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
|
|
sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
|
|
bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
|
|
if audioFormat == 0xFFFE && len(chunkData) >= 26 {
|
|
audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
|
|
}
|
|
foundFmt = true
|
|
case "data":
|
|
audioData = chunkData
|
|
}
|
|
|
|
offset += 8 + chunkSize
|
|
if chunkSize%2 != 0 {
|
|
offset++
|
|
}
|
|
}
|
|
|
|
if !foundFmt {
|
|
return nil, fmt.Errorf("no fmt chunk found in WAV file")
|
|
}
|
|
if audioFormat != 1 && audioFormat != 3 {
|
|
return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
|
|
}
|
|
if audioData == nil {
|
|
return nil, fmt.Errorf("no data chunk found in WAV file")
|
|
}
|
|
|
|
samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
|
|
if sampleRate != audioSampleRate {
|
|
samples = resampleLinear(samples, sampleRate, audioSampleRate)
|
|
}
|
|
return samples, nil
|
|
}
|
|
|
|
// decodeWAVSamples converts interleaved WAV sample bytes into mono float32
// samples by decoding every channel of a frame and averaging them.
//
// data is the raw contents of the data chunk, format is the WAV format tag
// (1 = integer PCM, 3 = IEEE float), bits is the bit depth of one sample and
// channels is the number of interleaved channels. Integer PCM is scaled by
// its full-scale value (8-bit samples are unsigned with a bias of 128).
// Trailing bytes that do not form a complete frame are ignored.
//
// It returns nil when bits is below 8 or channels is not positive, since no
// frame size can be derived. For a format/bit-depth combination that is not
// recognised, the result has the expected length but contains only zeros.
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
	bytesPerSample := bits / 8
	if bytesPerSample <= 0 || channels <= 0 {
		// Guard against the division by zero a malformed header would cause.
		return nil
	}

	frameSize := bytesPerSample * channels
	totalSamples := len(data) / frameSize
	mono := make([]float32, totalSamples)

	// Pick the per-sample decoder once instead of re-testing the format for
	// every sample.
	var decode func(b []byte) float64
	switch {
	case format == 1 && bits == 8:
		decode = func(b []byte) float64 {
			return (float64(b[0]) - 128.0) / 128.0
		}
	case format == 1 && bits == 16:
		decode = func(b []byte) float64 {
			return float64(int16(binary.LittleEndian.Uint16(b[:2]))) / 32768.0
		}
	case format == 1 && bits == 24:
		decode = func(b []byte) float64 {
			v := int32(b[0]) | int32(b[1])<<8 | int32(b[2])<<16
			// Shift up and back down arithmetically to sign-extend bit 23.
			v = v << 8 >> 8
			return float64(v) / 8388608.0
		}
	case format == 1 && bits == 32:
		decode = func(b []byte) float64 {
			return float64(int32(binary.LittleEndian.Uint32(b[:4]))) / 2147483648.0
		}
	case format == 3 && bits == 32:
		decode = func(b []byte) float64 {
			return float64(math.Float32frombits(binary.LittleEndian.Uint32(b[:4])))
		}
	default:
		return mono
	}

	for i := 0; i < totalSamples; i++ {
		frame := data[i*frameSize : (i+1)*frameSize]
		var sum float64
		for ch := 0; ch < channels; ch++ {
			sum += decode(frame[ch*bytesPerSample:])
		}
		mono[i] = float32(sum / float64(channels))
	}
	return mono
}
|
|
|
|
// resampleLinear converts samples from fromRate to toRate using linear
// interpolation between neighbouring input samples.
//
// The output holds len(samples)*toRate/fromRate samples (truncated), spread
// so that the first and last output samples coincide with the first and last
// input samples. An empty slice is returned when the input is empty, when
// either rate is not positive, or when the computed output length is zero.
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
	n := 0
	if fromRate > 0 && toRate > 0 {
		n = int(float64(len(samples)) / float64(fromRate) * float64(toRate))
	}
	if n < 0 {
		n = 0
	}
	out := make([]float32, n)
	if n == 0 {
		return out
	}
	if n == 1 {
		// A single output sample has no span to interpolate over; the general
		// formula below would divide by zero and index with a NaN position.
		out[0] = samples[0]
		return out
	}

	last := len(samples) - 1
	for i := 0; i < n; i++ {
		pos := float64(i) * float64(last) / float64(n-1)
		idx := int(pos)
		if idx >= last {
			out[i] = samples[last]
			continue
		}
		frac := float32(pos - float64(idx))
		out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
	}
	return out
}
|
|
|
|
// computeMelSpectrogram computes the log mel spectrogram from PCM samples.
|
|
// Returns shape [numFrames, melBins] as float32 slice, and numFrames.
|
|
func computeMelSpectrogram(samples []float32) ([]float32, int) {
|
|
fftLen := 1
|
|
for fftLen < frameLength {
|
|
fftLen <<= 1
|
|
}
|
|
fftLen *= 2 // fft_overdrive=True
|
|
|
|
// Hanning-nonzero window.
|
|
window := make([]float64, frameLength)
|
|
arg := math.Pi * 2.0 / float64(frameLength)
|
|
for i := range frameLength {
|
|
window[i] = 0.5 - 0.5*math.Cos(arg*(float64(i)+0.5))
|
|
}
|
|
|
|
numFreqBins := fftLen/2 + 1
|
|
melFilters := buildMelFilterBank(numFreqBins, melBins, minFrequency, maxFrequency, audioSampleRate)
|
|
|
|
frameSizeForUnfold := frameLength + 1
|
|
numFrames := (len(samples) - frameSizeForUnfold) / hopLength
|
|
if numFrames <= 0 {
|
|
return nil, 0
|
|
}
|
|
|
|
result := make([]float32, numFrames*melBins)
|
|
fftInput := make([]complex128, fftLen)
|
|
|
|
for f := range numFrames {
|
|
start := f * hopLength
|
|
for i := range frameLength {
|
|
fftInput[i] = complex(float64(samples[start+i])*window[i], 0)
|
|
}
|
|
for i := frameLength; i < fftLen; i++ {
|
|
fftInput[i] = 0
|
|
}
|
|
|
|
fft(fftInput)
|
|
|
|
for m := range melBins {
|
|
var melVal float64
|
|
for k := range numFreqBins {
|
|
mag := cmplx.Abs(fftInput[k])
|
|
melVal += mag * float64(melFilters[k*melBins+m])
|
|
}
|
|
if melVal < melFloor {
|
|
melVal = melFloor
|
|
}
|
|
result[f*melBins+m] = float32(math.Log(melVal))
|
|
}
|
|
}
|
|
|
|
return result, numFrames
|
|
}
|
|
|
|
// buildMelFilterBank builds numMels triangular filters over numFreqBins FFT
// bins and returns them as a flat slice indexed as filters[bin*numMels+mel].
//
// Filter edges are numMels+2 points spaced evenly on the mel scale between
// fMin and fMax (Hz) and converted back to Hz. Bin k is taken to lie at
// k*sr/(2*(numFreqBins-1)) Hz. Each filter rises linearly from 0 at its left
// edge to 1 at its centre and falls back to 0 at its right edge; weights
// outside that span are 0.
func buildMelFilterBank(numFreqBins, numMels int, fMin, fMax float64, sr int) []float32 {
	toMel := func(hz float64) float64 {
		return 2595.0 * math.Log10(1.0+hz/700.0)
	}
	toHz := func(mel float64) float64 {
		return 700.0 * (math.Pow(10.0, mel/2595.0) - 1.0)
	}

	lowMel, highMel := toMel(fMin), toMel(fMax)

	// Edge frequencies in Hz: edges[m], edges[m+1], edges[m+2] are the left,
	// centre and right corners of filter m.
	edges := make([]float64, numMels+2)
	for i := 0; i < len(edges); i++ {
		mel := lowMel + float64(i)*(highMel-lowMel)/float64(numMels+1)
		edges[i] = toHz(mel)
	}

	bank := make([]float32, numFreqBins*numMels)
	span := float64(2 * (numFreqBins - 1))
	for k := 0; k < numFreqBins; k++ {
		hz := float64(k) * float64(sr) / span
		row := bank[k*numMels : (k+1)*numMels]
		for m := 0; m < numMels; m++ {
			left, center, right := edges[m], edges[m+1], edges[m+2]

			var weight float64
			switch {
			case hz >= left && hz <= center && center > left:
				weight = (hz - left) / (center - left)
			case hz > center && hz <= right && right > center:
				weight = (right - hz) / (right - center)
			}
			if weight > 0 {
				row[m] = float32(weight)
			}
		}
	}
	return bank
}
|
|
|
|
// fft performs an in-place Cooley-Tukey radix-2 FFT.
//
// x is overwritten with its forward discrete Fourier transform. Slices of
// length 0 or 1 are left untouched. The radix-2 scheme is written for
// lengths that are a power of two.
func fft(x []complex128) {
	n := len(x)
	if n <= 1 {
		return
	}

	// Stage 1: reorder the input into bit-reversed index order. rev tracks
	// the bit-reversed counterpart of i and is advanced like a counter whose
	// carry propagates from the most significant bit downwards.
	rev := 0
	for i := 1; i < n; i++ {
		mask := n >> 1
		for ; rev&mask != 0; mask >>= 1 {
			rev ^= mask
		}
		rev ^= mask
		if i < rev {
			x[i], x[rev] = x[rev], x[i]
		}
	}

	// Stage 2: butterfly passes over blocks of doubling size.
	for size := 2; size <= n; size *= 2 {
		half := size / 2
		theta := 2 * math.Pi / float64(size)
		step := complex(math.Cos(theta), -math.Sin(theta))

		for base := 0; base < n; base += size {
			twiddle := complex(1, 0)
			for k := 0; k < half; k++ {
				lo, hi := base+k, base+k+half
				even := x[lo]
				odd := twiddle * x[hi]
				x[hi] = even - odd
				x[lo] = even + odd
				twiddle *= step
			}
		}
	}
}
|
|
|
|
// isAudioData reports whether data begins with the WAV magic bytes: "RIFF"
// at offset 0 and "WAVE" at offset 8. Inputs shorter than 12 bytes are never
// treated as audio.
func isAudioData(data []byte) bool {
	if len(data) < 12 {
		return false
	}
	container, form := string(data[:4]), string(data[8:12])
	return container == "RIFF" && form == "WAVE"
}
|