ollama source for Momentry Core verification
This commit is contained in:
183
x/internal/mlxthread/thread.go
Normal file
183
x/internal/mlxthread/thread.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package mlxthread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var ErrStopped = errors.New("mlx thread stopped")
|
||||
|
||||
type Thread struct {
|
||||
name string
|
||||
|
||||
jobs chan job
|
||||
done chan struct{}
|
||||
stopping atomic.Bool
|
||||
}
|
||||
|
||||
type job struct {
|
||||
fn func() error
|
||||
result chan result
|
||||
stop bool
|
||||
}
|
||||
|
||||
type result struct {
|
||||
err error
|
||||
panicValue any
|
||||
}
|
||||
|
||||
// Start creates a long-lived worker goroutine locked to one OS thread.
|
||||
func Start(name string, init func() error) (*Thread, error) {
|
||||
t := &Thread{
|
||||
name: name,
|
||||
jobs: make(chan job),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
initResult := make(chan result, 1)
|
||||
go t.loop(init, initResult)
|
||||
|
||||
res := <-initResult
|
||||
if res.panicValue != nil {
|
||||
panic(res.panicValue)
|
||||
}
|
||||
if res.err != nil {
|
||||
return nil, res.err
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Do runs fn on the locked OS thread.
|
||||
//
|
||||
// Context cancellation only applies while the work is queued. Once the worker
|
||||
// accepts a job, the job runs until fn returns or reaches its own cancellation
|
||||
// checks.
|
||||
func (t *Thread) Do(ctx context.Context, fn func() error) error {
|
||||
res, err := t.enqueue(ctx, fn, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.panicValue != nil {
|
||||
panic(res.panicValue)
|
||||
}
|
||||
return res.err
|
||||
}
|
||||
|
||||
func Call[T any](ctx context.Context, t *Thread, fn func() (T, error)) (T, error) {
|
||||
var value T
|
||||
err := t.Do(ctx, func() error {
|
||||
var err error
|
||||
value, err = fn()
|
||||
return err
|
||||
})
|
||||
return value, err
|
||||
}
|
||||
|
||||
// Stop runs cleanup on the locked OS thread and then shuts the worker down.
|
||||
func (t *Thread) Stop(ctx context.Context, cleanup func()) error {
|
||||
ctx = contextOrBackground(ctx)
|
||||
|
||||
if !t.stopping.CompareAndSwap(false, true) {
|
||||
select {
|
||||
case <-t.done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
res, err := t.enqueue(ctx, func() error {
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
return nil
|
||||
}, true, true)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrStopped) {
|
||||
t.stopping.Store(false)
|
||||
}
|
||||
return err
|
||||
}
|
||||
if res.panicValue != nil {
|
||||
panic(res.panicValue)
|
||||
}
|
||||
if res.err != nil {
|
||||
return res.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-t.done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Thread) loop(init func() error, initResult chan<- result) {
|
||||
runtime.LockOSThread()
|
||||
// Deliberately do not unlock. MLX thread-local state belongs to this worker
|
||||
// until shutdown so it cannot leak back to arbitrary Go goroutines.
|
||||
|
||||
res := run(init)
|
||||
initResult <- res
|
||||
if res.err != nil || res.panicValue != nil {
|
||||
close(t.done)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
j := <-t.jobs
|
||||
res := run(j.fn)
|
||||
j.result <- res
|
||||
if j.stop {
|
||||
close(t.done)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Thread) enqueue(ctx context.Context, fn func() error, stop, allowStopping bool) (result, error) {
|
||||
ctx = contextOrBackground(ctx)
|
||||
if err := ctx.Err(); err != nil {
|
||||
return result{}, err
|
||||
}
|
||||
|
||||
if !allowStopping && t.stopping.Load() {
|
||||
return result{}, ErrStopped
|
||||
}
|
||||
|
||||
resultCh := make(chan result, 1)
|
||||
j := job{fn: fn, result: resultCh, stop: stop}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return result{}, ctx.Err()
|
||||
case <-t.done:
|
||||
return result{}, ErrStopped
|
||||
case t.jobs <- j:
|
||||
}
|
||||
|
||||
return <-resultCh, nil
|
||||
}
|
||||
|
||||
func run(fn func() error) (res result) {
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
res.panicValue = v
|
||||
}
|
||||
}()
|
||||
if fn != nil {
|
||||
res.err = fn()
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func contextOrBackground(ctx context.Context) context.Context {
|
||||
if ctx != nil {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
Reference in New Issue
Block a user