mirror of
https://github.com/sablierapp/sablier.git
synced 2026-01-03 19:44:59 +01:00
refactor(provider): pass context.Context down to all operations
This means that, with more work, a canceled request would cancel the underlying request.
This commit is contained in:
@@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
|
||||
"github.com/acouvreur/sablier/app/instance"
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/events"
|
||||
@@ -34,9 +35,7 @@ func NewDockerClassicProvider() (*DockerClassicProvider, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (provider *DockerClassicProvider) GetGroups() (map[string][]string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *DockerClassicProvider) GetGroups(ctx context.Context) (map[string][]string, error) {
|
||||
filters := filters.NewArgs()
|
||||
filters.Add("label", fmt.Sprintf("%s=true", enableLabel))
|
||||
|
||||
@@ -65,9 +64,7 @@ func (provider *DockerClassicProvider) GetGroups() (map[string][]string, error)
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (provider *DockerClassicProvider) Start(name string) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *DockerClassicProvider) Start(ctx context.Context, name string) (instance.State, error) {
|
||||
err := provider.Client.ContainerStart(ctx, name, types.ContainerStartOptions{})
|
||||
|
||||
if err != nil {
|
||||
@@ -82,10 +79,7 @@ func (provider *DockerClassicProvider) Start(name string) (instance.State, error
|
||||
}, err
|
||||
}
|
||||
|
||||
func (provider *DockerClassicProvider) Stop(name string) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// TODO: Allow to specify a termination timeout
|
||||
func (provider *DockerClassicProvider) Stop(ctx context.Context, name string) (instance.State, error) {
|
||||
err := provider.Client.ContainerStop(ctx, name, container.StopOptions{})
|
||||
|
||||
if err != nil {
|
||||
@@ -100,9 +94,7 @@ func (provider *DockerClassicProvider) Stop(name string) (instance.State, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (provider *DockerClassicProvider) GetState(name string) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *DockerClassicProvider) GetState(ctx context.Context, name string) (instance.State, error) {
|
||||
spec, err := provider.Client.ContainerInspect(ctx, name)
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -248,7 +248,7 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
|
||||
|
||||
tt.fields.Client.On("ContainerInspect", mock.Anything, mock.Anything).Return(tt.containerSpec, tt.err)
|
||||
|
||||
got, err := provider.GetState(tt.args.name)
|
||||
got, err := provider.GetState(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DockerClassicProvider.GetState() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -320,7 +320,7 @@ func TestDockerClassicProvider_Stop(t *testing.T) {
|
||||
|
||||
tt.fields.Client.On("ContainerStop", mock.Anything, mock.Anything, mock.Anything).Return(tt.err)
|
||||
|
||||
got, err := provider.Stop(tt.args.name)
|
||||
got, err := provider.Stop(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DockerClassicProvider.Stop() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -392,7 +392,7 @@ func TestDockerClassicProvider_Start(t *testing.T) {
|
||||
|
||||
tt.fields.Client.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(tt.err)
|
||||
|
||||
got, err := provider.Start(tt.args.name)
|
||||
got, err := provider.Start(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DockerClassicProvider.Start() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/acouvreur/sablier/app/instance"
|
||||
"github.com/docker/docker/api/types"
|
||||
@@ -18,8 +17,6 @@ import (
|
||||
|
||||
type DockerSwarmProvider struct {
|
||||
Client client.APIClient
|
||||
updateGroups chan any
|
||||
groups *sync.Map
|
||||
desiredReplicas int
|
||||
}
|
||||
|
||||
@@ -31,22 +28,19 @@ func NewDockerSwarmProvider() (*DockerSwarmProvider, error) {
|
||||
return &DockerSwarmProvider{
|
||||
Client: cli,
|
||||
desiredReplicas: 1,
|
||||
updateGroups: make(chan any, 1),
|
||||
groups: &sync.Map{},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (provider *DockerSwarmProvider) Start(name string) (instance.State, error) {
|
||||
return provider.scale(name, uint64(provider.desiredReplicas))
|
||||
func (provider *DockerSwarmProvider) Start(ctx context.Context, name string) (instance.State, error) {
|
||||
return provider.scale(ctx, name, uint64(provider.desiredReplicas))
|
||||
}
|
||||
|
||||
func (provider *DockerSwarmProvider) Stop(name string) (instance.State, error) {
|
||||
return provider.scale(name, 0)
|
||||
func (provider *DockerSwarmProvider) Stop(ctx context.Context, name string) (instance.State, error) {
|
||||
return provider.scale(ctx, name, 0)
|
||||
}
|
||||
|
||||
func (provider *DockerSwarmProvider) scale(name string, replicas uint64) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
func (provider *DockerSwarmProvider) scale(ctx context.Context, name string, replicas uint64) (instance.State, error) {
|
||||
service, err := provider.getServiceByName(name, ctx)
|
||||
|
||||
if err != nil {
|
||||
@@ -74,9 +68,7 @@ func (provider *DockerSwarmProvider) scale(name string, replicas uint64) (instan
|
||||
return instance.NotReadyInstanceState(foundName, 0, provider.desiredReplicas)
|
||||
}
|
||||
|
||||
func (provider *DockerSwarmProvider) GetGroups() (map[string][]string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *DockerSwarmProvider) GetGroups(ctx context.Context) (map[string][]string, error) {
|
||||
filters := filters.NewArgs()
|
||||
filters.Add("label", fmt.Sprintf("%s=true", enableLabel))
|
||||
|
||||
@@ -103,16 +95,7 @@ func (provider *DockerSwarmProvider) GetGroups() (map[string][]string, error) {
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (provider *DockerSwarmProvider) GetGroup(group string) []string {
|
||||
containers, ok := provider.groups.Load(group)
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
return containers.([]string)
|
||||
}
|
||||
|
||||
func (provider *DockerSwarmProvider) GetState(name string) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
func (provider *DockerSwarmProvider) GetState(ctx context.Context, name string) (instance.State, error) {
|
||||
|
||||
service, err := provider.getServiceByName(name, ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
|
||||
clientMock.On("ServiceList", mock.Anything, mock.Anything).Return(tt.serviceList, nil)
|
||||
clientMock.On("ServiceUpdate", mock.Anything, tt.wantService.ID, tt.wantService.Meta.Version, tt.wantService.Spec, mock.Anything).Return(tt.response, nil)
|
||||
|
||||
got, err := provider.Start(tt.args.name)
|
||||
got, err := provider.Start(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DockerSwarmProvider.Start() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -201,7 +201,7 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
|
||||
clientMock.On("ServiceList", mock.Anything, mock.Anything).Return(tt.serviceList, nil)
|
||||
clientMock.On("ServiceUpdate", mock.Anything, tt.wantService.ID, tt.wantService.Meta.Version, tt.wantService.Spec, mock.Anything).Return(tt.response, nil)
|
||||
|
||||
got, err := provider.Stop(tt.args.name)
|
||||
got, err := provider.Stop(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DockerSwarmProvider.Stop() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -284,7 +284,7 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
|
||||
|
||||
clientMock.On("ServiceList", mock.Anything, mock.Anything).Return(tt.serviceList, nil)
|
||||
|
||||
got, err := provider.GetState(tt.args.name)
|
||||
got, err := provider.GetState(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DockerSwarmProvider.GetState() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
||||
@@ -76,28 +76,26 @@ func NewKubernetesProvider() (*KubernetesProvider, error) {
|
||||
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) Start(name string) (instance.State, error) {
|
||||
func (provider *KubernetesProvider) Start(ctx context.Context, name string) (instance.State, error) {
|
||||
config, err := convertName(name)
|
||||
if err != nil {
|
||||
return instance.UnrecoverableInstanceState(name, err.Error(), int(config.Replicas))
|
||||
}
|
||||
|
||||
return provider.scale(config, config.Replicas)
|
||||
return provider.scale(ctx, config, config.Replicas)
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) Stop(name string) (instance.State, error) {
|
||||
func (provider *KubernetesProvider) Stop(ctx context.Context, name string) (instance.State, error) {
|
||||
config, err := convertName(name)
|
||||
if err != nil {
|
||||
return instance.UnrecoverableInstanceState(name, err.Error(), int(config.Replicas))
|
||||
}
|
||||
|
||||
return provider.scale(config, 0)
|
||||
return provider.scale(ctx, config, 0)
|
||||
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) GetGroups() (map[string][]string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *KubernetesProvider) GetGroups(ctx context.Context) (map[string][]string, error) {
|
||||
deployments, err := provider.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{
|
||||
LabelSelector: enableLabel,
|
||||
})
|
||||
@@ -123,9 +121,7 @@ func (provider *KubernetesProvider) GetGroups() (map[string][]string, error) {
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) scale(config *Config, replicas int32) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *KubernetesProvider) scale(ctx context.Context, config *Config, replicas int32) (instance.State, error) {
|
||||
var workload Workload
|
||||
|
||||
switch config.Kind {
|
||||
@@ -152,7 +148,7 @@ func (provider *KubernetesProvider) scale(config *Config, replicas int32) (insta
|
||||
return instance.NotReadyInstanceState(config.OriginalName, 0, int(config.Replicas))
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) GetState(name string) (instance.State, error) {
|
||||
func (provider *KubernetesProvider) GetState(ctx context.Context, name string) (instance.State, error) {
|
||||
config, err := convertName(name)
|
||||
if err != nil {
|
||||
return instance.UnrecoverableInstanceState(name, err.Error(), int(config.Replicas))
|
||||
@@ -160,17 +156,15 @@ func (provider *KubernetesProvider) GetState(name string) (instance.State, error
|
||||
|
||||
switch config.Kind {
|
||||
case "deployment":
|
||||
return provider.getDeploymentState(config)
|
||||
return provider.getDeploymentState(ctx, config)
|
||||
case "statefulset":
|
||||
return provider.getStatefulsetState(config)
|
||||
return provider.getStatefulsetState(ctx, config)
|
||||
default:
|
||||
return instance.UnrecoverableInstanceState(config.OriginalName, fmt.Sprintf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", config.Kind), int(config.Replicas))
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) getDeploymentState(config *Config) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *KubernetesProvider) getDeploymentState(ctx context.Context, config *Config) (instance.State, error) {
|
||||
d, err := provider.Client.AppsV1().Deployments(config.Namespace).
|
||||
Get(ctx, config.Name, metav1.GetOptions{})
|
||||
|
||||
@@ -185,9 +179,7 @@ func (provider *KubernetesProvider) getDeploymentState(config *Config) (instance
|
||||
return instance.NotReadyInstanceState(config.OriginalName, int(d.Status.ReadyReplicas), int(config.Replicas))
|
||||
}
|
||||
|
||||
func (provider *KubernetesProvider) getStatefulsetState(config *Config) (instance.State, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
func (provider *KubernetesProvider) getStatefulsetState(ctx context.Context, config *Config) (instance.State, error) {
|
||||
ss, err := provider.Client.AppsV1().StatefulSets(config.Namespace).
|
||||
Get(ctx, config.Name, metav1.GetOptions{})
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -98,7 +99,7 @@ func TestKubernetesProvider_Start(t *testing.T) {
|
||||
statefulsetAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil)
|
||||
statefulsetAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil)
|
||||
|
||||
got, err := provider.Start(tt.args.name)
|
||||
got, err := provider.Start(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KubernetesProvider.Start() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -196,7 +197,7 @@ func TestKubernetesProvider_Stop(t *testing.T) {
|
||||
statefulsetAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil)
|
||||
statefulsetAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil)
|
||||
|
||||
got, err := provider.Stop(tt.args.name)
|
||||
got, err := provider.Stop(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KubernetesProvider.Stop() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -321,7 +322,7 @@ func TestKubernetesProvider_GetState(t *testing.T) {
|
||||
deploymentAPI.On("Get", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.getDeployment, nil)
|
||||
statefulsetAPI.On("Get", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.getStatefulSet, nil)
|
||||
|
||||
got, err := provider.GetState(tt.args.name)
|
||||
got, err := provider.GetState(context.Background(), tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KubernetesProvider.GetState() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
||||
@@ -13,10 +13,10 @@ const groupLabel = "sablier.group"
|
||||
const defaultGroupValue = "default"
|
||||
|
||||
type Provider interface {
|
||||
Start(name string) (instance.State, error)
|
||||
Stop(name string) (instance.State, error)
|
||||
GetState(name string) (instance.State, error)
|
||||
GetGroups() (map[string][]string, error)
|
||||
Start(ctx context.Context, name string) (instance.State, error)
|
||||
Stop(ctx context.Context, name string) (instance.State, error)
|
||||
GetState(ctx context.Context, name string) (instance.State, error)
|
||||
GetGroups(ctx context.Context) (map[string][]string, error)
|
||||
|
||||
NotifyInstanceStopped(ctx context.Context, instance chan<- string)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/acouvreur/sablier/app/http"
|
||||
"github.com/acouvreur/sablier/app/instance"
|
||||
"github.com/acouvreur/sablier/app/providers"
|
||||
@@ -56,7 +58,7 @@ func onSessionExpires(provider providers.Provider) func(key string, instance ins
|
||||
return func(_key string, _instance instance.State) {
|
||||
go func(key string, instance instance.State) {
|
||||
log.Debugf("stopping %s...", key)
|
||||
_, err := provider.Stop(key)
|
||||
_, err := provider.Stop(context.Background(), key)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("error stopping %s: %s", key, err.Error())
|
||||
|
||||
@@ -16,7 +16,7 @@ func watchGroups(ctx context.Context, provider providers.Provider, frequency tim
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
groups, err := provider.GetGroups()
|
||||
groups, err := provider.GetGroups(ctx)
|
||||
if err != nil {
|
||||
log.Warn("could not get groups", err)
|
||||
} else {
|
||||
|
||||
@@ -48,12 +48,12 @@ func (provider *ProviderMock) Wait() {
|
||||
provider.wg.Wait()
|
||||
}
|
||||
|
||||
func (provider *ProviderMock) GetState(name string) (instance.State, error) {
|
||||
func (provider *ProviderMock) GetState(ctx context.Context, name string) (instance.State, error) {
|
||||
args := provider.Mock.Called(name)
|
||||
return args.Get(0).(instance.State), args.Error(1)
|
||||
}
|
||||
|
||||
func (provider *ProviderMock) GetGroups() (map[string][]string, error) {
|
||||
func (provider *ProviderMock) GetGroups(ctx context.Context) (map[string][]string, error) {
|
||||
return make(map[string][]string), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ type SessionsManager struct {
|
||||
func NewSessionsManager(store tinykv.KV[instance.State], provider providers.Provider) Manager {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
groups, err := provider.GetGroups()
|
||||
groups, err := provider.GetGroups(ctx)
|
||||
if err != nil {
|
||||
groups = make(map[string][]string)
|
||||
log.Warn("could not get groups", err)
|
||||
@@ -184,7 +184,7 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura
|
||||
if !exists {
|
||||
log.Debugf("starting %s...", name)
|
||||
|
||||
state, err := s.provider.Start(name)
|
||||
state, err := s.provider.Start(s.ctx, name)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("an error occurred starting %s: %s", name, err.Error())
|
||||
@@ -199,7 +199,7 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura
|
||||
} else if requestState.Status != instance.Ready {
|
||||
log.Debugf("checking %s...", name)
|
||||
|
||||
state, err := s.provider.GetState(name)
|
||||
state, err := s.provider.GetState(s.ctx, name)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("an error occurred checking state %s: %s", name, err.Error())
|
||||
|
||||
Reference in New Issue
Block a user