mirror of
https://github.com/sablierapp/sablier.git
synced 2025-12-25 14:59:16 +01:00
add stop_all
This commit is contained in:
@@ -82,9 +82,9 @@ func (_c *MockProvider_Events_Call) RunAndReturn(run func(context.Context) (<-ch
|
||||
return _c
|
||||
}
|
||||
|
||||
// List provides a mock function with given fields: ctx, name
|
||||
func (_m *MockProvider) List(ctx context.Context, name string) ([]string, error) {
|
||||
ret := _m.Called(ctx, name)
|
||||
// List provides a mock function with given fields: ctx, opts
|
||||
func (_m *MockProvider) List(ctx context.Context, opts provider.ListOptions) ([]string, error) {
|
||||
ret := _m.Called(ctx, opts)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for List")
|
||||
@@ -92,19 +92,19 @@ func (_m *MockProvider) List(ctx context.Context, name string) ([]string, error)
|
||||
|
||||
var r0 []string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok {
|
||||
return rf(ctx, name)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, provider.ListOptions) ([]string, error)); ok {
|
||||
return rf(ctx, opts)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok {
|
||||
r0 = rf(ctx, name)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, provider.ListOptions) []string); ok {
|
||||
r0 = rf(ctx, opts)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, name)
|
||||
if rf, ok := ret.Get(1).(func(context.Context, provider.ListOptions) error); ok {
|
||||
r1 = rf(ctx, opts)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -119,14 +119,14 @@ type MockProvider_List_Call struct {
|
||||
|
||||
// List is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - name string
|
||||
func (_e *MockProvider_Expecter) List(ctx interface{}, name interface{}) *MockProvider_List_Call {
|
||||
return &MockProvider_List_Call{Call: _e.mock.On("List", ctx, name)}
|
||||
// - opts provider.ListOptions
|
||||
func (_e *MockProvider_Expecter) List(ctx interface{}, opts interface{}) *MockProvider_List_Call {
|
||||
return &MockProvider_List_Call{Call: _e.mock.On("List", ctx, opts)}
|
||||
}
|
||||
|
||||
func (_c *MockProvider_List_Call) Run(run func(ctx context.Context, name string)) *MockProvider_List_Call {
|
||||
func (_c *MockProvider_List_Call) Run(run func(ctx context.Context, opts provider.ListOptions)) *MockProvider_List_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
run(args[0].(context.Context), args[1].(provider.ListOptions))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (_c *MockProvider_List_Call) Return(_a0 []string, _a1 error) *MockProvider_
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProvider_List_Call) RunAndReturn(run func(context.Context, string) ([]string, error)) *MockProvider_List_Call {
|
||||
func (_c *MockProvider_List_Call) RunAndReturn(run func(context.Context, provider.ListOptions) ([]string, error)) *MockProvider_List_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -36,11 +36,16 @@ type StartOptions struct {
|
||||
ConsiderReadyAfter time.Duration
|
||||
}
|
||||
|
||||
type ListOptions struct {
|
||||
// All list all instances whatever their status (up or down)
|
||||
All bool
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
Start(ctx context.Context, name string, opts StartOptions) error
|
||||
Stop(ctx context.Context, name string) error
|
||||
Status(ctx context.Context, name string) (bool, error)
|
||||
List(ctx context.Context, name string) ([]string, error)
|
||||
List(ctx context.Context, opts ListOptions) ([]string, error)
|
||||
|
||||
Events(ctx context.Context) (<-chan Message, <-chan error)
|
||||
}
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
package sablier
|
||||
@@ -2,12 +2,14 @@ package sablier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sablierapp/sablier/pkg/promise"
|
||||
"github.com/sablierapp/sablier/pkg/provider"
|
||||
"github.com/sablierapp/sablier/pkg/tinykv"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
"time"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
type Sablier struct {
|
||||
@@ -31,10 +33,12 @@ func NewSablier(ctx context.Context, provider provider.Provider) *Sablier {
|
||||
}
|
||||
delete(promises, k)
|
||||
})
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
expirations.Stop()
|
||||
}()
|
||||
|
||||
return &Sablier{
|
||||
Provider: provider,
|
||||
promises: promises,
|
||||
@@ -42,3 +46,9 @@ func NewSablier(ctx context.Context, provider provider.Provider) *Sablier {
|
||||
lock: lock,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sablier) RegisteredInstances() []string {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return maps.Keys(s.promises)
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ package sablier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sablierapp/sablier/pkg/provider"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/sablierapp/sablier/pkg/promise"
|
||||
"github.com/sablierapp/sablier/pkg/provider"
|
||||
)
|
||||
|
||||
type StartOptions struct {
|
||||
|
||||
@@ -4,14 +4,15 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sablierapp/sablier/pkg/promise"
|
||||
"github.com/sablierapp/sablier/pkg/provider"
|
||||
pmock "github.com/sablierapp/sablier/pkg/provider/mock"
|
||||
"github.com/sablierapp/sablier/pkg/sablier"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStartInstance(t *testing.T) {
|
||||
|
||||
36
pkg/sablier/stop_all.go
Normal file
36
pkg/sablier/stop_all.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package sablier
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sablierapp/sablier/pkg/array"
|
||||
"github.com/sablierapp/sablier/pkg/provider"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func (s *Sablier) StopAllUnregistered(ctx context.Context) error {
|
||||
instances, err := s.Provider.List(ctx, provider.ListOptions{
|
||||
All: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
registered := s.RegisteredInstances()
|
||||
unregistered := array.RemoveElements(instances, registered)
|
||||
log.Tracef("Found %v unregistered instances ", len(unregistered))
|
||||
|
||||
waitGroup := errgroup.Group{}
|
||||
|
||||
// Previously, the variables declared by a “for” loop were created once and updated by each iteration.
|
||||
// In Go 1.22, each iteration of the loop creates new variables, to avoid accidental sharing bugs.
|
||||
// The transition support tooling described in the proposal continues to work in the same way it did in Go 1.21.
|
||||
for _, name := range unregistered {
|
||||
waitGroup.Go(func() error {
|
||||
return s.Provider.Stop(ctx, name)
|
||||
})
|
||||
}
|
||||
|
||||
return waitGroup.Wait()
|
||||
}
|
||||
59
pkg/sablier/stop_all_test.go
Normal file
59
pkg/sablier/stop_all_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package sablier_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sablierapp/sablier/pkg/provider"
|
||||
pmock "github.com/sablierapp/sablier/pkg/provider/mock"
|
||||
"github.com/sablierapp/sablier/pkg/sablier"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestStopAllUnregistered(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
m := pmock.NewMockProvider(t)
|
||||
s := sablier.NewSablier(ctx, m)
|
||||
|
||||
m.EXPECT().
|
||||
List(ctx, provider.ListOptions{All: false}).
|
||||
Return([]string{"instance1", "instance2"}, nil)
|
||||
m.EXPECT().Stop(ctx, "instance1").Return(nil).Once()
|
||||
m.EXPECT().Stop(ctx, "instance2").Return(nil).Once()
|
||||
err := s.StopAllUnregistered(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStopAllUnregisteredWithAlreadyRegistered(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
name := "instance1"
|
||||
opts := sablier.StartOptions{
|
||||
DesiredReplicas: 1,
|
||||
ConsiderReadyAfter: 5 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
ExpiresAfter: 1 * time.Minute,
|
||||
}
|
||||
m := pmock.NewMockProvider(t)
|
||||
s := sablier.NewSablier(ctx, m)
|
||||
|
||||
m.EXPECT().Start(mock.Anything, name, provider.StartOptions{
|
||||
DesiredReplicas: opts.DesiredReplicas,
|
||||
ConsiderReadyAfter: opts.ConsiderReadyAfter,
|
||||
}).Return(nil).Once()
|
||||
p := s.StartInstance(name, opts)
|
||||
_, err := p.Await(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m.EXPECT().
|
||||
List(ctx, provider.ListOptions{All: false}).
|
||||
Return([]string{"instance1", "instance2"}, nil)
|
||||
m.EXPECT().Stop(ctx, "instance2").Return(nil).Once()
|
||||
err = s.StopAllUnregistered(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user