diff --git a/app/discovery/autostop.go b/app/discovery/autostop.go index 98b5337..5f98617 100644 --- a/app/discovery/autostop.go +++ b/app/discovery/autostop.go @@ -43,7 +43,7 @@ func StopAllUnregisteredInstances(ctx context.Context, p provider.Provider, s st func stopFunc(ctx context.Context, name string, p provider.Provider, logger *slog.Logger) func() error { return func() error { - err := p.Stop(ctx, name) + err := p.InstanceStop(ctx, name) if err != nil { logger.ErrorContext(ctx, "failed to stop instance", slog.String("instance", name), slog.Any("error", err)) return err diff --git a/app/discovery/autostop_test.go b/app/discovery/autostop_test.go index e1e368a..1fcf31c 100644 --- a/app/discovery/autostop_test.go +++ b/app/discovery/autostop_test.go @@ -35,9 +35,9 @@ func TestStopAllUnregisteredInstances(t *testing.T) { Labels: []string{discovery.LabelEnable}, }).Return(instances, nil) - // Set up expectations for Stop - mockProvider.On("Stop", ctx, "instance2").Return(nil) - mockProvider.On("Stop", ctx, "instance3").Return(nil) + // Set up expectations for InstanceStop + mockProvider.On("InstanceStop", ctx, "instance2").Return(nil) + mockProvider.On("InstanceStop", ctx, "instance3").Return(nil) // Call the function under test err = discovery.StopAllUnregisteredInstances(ctx, mockProvider, store, slogt.New(t)) @@ -67,9 +67,9 @@ func TestStopAllUnregisteredInstances_WithError(t *testing.T) { Labels: []string{discovery.LabelEnable}, }).Return(instances, nil) - // Set up expectations for Stop with error - mockProvider.On("Stop", ctx, "instance2").Return(errors.New("stop error")) - mockProvider.On("Stop", ctx, "instance3").Return(nil) + // Set up expectations for InstanceStop with error + mockProvider.On("InstanceStop", ctx, "instance2").Return(errors.New("stop error")) + mockProvider.On("InstanceStop", ctx, "instance3").Return(nil) // Call the function under test err = discovery.StopAllUnregisteredInstances(ctx, mockProvider, store, slogt.New(t)) diff --git a/app/sablier.go b/app/sablier.go index 65a3a74..53dfc8d 100644 --- a/app/sablier.go +++ b/app/sablier.go @@ -12,6 +12,9 @@ import ( "github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/store/inmemory" "github.com/sablierapp/sablier/pkg/theme" + k8s "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "log/slog" "os" "os/signal" @@ -56,7 +59,7 @@ func Start(ctx context.Context, conf config.Config) error { loadSessions(storage, sessionsManager, logger) } - groups, err := provider.GetGroups(ctx) + groups, err := provider.InstanceGroups(ctx) if err != nil { logger.WarnContext(ctx, "initial group scan failed", slog.Any("reason", err)) } else { @@ -133,7 +136,7 @@ func onSessionExpires(ctx context.Context, provider provider.Provider, logger *s return func(_key string) { go func(key string) { logger.InfoContext(ctx, "instance expired", slog.String("instance", key)) - err := provider.Stop(ctx, key) + err := provider.InstanceStop(ctx, key) if err != nil { logger.ErrorContext(ctx, "instance expired could not be stopped from provider", slog.String("instance", key), slog.Any("error", err)) } @@ -185,7 +188,18 @@ func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provide } return docker.NewDockerClassicProvider(ctx, cli, logger) case "kubernetes": - return kubernetes.NewKubernetesProvider(ctx, logger, config.Kubernetes) + kubeclientConfig, err := rest.InClusterConfig() + if err != nil { + return nil, err + } + kubeclientConfig.QPS = config.Kubernetes.QPS + kubeclientConfig.Burst = config.Kubernetes.Burst + + cli, err := k8s.NewForConfig(kubeclientConfig) + if err != nil { + return nil, err + } + return kubernetes.NewKubernetesProvider(ctx, cli, logger, config.Kubernetes) } return nil, fmt.Errorf("unimplemented provider %s", config.Name) } @@ -197,7 +211,7 @@ func WatchGroups(ctx context.Context, provider provider.Provider, frequency time case <-ctx.Done(): return case <-ticker.C: - groups, err := provider.GetGroups(ctx) + groups, err := provider.InstanceGroups(ctx) if err != nil { logger.Error("cannot retrieve group from provider", slog.Any("reason", err)) } else if groups != nil { diff --git a/app/sessions/mocks/provider_mock.go b/app/sessions/mocks/provider_mock.go index c43ad29..6c61b8a 100644 --- a/app/sessions/mocks/provider_mock.go +++ b/app/sessions/mocks/provider_mock.go @@ -48,12 +48,12 @@ func (provider *ProviderMock) Wait() { provider.wg.Wait() } -func (provider *ProviderMock) GetState(ctx context.Context, name string) (instance.State, error) { +func (provider *ProviderMock) InstanceInspect(ctx context.Context, name string) (instance.State, error) { args := provider.Mock.Called(name) return args.Get(0).(instance.State), args.Error(1) } -func (provider *ProviderMock) GetGroups(ctx context.Context) (map[string][]string, error) { +func (provider *ProviderMock) InstanceGroups(ctx context.Context) (map[string][]string, error) { return make(map[string][]string), nil } diff --git a/app/sessions/sessions_manager.go b/app/sessions/sessions_manager.go index 6b305a3..70bb36b 100644 --- a/app/sessions/sessions_manager.go +++ b/app/sessions/sessions_manager.go @@ -181,12 +181,12 @@ func (s *SessionsManager) requestInstance(ctx context.Context, name string, dura if errors.Is(err, store.ErrKeyNotFound) { s.l.DebugContext(ctx, "request to start instance received", slog.String("instance", name)) - err := s.provider.Start(ctx, name) + err := s.provider.InstanceStart(ctx, name) if err != nil { return instance.State{}, err } - state, err = s.provider.GetState(ctx, name) + state, err = s.provider.InstanceInspect(ctx, name) if err != nil { return instance.State{}, err } @@ -196,7 +196,7 @@ func (s *SessionsManager) requestInstance(ctx context.Context, name string, dura return instance.State{}, fmt.Errorf("cannot retrieve instance from store: %w", err) } else if state.Status != instance.Ready { s.l.DebugContext(ctx, "request to check instance status received", slog.String("instance", name), slog.String("current_status", state.Status)) - state, err = s.provider.GetState(ctx, name) + state, err = s.provider.InstanceInspect(ctx, name) if err != nil { return instance.State{}, err } diff --git a/app/sessions/sessions_manager_test.go b/app/sessions/sessions_manager_test.go index 994b34e..34c0571 100644 --- a/app/sessions/sessions_manager_test.go +++ b/app/sessions/sessions_manager_test.go @@ -114,7 +114,7 @@ func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) { store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - provider.On("GetState", mock.Anything).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil) + provider.On("InstanceInspect", mock.Anything).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil) errchan := make(chan error) go func() { @@ -136,7 +136,7 @@ func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) { store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - provider.On("GetState", mock.Anything).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil) + provider.On("InstanceInspect", mock.Anything).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil) errchan := make(chan error) go func() { diff --git a/cmd/root_test.go b/cmd/root_test.go index aa227bd..a0e5fee 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -178,7 +178,7 @@ func unsetEnvsFromFile(path string) { func mockStartCommand() *cobra.Command { cmd := &cobra.Command{ Use: "start", - Short: "Start the Sablier server", + Short: "InstanceStart the Sablier server", Run: func(cmd *cobra.Command, args []string) { viper.Unmarshal(&conf) diff --git a/cmd/start.go b/cmd/start.go index c0b085d..517aa7a 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -9,7 +9,7 @@ import ( var newStartCommand = func() *cobra.Command { return &cobra.Command{ Use: "start", - Short: "Start the Sablier server", + Short: "InstanceStart the Sablier server", Run: func(cmd *cobra.Command, args []string) { viper.Unmarshal(&conf) diff --git a/go.mod b/go.mod index 4f50c91..a374554 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.35.0 + github.com/testcontainers/testcontainers-go/modules/k3s v0.35.0 github.com/testcontainers/testcontainers-go/modules/valkey v0.35.0 github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f github.com/valkey-io/valkey-go v1.0.55 diff --git a/go.sum b/go.sum index 4ab4034..bf9f287 100644 --- a/go.sum +++ b/go.sum @@ -263,6 +263,8 @@ github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSW github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= github.com/testcontainers/testcontainers-go v0.35.0 h1:uADsZpTKFAtp8SLK+hMwSaa+X+JiERHtd4sQAFmXeMo= github.com/testcontainers/testcontainers-go v0.35.0/go.mod h1:oEVBj5zrfJTrgjwONs1SsRbnBtH9OKl+IGl3UMcr2B4= +github.com/testcontainers/testcontainers-go/modules/k3s v0.35.0 h1:zEfdO1Dz7sA2jNpf1PVCOI6FND1t/mDpaeDCguaLRXw= +github.com/testcontainers/testcontainers-go/modules/k3s v0.35.0/go.mod h1:YWc+Yph4EvIXHsjRAwPezJEvQGoOFP1AEbfhrYrylAM= github.com/testcontainers/testcontainers-go/modules/valkey v0.35.0 h1:0cX9txu8oW4NVXzaGMLBEOX/BBmWmQtd1X55JILNb6E= github.com/testcontainers/testcontainers-go/modules/valkey v0.35.0/go.mod h1:Bro7Md5b9MoFzM1bs/NWEwazdePpYBy96thih94pYxs= github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= diff --git a/go.work.sum b/go.work.sum index a27c866..57c25b7 100644 --- a/go.work.sum +++ b/go.work.sum @@ -608,6 +608,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20240312041847-bd984b5ce465 h1:KwWnWVW github.com/ianlancetaylor/demangle v0.0.0-20240312041847-bd984b5ce465/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk= +github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg= github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab h1:HqW4xhhynfjrtEiiSGcQUd6vrK23iMam1FO8rI7mwig= github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= diff --git a/pkg/provider/docker/container_inspect.go b/pkg/provider/docker/container_inspect.go index ab2a010..e149cf9 100644 --- a/pkg/provider/docker/container_inspect.go +++ b/pkg/provider/docker/container_inspect.go @@ -7,12 +7,12 @@ import ( "log/slog" ) -func (p *DockerClassicProvider) GetState(ctx context.Context, name string) (instance.State, error) { +func (p *DockerClassicProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { spec, err := p.Client.ContainerInspect(ctx, name) if err != nil { return instance.State{}, fmt.Errorf("cannot inspect container: %w", err) } - + // "created", "running", "paused", "restarting", "removing", "exited", or "dead" switch spec.State.Status { case "created", "paused", "restarting", "removing": diff --git a/pkg/provider/docker/container_inspect_test.go b/pkg/provider/docker/container_inspect_test.go index e13ed55..b2b42f0 100644 --- a/pkg/provider/docker/container_inspect_test.go +++ b/pkg/provider/docker/container_inspect_test.go @@ -269,9 +269,9 @@ func TestDockerClassicProvider_GetState(t *testing.T) { assert.NilError(t, err) tt.want.Name = name - got, err := p.GetState(ctx, name) + got, err := p.InstanceInspect(ctx, name) if !cmp.Equal(err, tt.wantErr) { - t.Errorf("DockerClassicProvider.GetState() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("DockerClassicProvider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr) return } assert.DeepEqual(t, got, tt.want) diff --git a/pkg/provider/docker/container_list.go b/pkg/provider/docker/container_list.go index 0b1a364..e6f777a 100644 --- a/pkg/provider/docker/container_list.go +++ b/pkg/provider/docker/container_list.go @@ -50,7 +50,7 @@ func containerToInstance(c dockertypes.Container) types.Instance { } } -func (p *DockerClassicProvider) GetGroups(ctx context.Context) (map[string][]string, error) { +func (p *DockerClassicProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { args := filters.NewArgs() args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) diff --git a/pkg/provider/docker/container_list_test.go b/pkg/provider/docker/container_list_test.go index 099a277..5a8d77d 100644 --- a/pkg/provider/docker/container_list_test.go +++ b/pkg/provider/docker/container_list_test.go @@ -103,7 +103,7 @@ func TestDockerClassicProvider_GetGroups(t *testing.T) { i2, err := dind.client.ContainerInspect(ctx, c2.ID) assert.NilError(t, err) - got, err := p.GetGroups(ctx) + got, err := p.InstanceGroups(ctx) assert.NilError(t, err) want := map[string][]string{ diff --git a/pkg/provider/docker/container_start.go b/pkg/provider/docker/container_start.go index 45bc4c5..5378002 100644 --- a/pkg/provider/docker/container_start.go +++ b/pkg/provider/docker/container_start.go @@ -6,8 +6,8 @@ import ( "github.com/docker/docker/api/types/container" ) -func (p *DockerClassicProvider) Start(ctx context.Context, name string) error { - // TODO: Start should block until the container is ready. +func (p *DockerClassicProvider) InstanceStart(ctx context.Context, name string) error { + // TODO: InstanceStart should block until the container is ready. err := p.Client.ContainerStart(ctx, name, container.StartOptions{}) if err != nil { return fmt.Errorf("cannot start container %s: %w", name, err) diff --git a/pkg/provider/docker/container_start_test.go b/pkg/provider/docker/container_start_test.go index 89addfc..ef8d2fe 100644 --- a/pkg/provider/docker/container_start_test.go +++ b/pkg/provider/docker/container_start_test.go @@ -53,7 +53,7 @@ func TestDockerClassicProvider_Start(t *testing.T) { name, err := tt.args.do(c) assert.NilError(t, err) - err = p.Start(t.Context(), name) + err = p.InstanceStart(t.Context(), name) if tt.err != nil { assert.Error(t, err, tt.err.Error()) } else { diff --git a/pkg/provider/docker/container_stop.go b/pkg/provider/docker/container_stop.go index 2330fde..212c242 100644 --- a/pkg/provider/docker/container_stop.go +++ b/pkg/provider/docker/container_stop.go @@ -7,7 +7,7 @@ import ( "log/slog" ) -func (p *DockerClassicProvider) Stop(ctx context.Context, name string) error { +func (p *DockerClassicProvider) InstanceStop(ctx context.Context, name string) error { p.l.DebugContext(ctx, "stopping container", slog.String("name", name)) err := p.Client.ContainerStop(ctx, name, container.StopOptions{}) if err != nil { diff --git a/pkg/provider/docker/container_stop_test.go b/pkg/provider/docker/container_stop_test.go index dfafd16..d5a9374 100644 --- a/pkg/provider/docker/container_stop_test.go +++ b/pkg/provider/docker/container_stop_test.go @@ -62,7 +62,7 @@ func TestDockerClassicProvider_Stop(t *testing.T) { name, err := tt.args.do(c) assert.NilError(t, err) - err = p.Stop(t.Context(), name) + err = p.InstanceStop(t.Context(), name) if tt.err != nil { assert.Error(t, err, tt.err.Error()) } else { diff --git a/pkg/provider/dockerswarm/docker_swarm.go b/pkg/provider/dockerswarm/docker_swarm.go index 3c6ae6f..adde3a5 100644 --- a/pkg/provider/dockerswarm/docker_swarm.go +++ b/pkg/provider/dockerswarm/docker_swarm.go @@ -44,7 +44,7 @@ func NewDockerSwarmProvider(ctx context.Context, cli *client.Client, logger *slo } -func (p *DockerSwarmProvider) scale(ctx context.Context, name string, replicas uint64) error { +func (p *DockerSwarmProvider) ServiceUpdateReplicas(ctx context.Context, name string, replicas uint64) error { service, err := p.getServiceByName(name, ctx) if err != nil { return err diff --git a/pkg/provider/dockerswarm/service_inspect.go b/pkg/provider/dockerswarm/service_inspect.go index ee3fbad..81ed1b7 100644 --- a/pkg/provider/dockerswarm/service_inspect.go +++ b/pkg/provider/dockerswarm/service_inspect.go @@ -10,7 +10,7 @@ import ( "github.com/sablierapp/sablier/app/instance" ) -func (p *DockerSwarmProvider) GetState(ctx context.Context, name string) (instance.State, error) { +func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { service, err := p.getServiceByName(name, ctx) if err != nil { return instance.State{}, err diff --git a/pkg/provider/dockerswarm/service_inspect_test.go b/pkg/provider/dockerswarm/service_inspect_test.go index 01ce488..a8304ff 100644 --- a/pkg/provider/dockerswarm/service_inspect_test.go +++ b/pkg/provider/dockerswarm/service_inspect_test.go @@ -133,9 +133,9 @@ func TestDockerSwarmProvider_GetState(t *testing.T) { assert.NilError(t, err) tt.want.Name = name - got, err := p.GetState(ctx, name) + got, err := p.InstanceInspect(ctx, name) if !cmp.Equal(err, tt.wantErr) { - t.Errorf("DockerSwarmProvider.GetState() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("DockerSwarmProvider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr) return } assert.DeepEqual(t, got, tt.want) diff --git a/pkg/provider/dockerswarm/service_list.go b/pkg/provider/dockerswarm/service_list.go index 7f98e05..32e6d71 100644 --- a/pkg/provider/dockerswarm/service_list.go +++ b/pkg/provider/dockerswarm/service_list.go @@ -50,7 +50,7 @@ func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i types.Instan } } -func (p *DockerSwarmProvider) GetGroups(ctx context.Context) (map[string][]string, error) { +func (p *DockerSwarmProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { f := filters.NewArgs() f.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) diff --git a/pkg/provider/dockerswarm/service_list_test.go b/pkg/provider/dockerswarm/service_list_test.go index 3cb3619..05ba204 100644 --- a/pkg/provider/dockerswarm/service_list_test.go +++ b/pkg/provider/dockerswarm/service_list_test.go @@ -101,7 +101,7 @@ func TestDockerClassicProvider_GetGroups(t *testing.T) { i2, _, err := dind.client.ServiceInspectWithRaw(ctx, s2.ID, dockertypes.ServiceInspectOptions{}) assert.NilError(t, err) - got, err := p.GetGroups(ctx) + got, err := p.InstanceGroups(ctx) assert.NilError(t, err) want := map[string][]string{ diff --git a/pkg/provider/dockerswarm/service_start.go b/pkg/provider/dockerswarm/service_start.go index bb97030..41451c9 100644 --- a/pkg/provider/dockerswarm/service_start.go +++ b/pkg/provider/dockerswarm/service_start.go @@ -2,6 +2,6 @@ package dockerswarm import "context" -func (p *DockerSwarmProvider) Start(ctx context.Context, name string) error { - return p.scale(ctx, name, uint64(p.desiredReplicas)) +func (p *DockerSwarmProvider) InstanceStart(ctx context.Context, name string) error { + return p.ServiceUpdateReplicas(ctx, name, uint64(p.desiredReplicas)) } diff --git a/pkg/provider/dockerswarm/service_start_test.go b/pkg/provider/dockerswarm/service_start_test.go index e586098..d7ac215 100644 --- a/pkg/provider/dockerswarm/service_start_test.go +++ b/pkg/provider/dockerswarm/service_start_test.go @@ -132,9 +132,9 @@ func TestDockerSwarmProvider_Start(t *testing.T) { assert.NilError(t, err) tt.want.Name = name - err = p.Start(ctx, name) + err = p.InstanceStart(ctx, name) if !cmp.Equal(err, tt.wantErr) { - t.Errorf("DockerSwarmProvider.Stop() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("DockerSwarmProvider.InstanceStop() error = %v, wantErr %v", err, tt.wantErr) return } diff --git a/pkg/provider/dockerswarm/service_stop.go b/pkg/provider/dockerswarm/service_stop.go index a05a474..2061255 100644 --- a/pkg/provider/dockerswarm/service_stop.go +++ b/pkg/provider/dockerswarm/service_stop.go @@ -2,6 +2,6 @@ package dockerswarm import "context" -func (p *DockerSwarmProvider) Stop(ctx context.Context, name string) error { - return p.scale(ctx, name, 0) +func (p *DockerSwarmProvider) InstanceStop(ctx context.Context, name string) error { + return p.ServiceUpdateReplicas(ctx, name, 0) } diff --git a/pkg/provider/dockerswarm/service_stop_test.go b/pkg/provider/dockerswarm/service_stop_test.go index 4f79034..099a67f 100644 --- a/pkg/provider/dockerswarm/service_stop_test.go +++ b/pkg/provider/dockerswarm/service_stop_test.go @@ -100,9 +100,9 @@ func TestDockerSwarmProvider_Stop(t *testing.T) { assert.NilError(t, err) tt.want.Name = name - err = p.Stop(ctx, name) + err = p.InstanceStop(ctx, name) if !cmp.Equal(err, tt.wantErr) { - t.Errorf("DockerSwarmProvider.Stop() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("DockerSwarmProvider.InstanceStop() error = %v, wantErr %v", err, tt.wantErr) return } diff --git a/pkg/provider/kubernetes/deployment_events.go b/pkg/provider/kubernetes/deployment_events.go new file mode 100644 index 0000000..e6eb7de --- /dev/null +++ b/pkg/provider/kubernetes/deployment_events.go @@ -0,0 +1,37 @@ +package kubernetes + +import ( + appsv1 "k8s.io/api/apps/v1" + core_v1 "k8s.io/api/core/v1" + "k8s.io/client-go/informers" + "k8s.io/client-go/tools/cache" + "time" +) + +func (p *KubernetesProvider) watchDeployents(instance chan<- string) cache.SharedIndexInformer { + handler := cache.ResourceEventHandlerFuncs{ + UpdateFunc: func(old, new interface{}) { + newDeployment := new.(*appsv1.Deployment) + oldDeployment := old.(*appsv1.Deployment) + + if newDeployment.ObjectMeta.ResourceVersion == oldDeployment.ObjectMeta.ResourceVersion { + return + } + + if *newDeployment.Spec.Replicas == 0 { + parsed := DeploymentName(newDeployment, ParseOptions{Delimiter: p.delimiter}) + instance <- parsed.Original + } + }, + DeleteFunc: func(obj interface{}) { + deletedDeployment := obj.(*appsv1.Deployment) + parsed := DeploymentName(deletedDeployment, ParseOptions{Delimiter: p.delimiter}) + instance <- parsed.Original + }, + } + factory := informers.NewSharedInformerFactoryWithOptions(p.Client, 2*time.Second, informers.WithNamespace(core_v1.NamespaceAll)) + informer := factory.Apps().V1().Deployments().Informer() + + informer.AddEventHandler(handler) + return informer +} diff --git a/pkg/provider/kubernetes/deployment_inspect.go b/pkg/provider/kubernetes/deployment_inspect.go new file mode 100644 index 0000000..d355d5e --- /dev/null +++ b/pkg/provider/kubernetes/deployment_inspect.go @@ -0,0 +1,22 @@ +package kubernetes + +import ( + "context" + "fmt" + "github.com/sablierapp/sablier/app/instance" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func (p *KubernetesProvider) DeploymentInspect(ctx context.Context, config ParsedName) (instance.State, error) { + d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) + if err != nil { + return instance.State{}, fmt.Errorf("error getting deployment: %w", err) + } + + // TODO: Should add option to set ready as soon as one replica is ready + if *d.Spec.Replicas != 0 && *d.Spec.Replicas == d.Status.ReadyReplicas { + return instance.ReadyInstanceState(config.Original, config.Replicas), nil + } + + return instance.NotReadyInstanceState(config.Original, d.Status.ReadyReplicas, config.Replicas), nil +} diff --git a/pkg/provider/kubernetes/deployment_inspect_test.go b/pkg/provider/kubernetes/deployment_inspect_test.go new file mode 100644 index 0000000..cc7a3af --- /dev/null +++ b/pkg/provider/kubernetes/deployment_inspect_test.go @@ -0,0 +1,135 @@ +package kubernetes_test + +import ( + "context" + "fmt" + "github.com/google/go-cmp/cmp" + "github.com/neilotoole/slogt" + "github.com/sablierapp/sablier/app/instance" + "github.com/sablierapp/sablier/config" + "github.com/sablierapp/sablier/pkg/provider/kubernetes" + "gotest.tools/v3/assert" + autoscalingv1 "k8s.io/api/autoscaling/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "testing" +) + +func TestKubernetesProvider_DeploymentInspect(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + ctx := context.Background() + type args struct { + do func(dind *kindContainer) (string, error) + } + tests := []struct { + name string + args args + want instance.State + wantErr error + }{ + { + name: "deployment with 1/1 replicas", + args: args{ + do: func(dind *kindContainer) (string, error) { + d, err := dind.CreateMimicDeployment(ctx, MimicOptions{ + Cmd: []string{"/mimic"}, + Healthcheck: nil, + }) + if err != nil { + return "", err + } + + if err = WaitForDeploymentReady(ctx, dind.client, "default", d.Name); err != nil { + return "", fmt.Errorf("error waiting for deployment: %w", err) + } + + return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil + }, + }, + want: instance.State{ + CurrentReplicas: 1, + DesiredReplicas: 1, + Status: instance.Ready, + }, + wantErr: nil, + }, + { + name: "deployment with 0/1 replicas", + args: args{ + do: func(dind *kindContainer) (string, error) { + d, err := dind.CreateMimicDeployment(ctx, MimicOptions{ + Cmd: []string{"/mimic", "-running-after=1ms", "-healthy=false", "-healthy-after=10s"}, + Healthcheck: &corev1.Probe{}, + }) + if err != nil { + return "", err + } + + return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil + }, + }, + want: instance.State{ + CurrentReplicas: 0, + DesiredReplicas: 1, + Status: instance.NotReady, + }, + wantErr: nil, + }, + { + name: "deployment with 0/0 replicas", + args: args{ + do: func(dind *kindContainer) (string, error) { + d, err := dind.CreateMimicDeployment(ctx, MimicOptions{}) + if err != nil { + return "", err + } + + _, err = dind.client.AppsV1().Deployments(d.Namespace).UpdateScale(ctx, d.Name, &autoscalingv1.Scale{ + ObjectMeta: metav1.ObjectMeta{ + Name: d.Name, + }, + Spec: autoscalingv1.ScaleSpec{ + Replicas: 0, + }, + }, metav1.UpdateOptions{}) + if err != nil { + return "", err + } + + if err = WaitForDeploymentScale(ctx, dind.client, "default", d.Name, 0); err != nil { + return "", fmt.Errorf("error waiting for deployment: %w", err) + } + + return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil + }, + }, + want: instance.State{ + CurrentReplicas: 0, + DesiredReplicas: 1, + Status: instance.NotReady, + }, + wantErr: nil, + }, + } + c := setupKinD(t, ctx) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p, err := kubernetes.NewKubernetesProvider(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes) + + name, err := tt.args.do(c) + assert.NilError(t, err) + + tt.want.Name = name + got, err := p.InstanceInspect(ctx, name) + if !cmp.Equal(err, tt.wantErr) { + t.Errorf("KubernetesProvider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.DeepEqual(t, got, tt.want) + }) + } +} diff --git a/pkg/provider/kubernetes/deployment_list.go b/pkg/provider/kubernetes/deployment_list.go new file mode 100644 index 0000000..f5fc2f8 --- /dev/null +++ b/pkg/provider/kubernetes/deployment_list.go @@ -0,0 +1,49 @@ +package kubernetes + +import ( + "context" + "github.com/sablierapp/sablier/app/discovery" + "github.com/sablierapp/sablier/app/types" + "github.com/sablierapp/sablier/pkg/provider" + v1 "k8s.io/api/apps/v1" + core_v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "strings" +) + +func (p *KubernetesProvider) DeploymentList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { + deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ + LabelSelector: strings.Join(options.Labels, ","), + }) + + if err != nil { + return nil, err + } + + instances := make([]types.Instance, 0, len(deployments.Items)) + for _, d := range deployments.Items { + instance := p.deploymentToInstance(&d) + instances = append(instances, instance) + } + + return instances, nil +} + +func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) types.Instance { + var group string + + if _, ok := d.Labels[discovery.LabelEnable]; ok { + if g, ok := d.Labels[discovery.LabelGroup]; ok { + group = g + } else { + group = discovery.LabelGroupDefaultValue + } + } + + parsed := DeploymentName(d, ParseOptions{Delimiter: p.delimiter}) + + return types.Instance{ + Name: parsed.Original, + Group: group, + } +} diff --git a/pkg/provider/kubernetes/instance_events.go b/pkg/provider/kubernetes/instance_events.go new file mode 100644 index 0000000..0bdb81f --- /dev/null +++ b/pkg/provider/kubernetes/instance_events.go @@ -0,0 +1,10 @@ +package kubernetes + +import "context" + +func (p *KubernetesProvider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) { + informer := p.watchDeployents(instance) + go informer.Run(ctx.Done()) + informer = p.watchStatefulSets(instance) + go informer.Run(ctx.Done()) +} diff --git a/pkg/provider/kubernetes/instance_inspect.go b/pkg/provider/kubernetes/instance_inspect.go new file mode 100644 index 0000000..626f2e4 --- /dev/null +++ b/pkg/provider/kubernetes/instance_inspect.go @@ -0,0 +1,23 @@ +package kubernetes + +import ( + "context" + "fmt" + "github.com/sablierapp/sablier/app/instance" +) + +func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { + parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) + if err != nil { + return instance.State{}, err + } + + switch parsed.Kind { + case "deployment": + return p.DeploymentInspect(ctx, parsed) + case "statefulset": + return p.StatefulSetInspect(ctx, parsed) + default: + return instance.State{}, fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", parsed.Kind) + } +} diff --git a/pkg/provider/kubernetes/instance_list.go b/pkg/provider/kubernetes/instance_list.go new file mode 100644 index 0000000..f4eb9f6 --- /dev/null +++ b/pkg/provider/kubernetes/instance_list.go @@ -0,0 +1,21 @@ +package kubernetes + +import ( + "context" + "github.com/sablierapp/sablier/app/types" + "github.com/sablierapp/sablier/pkg/provider" +) + +func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { + deployments, err := p.DeploymentList(ctx, options) + if err != nil { + return nil, err + } + + statefulSets, err := p.StatefulSetList(ctx, options) + if err != nil { + return nil, err + } + + return append(deployments, statefulSets...), nil +} diff --git a/pkg/provider/kubernetes/kubernetes.go b/pkg/provider/kubernetes/kubernetes.go index a3c8a78..c95949f 100644 --- a/pkg/provider/kubernetes/kubernetes.go +++ b/pkg/provider/kubernetes/kubernetes.go @@ -2,54 +2,28 @@ package kubernetes import ( "context" - "fmt" "github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/pkg/provider" - "log/slog" - "time" - - appsv1 "k8s.io/api/apps/v1" core_v1 "k8s.io/api/core/v1" + "log/slog" - "github.com/sablierapp/sablier/app/instance" providerConfig "github.com/sablierapp/sablier/config" - autoscalingv1 "k8s.io/api/autoscaling/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" - "k8s.io/client-go/tools/cache" ) // Interface guard var _ provider.Provider = (*KubernetesProvider)(nil) -type Workload interface { - GetScale(ctx context.Context, workloadName string, options metav1.GetOptions) (*autoscalingv1.Scale, error) - UpdateScale(ctx context.Context, workloadName string, scale *autoscalingv1.Scale, opts metav1.UpdateOptions) (*autoscalingv1.Scale, error) -} - type KubernetesProvider struct { Client kubernetes.Interface delimiter string l *slog.Logger } -func NewKubernetesProvider(ctx context.Context, logger *slog.Logger, providerConfig providerConfig.Kubernetes) (*KubernetesProvider, error) { +func NewKubernetesProvider(ctx context.Context, client *kubernetes.Clientset, logger *slog.Logger, kubeclientConfig providerConfig.Kubernetes) (*KubernetesProvider, error) { logger = logger.With(slog.String("provider", "kubernetes")) - kubeclientConfig, err := rest.InClusterConfig() - if err != nil { - return nil, err - } - kubeclientConfig.QPS = providerConfig.QPS - kubeclientConfig.Burst = providerConfig.Burst - - client, err := kubernetes.NewForConfig(kubeclientConfig) - if err != nil { - return nil, err - } - info, err := client.ServerVersion() if err != nil { return nil, err @@ -63,13 +37,13 @@ func NewKubernetesProvider(ctx context.Context, logger *slog.Logger, providerCon return &KubernetesProvider{ Client: client, - delimiter: providerConfig.Delimiter, + delimiter: kubeclientConfig.Delimiter, l: logger, }, nil } -func (p *KubernetesProvider) Start(ctx context.Context, name string) error { +func (p *KubernetesProvider) InstanceStart(ctx context.Context, name string) error { parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) if err != nil { return err @@ -78,17 +52,16 @@ func (p *KubernetesProvider) Start(ctx context.Context, name string) error { return p.scale(ctx, parsed, parsed.Replicas) } -func (p *KubernetesProvider) Stop(ctx context.Context, name string) error { +func (p *KubernetesProvider) InstanceStop(ctx context.Context, name string) error { parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) if err != nil { return err } return p.scale(ctx, parsed, 0) - } -func (p *KubernetesProvider) GetGroups(ctx context.Context) (map[string][]string, error) { +func (p *KubernetesProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ LabelSelector: discovery.LabelEnable, }) @@ -105,7 +78,7 @@ func (p *KubernetesProvider) GetGroups(ctx context.Context) (map[string][]string } group := groups[groupName] - parsed := DeploymentName(deployment, ParseOptions{Delimiter: p.delimiter}) + parsed := DeploymentName(&deployment, ParseOptions{Delimiter: p.delimiter}) group = append(group, parsed.Original) groups[groupName] = group } @@ -125,139 +98,10 @@ func (p *KubernetesProvider) GetGroups(ctx context.Context) (map[string][]string } group := groups[groupName] - parsed := StatefulSetName(statefulSet, ParseOptions{Delimiter: p.delimiter}) + parsed := StatefulSetName(&statefulSet, ParseOptions{Delimiter: p.delimiter}) group = append(group, parsed.Original) groups[groupName] = group } return groups, nil } - -func (p *KubernetesProvider) scale(ctx context.Context, config ParsedName, replicas int32) error { - var workload Workload - - switch config.Kind { - case "deployment": - workload = p.Client.AppsV1().Deployments(config.Namespace) - case "statefulset": - workload = p.Client.AppsV1().StatefulSets(config.Namespace) - default: - return fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", config.Kind) - } - - s, err := workload.GetScale(ctx, config.Name, metav1.GetOptions{}) - if err != nil { - return err - } - - s.Spec.Replicas = replicas - _, err = workload.UpdateScale(ctx, config.Name, s, metav1.UpdateOptions{}) - - return err -} - -func (p *KubernetesProvider) GetState(ctx context.Context, name string) (instance.State, error) { - parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) - if err != nil { - return instance.State{}, err - } - - switch parsed.Kind { - case "deployment": - return p.getDeploymentState(ctx, parsed) - case "statefulset": - return p.getStatefulsetState(ctx, parsed) - default: - return instance.State{}, fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", parsed.Kind) - } -} - -func (p *KubernetesProvider) getDeploymentState(ctx context.Context, config ParsedName) (instance.State, error) { - d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) - if err != nil { - return instance.State{}, err - } - - if *d.Spec.Replicas == d.Status.ReadyReplicas { - return instance.ReadyInstanceState(config.Original, config.Replicas), nil - } - - return instance.NotReadyInstanceState(config.Original, d.Status.ReadyReplicas, config.Replicas), nil -} - -func (p *KubernetesProvider) getStatefulsetState(ctx context.Context, config ParsedName) (instance.State, error) { - ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) - if err != nil { - return instance.State{}, err - } - - if *ss.Spec.Replicas == ss.Status.ReadyReplicas { - return instance.ReadyInstanceState(config.Original, ss.Status.ReadyReplicas), nil - } - - return instance.NotReadyInstanceState(config.Original, ss.Status.ReadyReplicas, *ss.Spec.Replicas), nil -} - -func (p *KubernetesProvider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) { - - informer := p.watchDeployents(instance) - go informer.Run(ctx.Done()) - informer = p.watchStatefulSets(instance) - go informer.Run(ctx.Done()) -} - -func (p *KubernetesProvider) watchDeployents(instance chan<- string) cache.SharedIndexInformer { - handler := cache.ResourceEventHandlerFuncs{ - UpdateFunc: func(old, new interface{}) { - newDeployment := new.(*appsv1.Deployment) - oldDeployment := old.(*appsv1.Deployment) - - if newDeployment.ObjectMeta.ResourceVersion == oldDeployment.ObjectMeta.ResourceVersion { - return - } - - if *newDeployment.Spec.Replicas == 0 { - parsed := DeploymentName(*newDeployment, ParseOptions{Delimiter: p.delimiter}) - instance <- parsed.Original - } - }, - DeleteFunc: func(obj interface{}) { - deletedDeployment := obj.(*appsv1.Deployment) - parsed := DeploymentName(*deletedDeployment, ParseOptions{Delimiter: p.delimiter}) - instance <- parsed.Original - }, - } - factory := informers.NewSharedInformerFactoryWithOptions(p.Client, 2*time.Second, informers.WithNamespace(core_v1.NamespaceAll)) - informer := factory.Apps().V1().Deployments().Informer() - - informer.AddEventHandler(handler) - return informer -} - -func (p *KubernetesProvider) watchStatefulSets(instance chan<- string) cache.SharedIndexInformer { - handler := cache.ResourceEventHandlerFuncs{ - UpdateFunc: func(old, new interface{}) { - newStatefulSet := new.(*appsv1.StatefulSet) - oldStatefulSet := old.(*appsv1.StatefulSet) - - if newStatefulSet.ObjectMeta.ResourceVersion == oldStatefulSet.ObjectMeta.ResourceVersion { - return - } - - if *newStatefulSet.Spec.Replicas == 0 { - parsed := StatefulSetName(*newStatefulSet, ParseOptions{Delimiter: p.delimiter}) - instance <- parsed.Original - } - }, - DeleteFunc: func(obj interface{}) { - deletedStatefulSet := obj.(*appsv1.StatefulSet) - parsed := StatefulSetName(*deletedStatefulSet, ParseOptions{Delimiter: p.delimiter}) - instance <- parsed.Original - }, - } - factory := informers.NewSharedInformerFactoryWithOptions(p.Client, 2*time.Second, informers.WithNamespace(core_v1.NamespaceAll)) - informer := factory.Apps().V1().StatefulSets().Informer() - - informer.AddEventHandler(handler) - return informer -} diff --git a/pkg/provider/kubernetes/kubernetes_test.go b/pkg/provider/kubernetes/kubernetes_test.go deleted file mode 100644 index 1eed2ef..0000000 --- a/pkg/provider/kubernetes/kubernetes_test.go +++ /dev/null @@ -1,290 +0,0 @@ -package kubernetes - -import ( - "context" - "github.com/neilotoole/slogt" - "github.com/sablierapp/sablier/pkg/provider/mocks" - "k8s.io/client-go/kubernetes" - "reflect" - "testing" - - "github.com/sablierapp/sablier/app/instance" - "github.com/stretchr/testify/mock" - v1 "k8s.io/api/apps/v1" - autoscalingv1 "k8s.io/api/autoscaling/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) - -func setupProvider(t *testing.T, client kubernetes.Interface) *KubernetesProvider { - t.Helper() - return &KubernetesProvider{ - Client: client, - delimiter: "_", - l: slogt.New(t), - } -} - -func TestKubernetesProvider_Start(t *testing.T) { - type data struct { - name string - get *autoscalingv1.Scale - update *autoscalingv1.Scale - } - type args struct { - name string - } - tests := []struct { - name string - args args - want instance.State - data data - wantErr bool - }{ - { - name: "scale nginx deployment to 2 replicas", - args: args{ - name: "deployment_default_nginx_2", - }, - data: data{ - name: "nginx", - get: mocks.V1Scale(0), - update: mocks.V1Scale(2), - }, - wantErr: false, - }, - { - name: "scale nginx statefulset to 2 replicas", - args: args{ - name: "statefulset_default_nginx_2", - }, - data: data{ - name: "nginx", - get: mocks.V1Scale(0), - update: mocks.V1Scale(2), - }, - wantErr: false, - }, - { - name: "scale unsupported kind", - args: args{ - name: "gateway_default_nginx_2", - }, - data: data{ - name: "nginx", - get: mocks.V1Scale(0), - update: mocks.V1Scale(0), - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - deploymentAPI := mocks.DeploymentMock{} - statefulsetAPI := mocks.StatefulSetsMock{} - provider := setupProvider(t, mocks.NewKubernetesAPIClientMock(&deploymentAPI, &statefulsetAPI)) - - deploymentAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil) - deploymentAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil) - - statefulsetAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil) - statefulsetAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil) - - err := provider.Start(context.Background(), tt.args.name) - if (err != nil) != tt.wantErr { - t.Errorf("KubernetesProvider.Start() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func TestKubernetesProvider_Stop(t *testing.T) { - type data struct { - name string - get *autoscalingv1.Scale - update *autoscalingv1.Scale - } - type args struct { - name string - } - tests := []struct { - name string - args args - want instance.State - data data - wantErr bool - }{ - { - name: "scale nginx deployment to 2 replicas", - args: args{ - name: "deployment_default_nginx_2", - }, - data: data{ - name: "nginx", - get: mocks.V1Scale(2), - update: mocks.V1Scale(0), - }, - wantErr: false, - }, - { - name: "scale nginx statefulset to 2 replicas", - args: args{ - name: "statefulset_default_nginx_2", - }, - data: data{ - name: "nginx", - get: mocks.V1Scale(2), - update: mocks.V1Scale(0), - }, - wantErr: false, - }, - { - name: "scale unsupported kind", - args: args{ - name: "gateway_default_nginx_2", - }, - data: data{ - name: "nginx", - get: mocks.V1Scale(0), - update: mocks.V1Scale(0), - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - deploymentAPI := mocks.DeploymentMock{} - statefulsetAPI := mocks.StatefulSetsMock{} - provider := setupProvider(t, mocks.NewKubernetesAPIClientMock(&deploymentAPI, &statefulsetAPI)) - - deploymentAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil) - deploymentAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil) - - statefulsetAPI.On("GetScale", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.get, nil) - statefulsetAPI.On("UpdateScale", mock.Anything, tt.data.name, tt.data.update, metav1.UpdateOptions{}).Return(nil, nil) - - err := provider.Stop(context.Background(), tt.args.name) - if (err != nil) != tt.wantErr { - t.Errorf("KubernetesProvider.Stop() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func TestKubernetesProvider_GetState(t *testing.T) { - type data struct { - name string - getDeployment *v1.Deployment - getStatefulSet *v1.StatefulSet - } - type args struct { - name string - } - tests := []struct { - name string - args args - want instance.State - data data - wantErr bool - }{ - { - name: "ready nginx deployment with 2 ready replicas", - args: args{ - name: "deployment_default_nginx_2", - }, - want: instance.State{ - Name: "deployment_default_nginx_2", - CurrentReplicas: 2, - DesiredReplicas: 2, - Status: instance.Ready, - }, - data: data{ - name: "nginx", - getDeployment: mocks.V1Deployment(2, 2), - }, - wantErr: false, - }, - { - name: "not ready nginx deployment with 1 ready replica out of 2", - args: args{ - name: "deployment_default_nginx_2", - }, - want: instance.State{ - Name: "deployment_default_nginx_2", - CurrentReplicas: 1, - DesiredReplicas: 2, - Status: instance.NotReady, - }, - data: data{ - name: "nginx", - getDeployment: mocks.V1Deployment(2, 1), - }, - wantErr: false, - }, - { - name: "ready nginx statefulset to 2 replicas", - args: args{ - name: "statefulset_default_nginx_2", - }, - want: instance.State{ - Name: "statefulset_default_nginx_2", - CurrentReplicas: 2, - DesiredReplicas: 2, - Status: instance.Ready, - }, - data: data{ - name: "nginx", - getStatefulSet: mocks.V1StatefulSet(2, 2), - }, - wantErr: false, - }, - { - name: "not ready nginx statefulset to 1 ready replica out of 2", - args: args{ - name: "statefulset_default_nginx_2", - }, - want: instance.State{ - Name: "statefulset_default_nginx_2", - CurrentReplicas: 1, - DesiredReplicas: 2, - Status: instance.NotReady, - }, - data: data{ - name: "nginx", - getStatefulSet: mocks.V1StatefulSet(2, 1), - }, - wantErr: false, - }, - { - name: "scale unsupported kind", - args: args{ - name: "gateway_default_nginx_2", - }, - want: instance.State{}, - data: data{ - name: "nginx", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - deploymentAPI := mocks.DeploymentMock{} - statefulsetAPI := mocks.StatefulSetsMock{} - provider := setupProvider(t, mocks.NewKubernetesAPIClientMock(&deploymentAPI, &statefulsetAPI)) - - deploymentAPI.On("Get", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.getDeployment, nil) - statefulsetAPI.On("Get", mock.Anything, tt.data.name, metav1.GetOptions{}).Return(tt.data.getStatefulSet, nil) - - got, err := provider.GetState(context.Background(), tt.args.name) - if (err != nil) != tt.wantErr { - t.Errorf("KubernetesProvider.GetState() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("KubernetesProvider.GetState() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/provider/kubernetes/list.go b/pkg/provider/kubernetes/list.go deleted file mode 100644 index 6bb899f..0000000 --- a/pkg/provider/kubernetes/list.go +++ /dev/null @@ -1,100 +0,0 @@ -package kubernetes - -import ( - "context" - "github.com/sablierapp/sablier/app/discovery" - "github.com/sablierapp/sablier/app/types" - "github.com/sablierapp/sablier/pkg/provider" - v1 "k8s.io/api/apps/v1" - core_v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "strings" -) - -func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { - deployments, err := p.deploymentList(ctx, options) - if err != nil { - return nil, err - } - - statefulSets, err := p.statefulSetList(ctx, options) - if err != nil { - return nil, err - } - - return append(deployments, statefulSets...), nil -} - -func (p *KubernetesProvider) deploymentList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { - deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ - LabelSelector: strings.Join(options.Labels, ","), - }) - - if err != nil { - return nil, err - } - - instances := make([]types.Instance, 0, len(deployments.Items)) - for _, d := range deployments.Items { - instance := p.deploymentToInstance(d) - instances = append(instances, instance) - } - - return instances, nil -} - -func (p *KubernetesProvider) deploymentToInstance(d v1.Deployment) types.Instance { - var group string - - if _, ok := d.Labels[discovery.LabelEnable]; ok { - if g, ok := d.Labels[discovery.LabelGroup]; ok { - group = g - } else { - group = discovery.LabelGroupDefaultValue - } - } - - parsed := DeploymentName(d, ParseOptions{Delimiter: p.delimiter}) - - return types.Instance{ - Name: parsed.Original, - Group: group, - } -} - -func (p *KubernetesProvider) statefulSetList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { - statefulSets, err := p.Client.AppsV1().StatefulSets(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ - LabelSelector: strings.Join(options.Labels, ","), - }) - - if err != nil { - return nil, err - } - - instances := make([]types.Instance, 0, len(statefulSets.Items)) - for _, ss := range statefulSets.Items { - instance := p.statefulSetToInstance(ss) - instances = append(instances, instance) - } - - return instances, nil -} - -func (p *KubernetesProvider) statefulSetToInstance(ss v1.StatefulSet) types.Instance { - var group string - - if _, ok := ss.Labels[discovery.LabelEnable]; ok { - if g, ok := ss.Labels[discovery.LabelGroup]; ok { - group = g - } else { - group = discovery.LabelGroupDefaultValue - } - } - - parsed := StatefulSetName(ss, ParseOptions{Delimiter: p.delimiter}) - - return types.Instance{ - Name: parsed.Original, - Group: group, - } -} diff --git a/pkg/provider/kubernetes/parse_name.go b/pkg/provider/kubernetes/parse_name.go index 90df1ab..f83c070 100644 --- a/pkg/provider/kubernetes/parse_name.go +++ b/pkg/provider/kubernetes/parse_name.go @@ -41,7 +41,7 @@ func ParseName(name string, opts ParseOptions) (ParsedName, error) { }, nil } -func DeploymentName(deployment v1.Deployment, opts ParseOptions) ParsedName { +func DeploymentName(deployment *v1.Deployment, opts ParseOptions) ParsedName { kind := "deployment" namespace := deployment.Namespace name := deployment.Name @@ -57,7 +57,7 @@ func DeploymentName(deployment v1.Deployment, opts ParseOptions) ParsedName { } } -func StatefulSetName(statefulSet v1.StatefulSet, opts ParseOptions) ParsedName { +func StatefulSetName(statefulSet *v1.StatefulSet, opts ParseOptions) ParsedName { kind := "statefulset" namespace := statefulSet.Namespace name := statefulSet.Name diff --git a/pkg/provider/kubernetes/parse_name_test.go b/pkg/provider/kubernetes/parse_name_test.go index e96fba2..58b81e4 100644 --- a/pkg/provider/kubernetes/parse_name_test.go +++ b/pkg/provider/kubernetes/parse_name_test.go @@ -79,7 +79,7 @@ func TestDeploymentName(t *testing.T) { Replicas: 1, } - result := DeploymentName(deployment, opts) + result := DeploymentName(&deployment, opts) if result != expected { t.Errorf("expected %v but got %v", expected, result) } @@ -101,7 +101,7 @@ func TestStatefulSetName(t *testing.T) { Replicas: 1, } - result := StatefulSetName(statefulSet, opts) + result := StatefulSetName(&statefulSet, opts) if result != expected { t.Errorf("expected %v but got %v", expected, result) } diff --git a/pkg/provider/kubernetes/statefulset_events.go b/pkg/provider/kubernetes/statefulset_events.go new file mode 100644 index 0000000..57e931d --- /dev/null +++ b/pkg/provider/kubernetes/statefulset_events.go @@ -0,0 +1,37 @@ +package kubernetes + +import ( + appsv1 "k8s.io/api/apps/v1" + core_v1 "k8s.io/api/core/v1" + "k8s.io/client-go/informers" + "k8s.io/client-go/tools/cache" + "time" +) + +func (p *KubernetesProvider) watchStatefulSets(instance chan<- string) cache.SharedIndexInformer { + handler := cache.ResourceEventHandlerFuncs{ + UpdateFunc: func(old, new interface{}) { + newStatefulSet := new.(*appsv1.StatefulSet) + oldStatefulSet := old.(*appsv1.StatefulSet) + + if newStatefulSet.ObjectMeta.ResourceVersion == oldStatefulSet.ObjectMeta.ResourceVersion { + return + } + + if *newStatefulSet.Spec.Replicas == 0 { + parsed := StatefulSetName(newStatefulSet, ParseOptions{Delimiter: p.delimiter}) + instance <- parsed.Original + } + }, + DeleteFunc: func(obj interface{}) { + deletedStatefulSet := obj.(*appsv1.StatefulSet) + parsed := StatefulSetName(deletedStatefulSet, ParseOptions{Delimiter: p.delimiter}) + instance <- parsed.Original + }, + } + factory := informers.NewSharedInformerFactoryWithOptions(p.Client, 2*time.Second, informers.WithNamespace(core_v1.NamespaceAll)) + informer := factory.Apps().V1().StatefulSets().Informer() + + informer.AddEventHandler(handler) + return informer +} diff --git a/pkg/provider/kubernetes/statefulset_inspect.go b/pkg/provider/kubernetes/statefulset_inspect.go new file mode 100644 index 0000000..19e3131 --- /dev/null +++ b/pkg/provider/kubernetes/statefulset_inspect.go @@ -0,0 +1,20 @@ +package kubernetes + +import ( + "context" + "github.com/sablierapp/sablier/app/instance" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func (p *KubernetesProvider) StatefulSetInspect(ctx context.Context, config ParsedName) (instance.State, error) { + ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) + if err != nil { + return instance.State{}, err + } + + if *ss.Spec.Replicas != 0 && *ss.Spec.Replicas == ss.Status.ReadyReplicas { + return instance.ReadyInstanceState(config.Original, ss.Status.ReadyReplicas), nil + } + + return instance.NotReadyInstanceState(config.Original, ss.Status.ReadyReplicas, config.Replicas), nil +} diff --git a/pkg/provider/kubernetes/statefulset_inspect_test.go b/pkg/provider/kubernetes/statefulset_inspect_test.go new file mode 100644 index 0000000..7fc6e1e --- /dev/null +++ b/pkg/provider/kubernetes/statefulset_inspect_test.go @@ -0,0 +1,135 @@ +package kubernetes_test + +import ( + "context" + "fmt" + "github.com/google/go-cmp/cmp" + "github.com/neilotoole/slogt" + "github.com/sablierapp/sablier/app/instance" + "github.com/sablierapp/sablier/config" + "github.com/sablierapp/sablier/pkg/provider/kubernetes" + "gotest.tools/v3/assert" + autoscalingv1 "k8s.io/api/autoscaling/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "testing" +) + +func TestKubernetesProvider_InspectStatefulSet(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + ctx := context.Background() + type args struct { + do func(dind *kindContainer) (string, error) + } + tests := []struct { + name string + args args + want instance.State + wantErr error + }{ + { + name: "statefulSet with 1/1 replicas", + args: args{ + do: func(dind *kindContainer) (string, error) { + d, err := dind.CreateMimicStatefulSet(ctx, MimicOptions{ + Cmd: []string{"/mimic"}, + Healthcheck: nil, + }) + if err != nil { + return "", err + } + + if err = WaitForStatefulSetReady(ctx, dind.client, "default", d.Name); err != nil { + return "", fmt.Errorf("error waiting for statefulSet: %w", err) + } + + return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil + }, + }, + want: instance.State{ + CurrentReplicas: 1, + DesiredReplicas: 1, + Status: instance.Ready, + }, + wantErr: nil, + }, + { + name: "statefulSet with 0/1 replicas", + args: args{ + do: func(dind *kindContainer) (string, error) { + d, err := dind.CreateMimicStatefulSet(ctx, MimicOptions{ + Cmd: []string{"/mimic", "-running-after=1ms", "-healthy=false", "-healthy-after=10s"}, + Healthcheck: &corev1.Probe{}, + }) + if err != nil { + return "", err + } + + return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil + }, + }, + want: instance.State{ + CurrentReplicas: 0, + DesiredReplicas: 1, + Status: instance.NotReady, + }, + wantErr: nil, + }, + { + name: "statefulSet with 0/0 replicas", + args: args{ + do: func(dind *kindContainer) (string, error) { + d, err := dind.CreateMimicStatefulSet(ctx, MimicOptions{}) + if err != nil { + return "", err + } + + _, err = dind.client.AppsV1().StatefulSets(d.Namespace).UpdateScale(ctx, d.Name, &autoscalingv1.Scale{ + ObjectMeta: metav1.ObjectMeta{ + Name: d.Name, + }, + Spec: autoscalingv1.ScaleSpec{ + Replicas: 0, + }, + }, metav1.UpdateOptions{}) + if err != nil { + return "", err + } + + if err = WaitForStatefulSetScale(ctx, dind.client, "default", d.Name, 0); err != nil { + return "", fmt.Errorf("error waiting for statefulSet: %w", err) + } + + return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil + }, + }, + want: instance.State{ + CurrentReplicas: 0, + DesiredReplicas: 1, + Status: instance.NotReady, + }, + wantErr: nil, + }, + } + c := setupKinD(t, ctx) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p, err := kubernetes.NewKubernetesProvider(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes) + + name, err := tt.args.do(c) + assert.NilError(t, err) + + tt.want.Name = name + got, err := p.InstanceInspect(ctx, name) + if !cmp.Equal(err, tt.wantErr) { + t.Errorf("DockerSwarmProvider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.DeepEqual(t, got, tt.want) + }) + } +} diff --git a/pkg/provider/kubernetes/statefulset_list.go b/pkg/provider/kubernetes/statefulset_list.go new file mode 100644 index 0000000..c607717 --- /dev/null +++ b/pkg/provider/kubernetes/statefulset_list.go @@ -0,0 +1,49 @@ +package kubernetes + +import ( + "context" + "github.com/sablierapp/sablier/app/discovery" + "github.com/sablierapp/sablier/app/types" + "github.com/sablierapp/sablier/pkg/provider" + v1 "k8s.io/api/apps/v1" + core_v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "strings" +) + +func (p *KubernetesProvider) StatefulSetList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { + statefulSets, err := p.Client.AppsV1().StatefulSets(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ + LabelSelector: strings.Join(options.Labels, ","), + }) + + if err != nil { + return nil, err + } + + instances := make([]types.Instance, 0, len(statefulSets.Items)) + for _, ss := range statefulSets.Items { + instance := p.statefulSetToInstance(&ss) + instances = append(instances, instance) + } + + return instances, nil +} + +func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) types.Instance { + var group string + + if _, ok := ss.Labels[discovery.LabelEnable]; ok { + if g, ok := ss.Labels[discovery.LabelGroup]; ok { + group = g + } else { + group = discovery.LabelGroupDefaultValue + } + } + + parsed := StatefulSetName(ss, ParseOptions{Delimiter: p.delimiter}) + + return types.Instance{ + Name: parsed.Original, + Group: group, + } +} diff --git a/pkg/provider/kubernetes/statefulset_list_test.go b/pkg/provider/kubernetes/statefulset_list_test.go new file mode 100644 index 0000000..bdeb4e5 --- /dev/null +++ b/pkg/provider/kubernetes/statefulset_list_test.go @@ -0,0 +1 @@ +package kubernetes_test diff --git a/pkg/provider/kubernetes/testcontainers_test.go b/pkg/provider/kubernetes/testcontainers_test.go new file mode 100644 index 0000000..edddb21 --- /dev/null +++ b/pkg/provider/kubernetes/testcontainers_test.go @@ -0,0 +1,246 @@ +package kubernetes_test + +import ( + "context" + "fmt" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/k3s" + "gotest.tools/v3/assert" + v1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/clientcmd" + "math/rand" + "sync" + "testing" + "time" +) + +var r = rand.New(rand.NewSource(time.Now().UnixNano())) +var mu sync.Mutex // r is not safe for concurrent use + +type kindContainer struct { + testcontainers.Container + client *kubernetes.Clientset + t *testing.T +} + +type MimicOptions struct { + Cmd []string + Healthcheck *corev1.Probe + Labels map[string]string +} + +func (d *kindContainer) CreateMimicDeployment(ctx context.Context, opts MimicOptions) (*v1.Deployment, error) { + if len(opts.Cmd) == 0 { + opts.Cmd = []string{"/mimic", "-running", "-running-after=1s", "-healthy=false"} + } + + name := generateRandomName() + // Add the app label to the deployment for matching the selector + if opts.Labels == nil { + opts.Labels = make(map[string]string) + } + opts.Labels["app"] = name + d.t.Log("Creating mimic deployment with options", opts) + replicas := int32(1) + return d.client.AppsV1().Deployments("default").Create(ctx, &v1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + Spec: v1.DeploymentSpec{ + Replicas: &replicas, + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": name, + }, + }, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "mimic", + Image: "sablierapp/mimic:v0.3.1", + Command: opts.Cmd, + // ReadinessProbe: opts.Healthcheck, + }, + }, + }, + ObjectMeta: metav1.ObjectMeta{ + Labels: opts.Labels, + }, + }, + }, + }, metav1.CreateOptions{}) +} + +func (d *kindContainer) CreateMimicStatefulSet(ctx context.Context, opts MimicOptions) (*v1.StatefulSet, error) { + if len(opts.Cmd) == 0 { + opts.Cmd = []string{"/mimic", "-running", "-running-after=1s", "-healthy=false"} + } + + name := generateRandomName() + // Add the app label to the deployment for matching the selector + if opts.Labels == nil { + opts.Labels = make(map[string]string) + } + opts.Labels["app"] = name + d.t.Log("Creating mimic deployment with options", opts) + replicas := int32(1) + return d.client.AppsV1().StatefulSets("default").Create(ctx, &v1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + Spec: v1.StatefulSetSpec{ + Replicas: &replicas, + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app": name, + }, + }, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "mimic", + Image: "sablierapp/mimic:v0.3.1", + Command: opts.Cmd, + // ReadinessProbe: opts.Healthcheck, + }, + }, + }, + ObjectMeta: metav1.ObjectMeta{ + Labels: opts.Labels, + }, + }, + }, + }, metav1.CreateOptions{}) +} + +func setupKinD(t *testing.T, ctx context.Context) *kindContainer { + t.Helper() + + kind, err := k3s.Run(ctx, "rancher/k3s:v1.27.1-k3s1") + testcontainers.CleanupContainer(t, kind) + assert.NilError(t, err) + + kubeConfigYaml, err := kind.GetKubeConfig(ctx) + assert.NilError(t, err) + + restcfg, err := clientcmd.RESTConfigFromKubeConfig(kubeConfigYaml) + assert.NilError(t, err) + + provider, err := testcontainers.ProviderDocker.GetProvider() + assert.NilError(t, err) + + err = provider.PullImage(ctx, "sablierapp/mimic:v0.3.1") + assert.NilError(t, err) + + err = kind.LoadImages(ctx, "sablierapp/mimic:v0.3.1") + assert.NilError(t, err) + + k8s, err := kubernetes.NewForConfig(restcfg) + assert.NilError(t, err) + + return &kindContainer{ + Container: kind, + client: k8s, + t: t, + } +} + +func generateRandomName() string { + mu.Lock() + defer mu.Unlock() + letters := []rune("abcdefghijklmnopqrstuvwxyz") + name := make([]rune, 10) + for i := range name { + name[i] = letters[r.Intn(len(letters))] + } + return string(name) +} + +func WaitForDeploymentReady(ctx context.Context, client kubernetes.Interface, namespace, name string) error { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context canceled while waiting for deployment %s/%s", namespace, name) + case <-ticker.C: + deployment, err := client.AppsV1().Deployments(namespace).Get(ctx, name, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("error getting deployment: %w", err) + } + + if deployment.Status.ReadyReplicas == *deployment.Spec.Replicas { + return nil + } + } + } +} + +func WaitForStatefulSetReady(ctx context.Context, client kubernetes.Interface, namespace, name string) error { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context canceled while waiting for statefulSet %s/%s", namespace, name) + case <-ticker.C: + statefulSet, err := client.AppsV1().StatefulSets(namespace).Get(ctx, name, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("error getting statefulSet: %w", err) + } + + if statefulSet.Status.ReadyReplicas == *statefulSet.Spec.Replicas { + return nil + } + } + } +} + +func WaitForDeploymentScale(ctx context.Context, client kubernetes.Interface, namespace, name string, replicas int32) error { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context canceled while waiting for deployment %s/%s scale", namespace, name) + case <-ticker.C: + deployment, err := client.AppsV1().Deployments(namespace).Get(ctx, name, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("error getting deployment: %w", err) + } + + if *deployment.Spec.Replicas == replicas && deployment.Status.Replicas == replicas { + return nil + } + } + } +} + +func WaitForStatefulSetScale(ctx context.Context, client kubernetes.Interface, namespace, name string, replicas int32) error { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context canceled while waiting for statefulSet %s/%s scale", namespace, name) + case <-ticker.C: + statefulSet, err := client.AppsV1().StatefulSets(namespace).Get(ctx, name, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("error getting statefulSet: %w", err) + } + + if *statefulSet.Spec.Replicas == replicas && statefulSet.Status.Replicas == replicas { + return nil + } + } + } +} diff --git a/pkg/provider/kubernetes/workload_scale.go b/pkg/provider/kubernetes/workload_scale.go new file mode 100644 index 0000000..2243069 --- /dev/null +++ b/pkg/provider/kubernetes/workload_scale.go @@ -0,0 +1,36 @@ +package kubernetes + +import ( + "context" + "fmt" + autoscalingv1 "k8s.io/api/autoscaling/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type Workload interface { + GetScale(ctx context.Context, workloadName string, options metav1.GetOptions) (*autoscalingv1.Scale, error) + UpdateScale(ctx context.Context, workloadName string, scale *autoscalingv1.Scale, opts metav1.UpdateOptions) (*autoscalingv1.Scale, error) +} + +func (p *KubernetesProvider) scale(ctx context.Context, config ParsedName, replicas int32) error { + var workload Workload + + switch config.Kind { + case "deployment": + workload = p.Client.AppsV1().Deployments(config.Namespace) + case "statefulset": + workload = p.Client.AppsV1().StatefulSets(config.Namespace) + default: + return fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", config.Kind) + } + + s, err := workload.GetScale(ctx, config.Name, metav1.GetOptions{}) + if err != nil { + return err + } + + s.Spec.Replicas = replicas + _, err = workload.UpdateScale(ctx, config.Name, s, metav1.UpdateOptions{}) + + return err +} diff --git a/pkg/provider/mock/mock.go b/pkg/provider/mock/mock.go index 2f658db..ce91766 100644 --- a/pkg/provider/mock/mock.go +++ b/pkg/provider/mock/mock.go @@ -16,19 +16,19 @@ type ProviderMock struct { mock.Mock } -func (m *ProviderMock) Start(ctx context.Context, name string) error { +func (m *ProviderMock) InstanceStart(ctx context.Context, name string) error { args := m.Called(ctx, name) return args.Error(0) } -func (m *ProviderMock) Stop(ctx context.Context, name string) error { +func (m *ProviderMock) InstanceStop(ctx context.Context, name string) error { args := m.Called(ctx, name) return args.Error(0) } -func (m *ProviderMock) GetState(ctx context.Context, name string) (instance.State, error) { +func (m *ProviderMock) InstanceInspect(ctx context.Context, name string) (instance.State, error) { args := m.Called(ctx, name) return args.Get(0).(instance.State), args.Error(1) } -func (m *ProviderMock) GetGroups(ctx context.Context) (map[string][]string, error) { +func (m *ProviderMock) InstanceGroups(ctx context.Context) (map[string][]string, error) { args := m.Called(ctx) return args.Get(0).(map[string][]string), args.Error(1) } diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 3b57d5a..790ca5a 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -7,13 +7,13 @@ import ( "github.com/sablierapp/sablier/app/instance" ) -//go:generate mockgen -package providertest -source=provider.go -destination=providertest/mock_provider.go * +//go:generate go tool mockgen -package providertest -source=provider.go -destination=providertest/mock_provider.go * type Provider interface { - Start(ctx context.Context, name string) error - Stop(ctx context.Context, name string) error - GetState(ctx context.Context, name string) (instance.State, error) - GetGroups(ctx context.Context) (map[string][]string, error) + InstanceStart(ctx context.Context, name string) error + InstanceStop(ctx context.Context, name string) error + InstanceInspect(ctx context.Context, name string) (instance.State, error) + InstanceGroups(ctx context.Context) (map[string][]string, error) InstanceList(ctx context.Context, options InstanceListOptions) ([]types.Instance, error) NotifyInstanceStopped(ctx context.Context, instance chan<- string) diff --git a/pkg/provider/providertest/mock_provider.go b/pkg/provider/providertest/mock_provider.go index 869cb02..76f71a4 100644 --- a/pkg/provider/providertest/mock_provider.go +++ b/pkg/provider/providertest/mock_provider.go @@ -43,34 +43,34 @@ func (m *MockProvider) EXPECT() *MockProviderMockRecorder { return m.recorder } -// GetGroups mocks base method. -func (m *MockProvider) GetGroups(ctx context.Context) (map[string][]string, error) { +// InstanceGroups mocks base method. +func (m *MockProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGroups", ctx) + ret := m.ctrl.Call(m, "InstanceGroups", ctx) ret0, _ := ret[0].(map[string][]string) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetGroups indicates an expected call of GetGroups. -func (mr *MockProviderMockRecorder) GetGroups(ctx any) *gomock.Call { +// InstanceGroups indicates an expected call of InstanceGroups. +func (mr *MockProviderMockRecorder) InstanceGroups(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockProvider)(nil).GetGroups), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceGroups", reflect.TypeOf((*MockProvider)(nil).InstanceGroups), ctx) } -// GetState mocks base method. -func (m *MockProvider) GetState(ctx context.Context, name string) (instance.State, error) { +// InstanceInspect mocks base method. +func (m *MockProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetState", ctx, name) + ret := m.ctrl.Call(m, "InstanceInspect", ctx, name) ret0, _ := ret[0].(instance.State) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetState indicates an expected call of GetState. -func (mr *MockProviderMockRecorder) GetState(ctx, name any) *gomock.Call { +// InstanceInspect indicates an expected call of InstanceInspect. +func (mr *MockProviderMockRecorder) InstanceInspect(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MockProvider)(nil).GetState), ctx, name) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceInspect", reflect.TypeOf((*MockProvider)(nil).InstanceInspect), ctx, name) } // InstanceList mocks base method. @@ -88,6 +88,34 @@ func (mr *MockProviderMockRecorder) InstanceList(ctx, options any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceList", reflect.TypeOf((*MockProvider)(nil).InstanceList), ctx, options) } +// InstanceStart mocks base method. +func (m *MockProvider) InstanceStart(ctx context.Context, name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InstanceStart", ctx, name) + ret0, _ := ret[0].(error) + return ret0 +} + +// InstanceStart indicates an expected call of InstanceStart. +func (mr *MockProviderMockRecorder) InstanceStart(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceStart", reflect.TypeOf((*MockProvider)(nil).InstanceStart), ctx, name) +} + +// InstanceStop mocks base method. +func (m *MockProvider) InstanceStop(ctx context.Context, name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InstanceStop", ctx, name) + ret0, _ := ret[0].(error) + return ret0 +} + +// InstanceStop indicates an expected call of InstanceStop. +func (mr *MockProviderMockRecorder) InstanceStop(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceStop", reflect.TypeOf((*MockProvider)(nil).InstanceStop), ctx, name) +} + // NotifyInstanceStopped mocks base method. func (m *MockProvider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) { m.ctrl.T.Helper() @@ -99,31 +127,3 @@ func (mr *MockProviderMockRecorder) NotifyInstanceStopped(ctx, instance any) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyInstanceStopped", reflect.TypeOf((*MockProvider)(nil).NotifyInstanceStopped), ctx, instance) } - -// Start mocks base method. -func (m *MockProvider) Start(ctx context.Context, name string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Start", ctx, name) - ret0, _ := ret[0].(error) - return ret0 -} - -// Start indicates an expected call of Start. -func (mr *MockProviderMockRecorder) Start(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockProvider)(nil).Start), ctx, name) -} - -// Stop mocks base method. -func (m *MockProvider) Stop(ctx context.Context, name string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stop", ctx, name) - ret0, _ := ret[0].(error) - return ret0 -} - -// Stop indicates an expected call of Stop. -func (mr *MockProviderMockRecorder) Stop(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockProvider)(nil).Stop), ctx, name) -}