add stop_all

This commit is contained in:
Alexis Couvreur
2024-11-18 17:35:54 -05:00
parent 72895c200f
commit cd454bbf83
8 changed files with 132 additions and 22 deletions

View File

@@ -82,9 +82,9 @@ func (_c *MockProvider_Events_Call) RunAndReturn(run func(context.Context) (<-ch
return _c
}
// List provides a mock function with given fields: ctx, name
func (_m *MockProvider) List(ctx context.Context, name string) ([]string, error) {
ret := _m.Called(ctx, name)
// List provides a mock function with given fields: ctx, opts
func (_m *MockProvider) List(ctx context.Context, opts provider.ListOptions) ([]string, error) {
ret := _m.Called(ctx, opts)
if len(ret) == 0 {
panic("no return value specified for List")
@@ -92,19 +92,19 @@ func (_m *MockProvider) List(ctx context.Context, name string) ([]string, error)
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok {
return rf(ctx, name)
if rf, ok := ret.Get(0).(func(context.Context, provider.ListOptions) ([]string, error)); ok {
return rf(ctx, opts)
}
if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok {
r0 = rf(ctx, name)
if rf, ok := ret.Get(0).(func(context.Context, provider.ListOptions) []string); ok {
r0 = rf(ctx, opts)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, name)
if rf, ok := ret.Get(1).(func(context.Context, provider.ListOptions) error); ok {
r1 = rf(ctx, opts)
} else {
r1 = ret.Error(1)
}
@@ -119,14 +119,14 @@ type MockProvider_List_Call struct {
// List is a helper method to define mock.On call
// - ctx context.Context
// - name string
func (_e *MockProvider_Expecter) List(ctx interface{}, name interface{}) *MockProvider_List_Call {
return &MockProvider_List_Call{Call: _e.mock.On("List", ctx, name)}
// - opts provider.ListOptions
func (_e *MockProvider_Expecter) List(ctx interface{}, opts interface{}) *MockProvider_List_Call {
return &MockProvider_List_Call{Call: _e.mock.On("List", ctx, opts)}
}
func (_c *MockProvider_List_Call) Run(run func(ctx context.Context, name string)) *MockProvider_List_Call {
func (_c *MockProvider_List_Call) Run(run func(ctx context.Context, opts provider.ListOptions)) *MockProvider_List_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
run(args[0].(context.Context), args[1].(provider.ListOptions))
})
return _c
}
@@ -136,7 +136,7 @@ func (_c *MockProvider_List_Call) Return(_a0 []string, _a1 error) *MockProvider_
return _c
}
func (_c *MockProvider_List_Call) RunAndReturn(run func(context.Context, string) ([]string, error)) *MockProvider_List_Call {
func (_c *MockProvider_List_Call) RunAndReturn(run func(context.Context, provider.ListOptions) ([]string, error)) *MockProvider_List_Call {
_c.Call.Return(run)
return _c
}

View File

@@ -36,11 +36,16 @@ type StartOptions struct {
ConsiderReadyAfter time.Duration
}
type ListOptions struct {
// All list all instances whatever their status (up or down)
All bool
}
type Provider interface {
Start(ctx context.Context, name string, opts StartOptions) error
Stop(ctx context.Context, name string) error
Status(ctx context.Context, name string) (bool, error)
List(ctx context.Context, name string) ([]string, error)
List(ctx context.Context, opts ListOptions) ([]string, error)
Events(ctx context.Context) (<-chan Message, <-chan error)
}

View File

@@ -1 +0,0 @@
package sablier

View File

@@ -2,12 +2,14 @@ package sablier
import (
"context"
"sync"
"time"
"github.com/sablierapp/sablier/pkg/promise"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/tinykv"
log "github.com/sirupsen/logrus"
"sync"
"time"
"golang.org/x/exp/maps"
)
type Sablier struct {
@@ -31,10 +33,12 @@ func NewSablier(ctx context.Context, provider provider.Provider) *Sablier {
}
delete(promises, k)
})
go func() {
<-ctx.Done()
expirations.Stop()
}()
return &Sablier{
Provider: provider,
promises: promises,
@@ -42,3 +46,9 @@ func NewSablier(ctx context.Context, provider provider.Provider) *Sablier {
lock: lock,
}
}
func (s *Sablier) RegisteredInstances() []string {
s.lock.RLock()
defer s.lock.RUnlock()
return maps.Keys(s.promises)
}

View File

@@ -2,11 +2,11 @@ package sablier
import (
"context"
"github.com/sablierapp/sablier/pkg/provider"
"log"
"time"
"github.com/sablierapp/sablier/pkg/promise"
"github.com/sablierapp/sablier/pkg/provider"
)
type StartOptions struct {

View File

@@ -4,14 +4,15 @@ import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/sablierapp/sablier/pkg/promise"
"github.com/sablierapp/sablier/pkg/provider"
pmock "github.com/sablierapp/sablier/pkg/provider/mock"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
"time"
)
func TestStartInstance(t *testing.T) {

36
pkg/sablier/stop_all.go Normal file
View File

@@ -0,0 +1,36 @@
package sablier
import (
"context"
"github.com/sablierapp/sablier/pkg/array"
"github.com/sablierapp/sablier/pkg/provider"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
func (s *Sablier) StopAllUnregistered(ctx context.Context) error {
instances, err := s.Provider.List(ctx, provider.ListOptions{
All: false,
})
if err != nil {
return err
}
registered := s.RegisteredInstances()
unregistered := array.RemoveElements(instances, registered)
log.Tracef("Found %v unregistered instances ", len(unregistered))
waitGroup := errgroup.Group{}
// Previously, the variables declared by a “for” loop were created once and updated by each iteration.
// In Go 1.22, each iteration of the loop creates new variables, to avoid accidental sharing bugs.
// The transition support tooling described in the proposal continues to work in the same way it did in Go 1.21.
for _, name := range unregistered {
waitGroup.Go(func() error {
return s.Provider.Stop(ctx, name)
})
}
return waitGroup.Wait()
}

View File

@@ -0,0 +1,59 @@
package sablier_test
import (
"context"
"testing"
"time"
"github.com/sablierapp/sablier/pkg/provider"
pmock "github.com/sablierapp/sablier/pkg/provider/mock"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestStopAllUnregistered(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
m := pmock.NewMockProvider(t)
s := sablier.NewSablier(ctx, m)
m.EXPECT().
List(ctx, provider.ListOptions{All: false}).
Return([]string{"instance1", "instance2"}, nil)
m.EXPECT().Stop(ctx, "instance1").Return(nil).Once()
m.EXPECT().Stop(ctx, "instance2").Return(nil).Once()
err := s.StopAllUnregistered(ctx)
assert.NoError(t, err)
}
func TestStopAllUnregisteredWithAlreadyRegistered(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
name := "instance1"
opts := sablier.StartOptions{
DesiredReplicas: 1,
ConsiderReadyAfter: 5 * time.Second,
Timeout: 30 * time.Second,
ExpiresAfter: 1 * time.Minute,
}
m := pmock.NewMockProvider(t)
s := sablier.NewSablier(ctx, m)
m.EXPECT().Start(mock.Anything, name, provider.StartOptions{
DesiredReplicas: opts.DesiredReplicas,
ConsiderReadyAfter: opts.ConsiderReadyAfter,
}).Return(nil).Once()
p := s.StartInstance(name, opts)
_, err := p.Await(ctx)
assert.NoError(t, err)
m.EXPECT().
List(ctx, provider.ListOptions{All: false}).
Return([]string{"instance1", "instance2"}, nil)
m.EXPECT().Stop(ctx, "instance2").Return(nil).Once()
err = s.StopAllUnregistered(ctx)
assert.NoError(t, err)
}