352 lines
7.1 KiB
Go
352 lines
7.1 KiB
Go
package mlxthread
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"reflect"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestDoRunsInOrder(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
var got []int
|
|
for i := 0; i < 5; i++ {
|
|
i := i
|
|
if err := thread.Do(context.Background(), func() error {
|
|
got = append(got, i)
|
|
return nil
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
if want := []int{0, 1, 2, 3, 4}; !reflect.DeepEqual(got, want) {
|
|
t.Fatalf("got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestDoPropagatesPanicToCaller(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
defer func() {
|
|
if got := recover(); got != "boom" {
|
|
t.Fatalf("got panic %v, want boom", got)
|
|
}
|
|
}()
|
|
|
|
_ = thread.Do(context.Background(), func() error {
|
|
panic("boom")
|
|
})
|
|
}
|
|
|
|
func TestDoCancelsBeforeJobStarts(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
running := make(chan struct{})
|
|
release := make(chan struct{})
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- thread.Do(context.Background(), func() error {
|
|
close(running)
|
|
<-release
|
|
return nil
|
|
})
|
|
}()
|
|
<-running
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err = thread.Do(ctx, func() error {
|
|
t.Fatal("canceled job should not run")
|
|
return nil
|
|
})
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("got %v, want %v", err, context.Canceled)
|
|
}
|
|
|
|
close(release)
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestAlreadyCanceledContextDoesNotEnqueue(t *testing.T) {
|
|
t.Run("Do", func(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
ran := false
|
|
err = thread.Do(ctx, func() error {
|
|
ran = true
|
|
return nil
|
|
})
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("got %v, want %v", err, context.Canceled)
|
|
}
|
|
if ran {
|
|
t.Fatal("canceled job ran")
|
|
}
|
|
})
|
|
|
|
t.Run("Stop", func(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
cleaned := false
|
|
err = thread.Stop(ctx, func() {
|
|
cleaned = true
|
|
})
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("got %v, want %v", err, context.Canceled)
|
|
}
|
|
if cleaned {
|
|
t.Fatal("cleanup ran for canceled stop")
|
|
}
|
|
if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
|
|
t.Fatalf("thread did not accept work after canceled Stop: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCallReturnsValue(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
got, err := Call(context.Background(), thread, func() (int, error) {
|
|
return 42, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got != 42 {
|
|
t.Fatalf("got %d, want 42", got)
|
|
}
|
|
}
|
|
|
|
func TestDoRunsConcurrentlySubmittedWorkSerially(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
oldProcs := runtime.GOMAXPROCS(8)
|
|
defer runtime.GOMAXPROCS(oldProcs)
|
|
|
|
const goroutines = 16
|
|
const iterations = 64
|
|
|
|
var active atomic.Int32
|
|
var count atomic.Int64
|
|
var wg sync.WaitGroup
|
|
errCh := make(chan error, goroutines)
|
|
|
|
for range goroutines {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
for range iterations {
|
|
if err := thread.Do(context.Background(), func() error {
|
|
if got := active.Add(1); got != 1 {
|
|
return errors.New("thread executed jobs concurrently")
|
|
}
|
|
runtime.Gosched()
|
|
count.Add(1)
|
|
if got := active.Add(-1); got != 0 {
|
|
return errors.New("thread active count did not return to zero")
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errCh)
|
|
|
|
for err := range errCh {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := count.Load(), int64(goroutines*iterations); got != want {
|
|
t.Fatalf("got %d jobs, want %d", got, want)
|
|
}
|
|
}
|
|
|
|
func TestStopRunsCleanupAndRejectsWork(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cleaned := 0
|
|
if err := thread.Stop(context.Background(), func() {
|
|
cleaned++
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if cleaned != 1 {
|
|
t.Fatalf("cleanup ran %d times, want 1", cleaned)
|
|
}
|
|
|
|
if err := thread.Stop(context.Background(), func() {
|
|
cleaned++
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if cleaned != 1 {
|
|
t.Fatalf("cleanup ran %d times after second Stop, want 1", cleaned)
|
|
}
|
|
|
|
err = thread.Do(context.Background(), func() error {
|
|
t.Fatal("job should not run after stop")
|
|
return nil
|
|
})
|
|
if !errors.Is(err, ErrStopped) {
|
|
t.Fatalf("got %v, want %v", err, ErrStopped)
|
|
}
|
|
}
|
|
|
|
func TestStopCanceledBeforeEnqueueCanBeRetried(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
running := make(chan struct{})
|
|
release := make(chan struct{})
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- thread.Do(context.Background(), func() error {
|
|
close(running)
|
|
<-release
|
|
return nil
|
|
})
|
|
}()
|
|
<-running
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
|
defer cancel()
|
|
|
|
cleanupRan := false
|
|
err = thread.Stop(ctx, func() {
|
|
cleanupRan = true
|
|
})
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Fatalf("got %v, want %v", err, context.DeadlineExceeded)
|
|
}
|
|
if cleanupRan {
|
|
t.Fatal("cleanup ran even though stop was not enqueued")
|
|
}
|
|
|
|
close(release)
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
|
|
t.Fatalf("thread did not accept work after canceled Stop: %v", err)
|
|
}
|
|
|
|
cleanupRan = false
|
|
if err := thread.Stop(context.Background(), func() {
|
|
cleanupRan = true
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !cleanupRan {
|
|
t.Fatal("cleanup did not run on retried Stop")
|
|
}
|
|
}
|
|
|
|
func TestStopWaitsForActiveWorkBeforeCleanup(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
running := make(chan struct{})
|
|
release := make(chan struct{})
|
|
jobErr := make(chan error, 1)
|
|
go func() {
|
|
jobErr <- thread.Do(context.Background(), func() error {
|
|
close(running)
|
|
<-release
|
|
return nil
|
|
})
|
|
}()
|
|
<-running
|
|
|
|
cleaned := make(chan struct{})
|
|
stopErr := make(chan error, 1)
|
|
go func() {
|
|
stopErr <- thread.Stop(context.Background(), func() {
|
|
close(cleaned)
|
|
})
|
|
}()
|
|
|
|
select {
|
|
case <-cleaned:
|
|
t.Fatal("cleanup ran before active job completed")
|
|
case <-time.After(10 * time.Millisecond):
|
|
}
|
|
|
|
err = thread.Do(context.Background(), func() error {
|
|
return errors.New("work should be rejected once Stop starts")
|
|
})
|
|
if !errors.Is(err, ErrStopped) {
|
|
t.Fatalf("got %v, want %v", err, ErrStopped)
|
|
}
|
|
|
|
close(release)
|
|
if err := <-jobErr; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := <-stopErr; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
select {
|
|
case <-cleaned:
|
|
default:
|
|
t.Fatal("cleanup did not run")
|
|
}
|
|
}
|