refactor: remove discovery package (#553)

This commit is contained in:
Alexis Couvreur
2025-03-09 01:29:06 -05:00
committed by GitHub
parent 8e5d5758a9
commit b72c37a85a
17 changed files with 163 additions and 226 deletions

View File

@@ -1,77 +0,0 @@
package discovery_test
import (
"errors"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory"
gomock "go.uber.org/mock/gomock"
"gotest.tools/v3/assert"
"testing"
"time"
)
func TestStopAllUnregisteredInstances(t *testing.T) {
ctrl := gomock.NewController(t)
p := providertest.NewMockProvider(ctrl)
ctx := t.Context()
// Define instances and registered instances
instances := []sablier.InstanceConfiguration{
{Name: "instance1"},
{Name: "instance2"},
{Name: "instance3"},
}
store := inmemory.NewInMemory()
err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute)
assert.NilError(t, err)
// Set up expectations for InstanceList
p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{
All: false,
Labels: []string{discovery.LabelEnable},
}).Return(instances, nil)
// Set up expectations for InstanceStop
p.EXPECT().InstanceStop(ctx, "instance2").Return(nil)
p.EXPECT().InstanceStop(ctx, "instance3").Return(nil)
// Call the function under test
err = discovery.StopAllUnregisteredInstances(ctx, p, store, slogt.New(t))
assert.NilError(t, err)
}
func TestStopAllUnregisteredInstances_WithError(t *testing.T) {
ctrl := gomock.NewController(t)
p := providertest.NewMockProvider(ctrl)
ctx := t.Context()
// Define instances and registered instances
instances := []sablier.InstanceConfiguration{
{Name: "instance1"},
{Name: "instance2"},
{Name: "instance3"},
}
store := inmemory.NewInMemory()
err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute)
assert.NilError(t, err)
// Set up expectations for InstanceList
p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{
All: false,
Labels: []string{discovery.LabelEnable},
}).Return(instances, nil)
// Set up expectations for InstanceStop with error
p.EXPECT().InstanceStop(ctx, "instance2").Return(errors.New("stop error"))
p.EXPECT().InstanceStop(ctx, "instance3").Return(nil)
// Call the function under test
err = discovery.StopAllUnregisteredInstances(ctx, p, store, slogt.New(t))
assert.Error(t, err, "stop error")
}

View File

@@ -1,16 +0,0 @@
package discovery
const (
LabelEnable = "sablier.enable"
LabelGroup = "sablier.group"
LabelGroupDefaultValue = "default"
)
type Group struct {
Name string
Instances []Instance
}
type Instance struct {
Name string
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"github.com/docker/docker/client"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
@@ -74,7 +73,7 @@ func Start(ctx context.Context, conf config.Config) error {
}()
if conf.Provider.AutoStopOnStartup {
err := discovery.StopAllUnregisteredInstances(ctx, provider, store, logger)
err := s.StopAllUnregisteredInstances(ctx)
if err != nil {
logger.ErrorContext(ctx, "unable to stop unregistered instances", slog.Any("reason", err))
}

View File

@@ -1,58 +0,0 @@
package storage
import (
"fmt"
"io"
"log/slog"
"os"
"github.com/sablierapp/sablier/config"
)
type Storage interface {
Reader() (io.ReadCloser, error)
Writer() (io.WriteCloser, error)
}
type FileStorage struct {
file string
l *slog.Logger
}
func NewFileStorage(config config.Storage, logger *slog.Logger) (Storage, error) {
logger = logger.With(slog.String("file", config.File))
storage := &FileStorage{
file: config.File,
}
file, err := os.OpenFile(config.File, os.O_RDWR|os.O_CREATE, 0755)
if err != nil {
return nil, fmt.Errorf("unable to open file: %w", err)
}
defer file.Close()
stats, err := file.Stat()
if err != nil {
return nil, fmt.Errorf("unable to read file info: %w", err)
}
// Initialize file to an empty JSON3
if stats.Size() == 0 {
_, err := file.WriteString("{}")
if err != nil {
return nil, fmt.Errorf("unable to initialize file to valid json: %w", err)
}
}
logger.Info("storage successfully initialized")
return storage, nil
}
func (fs *FileStorage) Reader() (io.ReadCloser, error) {
return os.OpenFile(fs.file, os.O_RDWR|os.O_CREATE, 0755)
}
func (fs *FileStorage) Writer() (io.WriteCloser, error) {
return os.OpenFile(fs.file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755)
}

1
go.mod
View File

@@ -35,6 +35,7 @@ require (
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 // indirect
github.com/ajg/form v1.5.1 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/bytedance/sonic v1.12.8 // indirect

8
go.sum
View File

@@ -6,8 +6,8 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/acouvreur/httpexpect/v2 v2.16.0 h1:FGXaR9jt6IQMXxpqbM8YpX7EEvyERU0Lps3ooEc/gk8=
github.com/acouvreur/httpexpect/v2 v2.16.0/go.mod h1:7myOP3A3VyS4+qnA4cm8DAad8zMN+7zxDB80W9f8yIc=
github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 h1:ZBbLwSJqkHBuFDA6DUhhse0IGJ7T5bemHyNILUjvOq4=
github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2/go.mod h1:VSw57q4QFiWDbRnjdX8Cb3Ow0SFncRw+bA/ofY6Q83w=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
@@ -61,6 +61,8 @@ github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gavv/httpexpect/v2 v2.17.0 h1:nIJqt5v5e4P7/0jODpX2gtSw+pHXUqdP28YcjqwDZmE=
github.com/gavv/httpexpect/v2 v2.17.0/go.mod h1:E8ENFlT9MZ3Si2sfM6c6ONdwXV2noBCGkhA+lkJgkP0=
github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E=
github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
@@ -121,6 +123,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjw
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f h1:7LYC+Yfkj3CTRcShK0KOL/w6iTiKyqqBA9a41Wnggw8=
github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f/go.mod h1:pFlLw2CfqZiIBOx6BuCeRLCrfxBJipTY0nIOF/VbGcI=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=

View File

@@ -6,7 +6,6 @@ import (
dockertypes "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"strings"
@@ -14,7 +13,7 @@ import (
func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
args.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))
containers, err := p.Client.ContainerList(ctx, container.ListOptions{
All: options.All,
@@ -36,11 +35,11 @@ func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provid
func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration {
var group string
if _, ok := c.Labels[discovery.LabelEnable]; ok {
if g, ok := c.Labels[discovery.LabelGroup]; ok {
if _, ok := c.Labels["sablier.enable"]; ok {
if g, ok := c.Labels["sablier.group"]; ok {
group = g
} else {
group = discovery.LabelGroupDefaultValue
group = "default"
}
}
@@ -52,7 +51,7 @@ func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration
func (p *DockerClassicProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) {
args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
args.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))
containers, err := p.Client.ContainerList(ctx, container.ListOptions{
All: true,
@@ -65,9 +64,9 @@ func (p *DockerClassicProvider) InstanceGroups(ctx context.Context) (map[string]
groups := make(map[string][]string)
for _, c := range containers {
groupName := c.Labels[discovery.LabelGroup]
groupName := c.Labels["sablier.group"]
if len(groupName) == 0 {
groupName = discovery.LabelGroupDefaultValue
groupName = "default"
}
group := groups[groupName]
group = append(group, strings.TrimPrefix(c.Names[0], "/"))

View File

@@ -6,14 +6,13 @@ import (
dockertypes "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/swarm"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
)
func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
args.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))
args.Add("mode", "replicated")
services, err := p.Client.ServiceList(ctx, dockertypes.ServiceListOptions{
@@ -36,11 +35,11 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta
func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.InstanceConfiguration) {
var group string
if _, ok := s.Spec.Labels[discovery.LabelEnable]; ok {
if g, ok := s.Spec.Labels[discovery.LabelGroup]; ok {
if _, ok := s.Spec.Labels["sablier.enable"]; ok {
if g, ok := s.Spec.Labels["sablier.group"]; ok {
group = g
} else {
group = discovery.LabelGroupDefaultValue
group = "default"
}
}
@@ -52,7 +51,7 @@ func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.Inst
func (p *DockerSwarmProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) {
f := filters.NewArgs()
f.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
f.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))
services, err := p.Client.ServiceList(ctx, dockertypes.ServiceListOptions{
Filters: f,
@@ -64,9 +63,9 @@ func (p *DockerSwarmProvider) InstanceGroups(ctx context.Context) (map[string][]
groups := make(map[string][]string)
for _, service := range services {
groupName := service.Spec.Labels[discovery.LabelGroup]
groupName := service.Spec.Labels["sablier.group"]
if len(groupName) == 0 {
groupName = discovery.LabelGroupDefaultValue
groupName = "default"
}
group := groups[groupName]

View File

@@ -2,20 +2,19 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/pkg/sablier"
v1 "k8s.io/api/apps/v1"
core_v1 "k8s.io/api/core/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{
discovery.LabelEnable: "true",
"sablier.enable": "true",
},
}
deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{
deployments, err := p.Client.AppsV1().Deployments(corev1.NamespaceAll).List(ctx, metav1.ListOptions{
LabelSelector: metav1.FormatLabelSelector(&labelSelector),
})
if err != nil {
@@ -34,11 +33,11 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.Inst
func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.InstanceConfiguration {
var group string
if _, ok := d.Labels[discovery.LabelEnable]; ok {
if g, ok := d.Labels[discovery.LabelGroup]; ok {
if _, ok := d.Labels["sablier.enable"]; ok {
if g, ok := d.Labels["sablier.group"]; ok {
group = g
} else {
group = discovery.LabelGroupDefaultValue
group = "default"
}
}
@@ -53,10 +52,10 @@ func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.Inst
func (p *KubernetesProvider) DeploymentGroups(ctx context.Context) (map[string][]string, error) {
labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{
discovery.LabelEnable: "true",
"sablier.enable": "true",
},
}
deployments, err := p.Client.AppsV1().Deployments(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{
deployments, err := p.Client.AppsV1().Deployments(corev1.NamespaceAll).List(ctx, metav1.ListOptions{
LabelSelector: metav1.FormatLabelSelector(&labelSelector),
})
@@ -66,9 +65,9 @@ func (p *KubernetesProvider) DeploymentGroups(ctx context.Context) (map[string][
groups := make(map[string][]string)
for _, deployment := range deployments.Items {
groupName := deployment.Labels[discovery.LabelGroup]
groupName := deployment.Labels["sablier.group"]
if len(groupName) == 0 {
groupName = discovery.LabelGroupDefaultValue
groupName = "default"
}
group := groups[groupName]

View File

@@ -2,20 +2,19 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/pkg/sablier"
v1 "k8s.io/api/apps/v1"
core_v1 "k8s.io/api/core/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{
discovery.LabelEnable: "true",
"sablier.enable": "true",
},
}
statefulSets, err := p.Client.AppsV1().StatefulSets(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{
statefulSets, err := p.Client.AppsV1().StatefulSets(corev1.NamespaceAll).List(ctx, metav1.ListOptions{
LabelSelector: metav1.FormatLabelSelector(&labelSelector),
})
if err != nil {
@@ -34,11 +33,11 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.Ins
func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.InstanceConfiguration {
var group string
if _, ok := ss.Labels[discovery.LabelEnable]; ok {
if g, ok := ss.Labels[discovery.LabelGroup]; ok {
if _, ok := ss.Labels["sablier.enable"]; ok {
if g, ok := ss.Labels["sablier.group"]; ok {
group = g
} else {
group = discovery.LabelGroupDefaultValue
group = "default"
}
}
@@ -53,10 +52,10 @@ func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.I
func (p *KubernetesProvider) StatefulSetGroups(ctx context.Context) (map[string][]string, error) {
labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{
discovery.LabelEnable: "true",
"sablier.enable": "true",
},
}
statefulSets, err := p.Client.AppsV1().StatefulSets(core_v1.NamespaceAll).List(ctx, metav1.ListOptions{
statefulSets, err := p.Client.AppsV1().StatefulSets(corev1.NamespaceAll).List(ctx, metav1.ListOptions{
LabelSelector: metav1.FormatLabelSelector(&labelSelector),
})
if err != nil {
@@ -65,9 +64,9 @@ func (p *KubernetesProvider) StatefulSetGroups(ctx context.Context) (map[string]
groups := make(map[string][]string)
for _, ss := range statefulSets.Items {
groupName := ss.Labels[discovery.LabelGroup]
groupName := ss.Labels["sablier.group"]
if len(groupName) == 0 {
groupName = discovery.LabelGroupDefaultValue
groupName = "default"
}
group := groups[groupName]

View File

@@ -2,5 +2,4 @@ package provider
type InstanceListOptions struct {
All bool
Labels []string
}

View File

@@ -1,10 +1,9 @@
package discovery
package sablier
import (
"context"
"errors"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"golang.org/x/sync/errgroup"
"log/slog"
@@ -14,10 +13,9 @@ import (
// as running instances by Sablier.
// By default, Sablier does not stop all already running instances. Meaning that you need to make an
// initial request in order to trigger the scaling to zero.
func StopAllUnregisteredInstances(ctx context.Context, p sablier.Provider, s sablier.Store, logger *slog.Logger) error {
instances, err := p.InstanceList(ctx, provider.InstanceListOptions{
All: false, // Only running containers
Labels: []string{LabelEnable},
func (s *sablier) StopAllUnregisteredInstances(ctx context.Context) error {
instances, err := s.provider.InstanceList(ctx, provider.InstanceListOptions{
All: false, // Only running instances
})
if err != nil {
return err
@@ -25,31 +23,31 @@ func StopAllUnregisteredInstances(ctx context.Context, p sablier.Provider, s sab
unregistered := make([]string, 0)
for _, instance := range instances {
_, err = s.Get(ctx, instance.Name)
_, err = s.sessions.Get(ctx, instance.Name)
if errors.Is(err, store.ErrKeyNotFound) {
unregistered = append(unregistered, instance.Name)
}
}
logger.DebugContext(ctx, "found instances to stop", slog.Any("instances", unregistered))
s.l.DebugContext(ctx, "found instances to stop", slog.Any("instances", unregistered))
waitGroup := errgroup.Group{}
for _, name := range unregistered {
waitGroup.Go(stopFunc(ctx, name, p, logger))
waitGroup.Go(s.stopFunc(ctx, name))
}
return waitGroup.Wait()
}
func stopFunc(ctx context.Context, name string, p sablier.Provider, logger *slog.Logger) func() error {
func (s *sablier) stopFunc(ctx context.Context, name string) func() error {
return func() error {
err := p.InstanceStop(ctx, name)
err := s.provider.InstanceStop(ctx, name)
if err != nil {
logger.ErrorContext(ctx, "failed to stop instance", slog.String("instance", name), slog.Any("error", err))
s.l.ErrorContext(ctx, "failed to stop instance", slog.String("instance", name), slog.Any("error", err))
return err
}
logger.InfoContext(ctx, "stopped unregistered instance", slog.String("instance", name), slog.String("reason", "instance is enabled but not started by Sablier"))
s.l.InfoContext(ctx, "stopped unregistered instance", slog.String("instance", name), slog.String("reason", "instance is enabled but not started by Sablier"))
return nil
}
}

View File

@@ -0,0 +1,69 @@
package sablier_test
import (
"errors"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"gotest.tools/v3/assert"
"testing"
)
func TestStopAllUnregisteredInstances(t *testing.T) {
s, sessions, p := setupSablier(t)
ctx := t.Context()
// Define instances and registered instances
instances := []sablier.InstanceConfiguration{
{Name: "instance1"},
{Name: "instance2"},
}
sessions.EXPECT().Get(ctx, "instance1").Return(sablier.InstanceInfo{}, store.ErrKeyNotFound)
sessions.EXPECT().Get(ctx, "instance2").Return(sablier.InstanceInfo{
Name: "instance2",
Status: sablier.InstanceStatusReady,
}, nil)
// Set up expectations for InstanceList
p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{
All: false,
}).Return(instances, nil)
// Set up expectations for InstanceStop
p.EXPECT().InstanceStop(ctx, "instance1").Return(nil)
// Call the function under test
err := s.StopAllUnregisteredInstances(ctx)
assert.NilError(t, err)
}
func TestStopAllUnregisteredInstances_WithError(t *testing.T) {
s, sessions, p := setupSablier(t)
ctx := t.Context()
// Define instances and registered instances
instances := []sablier.InstanceConfiguration{
{Name: "instance1"},
{Name: "instance2"},
}
sessions.EXPECT().Get(ctx, "instance1").Return(sablier.InstanceInfo{}, store.ErrKeyNotFound)
sessions.EXPECT().Get(ctx, "instance2").Return(sablier.InstanceInfo{
Name: "instance2",
Status: sablier.InstanceStatusReady,
}, nil)
// Set up expectations for InstanceList
p.EXPECT().InstanceList(ctx, provider.InstanceListOptions{
All: false,
}).Return(instances, nil)
// Set up expectations for InstanceStop with error
p.EXPECT().InstanceStop(ctx, "instance1").Return(errors.New("stop error"))
// Call the function under test
err := s.StopAllUnregisteredInstances(ctx)
assert.Error(t, err, "stop error")
}

View File

@@ -17,6 +17,7 @@ type Sablier interface {
RemoveInstance(ctx context.Context, name string) error
SetGroups(groups map[string][]string)
StopAllUnregisteredInstances(ctx context.Context) error
}
type sablier struct {

View File

@@ -0,0 +1,21 @@
package sablier_test
import (
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/storetest"
"go.uber.org/mock/gomock"
"testing"
)
func setupSablier(t *testing.T) (sablier.Sablier, *storetest.MockStore, *providertest.MockProvider) {
t.Helper()
ctrl := gomock.NewController(t)
p := providertest.NewMockProvider(ctrl)
s := storetest.NewMockStore(ctrl)
m := sablier.New(slogt.New(t), s, p)
return m, s, p
}

View File

@@ -127,3 +127,17 @@ func (mr *MockSablierMockRecorder) SetGroups(groups any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroups", reflect.TypeOf((*MockSablier)(nil).SetGroups), groups)
}
// StopAllUnregisteredInstances mocks base method.
func (m *MockSablier) StopAllUnregisteredInstances(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopAllUnregisteredInstances", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// StopAllUnregisteredInstances indicates an expected call of StopAllUnregisteredInstances.
func (mr *MockSablierMockRecorder) StopAllUnregisteredInstances(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopAllUnregisteredInstances", reflect.TypeOf((*MockSablier)(nil).StopAllUnregisteredInstances), ctx)
}

View File

@@ -2,10 +2,7 @@ package sablier_test
import (
"context"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/storetest"
"go.uber.org/mock/gomock"
"testing"
"time"
@@ -86,20 +83,9 @@ func createMap(instances []sablier.InstanceInfo) map[string]sablier.InstanceInfo
return states
}
func setupSessionManager(t *testing.T) (sablier.Sablier, *storetest.MockStore, *providertest.MockProvider) {
t.Helper()
ctrl := gomock.NewController(t)
p := providertest.NewMockProvider(ctrl)
s := storetest.NewMockStore(ctrl)
m := sablier.New(slogt.New(t), s, p)
return m, s, p
}
func TestSessionsManager(t *testing.T) {
t.Run("RemoveInstance", func(t *testing.T) {
manager, store, _ := setupSessionManager(t)
manager, store, _ := setupSablier(t)
store.EXPECT().Delete(gomock.Any(), "test")
err := manager.RemoveInstance(t.Context(), "test")
assert.NilError(t, err)
@@ -109,7 +95,7 @@ func TestSessionsManager(t *testing.T) {
func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) {
t.Run("request ready session is cancelled by user", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
manager, store, provider := setupSessionManager(t)
manager, store, provider := setupSablier(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
@@ -131,7 +117,7 @@ func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) {
func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) {
t.Run("request ready session is cancelled by timeout", func(t *testing.T) {
manager, store, provider := setupSessionManager(t)
manager, store, provider := setupSablier(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
@@ -150,7 +136,7 @@ func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) {
func TestSessionsManager_RequestReadySession(t *testing.T) {
t.Run("request ready session is ready", func(t *testing.T) {
manager, store, _ := setupSessionManager(t)
manager, store, _ := setupSablier(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()