From b72c37a85a7974a5971a37b33e18440c71f11806 Mon Sep 17 00:00:00 2001 From: Alexis Couvreur Date: Sun, 9 Mar 2025 01:29:06 -0500 Subject: [PATCH] refactor: remove discovery package (#553) --- app/discovery/autostop_test.go | 77 --------------------- app/discovery/types.go | 16 ----- app/sablier.go | 3 +- app/storage/file.go | 58 ---------------- go.mod | 1 + go.sum | 8 ++- pkg/provider/docker/container_list.go | 15 ++-- pkg/provider/dockerswarm/service_list.go | 15 ++-- pkg/provider/kubernetes/deployment_list.go | 21 +++--- pkg/provider/kubernetes/statefulset_list.go | 21 +++--- pkg/provider/types.go | 3 +- {app/discovery => pkg/sablier}/autostop.go | 24 +++---- pkg/sablier/autostop_test.go | 69 ++++++++++++++++++ pkg/sablier/sablier.go | 1 + pkg/sablier/sablier_test.go | 21 ++++++ pkg/sablier/sabliertest/mocks_sablier.go | 14 ++++ pkg/sablier/session_request_test.go | 22 ++---- 17 files changed, 163 insertions(+), 226 deletions(-) delete mode 100644 app/discovery/autostop_test.go delete mode 100644 app/discovery/types.go delete mode 100644 app/storage/file.go rename {app/discovery => pkg/sablier}/autostop.go (50%) create mode 100644 pkg/sablier/autostop_test.go create mode 100644 pkg/sablier/sablier_test.go diff --git a/app/discovery/autostop_test.go b/app/discovery/autostop_test.go deleted file mode 100644 index dea395b..0000000 --- a/app/discovery/autostop_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package discovery_test - -import ( - "errors" - "github.com/neilotoole/slogt" - "github.com/sablierapp/sablier/app/discovery" - "github.com/sablierapp/sablier/pkg/provider" - "github.com/sablierapp/sablier/pkg/provider/providertest" - "github.com/sablierapp/sablier/pkg/sablier" - "github.com/sablierapp/sablier/pkg/store/inmemory" - gomock "go.uber.org/mock/gomock" - "gotest.tools/v3/assert" - "testing" - "time" -) - -func TestStopAllUnregisteredInstances(t *testing.T) { - ctrl := gomock.NewController(t) - p := providertest.NewMockProvider(ctrl) - - ctx := t.Context() - - // Define instances and registered instances - instances := []sablier.InstanceConfiguration{ - {Name: "instance1"}, - {Name: "instance2"}, - {Name: "instance3"}, - } - store := inmemory.NewInMemory() - err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute) - assert.NilError(t, err) - - // Set up expectations for InstanceList - p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{ - All: false, - Labels: []string{discovery.LabelEnable}, - }).Return(instances, nil) - - // Set up expectations for InstanceStop - p.EXPECT().InstanceStop(ctx, "instance2").Return(nil) - p.EXPECT().InstanceStop(ctx, "instance3").Return(nil) - - // Call the function under test - err = discovery.StopAllUnregisteredInstances(ctx, p, store, slogt.New(t)) - assert.NilError(t, err) -} - -func TestStopAllUnregisteredInstances_WithError(t *testing.T) { - ctrl := gomock.NewController(t) - p := providertest.NewMockProvider(ctrl) - - ctx := t.Context() - - // Define instances and registered instances - instances := []sablier.InstanceConfiguration{ - {Name: "instance1"}, - {Name: "instance2"}, - {Name: "instance3"}, - } - store := inmemory.NewInMemory() - err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute) - assert.NilError(t, err) - - // Set up expectations for InstanceList - p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{ - All: false, - Labels: []string{discovery.LabelEnable}, - }).Return(instances, nil) - - // Set up expectations for InstanceStop with error - p.EXPECT().InstanceStop(ctx, "instance2").Return(errors.New("stop error")) - p.EXPECT().InstanceStop(ctx, "instance3").Return(nil) - - // Call the function under test - err = discovery.StopAllUnregisteredInstances(ctx, p, store, slogt.New(t)) - assert.Error(t, err, "stop error") -} diff --git a/app/discovery/types.go b/app/discovery/types.go deleted file mode 100644 index deb7b44..0000000 --- a/app/discovery/types.go +++ /dev/null @@ -1,16 +0,0 @@ -package discovery - -const ( - LabelEnable = "sablier.enable" - LabelGroup = "sablier.group" - LabelGroupDefaultValue = "default" -) - -type Group struct { - Name string - Instances []Instance -} - -type Instance struct { - Name string -} diff --git a/app/sablier.go b/app/sablier.go index f32616f..df57214 100644 --- a/app/sablier.go +++ b/app/sablier.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "github.com/docker/docker/client" - "github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/pkg/provider/docker" "github.com/sablierapp/sablier/pkg/provider/dockerswarm" @@ -74,7 +73,7 @@ func Start(ctx context.Context, conf config.Config) error { }() if conf.Provider.AutoStopOnStartup { - err := discovery.StopAllUnregisteredInstances(ctx, provider, store, logger) + err := s.StopAllUnregisteredInstances(ctx) if err != nil { logger.ErrorContext(ctx, "unable to stop unregistered instances", slog.Any("reason", err)) } diff --git a/app/storage/file.go b/app/storage/file.go deleted file mode 100644 index c06725a..0000000 --- a/app/storage/file.go +++ /dev/null @@ -1,58 +0,0 @@ -package storage - -import ( - "fmt" - "io" - "log/slog" - "os" - - "github.com/sablierapp/sablier/config" -) - -type Storage interface { - Reader() (io.ReadCloser, error) - Writer() (io.WriteCloser, error) -} - -type FileStorage struct { - file string - l *slog.Logger -} - -func NewFileStorage(config config.Storage, logger *slog.Logger) (Storage, error) { - logger = logger.With(slog.String("file", config.File)) - storage := &FileStorage{ - file: config.File, - } - - file, err := os.OpenFile(config.File, os.O_RDWR|os.O_CREATE, 0755) - if err != nil { - return nil, fmt.Errorf("unable to open file: %w", err) - } - defer file.Close() - - stats, err := file.Stat() - if err != nil { - return nil, fmt.Errorf("unable to read file info: %w", err) - } - - // Initialize file to an empty JSON3 - if stats.Size() == 0 { - _, err := file.WriteString("{}") - if err != nil { - return nil, fmt.Errorf("unable to initialize file to valid json: %w", err) - } - } - - logger.Info("storage successfully initialized") - - return storage, nil -} - -func (fs *FileStorage) Reader() (io.ReadCloser, error) { - return os.OpenFile(fs.file, os.O_RDWR|os.O_CREATE, 0755) -} - -func (fs *FileStorage) Writer() (io.WriteCloser, error) { - return os.OpenFile(fs.file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755) -} diff --git a/go.mod b/go.mod index f57859c..5da35a9 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 // indirect github.com/ajg/form v1.5.1 // indirect github.com/andybalholm/brotli v1.1.1 // indirect github.com/bytedance/sonic v1.12.8 // indirect diff --git a/go.sum b/go.sum index 8ac27d7..f705904 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/acouvreur/httpexpect/v2 v2.16.0 h1:FGXaR9jt6IQMXxpqbM8YpX7EEvyERU0Lps3ooEc/gk8= -github.com/acouvreur/httpexpect/v2 v2.16.0/go.mod h1:7myOP3A3VyS4+qnA4cm8DAad8zMN+7zxDB80W9f8yIc= +github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 h1:ZBbLwSJqkHBuFDA6DUhhse0IGJ7T5bemHyNILUjvOq4= +github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2/go.mod h1:VSw57q4QFiWDbRnjdX8Cb3Ow0SFncRw+bA/ofY6Q83w= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= @@ -61,6 +61,8 @@ github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/gavv/httpexpect/v2 v2.17.0 h1:nIJqt5v5e4P7/0jODpX2gtSw+pHXUqdP28YcjqwDZmE= +github.com/gavv/httpexpect/v2 v2.17.0/go.mod h1:E8ENFlT9MZ3Si2sfM6c6ONdwXV2noBCGkhA+lkJgkP0= github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E= github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= @@ -121,6 +123,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjw github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f h1:7LYC+Yfkj3CTRcShK0KOL/w6iTiKyqqBA9a41Wnggw8= +github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f/go.mod h1:pFlLw2CfqZiIBOx6BuCeRLCrfxBJipTY0nIOF/VbGcI= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= diff --git a/pkg/provider/docker/container_list.go b/pkg/provider/docker/container_list.go index 2332672..6955f18 100644 --- a/pkg/provider/docker/container_list.go +++ b/pkg/provider/docker/container_list.go @@ -6,7 +6,6 @@ import ( dockertypes "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/filters" - "github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/sablier" "strings" @@ -14,7 +13,7 @@ import ( func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) { args := filters.NewArgs() - args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) + args.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) containers, err := p.Client.ContainerList(ctx, container.ListOptions{ All: options.All, @@ -36,11 +35,11 @@ func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provid func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration { var group string - if _, ok := c.Labels[discovery.LabelEnable]; ok { - if g, ok := c.Labels[discovery.LabelGroup]; ok { + if _, ok := c.Labels["sablier.enable"]; ok { + if g, ok := c.Labels["sablier.group"]; ok { group = g } else { - group = discovery.LabelGroupDefaultValue + group = "default" } } @@ -52,7 +51,7 @@ func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration func (p *DockerClassicProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { args := filters.NewArgs() - args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) + args.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) containers, err := p.Client.ContainerList(ctx, container.ListOptions{ All: true, @@ -65,9 +64,9 @@ func (p *DockerClassicProvider) InstanceGroups(ctx context.Context) (map[string] groups := make(map[string][]string) for _, c := range containers { - groupName := c.Labels[discovery.LabelGroup] + groupName := c.Labels["sablier.group"] if len(groupName) == 0 { - groupName = discovery.LabelGroupDefaultValue + groupName = "default" } group := groups[groupName] group = append(group, strings.TrimPrefix(c.Names[0], "/")) diff --git a/pkg/provider/dockerswarm/service_list.go b/pkg/provider/dockerswarm/service_list.go index 01b53a5..fff9e56 100644 --- a/pkg/provider/dockerswarm/service_list.go +++ b/pkg/provider/dockerswarm/service_list.go @@ -6,14 +6,13 @@ import ( dockertypes "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/swarm" - "github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/sablier" ) func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) { args := filters.NewArgs() - args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) + args.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) args.Add("mode", "replicated") services, err := p.Client.ServiceList(ctx, dockertypes.ServiceListOptions{ @@ -36,11 +35,11 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.InstanceConfiguration) { var group string - if _, ok := s.Spec.Labels[discovery.LabelEnable]; ok { - if g, ok := s.Spec.Labels[discovery.LabelGroup]; ok { + if _, ok := s.Spec.Labels["sablier.enable"]; ok { + if g, ok := s.Spec.Labels["sablier.group"]; ok { group = g } else { - group = discovery.LabelGroupDefaultValue + group = "default" } } @@ -52,7 +51,7 @@ func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.Inst func (p *DockerSwarmProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { f := filters.NewArgs() - f.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) + f.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) services, err := p.Client.ServiceList(ctx, dockertypes.ServiceListOptions{ Filters: f, @@ -64,9 +63,9 @@ func (p *DockerSwarmProvider) InstanceGroups(ctx context.Context) (map[string][] groups := make(map[string][]string) for _, service := range services { - groupName := service.Spec.Labels[discovery.LabelGroup] + groupName := service.Spec.Labels["sablier.group"] if len(groupName) == 0 { - groupName = discovery.LabelGroupDefaultValue + groupName = "default" } group := groups[groupName] diff --git a/pkg/provider/kubernetes/deployment_list.go b/pkg/provider/kubernetes/deployment_list.go index c911614..8b0362c 100644 --- a/pkg/provider/kubernetes/deployment_list.go +++ b/pkg/provider/kubernetes/deployment_list.go @@ -2,20 +2,19 @@ package kubernetes import ( "context" - "github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/pkg/sablier" v1 "k8s.io/api/apps/v1" - core_v1 "k8s.io/api/core/v1" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.InstanceConfiguration, error) { labelSelector := metav1.LabelSelector{ MatchLabels: map[string]string{ - discovery.LabelEnable: "true", + "sablier.enable": "true", }, } - deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ + deployments, err := p.Client.AppsV1().Deployments(corev1.NamespaceAll).List(ctx, metav1.ListOptions{ LabelSelector: metav1.FormatLabelSelector(&labelSelector), }) if err != nil { @@ -34,11 +33,11 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.Inst func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.InstanceConfiguration { var group string - if _, ok := d.Labels[discovery.LabelEnable]; ok { - if g, ok := d.Labels[discovery.LabelGroup]; ok { + if _, ok := d.Labels["sablier.enable"]; ok { + if g, ok := d.Labels["sablier.group"]; ok { group = g } else { - group = discovery.LabelGroupDefaultValue + group = "default" } } @@ -53,10 +52,10 @@ func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.Inst func (p *KubernetesProvider) DeploymentGroups(ctx context.Context) (map[string][]string, error) { labelSelector := metav1.LabelSelector{ MatchLabels: map[string]string{ - discovery.LabelEnable: "true", + "sablier.enable": "true", }, } - deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ + deployments, err := p.Client.AppsV1().Deployments(corev1.NamespaceAll).List(ctx, metav1.ListOptions{ LabelSelector: metav1.FormatLabelSelector(&labelSelector), }) @@ -66,9 +65,9 @@ func (p *KubernetesProvider) DeploymentGroups(ctx context.Context) (map[string][ groups := make(map[string][]string) for _, deployment := range deployments.Items { - groupName := deployment.Labels[discovery.LabelGroup] + groupName := deployment.Labels["sablier.group"] if len(groupName) == 0 { - groupName = discovery.LabelGroupDefaultValue + groupName = "default" } group := groups[groupName] diff --git a/pkg/provider/kubernetes/statefulset_list.go b/pkg/provider/kubernetes/statefulset_list.go index 1fa1620..801b274 100644 --- a/pkg/provider/kubernetes/statefulset_list.go +++ b/pkg/provider/kubernetes/statefulset_list.go @@ -2,20 +2,19 @@ package kubernetes import ( "context" - "github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/pkg/sablier" v1 "k8s.io/api/apps/v1" - core_v1 "k8s.io/api/core/v1" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.InstanceConfiguration, error) { labelSelector := metav1.LabelSelector{ MatchLabels: map[string]string{ - discovery.LabelEnable: "true", + "sablier.enable": "true", }, } - statefulSets, err := p.Client.AppsV1().StatefulSets(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ + statefulSets, err := p.Client.AppsV1().StatefulSets(corev1.NamespaceAll).List(ctx, metav1.ListOptions{ LabelSelector: metav1.FormatLabelSelector(&labelSelector), }) if err != nil { @@ -34,11 +33,11 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.Ins func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.InstanceConfiguration { var group string - if _, ok := ss.Labels[discovery.LabelEnable]; ok { - if g, ok := ss.Labels[discovery.LabelGroup]; ok { + if _, ok := ss.Labels["sablier.enable"]; ok { + if g, ok := ss.Labels["sablier.group"]; ok { group = g } else { - group = discovery.LabelGroupDefaultValue + group = "default" } } @@ -53,10 +52,10 @@ func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.I func (p *KubernetesProvider) StatefulSetGroups(ctx context.Context) (map[string][]string, error) { labelSelector := metav1.LabelSelector{ MatchLabels: map[string]string{ - discovery.LabelEnable: "true", + "sablier.enable": "true", }, } - statefulSets, err := p.Client.AppsV1().StatefulSets(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{ + statefulSets, err := p.Client.AppsV1().StatefulSets(corev1.NamespaceAll).List(ctx, metav1.ListOptions{ LabelSelector: metav1.FormatLabelSelector(&labelSelector), }) if err != nil { @@ -65,9 +64,9 @@ func (p *KubernetesProvider) StatefulSetGroups(ctx context.Context) (map[string] groups := make(map[string][]string) for _, ss := range statefulSets.Items { - groupName := ss.Labels[discovery.LabelGroup] + groupName := ss.Labels["sablier.group"] if len(groupName) == 0 { - groupName = discovery.LabelGroupDefaultValue + groupName = "default" } group := groups[groupName] diff --git a/pkg/provider/types.go b/pkg/provider/types.go index 2964369..05ce74b 100644 --- a/pkg/provider/types.go +++ b/pkg/provider/types.go @@ -1,6 +1,5 @@ package provider type InstanceListOptions struct { - All bool - Labels []string + All bool } diff --git a/app/discovery/autostop.go b/pkg/sablier/autostop.go similarity index 50% rename from app/discovery/autostop.go rename to pkg/sablier/autostop.go index be8fe05..56e3b2e 100644 --- a/app/discovery/autostop.go +++ b/pkg/sablier/autostop.go @@ -1,10 +1,9 @@ -package discovery +package sablier import ( "context" "errors" "github.com/sablierapp/sablier/pkg/provider" - "github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/store" "golang.org/x/sync/errgroup" "log/slog" @@ -14,10 +13,9 @@ import ( // as running instances by Sablier. // By default, Sablier does not stop all already running instances. Meaning that you need to make an // initial request in order to trigger the scaling to zero. -func StopAllUnregisteredInstances(ctx context.Context, p sablier.Provider, s sablier.Store, logger *slog.Logger) error { - instances, err := p.InstanceList(ctx, provider.InstanceListOptions{ - All: false, // Only running containers - Labels: []string{LabelEnable}, +func (s *sablier) StopAllUnregisteredInstances(ctx context.Context) error { + instances, err := s.provider.InstanceList(ctx, provider.InstanceListOptions{ + All: false, // Only running instances }) if err != nil { return err @@ -25,31 +23,31 @@ func StopAllUnregisteredInstances(ctx context.Context, p sablier.Provider, s sab unregistered := make([]string, 0) for _, instance := range instances { - _, err = s.Get(ctx, instance.Name) + _, err = s.sessions.Get(ctx, instance.Name) if errors.Is(err, store.ErrKeyNotFound) { unregistered = append(unregistered, instance.Name) } } - logger.DebugContext(ctx, "found instances to stop", slog.Any("instances", unregistered)) + s.l.DebugContext(ctx, "found instances to stop", slog.Any("instances", unregistered)) waitGroup := errgroup.Group{} for _, name := range unregistered { - waitGroup.Go(stopFunc(ctx, name, p, logger)) + waitGroup.Go(s.stopFunc(ctx, name)) } return waitGroup.Wait() } -func stopFunc(ctx context.Context, name string, p sablier.Provider, logger *slog.Logger) func() error { +func (s *sablier) stopFunc(ctx context.Context, name string) func() error { return func() error { - err := p.InstanceStop(ctx, name) + err := s.provider.InstanceStop(ctx, name) if err != nil { - logger.ErrorContext(ctx, "failed to stop instance", slog.String("instance", name), slog.Any("error", err)) + s.l.ErrorContext(ctx, "failed to stop instance", slog.String("instance", name), slog.Any("error", err)) return err } - logger.InfoContext(ctx, "stopped unregistered instance", slog.String("instance", name), slog.String("reason", "instance is enabled but not started by Sablier")) + s.l.InfoContext(ctx, "stopped unregistered instance", slog.String("instance", name), slog.String("reason", "instance is enabled but not started by Sablier")) return nil } } diff --git a/pkg/sablier/autostop_test.go b/pkg/sablier/autostop_test.go new file mode 100644 index 0000000..896c245 --- /dev/null +++ b/pkg/sablier/autostop_test.go @@ -0,0 +1,69 @@ +package sablier_test + +import ( + "errors" + "github.com/sablierapp/sablier/pkg/provider" + "github.com/sablierapp/sablier/pkg/sablier" + "github.com/sablierapp/sablier/pkg/store" + "gotest.tools/v3/assert" + "testing" +) + +func TestStopAllUnregisteredInstances(t *testing.T) { + s, sessions, p := setupSablier(t) + + ctx := t.Context() + + // Define instances and registered instances + instances := []sablier.InstanceConfiguration{ + {Name: "instance1"}, + {Name: "instance2"}, + } + + sessions.EXPECT().Get(ctx, "instance1").Return(sablier.InstanceInfo{}, store.ErrKeyNotFound) + sessions.EXPECT().Get(ctx, "instance2").Return(sablier.InstanceInfo{ + Name: "instance2", + Status: sablier.InstanceStatusReady, + }, nil) + + // Set up expectations for InstanceList + p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{ + All: false, + }).Return(instances, nil) + + // Set up expectations for InstanceStop + p.EXPECT().InstanceStop(ctx, "instance1").Return(nil) + + // Call the function under test + err := s.StopAllUnregisteredInstances(ctx) + assert.NilError(t, err) +} + +func TestStopAllUnregisteredInstances_WithError(t *testing.T) { + s, sessions, p := setupSablier(t) + ctx := t.Context() + + // Define instances and registered instances + instances := []sablier.InstanceConfiguration{ + {Name: "instance1"}, + {Name: "instance2"}, + } + + sessions.EXPECT().Get(ctx, "instance1").Return(sablier.InstanceInfo{}, store.ErrKeyNotFound) + sessions.EXPECT().Get(ctx, "instance2").Return(sablier.InstanceInfo{ + Name: "instance2", + Status: sablier.InstanceStatusReady, + }, nil) + + // Set up expectations for InstanceList + p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{ + All: false, + }).Return(instances, nil) + + // Set up expectations for InstanceStop with error + p.EXPECT().InstanceStop(ctx, "instance1").Return(errors.New("stop error")) + + // Call the function under test + err := s.StopAllUnregisteredInstances(ctx) + assert.Error(t, err, "stop error") +} diff --git a/pkg/sablier/sablier.go b/pkg/sablier/sablier.go index a5c83dc..04e866f 100644 --- a/pkg/sablier/sablier.go +++ b/pkg/sablier/sablier.go @@ -17,6 +17,7 @@ type Sablier interface { RemoveInstance(ctx context.Context, name string) error SetGroups(groups map[string][]string) + StopAllUnregisteredInstances(ctx context.Context) error } type sablier struct { diff --git a/pkg/sablier/sablier_test.go b/pkg/sablier/sablier_test.go new file mode 100644 index 0000000..30b8725 --- /dev/null +++ b/pkg/sablier/sablier_test.go @@ -0,0 +1,21 @@ +package sablier_test + +import ( + "github.com/neilotoole/slogt" + "github.com/sablierapp/sablier/pkg/provider/providertest" + "github.com/sablierapp/sablier/pkg/sablier" + "github.com/sablierapp/sablier/pkg/store/storetest" + "go.uber.org/mock/gomock" + "testing" +) + +func setupSablier(t *testing.T) (sablier.Sablier, *storetest.MockStore, *providertest.MockProvider) { + t.Helper() + ctrl := gomock.NewController(t) + + p := providertest.NewMockProvider(ctrl) + s := storetest.NewMockStore(ctrl) + + m := sablier.New(slogt.New(t), s, p) + return m, s, p +} diff --git a/pkg/sablier/sabliertest/mocks_sablier.go b/pkg/sablier/sabliertest/mocks_sablier.go index 6dedac3..ac06a81 100644 --- a/pkg/sablier/sabliertest/mocks_sablier.go +++ b/pkg/sablier/sabliertest/mocks_sablier.go @@ -127,3 +127,17 @@ func (mr *MockSablierMockRecorder) SetGroups(groups any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroups", reflect.TypeOf((*MockSablier)(nil).SetGroups), groups) } + +// StopAllUnregisteredInstances mocks base method. +func (m *MockSablier) StopAllUnregisteredInstances(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopAllUnregisteredInstances", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// StopAllUnregisteredInstances indicates an expected call of StopAllUnregisteredInstances. +func (mr *MockSablierMockRecorder) StopAllUnregisteredInstances(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopAllUnregisteredInstances", reflect.TypeOf((*MockSablier)(nil).StopAllUnregisteredInstances), ctx) +} diff --git a/pkg/sablier/session_request_test.go b/pkg/sablier/session_request_test.go index 2327b09..0e55d9e 100644 --- a/pkg/sablier/session_request_test.go +++ b/pkg/sablier/session_request_test.go @@ -2,10 +2,7 @@ package sablier_test import ( "context" - "github.com/neilotoole/slogt" - "github.com/sablierapp/sablier/pkg/provider/providertest" "github.com/sablierapp/sablier/pkg/sablier" - "github.com/sablierapp/sablier/pkg/store/storetest" "go.uber.org/mock/gomock" "testing" "time" @@ -86,20 +83,9 @@ func createMap(instances []sablier.InstanceInfo) map[string]sablier.InstanceInfo return states } -func setupSessionManager(t *testing.T) (sablier.Sablier, *storetest.MockStore, *providertest.MockProvider) { - t.Helper() - ctrl := gomock.NewController(t) - - p := providertest.NewMockProvider(ctrl) - s := storetest.NewMockStore(ctrl) - - m := sablier.New(slogt.New(t), s, p) - return m, s, p -} - func TestSessionsManager(t *testing.T) { t.Run("RemoveInstance", func(t *testing.T) { - manager, store, _ := setupSessionManager(t) + manager, store, _ := setupSablier(t) store.EXPECT().Delete(gomock.Any(), "test") err := manager.RemoveInstance(t.Context(), "test") assert.NilError(t, err) @@ -109,7 +95,7 @@ func TestSessionsManager(t *testing.T) { func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) { t.Run("request ready session is cancelled by user", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - manager, store, provider := setupSessionManager(t) + manager, store, provider := setupSablier(t) store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -131,7 +117,7 @@ func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) { func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) { t.Run("request ready session is cancelled by timeout", func(t *testing.T) { - manager, store, provider := setupSessionManager(t) + manager, store, provider := setupSablier(t) store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -150,7 +136,7 @@ func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) { func TestSessionsManager_RequestReadySession(t *testing.T) { t.Run("request ready session is ready", func(t *testing.T) { - manager, store, _ := setupSessionManager(t) + manager, store, _ := setupSablier(t) store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusReady}, nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()