refactor(provider): pass context.Context down to all operations

This means that with more work, a canceled request would cancel to underlying request.
This commit is contained in:
Alexis Couvreur
2023-09-14 09:29:13 -04:00
parent 526f188ade
commit 72ea3b3645
11 changed files with 47 additions and 77 deletions

View File

@@ -4,10 +4,11 @@ import (
"context"
"errors"
"fmt"
"github.com/docker/docker/api/types/container"
"io"
"strings"
"github.com/docker/docker/api/types/container"
"github.com/acouvreur/sablier/app/instance"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/events"
@@ -34,9 +35,7 @@ func NewDockerClassicProvider() (*DockerClassicProvider, error) {
}, nil
}
func (provider *DockerClassicProvider) GetGroups() (map[string][]string, error) {
ctx := context.Background()
func (provider *DockerClassicProvider) GetGroups(ctx context.Context) (map[string][]string, error) {
filters := filters.NewArgs()
filters.Add("label", fmt.Sprintf("%s=true", enableLabel))
@@ -65,9 +64,7 @@ func (provider *DockerClassicProvider) GetGroups() (map[string][]string, error)
return groups, nil
}
func (provider *DockerClassicProvider) Start(name string) (instance.State, error) {
ctx := context.Background()
func (provider *DockerClassicProvider) Start(ctx context.Context, name string) (instance.State, error) {
err := provider.Client.ContainerStart(ctx, name, types.ContainerStartOptions{})
if err != nil {
@@ -82,10 +79,7 @@ func (provider *DockerClassicProvider) Start(name string) (instance.State, error
}, err
}
func (provider *DockerClassicProvider) Stop(name string) (instance.State, error) {
ctx := context.Background()
// TODO: Allow to specify a termination timeout
func (provider *DockerClassicProvider) Stop(ctx context.Context, name string) (instance.State, error) {
err := provider.Client.ContainerStop(ctx, name, container.StopOptions{})
if err != nil {
@@ -100,9 +94,7 @@ func (provider *DockerClassicProvider) Stop(name string) (instance.State, error)
}, nil
}
func (provider *DockerClassicProvider) GetState(name string) (instance.State, error) {
ctx := context.Background()
func (provider *DockerClassicProvider) GetState(ctx context.Context, name string) (instance.State, error) {
spec, err := provider.Client.ContainerInspect(ctx, name)
if err != nil {

View File

@@ -248,7 +248,7 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
tt.fields.Client.On("ContainerInspect", mock.Anything, mock.Anything).Return(tt.containerSpec, tt.err)
got, err := provider.GetState(tt.args.name)
got, err := provider.GetState(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("DockerClassicProvider.GetState() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -320,7 +320,7 @@ func TestDockerClassicProvider_Stop(t *testing.T) {
tt.fields.Client.On("ContainerStop", mock.Anything, mock.Anything, mock.Anything).Return(tt.err)
got, err := provider.Stop(tt.args.name)
got, err := provider.Stop(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("DockerClassicProvider.Stop() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -392,7 +392,7 @@ func TestDockerClassicProvider_Start(t *testing.T) {
tt.fields.Client.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(tt.err)
got, err := provider.Start(tt.args.name)
got, err := provider.Start(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("DockerClassicProvider.Start() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"strings"
"sync"
"github.com/acouvreur/sablier/app/instance"
"github.com/docker/docker/api/types"
@@ -18,8 +17,6 @@ import (
type DockerSwarmProvider struct {
Client client.APIClient
updateGroups chan any
groups *sync.Map
desiredReplicas int
}
@@ -31,22 +28,19 @@ func NewDockerSwarmProvider() (*DockerSwarmProvider, error) {
return &DockerSwarmProvider{
Client: cli,
desiredReplicas: 1,
updateGroups: make(chan any, 1),
groups: &sync.Map{},
}, nil
}
func (provider *DockerSwarmProvider) Start(name string) (instance.State, error) {
return provider.scale(name, uint64(provider.desiredReplicas))
func (provider *DockerSwarmProvider) Start(ctx context.Context, name string) (instance.State, error) {
return provider.scale(ctx, name, uint64(provider.desiredReplicas))
}
func (provider *DockerSwarmProvider) Stop(name string) (instance.State, error) {
return provider.scale(name, 0)
func (provider *DockerSwarmProvider) Stop(ctx context.Context, name string) (instance.State, error) {
return provider.scale(ctx, name, 0)
}
func (provider *DockerSwarmProvider) scale(name string, replicas uint64) (instance.State, error) {
ctx := context.Background()
func (provider *DockerSwarmProvider) scale(ctx context.Context, name string, replicas uint64) (instance.State, error) {
service, err := provider.getServiceByName(name, ctx)
if err != nil {
@@ -74,9 +68,7 @@ func (provider *DockerSwarmProvider) scale(name string, replicas uint64) (instan
return instance.NotReadyInstanceState(foundName, 0, provider.desiredReplicas)
}
func (provider *DockerSwarmProvider) GetGroups() (map[string][]string, error) {
ctx := context.Background()
func (provider *DockerSwarmProvider) GetGroups(ctx context.Context) (map[string][]string, error) {
filters := filters.NewArgs()
filters.Add("label", fmt.Sprintf("%s=true", enableLabel))
@@ -103,16 +95,7 @@ func (provider *DockerSwarmProvider) GetGroups() (map[string][]string, error) {
return groups, nil
}
func (provider *DockerSwarmProvider) GetGroup(group string) []string {
containers, ok := provider.groups.Load(group)
if !ok {
return []string{}
}
return containers.([]string)
}
func (provider *DockerSwarmProvider) GetState(name string) (instance.State, error) {
ctx := context.Background()
func (provider *DockerSwarmProvider) GetState(ctx context.Context, name string) (instance.State, error) {
service, err := provider.getServiceByName(name, ctx)
if err != nil {

View File

@@ -101,7 +101,7 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
clientMock.On("ServiceList", mock.Anything, mock.Anything).Return(tt.serviceList, nil)
clientMock.On("ServiceUpdate", mock.Anything, tt.wantService.ID, tt.wantService.Meta.Version, tt.wantService.Spec, mock.Anything).Return(tt.response, nil)
got, err := provider.Start(tt.args.name)
got, err := provider.Start(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("DockerSwarmProvider.Start() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -201,7 +201,7 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
clientMock.On("ServiceList", mock.Anything, mock.Anything).Return(tt.serviceList, nil)
clientMock.On("ServiceUpdate", mock.Anything, tt.wantService.ID, tt.wantService.Meta.Version, tt.wantService.Spec, mock.Anything).Return(tt.response, nil)
got, err := provider.Stop(tt.args.name)
got, err := provider.Stop(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("DockerSwarmProvider.Stop() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -284,7 +284,7 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
clientMock.On("ServiceList", mock.Anything, mock.Anything).Return(tt.serviceList, nil)
got, err := provider.GetState(tt.args.name)
got, err := provider.GetState(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("DockerSwarmProvider.GetState() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@@ -76,28 +76,26 @@ func NewKubernetesProvider() (*KubernetesProvider, error) {
}
func (provider *KubernetesProvider) Start(name string) (instance.State, error) {
func (provider *KubernetesProvider) Start(ctx context.Context, name string) (instance.State, error) {
config, err := convertName(name)
if err != nil {
return instance.UnrecoverableInstanceState(name, err.Error(), int(config.Replicas))
}
return provider.scale(config, config.Replicas)
return provider.scale(ctx, config, config.Replicas)
}
func (provider *KubernetesProvider) Stop(name string) (instance.State, error) {
func (provider *KubernetesProvider) Stop(ctx context.Context, name string) (instance.State, error) {
config, err := convertName(name)
if err != nil {
return instance.UnrecoverableInstanceState(name, err.Error(), int(config.Replicas))
}
return provider.scale(config, 0)
return provider.scale(ctx, config, 0)
}
func (provider *KubernetesProvider) GetGroups() (map[string][]string, error) {
ctx := context.Background()
func (provider *KubernetesProvider) GetGroups(ctx context.Context) (map[string][]string, error) {
deployments, err := provider.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{
LabelSelector: enableLabel,
})
@@ -123,9 +121,7 @@ func (provider *KubernetesProvider) GetGroups() (map[string][]string, error) {
return groups, nil
}
func (provider *KubernetesProvider) scale(config *Config, replicas int32) (instance.State, error) {
ctx := context.Background()
func (provider *KubernetesProvider) scale(ctx context.Context, config *Config, replicas int32) (instance.State, error) {
var workload Workload
switch config.Kind {
@@ -152,7 +148,7 @@ func (provider *KubernetesProvider) scale(config *Config, replicas int32) (insta
return instance.NotReadyInstanceState(config.OriginalName, 0, int(config.Replicas))
}
func (provider *KubernetesProvider) GetState(name string) (instance.State, error) {
func (provider *KubernetesProvider) GetState(ctx context.Context, name string) (instance.State, error) {
config, err := convertName(name)
if err != nil {
return instance.UnrecoverableInstanceState(name, err.Error(), int(config.Replicas))
@@ -160,17 +156,15 @@ func (provider *KubernetesProvider) GetState(name string) (instance.State, error
switch config.Kind {
case "deployment":
return provider.getDeploymentState(config)
return provider.getDeploymentState(ctx, config)
case "statefulset":
return provider.getStatefulsetState(config)
return provider.getStatefulsetState(ctx, config)
default:
return instance.UnrecoverableInstanceState(config.OriginalName, fmt.Sprintf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", config.Kind), int(config.Replicas))
}
}
func (provider *KubernetesProvider) getDeploymentState(config *Config) (instance.State, error) {
ctx := context.Background()
func (provider *KubernetesProvider) getDeploymentState(ctx context.Context, config *Config) (instance.State, error) {
d, err := provider.Client.AppsV1().Deployments(config.Namespace).
Get(ctx, config.Name, metav1.GetOptions{})
@@ -185,9 +179,7 @@ func (provider *KubernetesProvider) getDeploymentState(config *Config) (instance
return instance.NotReadyInstanceState(config.OriginalName, int(d.Status.ReadyReplicas), int(config.Replicas))
}
func (provider *KubernetesProvider) getStatefulsetState(config *Config) (instance.State, error) {
ctx := context.Background()
func (provider *KubernetesProvider) getStatefulsetState(ctx context.Context, config *Config) (instance.State, error) {
ss, err := provider.Client.AppsV1().StatefulSets(config.Namespace).
Get(ctx, config.Name, metav1.GetOptions{})

View File

@@ -1,6 +1,7 @@
package providers
import (
"context"
"reflect"
"testing"
@@ -98,7 +99,7 @@ func TestKubernetesProvider_Start(t *testing.T) {
statefulsetAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil)
statefulsetAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil)
got, err := provider.Start(tt.args.name)
got, err := provider.Start(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("KubernetesProvider.Start() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -196,7 +197,7 @@ func TestKubernetesProvider_Stop(t *testing.T) {
statefulsetAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil)
statefulsetAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil)
got, err := provider.Stop(tt.args.name)
got, err := provider.Stop(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("KubernetesProvider.Stop() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -321,7 +322,7 @@ func TestKubernetesProvider_GetState(t *testing.T) {
deploymentAPI.On("Get", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.getDeployment, nil)
statefulsetAPI.On("Get", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.getStatefulSet, nil)
got, err := provider.GetState(tt.args.name)
got, err := provider.GetState(context.Background(), tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("KubernetesProvider.GetState() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@@ -13,10 +13,10 @@ const groupLabel = "sablier.group"
const defaultGroupValue = "default"
type Provider interface {
Start(name string) (instance.State, error)
Stop(name string) (instance.State, error)
GetState(name string) (instance.State, error)
GetGroups() (map[string][]string, error)
Start(ctx context.Context, name string) (instance.State, error)
Stop(ctx context.Context, name string) (instance.State, error)
GetState(ctx context.Context, name string) (instance.State, error)
GetGroups(ctx context.Context) (map[string][]string, error)
NotifyInstanceStopped(ctx context.Context, instance chan<- string)
}

View File

@@ -1,6 +1,8 @@
package app
import (
"context"
"github.com/acouvreur/sablier/app/http"
"github.com/acouvreur/sablier/app/instance"
"github.com/acouvreur/sablier/app/providers"
@@ -56,7 +58,7 @@ func onSessionExpires(provider providers.Provider) func(key string, instance ins
return func(_key string, _instance instance.State) {
go func(key string, instance instance.State) {
log.Debugf("stopping %s...", key)
_, err := provider.Stop(key)
_, err := provider.Stop(context.Background(), key)
if err != nil {
log.Warnf("error stopping %s: %s", key, err.Error())

View File

@@ -16,7 +16,7 @@ func watchGroups(ctx context.Context, provider providers.Provider, frequency tim
case <-ctx.Done():
return
case <-ticker.C:
groups, err := provider.GetGroups()
groups, err := provider.GetGroups(ctx)
if err != nil {
log.Warn("could not get groups", err)
} else {

View File

@@ -48,12 +48,12 @@ func (provider *ProviderMock) Wait() {
provider.wg.Wait()
}
func (provider *ProviderMock) GetState(name string) (instance.State, error) {
func (provider *ProviderMock) GetState(ctx context.Context, name string) (instance.State, error) {
args := provider.Mock.Called(name)
return args.Get(0).(instance.State), args.Error(1)
}
func (provider *ProviderMock) GetGroups() (map[string][]string, error) {
func (provider *ProviderMock) GetGroups(ctx context.Context) (map[string][]string, error) {
return make(map[string][]string), nil
}

View File

@@ -40,7 +40,7 @@ type SessionsManager struct {
func NewSessionsManager(store tinykv.KV[instance.State], provider providers.Provider) Manager {
ctx, cancel := context.WithCancel(context.Background())
groups, err := provider.GetGroups()
groups, err := provider.GetGroups(ctx)
if err != nil {
groups = make(map[string][]string)
log.Warn("could not get groups", err)
@@ -184,7 +184,7 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura
if !exists {
log.Debugf("starting %s...", name)
state, err := s.provider.Start(name)
state, err := s.provider.Start(s.ctx, name)
if err != nil {
log.Errorf("an error occurred starting %s: %s", name, err.Error())
@@ -199,7 +199,7 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura
} else if requestState.Status != instance.Ready {
log.Debugf("checking %s...", name)
state, err := s.provider.GetState(name)
state, err := s.provider.GetState(s.ctx, name)
if err != nil {
log.Errorf("an error occurred checking state %s: %s", name, err.Error())