diff --git a/pkg/provider/mock/mock.go b/pkg/provider/mock/mock.go index 4e3200f..f7ab012 100644 --- a/pkg/provider/mock/mock.go +++ b/pkg/provider/mock/mock.go @@ -82,9 +82,9 @@ func (_c *MockProvider_Events_Call) RunAndReturn(run func(context.Context) (<-ch return _c } -// List provides a mock function with given fields: ctx, name -func (_m *MockProvider) List(ctx context.Context, name string) ([]string, error) { - ret := _m.Called(ctx, name) +// List provides a mock function with given fields: ctx, opts +func (_m *MockProvider) List(ctx context.Context, opts provider.ListOptions) ([]string, error) { + ret := _m.Called(ctx, opts) if len(ret) == 0 { panic("no return value specified for List") @@ -92,19 +92,19 @@ func (_m *MockProvider) List(ctx context.Context, name string) ([]string, error) var r0 []string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok { - return rf(ctx, name) + if rf, ok := ret.Get(0).(func(context.Context, provider.ListOptions) ([]string, error)); ok { + return rf(ctx, opts) } - if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { - r0 = rf(ctx, name) + if rf, ok := ret.Get(0).(func(context.Context, provider.ListOptions) []string); ok { + r0 = rf(ctx, opts) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]string) } } - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, name) + if rf, ok := ret.Get(1).(func(context.Context, provider.ListOptions) error); ok { + r1 = rf(ctx, opts) } else { r1 = ret.Error(1) } @@ -119,14 +119,14 @@ type MockProvider_List_Call struct { // List is a helper method to define mock.On call // - ctx context.Context -// - name string -func (_e *MockProvider_Expecter) List(ctx interface{}, name interface{}) *MockProvider_List_Call { - return &MockProvider_List_Call{Call: _e.mock.On("List", ctx, name)} +// - opts provider.ListOptions +func (_e *MockProvider_Expecter) List(ctx interface{}, opts interface{}) *MockProvider_List_Call { + return &MockProvider_List_Call{Call: _e.mock.On("List", ctx, opts)} } -func (_c *MockProvider_List_Call) Run(run func(ctx context.Context, name string)) *MockProvider_List_Call { +func (_c *MockProvider_List_Call) Run(run func(ctx context.Context, opts provider.ListOptions)) *MockProvider_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) + run(args[0].(context.Context), args[1].(provider.ListOptions)) }) return _c } @@ -136,7 +136,7 @@ func (_c *MockProvider_List_Call) Return(_a0 []string, _a1 error) *MockProvider_ return _c } -func (_c *MockProvider_List_Call) RunAndReturn(run func(context.Context, string) ([]string, error)) *MockProvider_List_Call { +func (_c *MockProvider_List_Call) RunAndReturn(run func(context.Context, provider.ListOptions) ([]string, error)) *MockProvider_List_Call { _c.Call.Return(run) return _c } diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index e2b8f29..598969a 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -36,11 +36,16 @@ type StartOptions struct { ConsiderReadyAfter time.Duration } +type ListOptions struct { + // All list all instances whatever their status (up or down) + All bool +} + type Provider interface { Start(ctx context.Context, name string, opts StartOptions) error Stop(ctx context.Context, name string) error Status(ctx context.Context, name string) (bool, error) - List(ctx context.Context, name string) ([]string, error) + List(ctx context.Context, opts ListOptions) ([]string, error) Events(ctx context.Context) (<-chan Message, <-chan error) } diff --git a/pkg/sablier/autostop.go b/pkg/sablier/autostop.go deleted file mode 100644 index bb4cb38..0000000 --- a/pkg/sablier/autostop.go +++ /dev/null @@ -1 +0,0 @@ -package sablier diff --git a/pkg/sablier/sablier.go b/pkg/sablier/sablier.go index 349844e..d8207eb 100644 --- a/pkg/sablier/sablier.go +++ b/pkg/sablier/sablier.go @@ -2,12 +2,14 @@ package sablier import ( "context" + "sync" + "time" + "github.com/sablierapp/sablier/pkg/promise" "github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/tinykv" log "github.com/sirupsen/logrus" - "sync" - "time" + "golang.org/x/exp/maps" ) type Sablier struct { @@ -31,10 +33,12 @@ func NewSablier(ctx context.Context, provider provider.Provider) *Sablier { } delete(promises, k) }) + go func() { <-ctx.Done() expirations.Stop() }() + return &Sablier{ Provider: provider, promises: promises, @@ -42,3 +46,9 @@ func NewSablier(ctx context.Context, provider provider.Provider) *Sablier { lock: lock, } } + +func (s *Sablier) RegisteredInstances() []string { + s.lock.RLock() + defer s.lock.RUnlock() + return maps.Keys(s.promises) +} diff --git a/pkg/sablier/start_instance.go b/pkg/sablier/start_instance.go index 010cb47..499a390 100644 --- a/pkg/sablier/start_instance.go +++ b/pkg/sablier/start_instance.go @@ -2,11 +2,11 @@ package sablier import ( "context" - "github.com/sablierapp/sablier/pkg/provider" "log" "time" "github.com/sablierapp/sablier/pkg/promise" + "github.com/sablierapp/sablier/pkg/provider" ) type StartOptions struct { diff --git a/pkg/sablier/start_instance_test.go b/pkg/sablier/start_instance_test.go index 2821f88..022a174 100644 --- a/pkg/sablier/start_instance_test.go +++ b/pkg/sablier/start_instance_test.go @@ -4,14 +4,15 @@ import ( "context" "errors" "fmt" + "testing" + "time" + "github.com/sablierapp/sablier/pkg/promise" "github.com/sablierapp/sablier/pkg/provider" pmock "github.com/sablierapp/sablier/pkg/provider/mock" "github.com/sablierapp/sablier/pkg/sablier" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "testing" - "time" ) func TestStartInstance(t *testing.T) { diff --git a/pkg/sablier/stop_all.go b/pkg/sablier/stop_all.go new file mode 100644 index 0000000..d4a2ddd --- /dev/null +++ b/pkg/sablier/stop_all.go @@ -0,0 +1,36 @@ +package sablier + +import ( + "context" + + "github.com/sablierapp/sablier/pkg/array" + "github.com/sablierapp/sablier/pkg/provider" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +func (s *Sablier) StopAllUnregistered(ctx context.Context) error { + instances, err := s.Provider.List(ctx, provider.ListOptions{ + All: false, + }) + if err != nil { + return err + } + + registered := s.RegisteredInstances() + unregistered := array.RemoveElements(instances, registered) + log.Tracef("Found %v unregistered instances ", len(unregistered)) + + waitGroup := errgroup.Group{} + + // Previously, the variables declared by a “for” loop were created once and updated by each iteration. + // In Go 1.22, each iteration of the loop creates new variables, to avoid accidental sharing bugs. + // The transition support tooling described in the proposal continues to work in the same way it did in Go 1.21. + for _, name := range unregistered { + waitGroup.Go(func() error { + return s.Provider.Stop(ctx, name) + }) + } + + return waitGroup.Wait() +} diff --git a/pkg/sablier/stop_all_test.go b/pkg/sablier/stop_all_test.go new file mode 100644 index 0000000..6afc6cd --- /dev/null +++ b/pkg/sablier/stop_all_test.go @@ -0,0 +1,59 @@ +package sablier_test + +import ( + "context" + "testing" + "time" + + "github.com/sablierapp/sablier/pkg/provider" + pmock "github.com/sablierapp/sablier/pkg/provider/mock" + "github.com/sablierapp/sablier/pkg/sablier" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestStopAllUnregistered(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + m := pmock.NewMockProvider(t) + s := sablier.NewSablier(ctx, m) + + m.EXPECT(). + List(ctx, provider.ListOptions{All: false}). + Return([]string{"instance1", "instance2"}, nil) + m.EXPECT().Stop(ctx, "instance1").Return(nil).Once() + m.EXPECT().Stop(ctx, "instance2").Return(nil).Once() + err := s.StopAllUnregistered(ctx) + assert.NoError(t, err) +} + +func TestStopAllUnregisteredWithAlreadyRegistered(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + name := "instance1" + opts := sablier.StartOptions{ + DesiredReplicas: 1, + ConsiderReadyAfter: 5 * time.Second, + Timeout: 30 * time.Second, + ExpiresAfter: 1 * time.Minute, + } + m := pmock.NewMockProvider(t) + s := sablier.NewSablier(ctx, m) + + m.EXPECT().Start(mock.Anything, name, provider.StartOptions{ + DesiredReplicas: opts.DesiredReplicas, + ConsiderReadyAfter: opts.ConsiderReadyAfter, + }).Return(nil).Once() + p := s.StartInstance(name, opts) + _, err := p.Await(ctx) + assert.NoError(t, err) + + m.EXPECT(). + List(ctx, provider.ListOptions{All: false}). + Return([]string{"instance1", "instance2"}, nil) + m.EXPECT().Stop(ctx, "instance2").Return(nil).Once() + err = s.StopAllUnregistered(ctx) + assert.NoError(t, err) +}