184 lines
3.3 KiB
Go
184 lines
3.3 KiB
Go
package mlxthread
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"runtime"
|
|
"sync/atomic"
|
|
)
|
|
|
|
var ErrStopped = errors.New("mlx thread stopped")
|
|
|
|
type Thread struct {
|
|
name string
|
|
|
|
jobs chan job
|
|
done chan struct{}
|
|
stopping atomic.Bool
|
|
}
|
|
|
|
type job struct {
|
|
fn func() error
|
|
result chan result
|
|
stop bool
|
|
}
|
|
|
|
type result struct {
|
|
err error
|
|
panicValue any
|
|
}
|
|
|
|
// Start creates a long-lived worker goroutine locked to one OS thread.
|
|
func Start(name string, init func() error) (*Thread, error) {
|
|
t := &Thread{
|
|
name: name,
|
|
jobs: make(chan job),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
initResult := make(chan result, 1)
|
|
go t.loop(init, initResult)
|
|
|
|
res := <-initResult
|
|
if res.panicValue != nil {
|
|
panic(res.panicValue)
|
|
}
|
|
if res.err != nil {
|
|
return nil, res.err
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
// Do runs fn on the locked OS thread.
|
|
//
|
|
// Context cancellation only applies while the work is queued. Once the worker
|
|
// accepts a job, the job runs until fn returns or reaches its own cancellation
|
|
// checks.
|
|
func (t *Thread) Do(ctx context.Context, fn func() error) error {
|
|
res, err := t.enqueue(ctx, fn, false, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if res.panicValue != nil {
|
|
panic(res.panicValue)
|
|
}
|
|
return res.err
|
|
}
|
|
|
|
func Call[T any](ctx context.Context, t *Thread, fn func() (T, error)) (T, error) {
|
|
var value T
|
|
err := t.Do(ctx, func() error {
|
|
var err error
|
|
value, err = fn()
|
|
return err
|
|
})
|
|
return value, err
|
|
}
|
|
|
|
// Stop runs cleanup on the locked OS thread and then shuts the worker down.
|
|
func (t *Thread) Stop(ctx context.Context, cleanup func()) error {
|
|
ctx = contextOrBackground(ctx)
|
|
|
|
if !t.stopping.CompareAndSwap(false, true) {
|
|
select {
|
|
case <-t.done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
res, err := t.enqueue(ctx, func() error {
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
return nil
|
|
}, true, true)
|
|
if err != nil {
|
|
if !errors.Is(err, ErrStopped) {
|
|
t.stopping.Store(false)
|
|
}
|
|
return err
|
|
}
|
|
if res.panicValue != nil {
|
|
panic(res.panicValue)
|
|
}
|
|
if res.err != nil {
|
|
return res.err
|
|
}
|
|
|
|
select {
|
|
case <-t.done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (t *Thread) loop(init func() error, initResult chan<- result) {
|
|
runtime.LockOSThread()
|
|
// Deliberately do not unlock. MLX thread-local state belongs to this worker
|
|
// until shutdown so it cannot leak back to arbitrary Go goroutines.
|
|
|
|
res := run(init)
|
|
initResult <- res
|
|
if res.err != nil || res.panicValue != nil {
|
|
close(t.done)
|
|
return
|
|
}
|
|
|
|
for {
|
|
j := <-t.jobs
|
|
res := run(j.fn)
|
|
j.result <- res
|
|
if j.stop {
|
|
close(t.done)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *Thread) enqueue(ctx context.Context, fn func() error, stop, allowStopping bool) (result, error) {
|
|
ctx = contextOrBackground(ctx)
|
|
if err := ctx.Err(); err != nil {
|
|
return result{}, err
|
|
}
|
|
|
|
if !allowStopping && t.stopping.Load() {
|
|
return result{}, ErrStopped
|
|
}
|
|
|
|
resultCh := make(chan result, 1)
|
|
j := job{fn: fn, result: resultCh, stop: stop}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return result{}, ctx.Err()
|
|
case <-t.done:
|
|
return result{}, ErrStopped
|
|
case t.jobs <- j:
|
|
}
|
|
|
|
return <-resultCh, nil
|
|
}
|
|
|
|
func run(fn func() error) (res result) {
|
|
defer func() {
|
|
if v := recover(); v != nil {
|
|
res.panicValue = v
|
|
}
|
|
}()
|
|
if fn != nil {
|
|
res.err = fn()
|
|
}
|
|
return res
|
|
}
|
|
|
|
func contextOrBackground(ctx context.Context) context.Context {
|
|
if ctx != nil {
|
|
return ctx
|
|
}
|
|
return context.Background()
|
|
}
|