refactor: remove session manager

The session manager is now simply Sablier
This commit is contained in:
Alexis Couvreur
2025-03-08 12:17:12 -05:00
parent 1298d86c61
commit ce7de13ade
55 changed files with 745 additions and 832 deletions

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store" "github.com/sablierapp/sablier/pkg/store"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"log/slog" "log/slog"
@@ -13,7 +14,7 @@ import (
// as running instances by Sablier. // as running instances by Sablier.
// By default, Sablier does not stop all already running instances. Meaning that you need to make an // By default, Sablier does not stop all already running instances. Meaning that you need to make an
// initial request in order to trigger the scaling to zero. // initial request in order to trigger the scaling to zero.
func StopAllUnregisteredInstances(ctx context.Context, p provider.Provider, s store.Store, logger *slog.Logger) error { func StopAllUnregisteredInstances(ctx context.Context, p sablier.Provider, s sablier.Store, logger *slog.Logger) error {
instances, err := p.InstanceList(ctx, provider.InstanceListOptions{ instances, err := p.InstanceList(ctx, provider.InstanceListOptions{
All: false, // Only running containers All: false, // Only running containers
Labels: []string{LabelEnable}, Labels: []string{LabelEnable},
@@ -41,7 +42,7 @@ func StopAllUnregisteredInstances(ctx context.Context, p provider.Provider, s st
return waitGroup.Wait() return waitGroup.Wait()
} }
func stopFunc(ctx context.Context, name string, p provider.Provider, logger *slog.Logger) func() error { func stopFunc(ctx context.Context, name string, p sablier.Provider, logger *slog.Logger) func() error {
return func() error { return func() error {
err := p.InstanceStop(ctx, name) err := p.InstanceStop(ctx, name)
if err != nil { if err != nil {

View File

@@ -4,10 +4,9 @@ import (
"errors" "errors"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/providertest" "github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory" "github.com/sablierapp/sablier/pkg/store/inmemory"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -22,13 +21,13 @@ func TestStopAllUnregisteredInstances(t *testing.T) {
ctx := t.Context() ctx := t.Context()
// Define instances and registered instances // Define instances and registered instances
instances := []types.Instance{ instances := []sablier.InstanceConfiguration{
{Name: "instance1"}, {Name: "instance1"},
{Name: "instance2"}, {Name: "instance2"},
{Name: "instance3"}, {Name: "instance3"},
} }
store := inmemory.NewInMemory() store := inmemory.NewInMemory()
err := store.Put(ctx, instance.State{Name: "instance1"}, time.Minute) err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute)
assert.NilError(t, err) assert.NilError(t, err)
// Set up expectations for InstanceList // Set up expectations for InstanceList
@@ -53,13 +52,13 @@ func TestStopAllUnregisteredInstances_WithError(t *testing.T) {
ctx := t.Context() ctx := t.Context()
// Define instances and registered instances // Define instances and registered instances
instances := []types.Instance{ instances := []sablier.InstanceConfiguration{
{Name: "instance1"}, {Name: "instance1"},
{Name: "instance2"}, {Name: "instance2"},
{Name: "instance3"}, {Name: "instance3"},
} }
store := inmemory.NewInMemory() store := inmemory.NewInMemory()
err := store.Put(ctx, instance.State{Name: "instance1"}, time.Minute) err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute)
assert.NilError(t, err) assert.NilError(t, err)
// Set up expectations for InstanceList // Set up expectations for InstanceList

View File

@@ -1,15 +1,15 @@
package routes package routes
import ( import (
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/theme" "github.com/sablierapp/sablier/pkg/theme"
) )
type ServeStrategy struct { type ServeStrategy struct {
Theme *theme.Themes Theme *theme.Themes
SessionsManager sessions.Manager SessionsManager sablier.Sablier
StrategyConfig config.Strategy StrategyConfig config.Strategy
SessionsConfig config.Sessions SessionsConfig config.Sessions
} }

View File

@@ -1,55 +0,0 @@
package instance
var Ready = "ready"
var NotReady = "not-ready"
var Unrecoverable = "unrecoverable"
type State struct {
Name string `json:"name"`
CurrentReplicas int32 `json:"currentReplicas"`
DesiredReplicas int32 `json:"desiredReplicas"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
func (instance State) IsReady() bool {
return instance.Status == Ready
}
func ErrorInstanceState(name string, err error, desiredReplicas int32) (State, error) {
return State{
Name: name,
CurrentReplicas: 0,
DesiredReplicas: desiredReplicas,
Status: Unrecoverable,
Message: err.Error(),
}, err
}
func UnrecoverableInstanceState(name string, message string, desiredReplicas int32) State {
return State{
Name: name,
CurrentReplicas: 0,
DesiredReplicas: desiredReplicas,
Status: Unrecoverable,
Message: message,
}
}
func ReadyInstanceState(name string, replicas int32) State {
return State{
Name: name,
CurrentReplicas: replicas,
DesiredReplicas: replicas,
Status: Ready,
}
}
func NotReadyInstanceState(name string, currentReplicas int32, desiredReplicas int32) State {
return State{
Name: name,
CurrentReplicas: currentReplicas,
DesiredReplicas: desiredReplicas,
Status: NotReady,
}
}

View File

@@ -6,10 +6,10 @@ import (
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/docker" "github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm" "github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory" "github.com/sablierapp/sablier/pkg/store/inmemory"
"github.com/sablierapp/sablier/pkg/theme" "github.com/sablierapp/sablier/pkg/theme"
k8s "k8s.io/client-go/kubernetes" k8s "k8s.io/client-go/kubernetes"
@@ -21,8 +21,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/app/storage"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/internal/server" "github.com/sablierapp/sablier/internal/server"
"github.com/sablierapp/sablier/version" "github.com/sablierapp/sablier/version"
@@ -47,30 +45,20 @@ func Start(ctx context.Context, conf config.Config) error {
return err return err
} }
sessionsManager := sessions.NewSessionsManager(logger, store, provider) s := sablier.New(logger, store, provider)
if conf.Storage.File != "" {
storage, err := storage.NewFileStorage(conf.Storage, logger)
if err != nil {
return err
}
defer saveSessions(storage, sessionsManager, logger)
loadSessions(storage, sessionsManager, logger)
}
groups, err := provider.InstanceGroups(ctx) groups, err := provider.InstanceGroups(ctx)
if err != nil { if err != nil {
logger.WarnContext(ctx, "initial group scan failed", slog.Any("reason", err)) logger.WarnContext(ctx, "initial group scan failed", slog.Any("reason", err))
} else { } else {
sessionsManager.SetGroups(groups) s.SetGroups(groups)
} }
updateGroups := make(chan map[string][]string) updateGroups := make(chan map[string][]string)
go WatchGroups(ctx, provider, 2*time.Second, updateGroups, logger) go WatchGroups(ctx, provider, 2*time.Second, updateGroups, logger)
go func() { go func() {
for groups := range updateGroups { for groups := range updateGroups {
sessionsManager.SetGroups(groups) s.SetGroups(groups)
} }
}() }()
@@ -78,7 +66,7 @@ func Start(ctx context.Context, conf config.Config) error {
go provider.NotifyInstanceStopped(ctx, instanceStopped) go provider.NotifyInstanceStopped(ctx, instanceStopped)
go func() { go func() {
for stopped := range instanceStopped { for stopped := range instanceStopped {
err := sessionsManager.RemoveInstance(stopped) err := s.RemoveInstance(ctx, stopped)
if err != nil { if err != nil {
logger.Warn("could not remove instance", slog.Any("error", err)) logger.Warn("could not remove instance", slog.Any("error", err))
} }
@@ -111,7 +99,7 @@ func Start(ctx context.Context, conf config.Config) error {
strategy := &routes.ServeStrategy{ strategy := &routes.ServeStrategy{
Theme: t, Theme: t,
SessionsManager: sessionsManager, SessionsManager: s,
StrategyConfig: conf.Strategy, StrategyConfig: conf.Strategy,
SessionsConfig: conf.Sessions, SessionsConfig: conf.Sessions,
} }
@@ -132,7 +120,7 @@ func Start(ctx context.Context, conf config.Config) error {
return nil return nil
} }
func onSessionExpires(ctx context.Context, provider provider.Provider, logger *slog.Logger) func(key string) { func onSessionExpires(ctx context.Context, provider sablier.Provider, logger *slog.Logger) func(key string) {
return func(_key string) { return func(_key string) {
go func(key string) { go func(key string) {
logger.InfoContext(ctx, "instance expired", slog.String("instance", key)) logger.InfoContext(ctx, "instance expired", slog.String("instance", key))
@@ -144,32 +132,7 @@ func onSessionExpires(ctx context.Context, provider provider.Provider, logger *s
} }
} }
func loadSessions(storage storage.Storage, sessions sessions.Manager, logger *slog.Logger) { func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provider) (sablier.Provider, error) {
logger.Info("loading sessions from storage")
reader, err := storage.Reader()
if err != nil {
logger.Error("error loading sessions from storage", slog.Any("reason", err))
}
err = sessions.LoadSessions(reader)
if err != nil {
logger.Error("error loading sessions into Sablier", slog.Any("reason", err))
}
}
func saveSessions(storage storage.Storage, sessions sessions.Manager, logger *slog.Logger) {
logger.Info("writing sessions to storage")
writer, err := storage.Writer()
if err != nil {
logger.Error("error saving sessions to storage", slog.Any("reason", err))
return
}
err = sessions.SaveSessions(writer)
if err != nil {
logger.Error("error saving sessions from Sablier", slog.Any("reason", err))
}
}
func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provider) (provider.Provider, error) {
if err := config.IsValid(); err != nil { if err := config.IsValid(); err != nil {
return nil, err return nil, err
} }
@@ -204,7 +167,7 @@ func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provide
return nil, fmt.Errorf("unimplemented provider %s", config.Name) return nil, fmt.Errorf("unimplemented provider %s", config.Name)
} }
func WatchGroups(ctx context.Context, provider provider.Provider, frequency time.Duration, send chan<- map[string][]string, logger *slog.Logger) { func WatchGroups(ctx context.Context, provider sablier.Provider, frequency time.Duration, send chan<- map[string][]string, logger *slog.Logger) {
ticker := time.NewTicker(frequency) ticker := time.NewTicker(frequency)
for { for {
select { select {

View File

@@ -1,301 +0,0 @@
package sessions
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/store"
"io"
"log/slog"
"maps"
"slices"
"sync"
"time"
"github.com/sablierapp/sablier/app/instance"
)
//go:generate go tool mockgen -package sessionstest -source=sessions_manager.go -destination=sessionstest/mocks_sessions_manager.go *
type Manager interface {
RequestSession(ctx context.Context, names []string, duration time.Duration) (*SessionState, error)
RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*SessionState, error)
RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error)
RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (*SessionState, error)
LoadSessions(io.ReadCloser) error
SaveSessions(io.WriteCloser) error
RemoveInstance(name string) error
SetGroups(groups map[string][]string)
}
type SessionsManager struct {
store store.Store
provider provider.Provider
groups map[string][]string
l *slog.Logger
}
func NewSessionsManager(logger *slog.Logger, store store.Store, provider provider.Provider) Manager {
sm := &SessionsManager{
store: store,
provider: provider,
groups: map[string][]string{},
l: logger,
}
return sm
}
func (s *SessionsManager) SetGroups(groups map[string][]string) {
if groups == nil {
groups = map[string][]string{}
}
if diff := cmp.Diff(s.groups, groups); diff != "" {
// TODO: Change this log for a friendly logging, groups rarely change, so we can put some effort on displaying what changed
s.l.Info("set groups", slog.Any("old", s.groups), slog.Any("new", groups), slog.Any("diff", diff))
s.groups = groups
}
}
func (s *SessionsManager) RemoveInstance(name string) error {
return s.store.Delete(context.Background(), name)
}
func (s *SessionsManager) LoadSessions(reader io.ReadCloser) error {
unmarshaler, ok := s.store.(json.Unmarshaler)
defer reader.Close()
if ok {
return json.NewDecoder(reader).Decode(unmarshaler)
}
return nil
}
func (s *SessionsManager) SaveSessions(writer io.WriteCloser) error {
marshaler, ok := s.store.(json.Marshaler)
defer writer.Close()
if ok {
encoder := json.NewEncoder(writer)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")
return encoder.Encode(marshaler)
}
return nil
}
type InstanceState struct {
Instance instance.State `json:"instance"`
Error error `json:"error"`
}
type SessionState struct {
Instances map[string]InstanceState `json:"instances"`
}
func (s *SessionState) IsReady() bool {
if s.Instances == nil {
s.Instances = map[string]InstanceState{}
}
for _, v := range s.Instances {
if v.Error != nil || v.Instance.Status != instance.Ready {
return false
}
}
return true
}
func (s *SessionState) Status() string {
if s.IsReady() {
return "ready"
}
return "not-ready"
}
func (s *SessionsManager) RequestSession(ctx context.Context, names []string, duration time.Duration) (sessionState *SessionState, err error) {
if len(names) == 0 {
return nil, fmt.Errorf("names cannot be empty")
}
var wg sync.WaitGroup
mx := sync.Mutex{}
sessionState = &SessionState{
Instances: map[string]InstanceState{},
}
wg.Add(len(names))
for i := 0; i < len(names); i++ {
go func(name string) {
defer wg.Done()
state, err := s.requestInstance(ctx, name, duration)
mx.Lock()
defer mx.Unlock()
sessionState.Instances[name] = InstanceState{
Instance: state,
Error: err,
}
}(names[i])
}
wg.Wait()
return sessionState, nil
}
func (s *SessionsManager) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestSession(ctx, names, duration)
}
func (s *SessionsManager) requestInstance(ctx context.Context, name string, duration time.Duration) (instance.State, error) {
if name == "" {
return instance.State{}, errors.New("instance name cannot be empty")
}
state, err := s.store.Get(ctx, name)
if errors.Is(err, store.ErrKeyNotFound) {
s.l.DebugContext(ctx, "request to start instance received", slog.String("instance", name))
err := s.provider.InstanceStart(ctx, name)
if err != nil {
return instance.State{}, err
}
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return instance.State{}, err
}
s.l.DebugContext(ctx, "request to start instance status completed", slog.String("instance", name), slog.String("status", state.Status))
} else if err != nil {
s.l.ErrorContext(ctx, "request to start instance failed", slog.String("instance", name), slog.Any("error", err))
return instance.State{}, fmt.Errorf("cannot retrieve instance from store: %w", err)
} else if state.Status != instance.Ready {
s.l.DebugContext(ctx, "request to check instance status received", slog.String("instance", name), slog.String("current_status", state.Status))
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return instance.State{}, err
}
s.l.DebugContext(ctx, "request to check instance status completed", slog.String("instance", name), slog.String("new_status", state.Status))
}
s.l.DebugContext(ctx, "set expiration for instance", slog.String("instance", name), slog.Duration("expiration", duration))
// Refresh the duration
s.expiresAfter(ctx, state, duration)
return state, nil
}
func (s *SessionsManager) RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error) {
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
return nil, err
}
if session.IsReady() {
return session, nil
}
ticker := time.NewTicker(5 * time.Second)
readiness := make(chan *SessionState)
errch := make(chan error)
quit := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
errch <- err
return
}
if session.IsReady() {
readiness <- session
}
case <-quit:
ticker.Stop()
return
}
}
}()
select {
case <-ctx.Done():
s.l.DebugContext(ctx, "request cancelled", slog.Any("reason", ctx.Err()))
close(quit)
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled by user: %w", ctx.Err())
}
return nil, fmt.Errorf("request cancelled by user")
case status := <-readiness:
close(quit)
return status, nil
case err := <-errch:
close(quit)
return nil, err
case <-time.After(timeout):
close(quit)
return nil, fmt.Errorf("session was not ready after %s", timeout.String())
}
}
func (s *SessionsManager) RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestReadySession(ctx, names, duration, timeout)
}
func (s *SessionsManager) expiresAfter(ctx context.Context, instance instance.State, duration time.Duration) {
err := s.store.Put(ctx, instance, duration)
if err != nil {
s.l.Error("could not put instance to store, will not expire", slog.Any("error", err), slog.String("instance", instance.Name))
}
}
func (s *SessionState) MarshalJSON() ([]byte, error) {
instances := maps.Values(s.Instances)
return json.Marshal(map[string]any{
"instances": instances,
"status": s.Status(),
})
}

View File

@@ -1,158 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: sessions_manager.go
//
// Generated by this command:
//
// mockgen -package sessionstest -source=sessions_manager.go -destination=sessionstest/mocks_sessions_manager.go *
//
// Package sessionstest is a generated GoMock package.
package sessionstest
import (
context "context"
io "io"
reflect "reflect"
time "time"
sessions "github.com/sablierapp/sablier/app/sessions"
gomock "go.uber.org/mock/gomock"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
isgomock struct{}
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// LoadSessions mocks base method.
func (m *MockManager) LoadSessions(arg0 io.ReadCloser) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadSessions", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LoadSessions indicates an expected call of LoadSessions.
func (mr *MockManagerMockRecorder) LoadSessions(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadSessions", reflect.TypeOf((*MockManager)(nil).LoadSessions), arg0)
}
// RemoveInstance mocks base method.
func (m *MockManager) RemoveInstance(name string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveInstance", name)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveInstance indicates an expected call of RemoveInstance.
func (mr *MockManagerMockRecorder) RemoveInstance(name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveInstance", reflect.TypeOf((*MockManager)(nil).RemoveInstance), name)
}
// RequestReadySession mocks base method.
func (m *MockManager) RequestReadySession(ctx context.Context, names []string, duration, timeout time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySession", ctx, names, duration, timeout)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySession indicates an expected call of RequestReadySession.
func (mr *MockManagerMockRecorder) RequestReadySession(ctx, names, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySession", reflect.TypeOf((*MockManager)(nil).RequestReadySession), ctx, names, duration, timeout)
}
// RequestReadySessionGroup mocks base method.
func (m *MockManager) RequestReadySessionGroup(ctx context.Context, group string, duration, timeout time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySessionGroup", ctx, group, duration, timeout)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySessionGroup indicates an expected call of RequestReadySessionGroup.
func (mr *MockManagerMockRecorder) RequestReadySessionGroup(ctx, group, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySessionGroup", reflect.TypeOf((*MockManager)(nil).RequestReadySessionGroup), ctx, group, duration, timeout)
}
// RequestSession mocks base method.
func (m *MockManager) RequestSession(ctx context.Context, names []string, duration time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSession", ctx, names, duration)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSession indicates an expected call of RequestSession.
func (mr *MockManagerMockRecorder) RequestSession(ctx, names, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSession", reflect.TypeOf((*MockManager)(nil).RequestSession), ctx, names, duration)
}
// RequestSessionGroup mocks base method.
func (m *MockManager) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSessionGroup", ctx, group, duration)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSessionGroup indicates an expected call of RequestSessionGroup.
func (mr *MockManagerMockRecorder) RequestSessionGroup(ctx, group, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSessionGroup", reflect.TypeOf((*MockManager)(nil).RequestSessionGroup), ctx, group, duration)
}
// SaveSessions mocks base method.
func (m *MockManager) SaveSessions(arg0 io.WriteCloser) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveSessions", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SaveSessions indicates an expected call of SaveSessions.
func (mr *MockManagerMockRecorder) SaveSessions(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveSessions", reflect.TypeOf((*MockManager)(nil).SaveSessions), arg0)
}
// SetGroups mocks base method.
func (m *MockManager) SetGroups(groups map[string][]string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetGroups", groups)
}
// SetGroups indicates an expected call of SetGroups.
func (mr *MockManagerMockRecorder) SetGroups(groups any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroups", reflect.TypeOf((*MockManager)(nil).SetGroups), groups)
}

View File

@@ -1,6 +0,0 @@
package types
type Instance struct {
Name string
Group string
}

View File

@@ -2,14 +2,14 @@ package api
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/sessions" "github.com/sablierapp/sablier/pkg/sablier"
) )
const SablierStatusHeader = "X-Sablier-Session-Status" const SablierStatusHeader = "X-Sablier-Session-Status"
const SablierStatusReady = "ready" const SablierStatusReady = "ready"
const SablierStatusNotReady = "not-ready" const SablierStatusNotReady = "not-ready"
func AddSablierHeader(c *gin.Context, session *sessions.SessionState) { func AddSablierHeader(c *gin.Context, session *sablier.SessionState) {
if session.IsReady() { if session.IsReady() {
c.Header(SablierStatusHeader, SablierStatusReady) c.Header(SablierStatusHeader, SablierStatusReady)
} else { } else {

View File

@@ -4,8 +4,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/sessions/sessionstest"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/sablier/sabliertest"
"github.com/sablierapp/sablier/pkg/theme" "github.com/sablierapp/sablier/pkg/theme"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -14,7 +14,7 @@ import (
"testing" "testing"
) )
func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strategy *routes.ServeStrategy, mock *sessionstest.MockManager) { func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strategy *routes.ServeStrategy, mock *sabliertest.MockSablier) {
t.Helper() t.Helper()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
@@ -23,7 +23,7 @@ func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strateg
app = gin.New() app = gin.New()
router = app.Group("/api") router = app.Group("/api")
mock = sessionstest.NewMockManager(ctrl) mock = sabliertest.NewMockSablier(ctrl)
strategy = &routes.ServeStrategy{ strategy = &routes.ServeStrategy{
Theme: th, Theme: th,
SessionsManager: mock, SessionsManager: mock,

View File

@@ -1,7 +1,7 @@
package api package api
import ( import (
"github.com/sablierapp/sablier/app/sessions" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/theme" "github.com/sablierapp/sablier/pkg/theme"
"github.com/tniswong/go.rfcx/rfc7807" "github.com/tniswong/go.rfcx/rfc7807"
"net/http" "net/http"
@@ -25,7 +25,7 @@ func ProblemValidation(e error) rfc7807.Problem {
} }
} }
func ProblemGroupNotFound(e sessions.ErrGroupNotFound) rfc7807.Problem { func ProblemGroupNotFound(e sablier.ErrGroupNotFound) rfc7807.Problem {
pb := rfc7807.Problem{ pb := rfc7807.Problem{
Type: "https://sablierapp.dev/#/errors?id=group-not-found", Type: "https://sablierapp.dev/#/errors?id=group-not-found",
Title: "Group not found", Title: "Group not found",

View File

@@ -5,7 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/http/routes/models" "github.com/sablierapp/sablier/app/http/routes/models"
"github.com/sablierapp/sablier/app/sessions" "github.com/sablierapp/sablier/pkg/sablier"
"net/http" "net/http"
) )
@@ -31,13 +31,13 @@ func StartBlocking(router *gin.RouterGroup, s *routes.ServeStrategy) {
return return
} }
var sessionState *sessions.SessionState var sessionState *sablier.SessionState
var err error var err error
if len(request.Names) > 0 { if len(request.Names) > 0 {
sessionState, err = s.SessionsManager.RequestReadySession(c.Request.Context(), request.Names, request.SessionDuration, request.Timeout) sessionState, err = s.SessionsManager.RequestReadySession(c.Request.Context(), request.Names, request.SessionDuration, request.Timeout)
} else { } else {
sessionState, err = s.SessionsManager.RequestReadySessionGroup(c.Request.Context(), request.Group, request.SessionDuration, request.Timeout) sessionState, err = s.SessionsManager.RequestReadySessionGroup(c.Request.Context(), request.Group, request.SessionDuration, request.Timeout)
var groupNotFoundError sessions.ErrGroupNotFound var groupNotFoundError sablier.ErrGroupNotFound
if errors.As(err, &groupNotFoundError) { if errors.As(err, &groupNotFoundError) {
AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError)) AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError))
return return

View File

@@ -2,7 +2,7 @@ package api
import ( import (
"errors" "errors"
"github.com/sablierapp/sablier/app/sessions" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/tniswong/go.rfcx/rfc7807" "github.com/tniswong/go.rfcx/rfc7807"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -35,7 +35,7 @@ func TestStartBlocking(t *testing.T) {
t.Run("StartBlockingByNames", func(t *testing.T) { t.Run("StartBlockingByNames", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t) app, router, strategy, m := NewApiTest(t)
StartBlocking(router, strategy) StartBlocking(router, strategy)
m.EXPECT().RequestReadySession(gomock.Any(), []string{"test"}, gomock.Any(), gomock.Any()).Return(&sessions.SessionState{}, nil) m.EXPECT().RequestReadySession(gomock.Any(), []string{"test"}, gomock.Any(), gomock.Any()).Return(&sablier.SessionState{}, nil)
r := PerformRequest(app, "GET", "/api/strategies/blocking?names=test") r := PerformRequest(app, "GET", "/api/strategies/blocking?names=test")
assert.Equal(t, http.StatusOK, r.Code) assert.Equal(t, http.StatusOK, r.Code)
assert.Equal(t, SablierStatusReady, r.Header().Get(SablierStatusHeader)) assert.Equal(t, SablierStatusReady, r.Header().Get(SablierStatusHeader))
@@ -43,7 +43,7 @@ func TestStartBlocking(t *testing.T) {
t.Run("StartBlockingByGroup", func(t *testing.T) { t.Run("StartBlockingByGroup", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t) app, router, strategy, m := NewApiTest(t)
StartBlocking(router, strategy) StartBlocking(router, strategy)
m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(&sessions.SessionState{}, nil) m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(&sablier.SessionState{}, nil)
r := PerformRequest(app, "GET", "/api/strategies/blocking?group=test") r := PerformRequest(app, "GET", "/api/strategies/blocking?group=test")
assert.Equal(t, http.StatusOK, r.Code) assert.Equal(t, http.StatusOK, r.Code)
assert.Equal(t, SablierStatusReady, r.Header().Get(SablierStatusHeader)) assert.Equal(t, SablierStatusReady, r.Header().Get(SablierStatusHeader))
@@ -51,7 +51,7 @@ func TestStartBlocking(t *testing.T) {
t.Run("StartBlockingErrGroupNotFound", func(t *testing.T) { t.Run("StartBlockingErrGroupNotFound", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t) app, router, strategy, m := NewApiTest(t)
StartBlocking(router, strategy) StartBlocking(router, strategy)
m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(nil, sessions.ErrGroupNotFound{ m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(nil, sablier.ErrGroupNotFound{
Group: "test", Group: "test",
AvailableGroups: []string{"test1", "test2"}, AvailableGroups: []string{"test1", "test2"},
}) })

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"errors" "errors"
"github.com/sablierapp/sablier/pkg/sablier"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -11,8 +12,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/http/routes/models" "github.com/sablierapp/sablier/app/http/routes/models"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/sessions"
theme2 "github.com/sablierapp/sablier/pkg/theme" theme2 "github.com/sablierapp/sablier/pkg/theme"
) )
@@ -40,13 +39,13 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
return return
} }
var sessionState *sessions.SessionState var sessionState *sablier.SessionState
var err error var err error
if len(request.Names) > 0 { if len(request.Names) > 0 {
sessionState, err = s.SessionsManager.RequestSession(c, request.Names, request.SessionDuration) sessionState, err = s.SessionsManager.RequestSession(c, request.Names, request.SessionDuration)
} else { } else {
sessionState, err = s.SessionsManager.RequestSessionGroup(c, request.Group, request.SessionDuration) sessionState, err = s.SessionsManager.RequestSessionGroup(c, request.Group, request.SessionDuration)
var groupNotFoundError sessions.ErrGroupNotFound var groupNotFoundError sablier.ErrGroupNotFound
if errors.As(err, &groupNotFoundError) { if errors.As(err, &groupNotFoundError) {
AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError)) AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError))
return return
@@ -90,7 +89,7 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
}) })
} }
func sessionStateToRenderOptionsInstanceState(sessionState *sessions.SessionState) (instances []theme2.Instance) { func sessionStateToRenderOptionsInstanceState(sessionState *sablier.SessionState) (instances []theme2.Instance) {
if sessionState == nil { if sessionState == nil {
return return
} }
@@ -106,7 +105,7 @@ func sessionStateToRenderOptionsInstanceState(sessionState *sessions.SessionStat
return return
} }
func instanceStateToRenderOptionsRequestState(instanceState instance.State) theme2.Instance { func instanceStateToRenderOptionsRequestState(instanceState sablier.InstanceInfo) theme2.Instance {
var err error var err error
if instanceState.Message == "" { if instanceState.Message == "" {
@@ -117,7 +116,7 @@ func instanceStateToRenderOptionsRequestState(instanceState instance.State) them
return theme2.Instance{ return theme2.Instance{
Name: instanceState.Name, Name: instanceState.Name,
Status: instanceState.Status, Status: string(instanceState.Status),
CurrentReplicas: instanceState.CurrentReplicas, CurrentReplicas: instanceState.CurrentReplicas,
DesiredReplicas: instanceState.DesiredReplicas, DesiredReplicas: instanceState.DesiredReplicas,
Error: err, Error: err,

View File

@@ -2,8 +2,7 @@ package api
import ( import (
"errors" "errors"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/app/sessions"
"github.com/tniswong/go.rfcx/rfc7807" "github.com/tniswong/go.rfcx/rfc7807"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -11,11 +10,11 @@ import (
"testing" "testing"
) )
func session() *sessions.SessionState { func session() *sablier.SessionState {
state := instance.ReadyInstanceState("test", 1) state := sablier.ReadyInstanceState("test", 1)
state2 := instance.ReadyInstanceState("test2", 1) state2 := sablier.ReadyInstanceState("test2", 1)
return &sessions.SessionState{ return &sablier.SessionState{
Instances: map[string]sessions.InstanceState{ Instances: map[string]sablier.InstanceInfoWithError{
"test": { "test": {
Instance: state, Instance: state,
Error: nil, Error: nil,
@@ -77,7 +76,7 @@ func TestStartDynamic(t *testing.T) {
t.Run("StartDynamicErrGroupNotFound", func(t *testing.T) { t.Run("StartDynamicErrGroupNotFound", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t) app, router, strategy, m := NewApiTest(t)
StartDynamic(router, strategy) StartDynamic(router, strategy)
m.EXPECT().RequestSessionGroup(gomock.Any(), "test", gomock.Any()).Return(nil, sessions.ErrGroupNotFound{ m.EXPECT().RequestSessionGroup(gomock.Any(), "test", gomock.Any()).Return(nil, sablier.ErrGroupNotFound{
Group: "test", Group: "test",
AvailableGroups: []string{"test1", "test2"}, AvailableGroups: []string{"test1", "test2"},
}) })

View File

@@ -3,41 +3,41 @@ package docker
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
"log/slog" "log/slog"
) )
func (p *DockerClassicProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { func (p *DockerClassicProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
spec, err := p.Client.ContainerInspect(ctx, name) spec, err := p.Client.ContainerInspect(ctx, name)
if err != nil { if err != nil {
return instance.State{}, fmt.Errorf("cannot inspect container: %w", err) return sablier.InstanceInfo{}, fmt.Errorf("cannot inspect container: %w", err)
} }
// "created", "running", "paused", "restarting", "removing", "exited", or "dead" // "created", "running", "paused", "restarting", "removing", "exited", or "dead"
switch spec.State.Status { switch spec.State.Status {
case "created", "paused", "restarting", "removing": case "created", "paused", "restarting", "removing":
return instance.NotReadyInstanceState(name, 0, p.desiredReplicas), nil return sablier.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
case "running": case "running":
if spec.State.Health != nil { if spec.State.Health != nil {
// // "starting", "healthy" or "unhealthy" // // "starting", "healthy" or "unhealthy"
if spec.State.Health.Status == "healthy" { if spec.State.Health.Status == "healthy" {
return instance.ReadyInstanceState(name, p.desiredReplicas), nil return sablier.ReadyInstanceState(name, p.desiredReplicas), nil
} else if spec.State.Health.Status == "unhealthy" { } else if spec.State.Health.Status == "unhealthy" {
return instance.UnrecoverableInstanceState(name, "container is unhealthy", p.desiredReplicas), nil return sablier.UnrecoverableInstanceState(name, "container is unhealthy", p.desiredReplicas), nil
} else { } else {
return instance.NotReadyInstanceState(name, 0, p.desiredReplicas), nil return sablier.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
} }
} }
p.l.WarnContext(ctx, "container running without healthcheck, you should define a healthcheck on your container so that Sablier properly detects when the container is ready to handle requests.", slog.String("container", name)) p.l.WarnContext(ctx, "container running without healthcheck, you should define a healthcheck on your container so that Sablier properly detects when the container is ready to handle requests.", slog.String("container", name))
return instance.ReadyInstanceState(name, p.desiredReplicas), nil return sablier.ReadyInstanceState(name, p.desiredReplicas), nil
case "exited": case "exited":
if spec.State.ExitCode != 0 { if spec.State.ExitCode != 0 {
return instance.UnrecoverableInstanceState(name, fmt.Sprintf("container exited with code \"%d\"", spec.State.ExitCode), p.desiredReplicas), nil return sablier.UnrecoverableInstanceState(name, fmt.Sprintf("container exited with code \"%d\"", spec.State.ExitCode), p.desiredReplicas), nil
} }
return instance.NotReadyInstanceState(name, 0, p.desiredReplicas), nil return sablier.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
case "dead": case "dead":
return instance.UnrecoverableInstanceState(name, "container in \"dead\" state cannot be restarted", p.desiredReplicas), nil return sablier.UnrecoverableInstanceState(name, "container in \"dead\" state cannot be restarted", p.desiredReplicas), nil
default: default:
return instance.UnrecoverableInstanceState(name, fmt.Sprintf("container status \"%s\" not handled", spec.State.Status), p.desiredReplicas), nil return sablier.UnrecoverableInstanceState(name, fmt.Sprintf("container status \"%s\" not handled", spec.State.Status), p.desiredReplicas), nil
} }
} }

View File

@@ -5,8 +5,8 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/docker" "github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
"time" "time"
@@ -24,7 +24,7 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want instance.State want sablier.InstanceInfo
wantErr error wantErr error
}{ }{
{ {
@@ -38,10 +38,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return resp.ID, err return resp.ID, err
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -60,10 +60,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, dind.client.ContainerStart(ctx, c.ID, container.StartOptions{}) return c.ID, dind.client.ContainerStart(ctx, c.ID, container.StartOptions{})
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -90,10 +90,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, dind.client.ContainerStart(ctx, c.ID, container.StartOptions{}) return c.ID, dind.client.ContainerStart(ctx, c.ID, container.StartOptions{})
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -126,10 +126,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil return c.ID, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Unrecoverable, Status: sablier.InstanceStatusUnrecoverable,
Message: "container is unhealthy", Message: "container is unhealthy",
}, },
wantErr: nil, wantErr: nil,
@@ -163,10 +163,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil return c.ID, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -192,10 +192,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil return c.ID, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -221,10 +221,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil return c.ID, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -250,10 +250,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil return c.ID, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Unrecoverable, Status: sablier.InstanceStatusUnrecoverable,
Message: "container exited with code \"137\"", Message: "container exited with code \"137\"",
}, },
wantErr: nil, wantErr: nil,

View File

@@ -7,12 +7,12 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/filters"
"github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"strings" "strings"
) )
func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs() args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
@@ -24,7 +24,7 @@ func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provid
return nil, err return nil, err
} }
instances := make([]types.Instance, 0, len(containers)) instances := make([]sablier.InstanceConfiguration, 0, len(containers))
for _, c := range containers { for _, c := range containers {
instance := containerToInstance(c) instance := containerToInstance(c)
instances = append(instances, instance) instances = append(instances, instance)
@@ -33,7 +33,7 @@ func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provid
return instances, nil return instances, nil
} }
func containerToInstance(c dockertypes.Container) types.Instance { func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration {
var group string var group string
if _, ok := c.Labels[discovery.LabelEnable]; ok { if _, ok := c.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func containerToInstance(c dockertypes.Container) types.Instance {
} }
} }
return types.Instance{ return sablier.InstanceConfiguration{
Name: strings.TrimPrefix(c.Names[0], "/"), // Containers name are reported with a leading slash Name: strings.TrimPrefix(c.Names[0], "/"), // Containers name are reported with a leading slash
Group: group, Group: group,
} }

View File

@@ -2,9 +2,9 @@ package docker_test
import ( import (
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/docker" "github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"sort" "sort"
"strings" "strings"
@@ -49,7 +49,7 @@ func TestDockerClassicProvider_InstanceList(t *testing.T) {
}) })
assert.NilError(t, err) assert.NilError(t, err)
want := []types.Instance{ want := []sablier.InstanceConfiguration{
{ {
Name: strings.TrimPrefix(i1.Name, "/"), Name: strings.TrimPrefix(i1.Name, "/"),
Group: "default", Group: "default",

View File

@@ -4,12 +4,12 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/sablier"
"log/slog" "log/slog"
) )
// Interface guard // Interface guard
var _ provider.Provider = (*DockerClassicProvider)(nil) var _ sablier.Provider = (*DockerClassicProvider)(nil)
type DockerClassicProvider struct { type DockerClassicProvider struct {
Client client.APIClient Client client.APIClient

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/sablier"
"log/slog" "log/slog"
"strings" "strings"
@@ -14,7 +14,7 @@ import (
) )
// Interface guard // Interface guard
var _ provider.Provider = (*DockerSwarmProvider)(nil) var _ sablier.Provider = (*DockerSwarmProvider)(nil)
type DockerSwarmProvider struct { type DockerSwarmProvider struct {
Client client.APIClient Client client.APIClient

View File

@@ -7,26 +7,26 @@ import (
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/swarm" "github.com/docker/docker/api/types/swarm"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
service, err := p.getServiceByName(name, ctx) service, err := p.getServiceByName(name, ctx)
if err != nil { if err != nil {
return instance.State{}, err return sablier.InstanceInfo{}, err
} }
foundName := p.getInstanceName(name, *service) foundName := p.getInstanceName(name, *service)
if service.Spec.Mode.Replicated == nil { if service.Spec.Mode.Replicated == nil {
return instance.State{}, errors.New("swarm service is not in \"replicated\" mode") return sablier.InstanceInfo{}, errors.New("swarm service is not in \"replicated\" mode")
} }
if service.ServiceStatus.DesiredTasks != service.ServiceStatus.RunningTasks || service.ServiceStatus.DesiredTasks == 0 { if service.ServiceStatus.DesiredTasks != service.ServiceStatus.RunningTasks || service.ServiceStatus.DesiredTasks == 0 {
return instance.NotReadyInstanceState(foundName, 0, p.desiredReplicas), nil return sablier.NotReadyInstanceState(foundName, 0, p.desiredReplicas), nil
} }
return instance.ReadyInstanceState(foundName, p.desiredReplicas), nil return sablier.ReadyInstanceState(foundName, p.desiredReplicas), nil
} }
func (p *DockerSwarmProvider) getServiceByName(name string, ctx context.Context) (*swarm.Service, error) { func (p *DockerSwarmProvider) getServiceByName(name string, ctx context.Context) (*swarm.Service, error) {

View File

@@ -6,8 +6,8 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm" "github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
"time" "time"
@@ -25,7 +25,7 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want instance.State want sablier.InstanceInfo
wantErr error wantErr error
}{ }{
{ {
@@ -50,10 +50,10 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
return service.Spec.Name, err return service.Spec.Name, err
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -84,10 +84,10 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
return service.Spec.Name, nil return service.Spec.Name, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -115,10 +115,10 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
return service.Spec.Name, nil return service.Spec.Name, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },

View File

@@ -7,11 +7,11 @@ import (
"github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/swarm" "github.com/docker/docker/api/types/swarm"
"github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]types.Instance, error) { func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs() args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable)) args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
args.Add("mode", "replicated") args.Add("mode", "replicated")
@@ -24,7 +24,7 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta
return nil, err return nil, err
} }
instances := make([]types.Instance, 0, len(services)) instances := make([]sablier.InstanceConfiguration, 0, len(services))
for _, s := range services { for _, s := range services {
instance := p.serviceToInstance(s) instance := p.serviceToInstance(s)
instances = append(instances, instance) instances = append(instances, instance)
@@ -33,7 +33,7 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta
return instances, nil return instances, nil
} }
func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i types.Instance) { func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.InstanceConfiguration) {
var group string var group string
if _, ok := s.Spec.Labels[discovery.LabelEnable]; ok { if _, ok := s.Spec.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i types.Instan
} }
} }
return types.Instance{ return sablier.InstanceConfiguration{
Name: s.Spec.Name, Name: s.Spec.Name,
Group: group, Group: group,
} }

View File

@@ -2,9 +2,9 @@ package dockerswarm_test
import ( import (
dockertypes "github.com/docker/docker/api/types" dockertypes "github.com/docker/docker/api/types"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm" "github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -49,7 +49,7 @@ func TestDockerClassicProvider_InstanceList(t *testing.T) {
}) })
assert.NilError(t, err) assert.NilError(t, err)
want := []types.Instance{ want := []sablier.InstanceConfiguration{
{ {
Name: i1.Spec.Name, Name: i1.Spec.Name,
Group: "default", Group: "default",

View File

@@ -6,8 +6,8 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm" "github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
"time" "time"
@@ -25,7 +25,7 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want instance.State want sablier.InstanceInfo
wantErr error wantErr error
}{ }{
{ {
@@ -48,10 +48,10 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
return service.Spec.Name, err return service.Spec.Name, err
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -82,10 +82,10 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
return service.Spec.Name, nil return service.Spec.Name, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -114,10 +114,10 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
return service.Spec.Name, nil return service.Spec.Name, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },

View File

@@ -6,8 +6,8 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm" "github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
"time" "time"
@@ -25,7 +25,7 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want instance.State want sablier.InstanceInfo
wantErr error wantErr error
}{ }{
{ {
@@ -48,10 +48,10 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
return service.Spec.Name, err return service.Spec.Name, err
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -82,10 +82,10 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
return service.Spec.Name, nil return service.Spec.Name, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },

View File

@@ -3,20 +3,20 @@ package kubernetes
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) DeploymentInspect(ctx context.Context, config ParsedName) (instance.State, error) { func (p *KubernetesProvider) DeploymentInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) {
d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{})
if err != nil { if err != nil {
return instance.State{}, fmt.Errorf("error getting deployment: %w", err) return sablier.InstanceInfo{}, fmt.Errorf("error getting deployment: %w", err)
}
// TODO: Should add option to set ready as soon as one replica is ready
if *d.Spec.Replicas != 0 && *d.Spec.Replicas == d.Status.ReadyReplicas {
return instance.ReadyInstanceState(config.Original, config.Replicas), nil
} }
return instance.NotReadyInstanceState(config.Original, d.Status.ReadyReplicas, config.Replicas), nil // TODO: Should add option to set ready as soon as one replica is ready
if *d.Spec.Replicas != 0 && *d.Spec.Replicas == d.Status.ReadyReplicas {
return sablier.ReadyInstanceState(config.Original, config.Replicas), nil
}
return sablier.NotReadyInstanceState(config.Original, d.Status.ReadyReplicas, config.Replicas), nil
} }

View File

@@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
autoscalingv1 "k8s.io/api/autoscaling/v1" autoscalingv1 "k8s.io/api/autoscaling/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
@@ -27,7 +27,7 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want instance.State want sablier.InstanceInfo
wantErr error wantErr error
}{ }{
{ {
@@ -49,10 +49,10 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -71,10 +71,10 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -106,10 +106,10 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },

View File

@@ -3,13 +3,13 @@ package kubernetes
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types" "github.com/sablierapp/sablier/pkg/sablier"
v1 "k8s.io/api/apps/v1" v1 "k8s.io/api/apps/v1"
core_v1 "k8s.io/api/core/v1" core_v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]types.Instance, error) { func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{ labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{ MatchLabels: map[string]string{
discovery.LabelEnable: "true", discovery.LabelEnable: "true",
@@ -22,7 +22,7 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]types.Instan
return nil, err return nil, err
} }
instances := make([]types.Instance, 0, len(deployments.Items)) instances := make([]sablier.InstanceConfiguration, 0, len(deployments.Items))
for _, d := range deployments.Items { for _, d := range deployments.Items {
instance := p.deploymentToInstance(&d) instance := p.deploymentToInstance(&d)
instances = append(instances, instance) instances = append(instances, instance)
@@ -31,7 +31,7 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]types.Instan
return instances, nil return instances, nil
} }
func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) types.Instance { func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.InstanceConfiguration {
var group string var group string
if _, ok := d.Labels[discovery.LabelEnable]; ok { if _, ok := d.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) types.Instan
parsed := DeploymentName(d, ParseOptions{Delimiter: p.delimiter}) parsed := DeploymentName(d, ParseOptions{Delimiter: p.delimiter})
return types.Instance{ return sablier.InstanceConfiguration{
Name: parsed.Original, Name: parsed.Original,
Group: group, Group: group,
} }

View File

@@ -3,13 +3,13 @@ package kubernetes
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter})
if err != nil { if err != nil {
return instance.State{}, err return sablier.InstanceInfo{}, err
} }
switch parsed.Kind { switch parsed.Kind {
@@ -18,6 +18,6 @@ func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (
case "statefulset": case "statefulset":
return p.StatefulSetInspect(ctx, parsed) return p.StatefulSetInspect(ctx, parsed)
default: default:
return instance.State{}, fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", parsed.Kind) return sablier.InstanceInfo{}, fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", parsed.Kind)
} }
} }

View File

@@ -11,6 +11,10 @@ import (
) )
func TestKubernetesProvider_InstanceInspect(t *testing.T) { func TestKubernetesProvider_InstanceInspect(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ctx := context.Background() ctx := context.Background()
type args struct { type args struct {
name string name string

View File

@@ -2,11 +2,11 @@ package kubernetes
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
deployments, err := p.DeploymentList(ctx) deployments, err := p.DeploymentList(ctx)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -2,10 +2,10 @@ package kubernetes_test
import ( import (
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"sort" "sort"
"strings" "strings"
@@ -57,7 +57,7 @@ func TestKubernetesProvider_InstanceList(t *testing.T) {
}) })
assert.NilError(t, err) assert.NilError(t, err)
want := []types.Instance{ want := []sablier.InstanceConfiguration{
{ {
Name: kubernetes.DeploymentName(d1, kubernetes.ParseOptions{Delimiter: "_"}).Original, Name: kubernetes.DeploymentName(d1, kubernetes.ParseOptions{Delimiter: "_"}).Original,
Group: "default", Group: "default",

View File

@@ -2,7 +2,7 @@ package kubernetes
import ( import (
"context" "context"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/sablier"
"log/slog" "log/slog"
providerConfig "github.com/sablierapp/sablier/config" providerConfig "github.com/sablierapp/sablier/config"
@@ -10,7 +10,7 @@ import (
) )
// Interface guard // Interface guard
var _ provider.Provider = (*KubernetesProvider)(nil) var _ sablier.Provider = (*KubernetesProvider)(nil)
type KubernetesProvider struct { type KubernetesProvider struct {
Client kubernetes.Interface Client kubernetes.Interface

View File

@@ -2,19 +2,19 @@ package kubernetes
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) StatefulSetInspect(ctx context.Context, config ParsedName) (instance.State, error) { func (p *KubernetesProvider) StatefulSetInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) {
ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{})
if err != nil { if err != nil {
return instance.State{}, err return sablier.InstanceInfo{}, err
} }
if *ss.Spec.Replicas != 0 && *ss.Spec.Replicas == ss.Status.ReadyReplicas { if *ss.Spec.Replicas != 0 && *ss.Spec.Replicas == ss.Status.ReadyReplicas {
return instance.ReadyInstanceState(config.Original, ss.Status.ReadyReplicas), nil return sablier.ReadyInstanceState(config.Original, ss.Status.ReadyReplicas), nil
} }
return instance.NotReadyInstanceState(config.Original, ss.Status.ReadyReplicas, config.Replicas), nil return sablier.NotReadyInstanceState(config.Original, ss.Status.ReadyReplicas, config.Replicas), nil
} }

View File

@@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
autoscalingv1 "k8s.io/api/autoscaling/v1" autoscalingv1 "k8s.io/api/autoscaling/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
@@ -27,7 +27,7 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want instance.State want sablier.InstanceInfo
wantErr error wantErr error
}{ }{
{ {
@@ -49,10 +49,10 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 1, CurrentReplicas: 1,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.Ready, Status: sablier.InstanceStatusReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -71,10 +71,10 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },
@@ -106,10 +106,10 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
}, },
}, },
want: instance.State{ want: sablier.InstanceInfo{
CurrentReplicas: 0, CurrentReplicas: 0,
DesiredReplicas: 1, DesiredReplicas: 1,
Status: instance.NotReady, Status: sablier.InstanceStatusNotReady,
}, },
wantErr: nil, wantErr: nil,
}, },

View File

@@ -3,13 +3,13 @@ package kubernetes
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/discovery" "github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types" "github.com/sablierapp/sablier/pkg/sablier"
v1 "k8s.io/api/apps/v1" v1 "k8s.io/api/apps/v1"
core_v1 "k8s.io/api/core/v1" core_v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]types.Instance, error) { func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{ labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{ MatchLabels: map[string]string{
discovery.LabelEnable: "true", discovery.LabelEnable: "true",
@@ -22,7 +22,7 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]types.Insta
return nil, err return nil, err
} }
instances := make([]types.Instance, 0, len(statefulSets.Items)) instances := make([]sablier.InstanceConfiguration, 0, len(statefulSets.Items))
for _, ss := range statefulSets.Items { for _, ss := range statefulSets.Items {
instance := p.statefulSetToInstance(&ss) instance := p.statefulSetToInstance(&ss)
instances = append(instances, instance) instances = append(instances, instance)
@@ -31,7 +31,7 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]types.Insta
return instances, nil return instances, nil
} }
func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) types.Instance { func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.InstanceConfiguration {
var group string var group string
if _, ok := ss.Labels[discovery.LabelEnable]; ok { if _, ok := ss.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) types.Ins
parsed := StatefulSetName(ss, ParseOptions{Delimiter: p.delimiter}) parsed := StatefulSetName(ss, ParseOptions{Delimiter: p.delimiter})
return types.Instance{ return sablier.InstanceConfiguration{
Name: parsed.Original, Name: parsed.Original,
Group: group, Group: group,
} }

View File

@@ -3,7 +3,7 @@
// //
// Generated by this command: // Generated by this command:
// //
// mockgen -package providertest -source=provider.go -destination=providertest/mock_provider.go * // mockgen -package providertest -source=provider.go -destination=../provider/providertest/mock_provider.go *
// //
// Package providertest is a generated GoMock package. // Package providertest is a generated GoMock package.
@@ -13,9 +13,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
instance "github.com/sablierapp/sablier/app/instance"
types "github.com/sablierapp/sablier/app/types"
provider "github.com/sablierapp/sablier/pkg/provider" provider "github.com/sablierapp/sablier/pkg/provider"
sablier "github.com/sablierapp/sablier/pkg/sablier"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@@ -59,10 +58,10 @@ func (mr *MockProviderMockRecorder) InstanceGroups(ctx any) *gomock.Call {
} }
// InstanceInspect mocks base method. // InstanceInspect mocks base method.
func (m *MockProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) { func (m *MockProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InstanceInspect", ctx, name) ret := m.ctrl.Call(m, "InstanceInspect", ctx, name)
ret0, _ := ret[0].(instance.State) ret0, _ := ret[0].(sablier.InstanceInfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -74,10 +73,10 @@ func (mr *MockProviderMockRecorder) InstanceInspect(ctx, name any) *gomock.Call
} }
// InstanceList mocks base method. // InstanceList mocks base method.
func (m *MockProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) { func (m *MockProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InstanceList", ctx, options) ret := m.ctrl.Call(m, "InstanceList", ctx, options)
ret0, _ := ret[0].([]types.Instance) ret0, _ := ret[0].([]sablier.InstanceConfiguration)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View File

@@ -1,4 +1,4 @@
package sessions package sablier
import ( import (
"fmt" "fmt"

54
pkg/sablier/instance.go Normal file
View File

@@ -0,0 +1,54 @@
package sablier
type InstanceStatus string
const (
InstanceStatusReady = "ready"
InstanceStatusNotReady = "not-ready"
InstanceStatusUnrecoverable = "unrecoverable"
)
type InstanceInfo struct {
Name string `json:"name"`
CurrentReplicas int32 `json:"currentReplicas"`
DesiredReplicas int32 `json:"desiredReplicas"`
Status InstanceStatus `json:"status"`
Message string `json:"message,omitempty"`
}
type InstanceConfiguration struct {
Name string
Group string
}
func (instance InstanceInfo) IsReady() bool {
return instance.Status == InstanceStatusReady
}
func UnrecoverableInstanceState(name string, message string, desiredReplicas int32) InstanceInfo {
return InstanceInfo{
Name: name,
CurrentReplicas: 0,
DesiredReplicas: desiredReplicas,
Status: InstanceStatusUnrecoverable,
Message: message,
}
}
func ReadyInstanceState(name string, replicas int32) InstanceInfo {
return InstanceInfo{
Name: name,
CurrentReplicas: replicas,
DesiredReplicas: replicas,
Status: InstanceStatusReady,
}
}
func NotReadyInstanceState(name string, currentReplicas int32, desiredReplicas int32) InstanceInfo {
return InstanceInfo{
Name: name,
CurrentReplicas: currentReplicas,
DesiredReplicas: desiredReplicas,
Status: InstanceStatusNotReady,
}
}

View File

@@ -0,0 +1,51 @@
package sablier
import (
"context"
"errors"
"fmt"
"github.com/sablierapp/sablier/pkg/store"
"log/slog"
"time"
)
func (s *sablier) InstanceRequest(ctx context.Context, name string, duration time.Duration) (InstanceInfo, error) {
if name == "" {
return InstanceInfo{}, errors.New("instance name cannot be empty")
}
state, err := s.sessions.Get(ctx, name)
if errors.Is(err, store.ErrKeyNotFound) {
s.l.DebugContext(ctx, "request to start instance received", slog.String("instance", name))
err = s.provider.InstanceStart(ctx, name)
if err != nil {
return InstanceInfo{}, err
}
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return InstanceInfo{}, err
}
s.l.DebugContext(ctx, "request to start instance status completed", slog.String("instance", name), slog.String("status", string(state.Status)))
} else if err != nil {
s.l.ErrorContext(ctx, "request to start instance failed", slog.String("instance", name), slog.Any("error", err))
return InstanceInfo{}, fmt.Errorf("cannot retrieve instance from store: %w", err)
} else if state.Status != InstanceStatusReady {
s.l.DebugContext(ctx, "request to check instance status received", slog.String("instance", name), slog.String("current_status", string(state.Status)))
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return InstanceInfo{}, err
}
s.l.DebugContext(ctx, "request to check instance status completed", slog.String("instance", name), slog.String("new_status", string(state.Status)))
}
s.l.DebugContext(ctx, "set expiration for instance", slog.String("instance", name), slog.Duration("expiration", duration))
err = s.sessions.Put(ctx, state, duration)
if err != nil {
s.l.Error("could not put instance to store, will not expire", slog.Any("error", err), slog.String("instance", state.Name))
return InstanceInfo{}, fmt.Errorf("could not put instance to store: %w", err)
}
return state, nil
}

View File

@@ -1,20 +1,18 @@
package provider package sablier
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/types" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/app/instance"
) )
//go:generate go tool mockgen -package providertest -source=provider.go -destination=providertest/mock_provider.go * //go:generate go tool mockgen -package providertest -source=provider.go -destination=../provider/providertest/mock_provider.go *
type Provider interface { type Provider interface {
InstanceStart(ctx context.Context, name string) error InstanceStart(ctx context.Context, name string) error
InstanceStop(ctx context.Context, name string) error InstanceStop(ctx context.Context, name string) error
InstanceInspect(ctx context.Context, name string) (instance.State, error) InstanceInspect(ctx context.Context, name string) (InstanceInfo, error)
InstanceGroups(ctx context.Context) (map[string][]string, error) InstanceGroups(ctx context.Context) (map[string][]string, error)
InstanceList(ctx context.Context, options InstanceListOptions) ([]types.Instance, error) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]InstanceConfiguration, error)
NotifyInstanceStopped(ctx context.Context, instance chan<- string) NotifyInstanceStopped(ctx context.Context, instance chan<- string)
} }

52
pkg/sablier/sablier.go Normal file
View File

@@ -0,0 +1,52 @@
package sablier
import (
"context"
"github.com/google/go-cmp/cmp"
"log/slog"
"time"
)
//go:generate go tool mockgen -package sabliertest -source=sablier.go -destination=sabliertest/mocks_sablier.go *
type Sablier interface {
RequestSession(ctx context.Context, names []string, duration time.Duration) (*SessionState, error)
RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*SessionState, error)
RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error)
RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (*SessionState, error)
RemoveInstance(ctx context.Context, name string) error
SetGroups(groups map[string][]string)
}
type sablier struct {
provider Provider
sessions Store
groups map[string][]string
l *slog.Logger
}
func New(logger *slog.Logger, store Store, provider Provider) Sablier {
return &sablier{
provider: provider,
sessions: store,
groups: map[string][]string{},
l: logger,
}
}
func (s *sablier) SetGroups(groups map[string][]string) {
if groups == nil {
groups = map[string][]string{}
}
if diff := cmp.Diff(s.groups, groups); diff != "" {
// TODO: Change this log for a friendly logging, groups rarely change, so we can put some effort on displaying what changed
s.l.Info("set groups", slog.Any("old", s.groups), slog.Any("new", groups), slog.Any("diff", diff))
s.groups = groups
}
}
func (s *sablier) RemoveInstance(ctx context.Context, name string) error {
return s.sessions.Delete(ctx, name)
}

View File

@@ -0,0 +1,129 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: sablier.go
//
// Generated by this command:
//
// mockgen -package sabliertest -source=sablier.go -destination=sabliertest/mocks_sablier.go *
//
// Package sabliertest is a generated GoMock package.
package sabliertest
import (
context "context"
reflect "reflect"
time "time"
sablier "github.com/sablierapp/sablier/pkg/sablier"
gomock "go.uber.org/mock/gomock"
)
// MockSablier is a mock of Sablier interface.
type MockSablier struct {
ctrl *gomock.Controller
recorder *MockSablierMockRecorder
isgomock struct{}
}
// MockSablierMockRecorder is the mock recorder for MockSablier.
type MockSablierMockRecorder struct {
mock *MockSablier
}
// NewMockSablier creates a new mock instance.
func NewMockSablier(ctrl *gomock.Controller) *MockSablier {
mock := &MockSablier{ctrl: ctrl}
mock.recorder = &MockSablierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSablier) EXPECT() *MockSablierMockRecorder {
return m.recorder
}
// RemoveInstance mocks base method.
func (m *MockSablier) RemoveInstance(ctx context.Context, name string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveInstance", ctx, name)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveInstance indicates an expected call of RemoveInstance.
func (mr *MockSablierMockRecorder) RemoveInstance(ctx, name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveInstance", reflect.TypeOf((*MockSablier)(nil).RemoveInstance), ctx, name)
}
// RequestReadySession mocks base method.
func (m *MockSablier) RequestReadySession(ctx context.Context, names []string, duration, timeout time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySession", ctx, names, duration, timeout)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySession indicates an expected call of RequestReadySession.
func (mr *MockSablierMockRecorder) RequestReadySession(ctx, names, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySession", reflect.TypeOf((*MockSablier)(nil).RequestReadySession), ctx, names, duration, timeout)
}
// RequestReadySessionGroup mocks base method.
func (m *MockSablier) RequestReadySessionGroup(ctx context.Context, group string, duration, timeout time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySessionGroup", ctx, group, duration, timeout)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySessionGroup indicates an expected call of RequestReadySessionGroup.
func (mr *MockSablierMockRecorder) RequestReadySessionGroup(ctx, group, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySessionGroup", reflect.TypeOf((*MockSablier)(nil).RequestReadySessionGroup), ctx, group, duration, timeout)
}
// RequestSession mocks base method.
func (m *MockSablier) RequestSession(ctx context.Context, names []string, duration time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSession", ctx, names, duration)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSession indicates an expected call of RequestSession.
func (mr *MockSablierMockRecorder) RequestSession(ctx, names, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSession", reflect.TypeOf((*MockSablier)(nil).RequestSession), ctx, names, duration)
}
// RequestSessionGroup mocks base method.
func (m *MockSablier) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSessionGroup", ctx, group, duration)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSessionGroup indicates an expected call of RequestSessionGroup.
func (mr *MockSablierMockRecorder) RequestSessionGroup(ctx, group, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSessionGroup", reflect.TypeOf((*MockSablier)(nil).RequestSessionGroup), ctx, group, duration)
}
// SetGroups mocks base method.
func (m *MockSablier) SetGroups(groups map[string][]string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetGroups", groups)
}
// SetGroups indicates an expected call of SetGroups.
func (mr *MockSablierMockRecorder) SetGroups(groups any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroups", reflect.TypeOf((*MockSablier)(nil).SetGroups), groups)
}

41
pkg/sablier/session.go Normal file
View File

@@ -0,0 +1,41 @@
package sablier
import (
"encoding/json"
"maps"
)
type SessionState struct {
Instances map[string]InstanceInfoWithError `json:"instances"`
}
func (s *SessionState) IsReady() bool {
if s.Instances == nil {
s.Instances = map[string]InstanceInfoWithError{}
}
for _, v := range s.Instances {
if v.Error != nil || v.Instance.Status != InstanceStatusReady {
return false
}
}
return true
}
func (s *SessionState) Status() string {
if s.IsReady() {
return "ready"
}
return "not-ready"
}
func (s *SessionState) MarshalJSON() ([]byte, error) {
instances := maps.Values(s.Instances)
return json.Marshal(map[string]any{
"instances": instances,
"status": s.Status(),
})
}

View File

@@ -0,0 +1,143 @@
package sablier
import (
"context"
"fmt"
"log/slog"
"maps"
"slices"
"sync"
"time"
)
type InstanceInfoWithError struct {
Instance InstanceInfo `json:"instance"`
Error error `json:"error"`
}
func (s *sablier) RequestSession(ctx context.Context, names []string, duration time.Duration) (sessionState *SessionState, err error) {
if len(names) == 0 {
return nil, fmt.Errorf("names cannot be empty")
}
var wg sync.WaitGroup
mx := sync.Mutex{}
sessionState = &SessionState{
Instances: map[string]InstanceInfoWithError{},
}
wg.Add(len(names))
for i := 0; i < len(names); i++ {
go func(name string) {
defer wg.Done()
state, err := s.InstanceRequest(ctx, name, duration)
mx.Lock()
defer mx.Unlock()
sessionState.Instances[name] = InstanceInfoWithError{
Instance: state,
Error: err,
}
}(names[i])
}
wg.Wait()
return sessionState, nil
}
func (s *sablier) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestSession(ctx, names, duration)
}
func (s *sablier) RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error) {
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
return nil, err
}
if session.IsReady() {
return session, nil
}
ticker := time.NewTicker(5 * time.Second)
readiness := make(chan *SessionState)
errch := make(chan error)
quit := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
errch <- err
return
}
if session.IsReady() {
readiness <- session
}
case <-quit:
ticker.Stop()
return
}
}
}()
select {
case <-ctx.Done():
s.l.DebugContext(ctx, "request cancelled", slog.Any("reason", ctx.Err()))
close(quit)
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled by user: %w", ctx.Err())
}
return nil, fmt.Errorf("request cancelled by user")
case status := <-readiness:
close(quit)
return status, nil
case err := <-errch:
close(quit)
return nil, err
case <-time.After(timeout):
close(quit)
return nil, fmt.Errorf("session was not ready after %s", timeout.String())
}
}
func (s *sablier) RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestReadySession(ctx, names, duration, timeout)
}

View File

@@ -1,21 +1,21 @@
package sessions package sablier_test
import ( import (
"context" "context"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/pkg/provider/providertest" "github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/storetest" "github.com/sablierapp/sablier/pkg/store/storetest"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"testing" "testing"
"time" "time"
"github.com/sablierapp/sablier/app/instance"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
) )
func TestSessionState_IsReady(t *testing.T) { func TestSessionState_IsReady(t *testing.T) {
type fields struct { type fields struct {
Instances map[string]InstanceState Instances map[string]sablier.InstanceInfoWithError
Error error Error error
} }
tests := []struct { tests := []struct {
@@ -26,9 +26,9 @@ func TestSessionState_IsReady(t *testing.T) {
{ {
name: "all instances are ready", name: "all instances are ready",
fields: fields{ fields: fields{
Instances: createMap([]instance.State{ Instances: createMap([]sablier.InstanceInfo{
{Name: "nginx", Status: instance.Ready}, {Name: "nginx", Status: sablier.InstanceStatusReady},
{Name: "apache", Status: instance.Ready}, {Name: "apache", Status: sablier.InstanceStatusReady},
}), }),
}, },
want: true, want: true,
@@ -36,9 +36,9 @@ func TestSessionState_IsReady(t *testing.T) {
{ {
name: "one instance is not ready", name: "one instance is not ready",
fields: fields{ fields: fields{
Instances: createMap([]instance.State{ Instances: createMap([]sablier.InstanceInfo{
{Name: "nginx", Status: instance.Ready}, {Name: "nginx", Status: sablier.InstanceStatusReady},
{Name: "apache", Status: instance.NotReady}, {Name: "apache", Status: sablier.InstanceStatusNotReady},
}), }),
}, },
want: false, want: false,
@@ -46,16 +46,16 @@ func TestSessionState_IsReady(t *testing.T) {
{ {
name: "no instances specified", name: "no instances specified",
fields: fields{ fields: fields{
Instances: createMap([]instance.State{}), Instances: createMap([]sablier.InstanceInfo{}),
}, },
want: true, want: true,
}, },
{ {
name: "one instance has an error", name: "one instance has an error",
fields: fields{ fields: fields{
Instances: createMap([]instance.State{ Instances: createMap([]sablier.InstanceInfo{
{Name: "nginx-error", Status: instance.Unrecoverable, Message: "connection timeout"}, {Name: "nginx-error", Status: sablier.InstanceStatusUnrecoverable, Message: "connection timeout"},
{Name: "apache", Status: instance.Ready}, {Name: "apache", Status: sablier.InstanceStatusReady},
}), }),
}, },
want: false, want: false,
@@ -63,7 +63,7 @@ func TestSessionState_IsReady(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &SessionState{ s := &sablier.SessionState{
Instances: tt.fields.Instances, Instances: tt.fields.Instances,
} }
if got := s.IsReady(); got != tt.want { if got := s.IsReady(); got != tt.want {
@@ -73,11 +73,11 @@ func TestSessionState_IsReady(t *testing.T) {
} }
} }
func createMap(instances []instance.State) map[string]InstanceState { func createMap(instances []sablier.InstanceInfo) map[string]sablier.InstanceInfoWithError {
states := make(map[string]InstanceState) states := make(map[string]sablier.InstanceInfoWithError)
for _, v := range instances { for _, v := range instances {
states[v.Name] = InstanceState{ states[v.Name] = sablier.InstanceInfoWithError{
Instance: v, Instance: v,
Error: nil, Error: nil,
} }
@@ -86,14 +86,14 @@ func createMap(instances []instance.State) map[string]InstanceState {
return states return states
} }
func setupSessionManager(t *testing.T) (Manager, *storetest.MockStore, *providertest.MockProvider) { func setupSessionManager(t *testing.T) (sablier.Sablier, *storetest.MockStore, *providertest.MockProvider) {
t.Helper() t.Helper()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
p := providertest.NewMockProvider(ctrl) p := providertest.NewMockProvider(ctrl)
s := storetest.NewMockStore(ctrl) s := storetest.NewMockStore(ctrl)
m := NewSessionsManager(slogt.New(t), s, p) m := sablier.New(slogt.New(t), s, p)
return m, s, p return m, s, p
} }
@@ -101,7 +101,7 @@ func TestSessionsManager(t *testing.T) {
t.Run("RemoveInstance", func(t *testing.T) { t.Run("RemoveInstance", func(t *testing.T) {
manager, store, _ := setupSessionManager(t) manager, store, _ := setupSessionManager(t)
store.EXPECT().Delete(gomock.Any(), "test") store.EXPECT().Delete(gomock.Any(), "test")
err := manager.RemoveInstance("test") err := manager.RemoveInstance(t.Context(), "test")
assert.NilError(t, err) assert.NilError(t, err)
}) })
} }
@@ -110,10 +110,10 @@ func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) {
t.Run("request ready session is cancelled by user", func(t *testing.T) { t.Run("request ready session is cancelled by user", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
manager, store, provider := setupSessionManager(t) manager, store, provider := setupSessionManager(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil).AnyTimes() store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
provider.EXPECT().InstanceInspect(ctx, gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil) provider.EXPECT().InstanceInspect(ctx, gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil)
errchan := make(chan error) errchan := make(chan error)
go func() { go func() {
@@ -132,10 +132,10 @@ func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) {
t.Run("request ready session is cancelled by timeout", func(t *testing.T) { t.Run("request ready session is cancelled by timeout", func(t *testing.T) {
manager, store, provider := setupSessionManager(t) manager, store, provider := setupSessionManager(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil).AnyTimes() store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
provider.EXPECT().InstanceInspect(t.Context(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil) provider.EXPECT().InstanceInspect(t.Context(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil)
errchan := make(chan error) errchan := make(chan error)
go func() { go func() {
@@ -151,7 +151,7 @@ func TestSessionsManager_RequestReadySession(t *testing.T) {
t.Run("request ready session is ready", func(t *testing.T) { t.Run("request ready session is ready", func(t *testing.T) {
manager, store, _ := setupSessionManager(t) manager, store, _ := setupSessionManager(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.Ready}, nil).AnyTimes() store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
errchan := make(chan error) errchan := make(chan error)

15
pkg/sablier/store.go Normal file
View File

@@ -0,0 +1,15 @@
package sablier
import (
"context"
"time"
)
//go:generate go tool mockgen -package storetest -source=store.go -destination=../store/storetest/mocks_store.go *
type Store interface {
Get(context.Context, string) (InstanceInfo, error)
Put(context.Context, InstanceInfo, time.Duration) error
Delete(context.Context, string) error
OnExpire(context.Context, func(string)) error
}

View File

@@ -3,24 +3,24 @@ package inmemory
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store" "github.com/sablierapp/sablier/pkg/store"
"github.com/sablierapp/sablier/pkg/tinykv" "github.com/sablierapp/sablier/pkg/tinykv"
"time" "time"
) )
var _ store.Store = (*InMemory)(nil) var _ sablier.Store = (*InMemory)(nil)
var _ json.Marshaler = (*InMemory)(nil) var _ json.Marshaler = (*InMemory)(nil)
var _ json.Unmarshaler = (*InMemory)(nil) var _ json.Unmarshaler = (*InMemory)(nil)
func NewInMemory() store.Store { func NewInMemory() sablier.Store {
return &InMemory{ return &InMemory{
kv: tinykv.New[instance.State](1*time.Second, nil), kv: tinykv.New[sablier.InstanceInfo](1*time.Second, nil),
} }
} }
type InMemory struct { type InMemory struct {
kv tinykv.KV[instance.State] kv tinykv.KV[sablier.InstanceInfo]
} }
func (i InMemory) UnmarshalJSON(bytes []byte) error { func (i InMemory) UnmarshalJSON(bytes []byte) error {
@@ -31,15 +31,15 @@ func (i InMemory) MarshalJSON() ([]byte, error) {
return i.kv.MarshalJSON() return i.kv.MarshalJSON()
} }
func (i InMemory) Get(_ context.Context, s string) (instance.State, error) { func (i InMemory) Get(_ context.Context, s string) (sablier.InstanceInfo, error) {
val, ok := i.kv.Get(s) val, ok := i.kv.Get(s)
if !ok { if !ok {
return instance.State{}, store.ErrKeyNotFound return sablier.InstanceInfo{}, store.ErrKeyNotFound
} }
return val, nil return val, nil
} }
func (i InMemory) Put(_ context.Context, state instance.State, duration time.Duration) error { func (i InMemory) Put(_ context.Context, state sablier.InstanceInfo, duration time.Duration) error {
return i.kv.Put(state.Name, state, duration) return i.kv.Put(state.Name, state, duration)
} }
@@ -49,7 +49,7 @@ func (i InMemory) Delete(_ context.Context, s string) error {
} }
func (i InMemory) OnExpire(_ context.Context, f func(string)) error { func (i InMemory) OnExpire(_ context.Context, f func(string)) error {
i.kv.SetOnExpire(func(k string, _ instance.State) { i.kv.SetOnExpire(func(k string, _ sablier.InstanceInfo) {
f(k) f(k)
}) })
return nil return nil

View File

@@ -2,7 +2,7 @@ package inmemory
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store" "github.com/sablierapp/sablier/pkg/store"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
@@ -24,7 +24,7 @@ func TestInMemory(t *testing.T) {
ctx := context.Background() ctx := context.Background()
vk := NewInMemory() vk := NewInMemory()
err := vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second) err := vk.Put(ctx, sablier.InstanceInfo{Name: "test"}, 1*time.Second)
assert.NilError(t, err) assert.NilError(t, err)
i, err := vk.Get(ctx, "test") i, err := vk.Get(ctx, "test")
@@ -40,7 +40,7 @@ func TestInMemory(t *testing.T) {
ctx := context.Background() ctx := context.Background()
vk := NewInMemory() vk := NewInMemory()
err := vk.Put(ctx, instance.State{Name: "test"}, 30*time.Second) err := vk.Put(ctx, sablier.InstanceInfo{Name: "test"}, 30*time.Second)
assert.NilError(t, err) assert.NilError(t, err)
i, err := vk.Get(ctx, "test") i, err := vk.Get(ctx, "test")
@@ -66,7 +66,7 @@ func TestInMemory(t *testing.T) {
}) })
assert.NilError(t, err) assert.NilError(t, err)
err = vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second) err = vk.Put(ctx, sablier.InstanceInfo{Name: "test"}, 1*time.Second)
assert.NilError(t, err) assert.NilError(t, err)
expired := <-expirations expired := <-expirations
assert.Equal(t, expired, "test") assert.Equal(t, expired, "test")

View File

@@ -1,19 +1,7 @@
package store package store
import ( import (
"context"
"errors" "errors"
"github.com/sablierapp/sablier/app/instance"
"time"
) )
var ErrKeyNotFound = errors.New("key not found") var ErrKeyNotFound = errors.New("key not found")
//go:generate go tool mockgen -package storetest -source=store.go -destination=storetest/mocks_store.go *
type Store interface {
Get(context.Context, string) (instance.State, error)
Put(context.Context, instance.State, time.Duration) error
Delete(context.Context, string) error
OnExpire(context.Context, func(string)) error
}

View File

@@ -3,7 +3,7 @@
// //
// Generated by this command: // Generated by this command:
// //
// mockgen -package storetest -source=store.go -destination=storetest/mocks_store.go * // mockgen -package storetest -source=store.go -destination=../store/storetest/mocks_store.go *
// //
// Package storetest is a generated GoMock package. // Package storetest is a generated GoMock package.
@@ -14,7 +14,7 @@ import (
reflect "reflect" reflect "reflect"
time "time" time "time"
instance "github.com/sablierapp/sablier/app/instance" sablier "github.com/sablierapp/sablier/pkg/sablier"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@@ -57,10 +57,10 @@ func (mr *MockStoreMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
} }
// Get mocks base method. // Get mocks base method.
func (m *MockStore) Get(arg0 context.Context, arg1 string) (instance.State, error) { func (m *MockStore) Get(arg0 context.Context, arg1 string) (sablier.InstanceInfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1) ret := m.ctrl.Call(m, "Get", arg0, arg1)
ret0, _ := ret[0].(instance.State) ret0, _ := ret[0].(sablier.InstanceInfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -86,7 +86,7 @@ func (mr *MockStoreMockRecorder) OnExpire(arg0, arg1 any) *gomock.Call {
} }
// Put mocks base method. // Put mocks base method.
func (m *MockStore) Put(arg0 context.Context, arg1 instance.State, arg2 time.Duration) error { func (m *MockStore) Put(arg0 context.Context, arg1 sablier.InstanceInfo, arg2 time.Duration) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Put", arg0, arg1, arg2) ret := m.ctrl.Call(m, "Put", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -3,7 +3,7 @@ package valkey
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store" "github.com/sablierapp/sablier/pkg/store"
"github.com/valkey-io/valkey-go" "github.com/valkey-io/valkey-go"
"log/slog" "log/slog"
@@ -11,13 +11,13 @@ import (
"time" "time"
) )
var _ store.Store = (*ValKey)(nil) var _ sablier.Store = (*ValKey)(nil)
type ValKey struct { type ValKey struct {
Client valkey.Client Client valkey.Client
} }
func New(ctx context.Context, client valkey.Client) (store.Store, error) { func New(ctx context.Context, client valkey.Client) (sablier.Store, error) {
err := client.Do(ctx, client.B().Ping().Build()).Error() err := client.Do(ctx, client.B().Ping().Build()).Error()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -33,25 +33,25 @@ func New(ctx context.Context, client valkey.Client) (store.Store, error) {
return &ValKey{Client: client}, nil return &ValKey{Client: client}, nil
} }
func (v *ValKey) Get(ctx context.Context, s string) (instance.State, error) { func (v *ValKey) Get(ctx context.Context, s string) (sablier.InstanceInfo, error) {
b, err := v.Client.Do(ctx, v.Client.B().Get().Key(s).Build()).AsBytes() b, err := v.Client.Do(ctx, v.Client.B().Get().Key(s).Build()).AsBytes()
if valkey.IsValkeyNil(err) { if valkey.IsValkeyNil(err) {
return instance.State{}, store.ErrKeyNotFound return sablier.InstanceInfo{}, store.ErrKeyNotFound
} }
if err != nil { if err != nil {
return instance.State{}, err return sablier.InstanceInfo{}, err
} }
var i instance.State var i sablier.InstanceInfo
err = json.Unmarshal(b, &i) err = json.Unmarshal(b, &i)
if err != nil { if err != nil {
return instance.State{}, err return sablier.InstanceInfo{}, err
} }
return i, nil return i, nil
} }
func (v *ValKey) Put(ctx context.Context, state instance.State, duration time.Duration) error { func (v *ValKey) Put(ctx context.Context, state sablier.InstanceInfo, duration time.Duration) error {
value, err := json.Marshal(state) value, err := json.Marshal(state)
if err != nil { if err != nil {
return err return err

View File

@@ -2,7 +2,7 @@ package valkey
import ( import (
"context" "context"
"github.com/sablierapp/sablier/app/instance" "github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store" "github.com/sablierapp/sablier/pkg/store"
"github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go"
tcvalkey "github.com/testcontainers/testcontainers-go/modules/valkey" tcvalkey "github.com/testcontainers/testcontainers-go/modules/valkey"
@@ -44,53 +44,51 @@ func setupValKey(t *testing.T) *ValKey {
} }
func TestValKey(t *testing.T) { func TestValKey(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ctx := t.Context()
vk := setupValKey(t)
t.Parallel() t.Parallel()
t.Run("ValKeyErrNotFound", func(t *testing.T) { t.Run("ValKeyErrNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() _, err := vk.Get(ctx, "ValKeyErrNotFound")
vk := setupValKey(t)
_, err := vk.Get(ctx, "test")
assert.ErrorIs(t, err, store.ErrKeyNotFound) assert.ErrorIs(t, err, store.ErrKeyNotFound)
}) })
t.Run("ValKeyPut", func(t *testing.T) { t.Run("ValKeyPut", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
vk := setupValKey(t)
err := vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second) err := vk.Put(ctx, sablier.InstanceInfo{Name: "ValKeyPut"}, 1*time.Second)
assert.NilError(t, err) assert.NilError(t, err)
i, err := vk.Get(ctx, "test") i, err := vk.Get(ctx, "ValKeyPut")
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, i.Name, "test") assert.Equal(t, i.Name, "ValKeyPut")
<-time.After(2 * time.Second) <-time.After(2 * time.Second)
_, err = vk.Get(ctx, "test") _, err = vk.Get(ctx, "ValKeyPut")
assert.ErrorIs(t, err, store.ErrKeyNotFound) assert.ErrorIs(t, err, store.ErrKeyNotFound)
}) })
t.Run("ValKeyDelete", func(t *testing.T) { t.Run("ValKeyDelete", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
vk := setupValKey(t)
err := vk.Put(ctx, instance.State{Name: "test"}, 30*time.Second) err := vk.Put(ctx, sablier.InstanceInfo{Name: "ValKeyDelete"}, 30*time.Second)
assert.NilError(t, err) assert.NilError(t, err)
i, err := vk.Get(ctx, "test") i, err := vk.Get(ctx, "ValKeyDelete")
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, i.Name, "test") assert.Equal(t, i.Name, "ValKeyDelete")
err = vk.Delete(ctx, "test") err = vk.Delete(ctx, "ValKeyDelete")
assert.NilError(t, err) assert.NilError(t, err)
_, err = vk.Get(ctx, "test") _, err = vk.Get(ctx, "ValKeyDelete")
assert.ErrorIs(t, err, store.ErrKeyNotFound) assert.ErrorIs(t, err, store.ErrKeyNotFound)
}) })
t.Run("ValKeyOnExpire", func(t *testing.T) { t.Run("ValKeyOnExpire", func(t *testing.T) {
t.Parallel() t.Parallel()
vk := setupValKey(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
@@ -100,9 +98,9 @@ func TestValKey(t *testing.T) {
}) })
assert.NilError(t, err) assert.NilError(t, err)
err = vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second) err = vk.Put(ctx, sablier.InstanceInfo{Name: "ValKeyOnExpire"}, 1*time.Second)
assert.NilError(t, err) assert.NilError(t, err)
expired := <-expirations expired := <-expirations
assert.Equal(t, expired, "test") assert.Equal(t, expired, "ValKeyOnExpire")
}) })
} }