refactor: remove session manager

The session manager is now simply Sablier
This commit is contained in:
Alexis Couvreur
2025-03-08 12:17:12 -05:00
parent 1298d86c61
commit ce7de13ade
55 changed files with 745 additions and 832 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"golang.org/x/sync/errgroup"
"log/slog"
@@ -13,7 +14,7 @@ import (
// as running instances by Sablier.
// By default, Sablier does not stop all already running instances. Meaning that you need to make an
// initial request in order to trigger the scaling to zero.
func StopAllUnregisteredInstances(ctx context.Context, p provider.Provider, s store.Store, logger *slog.Logger) error {
func StopAllUnregisteredInstances(ctx context.Context, p sablier.Provider, s sablier.Store, logger *slog.Logger) error {
instances, err := p.InstanceList(ctx, provider.InstanceListOptions{
All: false, // Only running containers
Labels: []string{LabelEnable},
@@ -41,7 +42,7 @@ func StopAllUnregisteredInstances(ctx context.Context, p provider.Provider, s st
return waitGroup.Wait()
}
func stopFunc(ctx context.Context, name string, p provider.Provider, logger *slog.Logger) func() error {
func stopFunc(ctx context.Context, name string, p sablier.Provider, logger *slog.Logger) func() error {
return func() error {
err := p.InstanceStop(ctx, name)
if err != nil {

View File

@@ -4,10 +4,9 @@ import (
"errors"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory"
gomock "go.uber.org/mock/gomock"
"gotest.tools/v3/assert"
@@ -22,13 +21,13 @@ func TestStopAllUnregisteredInstances(t *testing.T) {
ctx := t.Context()
// Define instances and registered instances
instances := []types.Instance{
instances := []sablier.InstanceConfiguration{
{Name: "instance1"},
{Name: "instance2"},
{Name: "instance3"},
}
store := inmemory.NewInMemory()
err := store.Put(ctx, instance.State{Name: "instance1"}, time.Minute)
err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute)
assert.NilError(t, err)
// Set up expectations for InstanceList
@@ -53,13 +52,13 @@ func TestStopAllUnregisteredInstances_WithError(t *testing.T) {
ctx := t.Context()
// Define instances and registered instances
instances := []types.Instance{
instances := []sablier.InstanceConfiguration{
{Name: "instance1"},
{Name: "instance2"},
{Name: "instance3"},
}
store := inmemory.NewInMemory()
err := store.Put(ctx, instance.State{Name: "instance1"}, time.Minute)
err := store.Put(ctx, sablier.InstanceInfo{Name: "instance1"}, time.Minute)
assert.NilError(t, err)
// Set up expectations for InstanceList

View File

@@ -1,15 +1,15 @@
package routes
import (
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/theme"
)
type ServeStrategy struct {
Theme *theme.Themes
SessionsManager sessions.Manager
SessionsManager sablier.Sablier
StrategyConfig config.Strategy
SessionsConfig config.Sessions
}

View File

@@ -1,55 +0,0 @@
package instance
var Ready = "ready"
var NotReady = "not-ready"
var Unrecoverable = "unrecoverable"
type State struct {
Name string `json:"name"`
CurrentReplicas int32 `json:"currentReplicas"`
DesiredReplicas int32 `json:"desiredReplicas"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
func (instance State) IsReady() bool {
return instance.Status == Ready
}
func ErrorInstanceState(name string, err error, desiredReplicas int32) (State, error) {
return State{
Name: name,
CurrentReplicas: 0,
DesiredReplicas: desiredReplicas,
Status: Unrecoverable,
Message: err.Error(),
}, err
}
func UnrecoverableInstanceState(name string, message string, desiredReplicas int32) State {
return State{
Name: name,
CurrentReplicas: 0,
DesiredReplicas: desiredReplicas,
Status: Unrecoverable,
Message: message,
}
}
func ReadyInstanceState(name string, replicas int32) State {
return State{
Name: name,
CurrentReplicas: replicas,
DesiredReplicas: replicas,
Status: Ready,
}
}
func NotReadyInstanceState(name string, currentReplicas int32, desiredReplicas int32) State {
return State{
Name: name,
CurrentReplicas: currentReplicas,
DesiredReplicas: desiredReplicas,
Status: NotReady,
}
}

View File

@@ -6,10 +6,10 @@ import (
"github.com/docker/docker/client"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory"
"github.com/sablierapp/sablier/pkg/theme"
k8s "k8s.io/client-go/kubernetes"
@@ -21,8 +21,6 @@ import (
"syscall"
"time"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/app/storage"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/internal/server"
"github.com/sablierapp/sablier/version"
@@ -47,30 +45,20 @@ func Start(ctx context.Context, conf config.Config) error {
return err
}
sessionsManager := sessions.NewSessionsManager(logger, store, provider)
if conf.Storage.File != "" {
storage, err := storage.NewFileStorage(conf.Storage, logger)
if err != nil {
return err
}
defer saveSessions(storage, sessionsManager, logger)
loadSessions(storage, sessionsManager, logger)
}
s := sablier.New(logger, store, provider)
groups, err := provider.InstanceGroups(ctx)
if err != nil {
logger.WarnContext(ctx, "initial group scan failed", slog.Any("reason", err))
} else {
sessionsManager.SetGroups(groups)
s.SetGroups(groups)
}
updateGroups := make(chan map[string][]string)
go WatchGroups(ctx, provider, 2*time.Second, updateGroups, logger)
go func() {
for groups := range updateGroups {
sessionsManager.SetGroups(groups)
s.SetGroups(groups)
}
}()
@@ -78,7 +66,7 @@ func Start(ctx context.Context, conf config.Config) error {
go provider.NotifyInstanceStopped(ctx, instanceStopped)
go func() {
for stopped := range instanceStopped {
err := sessionsManager.RemoveInstance(stopped)
err := s.RemoveInstance(ctx, stopped)
if err != nil {
logger.Warn("could not remove instance", slog.Any("error", err))
}
@@ -111,7 +99,7 @@ func Start(ctx context.Context, conf config.Config) error {
strategy := &routes.ServeStrategy{
Theme: t,
SessionsManager: sessionsManager,
SessionsManager: s,
StrategyConfig: conf.Strategy,
SessionsConfig: conf.Sessions,
}
@@ -132,7 +120,7 @@ func Start(ctx context.Context, conf config.Config) error {
return nil
}
func onSessionExpires(ctx context.Context, provider provider.Provider, logger *slog.Logger) func(key string) {
func onSessionExpires(ctx context.Context, provider sablier.Provider, logger *slog.Logger) func(key string) {
return func(_key string) {
go func(key string) {
logger.InfoContext(ctx, "instance expired", slog.String("instance", key))
@@ -144,32 +132,7 @@ func onSessionExpires(ctx context.Context, provider provider.Provider, logger *s
}
}
func loadSessions(storage storage.Storage, sessions sessions.Manager, logger *slog.Logger) {
logger.Info("loading sessions from storage")
reader, err := storage.Reader()
if err != nil {
logger.Error("error loading sessions from storage", slog.Any("reason", err))
}
err = sessions.LoadSessions(reader)
if err != nil {
logger.Error("error loading sessions into Sablier", slog.Any("reason", err))
}
}
func saveSessions(storage storage.Storage, sessions sessions.Manager, logger *slog.Logger) {
logger.Info("writing sessions to storage")
writer, err := storage.Writer()
if err != nil {
logger.Error("error saving sessions to storage", slog.Any("reason", err))
return
}
err = sessions.SaveSessions(writer)
if err != nil {
logger.Error("error saving sessions from Sablier", slog.Any("reason", err))
}
}
func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provider) (provider.Provider, error) {
func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provider) (sablier.Provider, error) {
if err := config.IsValid(); err != nil {
return nil, err
}
@@ -204,7 +167,7 @@ func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provide
return nil, fmt.Errorf("unimplemented provider %s", config.Name)
}
func WatchGroups(ctx context.Context, provider provider.Provider, frequency time.Duration, send chan<- map[string][]string, logger *slog.Logger) {
func WatchGroups(ctx context.Context, provider sablier.Provider, frequency time.Duration, send chan<- map[string][]string, logger *slog.Logger) {
ticker := time.NewTicker(frequency)
for {
select {

View File

@@ -1,301 +0,0 @@
package sessions
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/store"
"io"
"log/slog"
"maps"
"slices"
"sync"
"time"
"github.com/sablierapp/sablier/app/instance"
)
//go:generate go tool mockgen -package sessionstest -source=sessions_manager.go -destination=sessionstest/mocks_sessions_manager.go *
type Manager interface {
RequestSession(ctx context.Context, names []string, duration time.Duration) (*SessionState, error)
RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*SessionState, error)
RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error)
RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (*SessionState, error)
LoadSessions(io.ReadCloser) error
SaveSessions(io.WriteCloser) error
RemoveInstance(name string) error
SetGroups(groups map[string][]string)
}
type SessionsManager struct {
store store.Store
provider provider.Provider
groups map[string][]string
l *slog.Logger
}
func NewSessionsManager(logger *slog.Logger, store store.Store, provider provider.Provider) Manager {
sm := &SessionsManager{
store: store,
provider: provider,
groups: map[string][]string{},
l: logger,
}
return sm
}
func (s *SessionsManager) SetGroups(groups map[string][]string) {
if groups == nil {
groups = map[string][]string{}
}
if diff := cmp.Diff(s.groups, groups); diff != "" {
// TODO: Change this log for a friendly logging, groups rarely change, so we can put some effort on displaying what changed
s.l.Info("set groups", slog.Any("old", s.groups), slog.Any("new", groups), slog.Any("diff", diff))
s.groups = groups
}
}
func (s *SessionsManager) RemoveInstance(name string) error {
return s.store.Delete(context.Background(), name)
}
func (s *SessionsManager) LoadSessions(reader io.ReadCloser) error {
unmarshaler, ok := s.store.(json.Unmarshaler)
defer reader.Close()
if ok {
return json.NewDecoder(reader).Decode(unmarshaler)
}
return nil
}
func (s *SessionsManager) SaveSessions(writer io.WriteCloser) error {
marshaler, ok := s.store.(json.Marshaler)
defer writer.Close()
if ok {
encoder := json.NewEncoder(writer)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")
return encoder.Encode(marshaler)
}
return nil
}
type InstanceState struct {
Instance instance.State `json:"instance"`
Error error `json:"error"`
}
type SessionState struct {
Instances map[string]InstanceState `json:"instances"`
}
func (s *SessionState) IsReady() bool {
if s.Instances == nil {
s.Instances = map[string]InstanceState{}
}
for _, v := range s.Instances {
if v.Error != nil || v.Instance.Status != instance.Ready {
return false
}
}
return true
}
func (s *SessionState) Status() string {
if s.IsReady() {
return "ready"
}
return "not-ready"
}
func (s *SessionsManager) RequestSession(ctx context.Context, names []string, duration time.Duration) (sessionState *SessionState, err error) {
if len(names) == 0 {
return nil, fmt.Errorf("names cannot be empty")
}
var wg sync.WaitGroup
mx := sync.Mutex{}
sessionState = &SessionState{
Instances: map[string]InstanceState{},
}
wg.Add(len(names))
for i := 0; i < len(names); i++ {
go func(name string) {
defer wg.Done()
state, err := s.requestInstance(ctx, name, duration)
mx.Lock()
defer mx.Unlock()
sessionState.Instances[name] = InstanceState{
Instance: state,
Error: err,
}
}(names[i])
}
wg.Wait()
return sessionState, nil
}
func (s *SessionsManager) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestSession(ctx, names, duration)
}
func (s *SessionsManager) requestInstance(ctx context.Context, name string, duration time.Duration) (instance.State, error) {
if name == "" {
return instance.State{}, errors.New("instance name cannot be empty")
}
state, err := s.store.Get(ctx, name)
if errors.Is(err, store.ErrKeyNotFound) {
s.l.DebugContext(ctx, "request to start instance received", slog.String("instance", name))
err := s.provider.InstanceStart(ctx, name)
if err != nil {
return instance.State{}, err
}
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return instance.State{}, err
}
s.l.DebugContext(ctx, "request to start instance status completed", slog.String("instance", name), slog.String("status", state.Status))
} else if err != nil {
s.l.ErrorContext(ctx, "request to start instance failed", slog.String("instance", name), slog.Any("error", err))
return instance.State{}, fmt.Errorf("cannot retrieve instance from store: %w", err)
} else if state.Status != instance.Ready {
s.l.DebugContext(ctx, "request to check instance status received", slog.String("instance", name), slog.String("current_status", state.Status))
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return instance.State{}, err
}
s.l.DebugContext(ctx, "request to check instance status completed", slog.String("instance", name), slog.String("new_status", state.Status))
}
s.l.DebugContext(ctx, "set expiration for instance", slog.String("instance", name), slog.Duration("expiration", duration))
// Refresh the duration
s.expiresAfter(ctx, state, duration)
return state, nil
}
func (s *SessionsManager) RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error) {
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
return nil, err
}
if session.IsReady() {
return session, nil
}
ticker := time.NewTicker(5 * time.Second)
readiness := make(chan *SessionState)
errch := make(chan error)
quit := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
errch <- err
return
}
if session.IsReady() {
readiness <- session
}
case <-quit:
ticker.Stop()
return
}
}
}()
select {
case <-ctx.Done():
s.l.DebugContext(ctx, "request cancelled", slog.Any("reason", ctx.Err()))
close(quit)
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled by user: %w", ctx.Err())
}
return nil, fmt.Errorf("request cancelled by user")
case status := <-readiness:
close(quit)
return status, nil
case err := <-errch:
close(quit)
return nil, err
case <-time.After(timeout):
close(quit)
return nil, fmt.Errorf("session was not ready after %s", timeout.String())
}
}
func (s *SessionsManager) RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestReadySession(ctx, names, duration, timeout)
}
func (s *SessionsManager) expiresAfter(ctx context.Context, instance instance.State, duration time.Duration) {
err := s.store.Put(ctx, instance, duration)
if err != nil {
s.l.Error("could not put instance to store, will not expire", slog.Any("error", err), slog.String("instance", instance.Name))
}
}
func (s *SessionState) MarshalJSON() ([]byte, error) {
instances := maps.Values(s.Instances)
return json.Marshal(map[string]any{
"instances": instances,
"status": s.Status(),
})
}

View File

@@ -1,158 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: sessions_manager.go
//
// Generated by this command:
//
// mockgen -package sessionstest -source=sessions_manager.go -destination=sessionstest/mocks_sessions_manager.go *
//
// Package sessionstest is a generated GoMock package.
package sessionstest
import (
context "context"
io "io"
reflect "reflect"
time "time"
sessions "github.com/sablierapp/sablier/app/sessions"
gomock "go.uber.org/mock/gomock"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
isgomock struct{}
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// LoadSessions mocks base method.
func (m *MockManager) LoadSessions(arg0 io.ReadCloser) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadSessions", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LoadSessions indicates an expected call of LoadSessions.
func (mr *MockManagerMockRecorder) LoadSessions(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadSessions", reflect.TypeOf((*MockManager)(nil).LoadSessions), arg0)
}
// RemoveInstance mocks base method.
func (m *MockManager) RemoveInstance(name string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveInstance", name)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveInstance indicates an expected call of RemoveInstance.
func (mr *MockManagerMockRecorder) RemoveInstance(name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveInstance", reflect.TypeOf((*MockManager)(nil).RemoveInstance), name)
}
// RequestReadySession mocks base method.
func (m *MockManager) RequestReadySession(ctx context.Context, names []string, duration, timeout time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySession", ctx, names, duration, timeout)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySession indicates an expected call of RequestReadySession.
func (mr *MockManagerMockRecorder) RequestReadySession(ctx, names, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySession", reflect.TypeOf((*MockManager)(nil).RequestReadySession), ctx, names, duration, timeout)
}
// RequestReadySessionGroup mocks base method.
func (m *MockManager) RequestReadySessionGroup(ctx context.Context, group string, duration, timeout time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySessionGroup", ctx, group, duration, timeout)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySessionGroup indicates an expected call of RequestReadySessionGroup.
func (mr *MockManagerMockRecorder) RequestReadySessionGroup(ctx, group, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySessionGroup", reflect.TypeOf((*MockManager)(nil).RequestReadySessionGroup), ctx, group, duration, timeout)
}
// RequestSession mocks base method.
func (m *MockManager) RequestSession(ctx context.Context, names []string, duration time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSession", ctx, names, duration)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSession indicates an expected call of RequestSession.
func (mr *MockManagerMockRecorder) RequestSession(ctx, names, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSession", reflect.TypeOf((*MockManager)(nil).RequestSession), ctx, names, duration)
}
// RequestSessionGroup mocks base method.
func (m *MockManager) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*sessions.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSessionGroup", ctx, group, duration)
ret0, _ := ret[0].(*sessions.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSessionGroup indicates an expected call of RequestSessionGroup.
func (mr *MockManagerMockRecorder) RequestSessionGroup(ctx, group, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSessionGroup", reflect.TypeOf((*MockManager)(nil).RequestSessionGroup), ctx, group, duration)
}
// SaveSessions mocks base method.
func (m *MockManager) SaveSessions(arg0 io.WriteCloser) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveSessions", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SaveSessions indicates an expected call of SaveSessions.
func (mr *MockManagerMockRecorder) SaveSessions(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveSessions", reflect.TypeOf((*MockManager)(nil).SaveSessions), arg0)
}
// SetGroups mocks base method.
func (m *MockManager) SetGroups(groups map[string][]string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetGroups", groups)
}
// SetGroups indicates an expected call of SetGroups.
func (mr *MockManagerMockRecorder) SetGroups(groups any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroups", reflect.TypeOf((*MockManager)(nil).SetGroups), groups)
}

View File

@@ -1,6 +0,0 @@
package types
type Instance struct {
Name string
Group string
}

View File

@@ -2,14 +2,14 @@ package api
import (
"github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/pkg/sablier"
)
const SablierStatusHeader = "X-Sablier-Session-Status"
const SablierStatusReady = "ready"
const SablierStatusNotReady = "not-ready"
func AddSablierHeader(c *gin.Context, session *sessions.SessionState) {
func AddSablierHeader(c *gin.Context, session *sablier.SessionState) {
if session.IsReady() {
c.Header(SablierStatusHeader, SablierStatusReady)
} else {

View File

@@ -4,8 +4,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/sessions/sessionstest"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/sablier/sabliertest"
"github.com/sablierapp/sablier/pkg/theme"
"go.uber.org/mock/gomock"
"gotest.tools/v3/assert"
@@ -14,7 +14,7 @@ import (
"testing"
)
func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strategy *routes.ServeStrategy, mock *sessionstest.MockManager) {
func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strategy *routes.ServeStrategy, mock *sabliertest.MockSablier) {
t.Helper()
gin.SetMode(gin.TestMode)
ctrl := gomock.NewController(t)
@@ -23,7 +23,7 @@ func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strateg
app = gin.New()
router = app.Group("/api")
mock = sessionstest.NewMockManager(ctrl)
mock = sabliertest.NewMockSablier(ctrl)
strategy = &routes.ServeStrategy{
Theme: th,
SessionsManager: mock,

View File

@@ -1,7 +1,7 @@
package api
import (
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/theme"
"github.com/tniswong/go.rfcx/rfc7807"
"net/http"
@@ -25,7 +25,7 @@ func ProblemValidation(e error) rfc7807.Problem {
}
}
func ProblemGroupNotFound(e sessions.ErrGroupNotFound) rfc7807.Problem {
func ProblemGroupNotFound(e sablier.ErrGroupNotFound) rfc7807.Problem {
pb := rfc7807.Problem{
Type: "https://sablierapp.dev/#/errors?id=group-not-found",
Title: "Group not found",

View File

@@ -5,7 +5,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/http/routes/models"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/pkg/sablier"
"net/http"
)
@@ -31,13 +31,13 @@ func StartBlocking(router *gin.RouterGroup, s *routes.ServeStrategy) {
return
}
var sessionState *sessions.SessionState
var sessionState *sablier.SessionState
var err error
if len(request.Names) > 0 {
sessionState, err = s.SessionsManager.RequestReadySession(c.Request.Context(), request.Names, request.SessionDuration, request.Timeout)
} else {
sessionState, err = s.SessionsManager.RequestReadySessionGroup(c.Request.Context(), request.Group, request.SessionDuration, request.Timeout)
var groupNotFoundError sessions.ErrGroupNotFound
var groupNotFoundError sablier.ErrGroupNotFound
if errors.As(err, &groupNotFoundError) {
AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError))
return

View File

@@ -2,7 +2,7 @@ package api
import (
"errors"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/tniswong/go.rfcx/rfc7807"
"go.uber.org/mock/gomock"
"gotest.tools/v3/assert"
@@ -35,7 +35,7 @@ func TestStartBlocking(t *testing.T) {
t.Run("StartBlockingByNames", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t)
StartBlocking(router, strategy)
m.EXPECT().RequestReadySession(gomock.Any(), []string{"test"}, gomock.Any(), gomock.Any()).Return(&sessions.SessionState{}, nil)
m.EXPECT().RequestReadySession(gomock.Any(), []string{"test"}, gomock.Any(), gomock.Any()).Return(&sablier.SessionState{}, nil)
r := PerformRequest(app, "GET", "/api/strategies/blocking?names=test")
assert.Equal(t, http.StatusOK, r.Code)
assert.Equal(t, SablierStatusReady, r.Header().Get(SablierStatusHeader))
@@ -43,7 +43,7 @@ func TestStartBlocking(t *testing.T) {
t.Run("StartBlockingByGroup", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t)
StartBlocking(router, strategy)
m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(&sessions.SessionState{}, nil)
m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(&sablier.SessionState{}, nil)
r := PerformRequest(app, "GET", "/api/strategies/blocking?group=test")
assert.Equal(t, http.StatusOK, r.Code)
assert.Equal(t, SablierStatusReady, r.Header().Get(SablierStatusHeader))
@@ -51,7 +51,7 @@ func TestStartBlocking(t *testing.T) {
t.Run("StartBlockingErrGroupNotFound", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t)
StartBlocking(router, strategy)
m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(nil, sessions.ErrGroupNotFound{
m.EXPECT().RequestReadySessionGroup(gomock.Any(), "test", gomock.Any(), gomock.Any()).Return(nil, sablier.ErrGroupNotFound{
Group: "test",
AvailableGroups: []string{"test1", "test2"},
})

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"errors"
"github.com/sablierapp/sablier/pkg/sablier"
"sort"
"strconv"
"strings"
@@ -11,8 +12,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/http/routes/models"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/sessions"
theme2 "github.com/sablierapp/sablier/pkg/theme"
)
@@ -40,13 +39,13 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
return
}
var sessionState *sessions.SessionState
var sessionState *sablier.SessionState
var err error
if len(request.Names) > 0 {
sessionState, err = s.SessionsManager.RequestSession(c, request.Names, request.SessionDuration)
} else {
sessionState, err = s.SessionsManager.RequestSessionGroup(c, request.Group, request.SessionDuration)
var groupNotFoundError sessions.ErrGroupNotFound
var groupNotFoundError sablier.ErrGroupNotFound
if errors.As(err, &groupNotFoundError) {
AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError))
return
@@ -90,7 +89,7 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
})
}
func sessionStateToRenderOptionsInstanceState(sessionState *sessions.SessionState) (instances []theme2.Instance) {
func sessionStateToRenderOptionsInstanceState(sessionState *sablier.SessionState) (instances []theme2.Instance) {
if sessionState == nil {
return
}
@@ -106,7 +105,7 @@ func sessionStateToRenderOptionsInstanceState(sessionState *sessions.SessionStat
return
}
func instanceStateToRenderOptionsRequestState(instanceState instance.State) theme2.Instance {
func instanceStateToRenderOptionsRequestState(instanceState sablier.InstanceInfo) theme2.Instance {
var err error
if instanceState.Message == "" {
@@ -117,7 +116,7 @@ func instanceStateToRenderOptionsRequestState(instanceState instance.State) them
return theme2.Instance{
Name: instanceState.Name,
Status: instanceState.Status,
Status: string(instanceState.Status),
CurrentReplicas: instanceState.CurrentReplicas,
DesiredReplicas: instanceState.DesiredReplicas,
Error: err,

View File

@@ -2,8 +2,7 @@ package api
import (
"errors"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/tniswong/go.rfcx/rfc7807"
"go.uber.org/mock/gomock"
"gotest.tools/v3/assert"
@@ -11,11 +10,11 @@ import (
"testing"
)
func session() *sessions.SessionState {
state := instance.ReadyInstanceState("test", 1)
state2 := instance.ReadyInstanceState("test2", 1)
return &sessions.SessionState{
Instances: map[string]sessions.InstanceState{
func session() *sablier.SessionState {
state := sablier.ReadyInstanceState("test", 1)
state2 := sablier.ReadyInstanceState("test2", 1)
return &sablier.SessionState{
Instances: map[string]sablier.InstanceInfoWithError{
"test": {
Instance: state,
Error: nil,
@@ -77,7 +76,7 @@ func TestStartDynamic(t *testing.T) {
t.Run("StartDynamicErrGroupNotFound", func(t *testing.T) {
app, router, strategy, m := NewApiTest(t)
StartDynamic(router, strategy)
m.EXPECT().RequestSessionGroup(gomock.Any(), "test", gomock.Any()).Return(nil, sessions.ErrGroupNotFound{
m.EXPECT().RequestSessionGroup(gomock.Any(), "test", gomock.Any()).Return(nil, sablier.ErrGroupNotFound{
Group: "test",
AvailableGroups: []string{"test1", "test2"},
})

View File

@@ -3,41 +3,41 @@ package docker
import (
"context"
"fmt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
"log/slog"
)
func (p *DockerClassicProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) {
func (p *DockerClassicProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
spec, err := p.Client.ContainerInspect(ctx, name)
if err != nil {
return instance.State{}, fmt.Errorf("cannot inspect container: %w", err)
return sablier.InstanceInfo{}, fmt.Errorf("cannot inspect container: %w", err)
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
switch spec.State.Status {
case "created", "paused", "restarting", "removing":
return instance.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
return sablier.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
case "running":
if spec.State.Health != nil {
// // "starting", "healthy" or "unhealthy"
if spec.State.Health.Status == "healthy" {
return instance.ReadyInstanceState(name, p.desiredReplicas), nil
return sablier.ReadyInstanceState(name, p.desiredReplicas), nil
} else if spec.State.Health.Status == "unhealthy" {
return instance.UnrecoverableInstanceState(name, "container is unhealthy", p.desiredReplicas), nil
return sablier.UnrecoverableInstanceState(name, "container is unhealthy", p.desiredReplicas), nil
} else {
return instance.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
return sablier.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
}
}
p.l.WarnContext(ctx, "container running without healthcheck, you should define a healthcheck on your container so that Sablier properly detects when the container is ready to handle requests.", slog.String("container", name))
return instance.ReadyInstanceState(name, p.desiredReplicas), nil
return sablier.ReadyInstanceState(name, p.desiredReplicas), nil
case "exited":
if spec.State.ExitCode != 0 {
return instance.UnrecoverableInstanceState(name, fmt.Sprintf("container exited with code \"%d\"", spec.State.ExitCode), p.desiredReplicas), nil
return sablier.UnrecoverableInstanceState(name, fmt.Sprintf("container exited with code \"%d\"", spec.State.ExitCode), p.desiredReplicas), nil
}
return instance.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
return sablier.NotReadyInstanceState(name, 0, p.desiredReplicas), nil
case "dead":
return instance.UnrecoverableInstanceState(name, "container in \"dead\" state cannot be restarted", p.desiredReplicas), nil
return sablier.UnrecoverableInstanceState(name, "container in \"dead\" state cannot be restarted", p.desiredReplicas), nil
default:
return instance.UnrecoverableInstanceState(name, fmt.Sprintf("container status \"%s\" not handled", spec.State.Status), p.desiredReplicas), nil
return sablier.UnrecoverableInstanceState(name, fmt.Sprintf("container status \"%s\" not handled", spec.State.Status), p.desiredReplicas), nil
}
}

View File

@@ -5,8 +5,8 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
"testing"
"time"
@@ -24,7 +24,7 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
tests := []struct {
name string
args args
want instance.State
want sablier.InstanceInfo
wantErr error
}{
{
@@ -38,10 +38,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return resp.ID, err
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -60,10 +60,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, dind.client.ContainerStart(ctx, c.ID, container.StartOptions{})
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -90,10 +90,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, dind.client.ContainerStart(ctx, c.ID, container.StartOptions{})
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -126,10 +126,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.Unrecoverable,
Status: sablier.InstanceStatusUnrecoverable,
Message: "container is unhealthy",
},
wantErr: nil,
@@ -163,10 +163,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -192,10 +192,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -221,10 +221,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -250,10 +250,10 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
return c.ID, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.Unrecoverable,
Status: sablier.InstanceStatusUnrecoverable,
Message: "container exited with code \"137\"",
},
wantErr: nil,

View File

@@ -7,12 +7,12 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"strings"
)
func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) {
func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
@@ -24,7 +24,7 @@ func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provid
return nil, err
}
instances := make([]types.Instance, 0, len(containers))
instances := make([]sablier.InstanceConfiguration, 0, len(containers))
for _, c := range containers {
instance := containerToInstance(c)
instances = append(instances, instance)
@@ -33,7 +33,7 @@ func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provid
return instances, nil
}
func containerToInstance(c dockertypes.Container) types.Instance {
func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration {
var group string
if _, ok := c.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func containerToInstance(c dockertypes.Container) types.Instance {
}
}
return types.Instance{
return sablier.InstanceConfiguration{
Name: strings.TrimPrefix(c.Names[0], "/"), // Containers name are reported with a leading slash
Group: group,
}

View File

@@ -2,9 +2,9 @@ package docker_test
import (
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
"sort"
"strings"
@@ -49,7 +49,7 @@ func TestDockerClassicProvider_InstanceList(t *testing.T) {
})
assert.NilError(t, err)
want := []types.Instance{
want := []sablier.InstanceConfiguration{
{
Name: strings.TrimPrefix(i1.Name, "/"),
Group: "default",

View File

@@ -4,12 +4,12 @@ import (
"context"
"fmt"
"github.com/docker/docker/client"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"log/slog"
)
// Interface guard
var _ provider.Provider = (*DockerClassicProvider)(nil)
var _ sablier.Provider = (*DockerClassicProvider)(nil)
type DockerClassicProvider struct {
Client client.APIClient

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"log/slog"
"strings"
@@ -14,7 +14,7 @@ import (
)
// Interface guard
var _ provider.Provider = (*DockerSwarmProvider)(nil)
var _ sablier.Provider = (*DockerSwarmProvider)(nil)
type DockerSwarmProvider struct {
Client client.APIClient

View File

@@ -7,26 +7,26 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/swarm"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
)
func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) {
func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
service, err := p.getServiceByName(name, ctx)
if err != nil {
return instance.State{}, err
return sablier.InstanceInfo{}, err
}
foundName := p.getInstanceName(name, *service)
if service.Spec.Mode.Replicated == nil {
return instance.State{}, errors.New("swarm service is not in \"replicated\" mode")
return sablier.InstanceInfo{}, errors.New("swarm service is not in \"replicated\" mode")
}
if service.ServiceStatus.DesiredTasks != service.ServiceStatus.RunningTasks || service.ServiceStatus.DesiredTasks == 0 {
return instance.NotReadyInstanceState(foundName, 0, p.desiredReplicas), nil
return sablier.NotReadyInstanceState(foundName, 0, p.desiredReplicas), nil
}
return instance.ReadyInstanceState(foundName, p.desiredReplicas), nil
return sablier.ReadyInstanceState(foundName, p.desiredReplicas), nil
}
func (p *DockerSwarmProvider) getServiceByName(name string, ctx context.Context) (*swarm.Service, error) {

View File

@@ -6,8 +6,8 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
"testing"
"time"
@@ -25,7 +25,7 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
tests := []struct {
name string
args args
want instance.State
want sablier.InstanceInfo
wantErr error
}{
{
@@ -50,10 +50,10 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
return service.Spec.Name, err
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -84,10 +84,10 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
return service.Spec.Name, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -115,10 +115,10 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
return service.Spec.Name, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},

View File

@@ -7,11 +7,11 @@ import (
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/swarm"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
)
func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]types.Instance, error) {
func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", discovery.LabelEnable))
args.Add("mode", "replicated")
@@ -24,7 +24,7 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta
return nil, err
}
instances := make([]types.Instance, 0, len(services))
instances := make([]sablier.InstanceConfiguration, 0, len(services))
for _, s := range services {
instance := p.serviceToInstance(s)
instances = append(instances, instance)
@@ -33,7 +33,7 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta
return instances, nil
}
func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i types.Instance) {
func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.InstanceConfiguration) {
var group string
if _, ok := s.Spec.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i types.Instan
}
}
return types.Instance{
return sablier.InstanceConfiguration{
Name: s.Spec.Name,
Group: group,
}

View File

@@ -2,9 +2,9 @@ package dockerswarm_test
import (
dockertypes "github.com/docker/docker/api/types"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"gotest.tools/v3/assert"
@@ -49,7 +49,7 @@ func TestDockerClassicProvider_InstanceList(t *testing.T) {
})
assert.NilError(t, err)
want := []types.Instance{
want := []sablier.InstanceConfiguration{
{
Name: i1.Spec.Name,
Group: "default",

View File

@@ -6,8 +6,8 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
"testing"
"time"
@@ -25,7 +25,7 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
tests := []struct {
name string
args args
want instance.State
want sablier.InstanceInfo
wantErr error
}{
{
@@ -48,10 +48,10 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
return service.Spec.Name, err
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -82,10 +82,10 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
return service.Spec.Name, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -114,10 +114,10 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
return service.Spec.Name, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},

View File

@@ -6,8 +6,8 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
"testing"
"time"
@@ -25,7 +25,7 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
tests := []struct {
name string
args args
want instance.State
want sablier.InstanceInfo
wantErr error
}{
{
@@ -48,10 +48,10 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
return service.Spec.Name, err
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -82,10 +82,10 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
return service.Spec.Name, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},

View File

@@ -3,20 +3,20 @@ package kubernetes
import (
"context"
"fmt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (p *KubernetesProvider) DeploymentInspect(ctx context.Context, config ParsedName) (instance.State, error) {
func (p *KubernetesProvider) DeploymentInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) {
d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{})
if err != nil {
return instance.State{}, fmt.Errorf("error getting deployment: %w", err)
}
// TODO: Should add option to set ready as soon as one replica is ready
if *d.Spec.Replicas != 0 && *d.Spec.Replicas == d.Status.ReadyReplicas {
return instance.ReadyInstanceState(config.Original, config.Replicas), nil
return sablier.InstanceInfo{}, fmt.Errorf("error getting deployment: %w", err)
}
return instance.NotReadyInstanceState(config.Original, d.Status.ReadyReplicas, config.Replicas), nil
// TODO: Should add option to set ready as soon as one replica is ready
if *d.Spec.Replicas != 0 && *d.Spec.Replicas == d.Status.ReadyReplicas {
return sablier.ReadyInstanceState(config.Original, config.Replicas), nil
}
return sablier.NotReadyInstanceState(config.Original, d.Status.ReadyReplicas, config.Replicas), nil
}

View File

@@ -5,9 +5,9 @@ import (
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
autoscalingv1 "k8s.io/api/autoscaling/v1"
corev1 "k8s.io/api/core/v1"
@@ -27,7 +27,7 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
tests := []struct {
name string
args args
want instance.State
want sablier.InstanceInfo
wantErr error
}{
{
@@ -49,10 +49,10 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -71,10 +71,10 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -106,10 +106,10 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
return kubernetes.DeploymentName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},

View File

@@ -3,13 +3,13 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/sablier"
v1 "k8s.io/api/apps/v1"
core_v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]types.Instance, error) {
func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{
discovery.LabelEnable: "true",
@@ -22,7 +22,7 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]types.Instan
return nil, err
}
instances := make([]types.Instance, 0, len(deployments.Items))
instances := make([]sablier.InstanceConfiguration, 0, len(deployments.Items))
for _, d := range deployments.Items {
instance := p.deploymentToInstance(&d)
instances = append(instances, instance)
@@ -31,7 +31,7 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]types.Instan
return instances, nil
}
func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) types.Instance {
func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.InstanceConfiguration {
var group string
if _, ok := d.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) types.Instan
parsed := DeploymentName(d, ParseOptions{Delimiter: p.delimiter})
return types.Instance{
return sablier.InstanceConfiguration{
Name: parsed.Original,
Group: group,
}

View File

@@ -3,13 +3,13 @@ package kubernetes
import (
"context"
"fmt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
)
func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) {
func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter})
if err != nil {
return instance.State{}, err
return sablier.InstanceInfo{}, err
}
switch parsed.Kind {
@@ -18,6 +18,6 @@ func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (
case "statefulset":
return p.StatefulSetInspect(ctx, parsed)
default:
return instance.State{}, fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", parsed.Kind)
return sablier.InstanceInfo{}, fmt.Errorf("unsupported kind \"%s\" must be one of \"deployment\", \"statefulset\"", parsed.Kind)
}
}

View File

@@ -11,6 +11,10 @@ import (
)
func TestKubernetesProvider_InstanceInspect(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ctx := context.Background()
type args struct {
name string

View File

@@ -2,11 +2,11 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
)
func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) {
func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
deployments, err := p.DeploymentList(ctx)
if err != nil {
return nil, err

View File

@@ -2,10 +2,10 @@ package kubernetes_test
import (
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
"sort"
"strings"
@@ -57,7 +57,7 @@ func TestKubernetesProvider_InstanceList(t *testing.T) {
})
assert.NilError(t, err)
want := []types.Instance{
want := []sablier.InstanceConfiguration{
{
Name: kubernetes.DeploymentName(d1, kubernetes.ParseOptions{Delimiter: "_"}).Original,
Group: "default",

View File

@@ -2,7 +2,7 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/sablier"
"log/slog"
providerConfig "github.com/sablierapp/sablier/config"
@@ -10,7 +10,7 @@ import (
)
// Interface guard
var _ provider.Provider = (*KubernetesProvider)(nil)
var _ sablier.Provider = (*KubernetesProvider)(nil)
type KubernetesProvider struct {
Client kubernetes.Interface

View File

@@ -2,19 +2,19 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (p *KubernetesProvider) StatefulSetInspect(ctx context.Context, config ParsedName) (instance.State, error) {
func (p *KubernetesProvider) StatefulSetInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) {
ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{})
if err != nil {
return instance.State{}, err
return sablier.InstanceInfo{}, err
}
if *ss.Spec.Replicas != 0 && *ss.Spec.Replicas == ss.Status.ReadyReplicas {
return instance.ReadyInstanceState(config.Original, ss.Status.ReadyReplicas), nil
return sablier.ReadyInstanceState(config.Original, ss.Status.ReadyReplicas), nil
}
return instance.NotReadyInstanceState(config.Original, ss.Status.ReadyReplicas, config.Replicas), nil
return sablier.NotReadyInstanceState(config.Original, ss.Status.ReadyReplicas, config.Replicas), nil
}

View File

@@ -5,9 +5,9 @@ import (
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert"
autoscalingv1 "k8s.io/api/autoscaling/v1"
corev1 "k8s.io/api/core/v1"
@@ -27,7 +27,7 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
tests := []struct {
name string
args args
want instance.State
want sablier.InstanceInfo
wantErr error
}{
{
@@ -49,10 +49,10 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 1,
DesiredReplicas: 1,
Status: instance.Ready,
Status: sablier.InstanceStatusReady,
},
wantErr: nil,
},
@@ -71,10 +71,10 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},
@@ -106,10 +106,10 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
return kubernetes.StatefulSetName(d, kubernetes.ParseOptions{Delimiter: "_"}).Original, nil
},
},
want: instance.State{
want: sablier.InstanceInfo{
CurrentReplicas: 0,
DesiredReplicas: 1,
Status: instance.NotReady,
Status: sablier.InstanceStatusNotReady,
},
wantErr: nil,
},

View File

@@ -3,13 +3,13 @@ package kubernetes
import (
"context"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/sablier"
v1 "k8s.io/api/apps/v1"
core_v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]types.Instance, error) {
func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{
discovery.LabelEnable: "true",
@@ -22,7 +22,7 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]types.Insta
return nil, err
}
instances := make([]types.Instance, 0, len(statefulSets.Items))
instances := make([]sablier.InstanceConfiguration, 0, len(statefulSets.Items))
for _, ss := range statefulSets.Items {
instance := p.statefulSetToInstance(&ss)
instances = append(instances, instance)
@@ -31,7 +31,7 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]types.Insta
return instances, nil
}
func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) types.Instance {
func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.InstanceConfiguration {
var group string
if _, ok := ss.Labels[discovery.LabelEnable]; ok {
@@ -44,7 +44,7 @@ func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) types.Ins
parsed := StatefulSetName(ss, ParseOptions{Delimiter: p.delimiter})
return types.Instance{
return sablier.InstanceConfiguration{
Name: parsed.Original,
Group: group,
}

View File

@@ -3,7 +3,7 @@
//
// Generated by this command:
//
// mockgen -package providertest -source=provider.go -destination=providertest/mock_provider.go *
// mockgen -package providertest -source=provider.go -destination=../provider/providertest/mock_provider.go *
//
// Package providertest is a generated GoMock package.
@@ -13,9 +13,8 @@ import (
context "context"
reflect "reflect"
instance "github.com/sablierapp/sablier/app/instance"
types "github.com/sablierapp/sablier/app/types"
provider "github.com/sablierapp/sablier/pkg/provider"
sablier "github.com/sablierapp/sablier/pkg/sablier"
gomock "go.uber.org/mock/gomock"
)
@@ -59,10 +58,10 @@ func (mr *MockProviderMockRecorder) InstanceGroups(ctx any) *gomock.Call {
}
// InstanceInspect mocks base method.
func (m *MockProvider) InstanceInspect(ctx context.Context, name string) (instance.State, error) {
func (m *MockProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InstanceInspect", ctx, name)
ret0, _ := ret[0].(instance.State)
ret0, _ := ret[0].(sablier.InstanceInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -74,10 +73,10 @@ func (mr *MockProviderMockRecorder) InstanceInspect(ctx, name any) *gomock.Call
}
// InstanceList mocks base method.
func (m *MockProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]types.Instance, error) {
func (m *MockProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InstanceList", ctx, options)
ret0, _ := ret[0].([]types.Instance)
ret0, _ := ret[0].([]sablier.InstanceConfiguration)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View File

@@ -1,4 +1,4 @@
package sessions
package sablier
import (
"fmt"

54
pkg/sablier/instance.go Normal file
View File

@@ -0,0 +1,54 @@
package sablier
type InstanceStatus string
const (
InstanceStatusReady = "ready"
InstanceStatusNotReady = "not-ready"
InstanceStatusUnrecoverable = "unrecoverable"
)
type InstanceInfo struct {
Name string `json:"name"`
CurrentReplicas int32 `json:"currentReplicas"`
DesiredReplicas int32 `json:"desiredReplicas"`
Status InstanceStatus `json:"status"`
Message string `json:"message,omitempty"`
}
type InstanceConfiguration struct {
Name string
Group string
}
func (instance InstanceInfo) IsReady() bool {
return instance.Status == InstanceStatusReady
}
func UnrecoverableInstanceState(name string, message string, desiredReplicas int32) InstanceInfo {
return InstanceInfo{
Name: name,
CurrentReplicas: 0,
DesiredReplicas: desiredReplicas,
Status: InstanceStatusUnrecoverable,
Message: message,
}
}
func ReadyInstanceState(name string, replicas int32) InstanceInfo {
return InstanceInfo{
Name: name,
CurrentReplicas: replicas,
DesiredReplicas: replicas,
Status: InstanceStatusReady,
}
}
func NotReadyInstanceState(name string, currentReplicas int32, desiredReplicas int32) InstanceInfo {
return InstanceInfo{
Name: name,
CurrentReplicas: currentReplicas,
DesiredReplicas: desiredReplicas,
Status: InstanceStatusNotReady,
}
}

View File

@@ -0,0 +1,51 @@
package sablier
import (
"context"
"errors"
"fmt"
"github.com/sablierapp/sablier/pkg/store"
"log/slog"
"time"
)
func (s *sablier) InstanceRequest(ctx context.Context, name string, duration time.Duration) (InstanceInfo, error) {
if name == "" {
return InstanceInfo{}, errors.New("instance name cannot be empty")
}
state, err := s.sessions.Get(ctx, name)
if errors.Is(err, store.ErrKeyNotFound) {
s.l.DebugContext(ctx, "request to start instance received", slog.String("instance", name))
err = s.provider.InstanceStart(ctx, name)
if err != nil {
return InstanceInfo{}, err
}
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return InstanceInfo{}, err
}
s.l.DebugContext(ctx, "request to start instance status completed", slog.String("instance", name), slog.String("status", string(state.Status)))
} else if err != nil {
s.l.ErrorContext(ctx, "request to start instance failed", slog.String("instance", name), slog.Any("error", err))
return InstanceInfo{}, fmt.Errorf("cannot retrieve instance from store: %w", err)
} else if state.Status != InstanceStatusReady {
s.l.DebugContext(ctx, "request to check instance status received", slog.String("instance", name), slog.String("current_status", string(state.Status)))
state, err = s.provider.InstanceInspect(ctx, name)
if err != nil {
return InstanceInfo{}, err
}
s.l.DebugContext(ctx, "request to check instance status completed", slog.String("instance", name), slog.String("new_status", string(state.Status)))
}
s.l.DebugContext(ctx, "set expiration for instance", slog.String("instance", name), slog.Duration("expiration", duration))
err = s.sessions.Put(ctx, state, duration)
if err != nil {
s.l.Error("could not put instance to store, will not expire", slog.Any("error", err), slog.String("instance", state.Name))
return InstanceInfo{}, fmt.Errorf("could not put instance to store: %w", err)
}
return state, nil
}

View File

@@ -1,20 +1,18 @@
package provider
package sablier
import (
"context"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/provider"
)
//go:generate go tool mockgen -package providertest -source=provider.go -destination=providertest/mock_provider.go *
//go:generate go tool mockgen -package providertest -source=provider.go -destination=../provider/providertest/mock_provider.go *
type Provider interface {
InstanceStart(ctx context.Context, name string) error
InstanceStop(ctx context.Context, name string) error
InstanceInspect(ctx context.Context, name string) (instance.State, error)
InstanceInspect(ctx context.Context, name string) (InstanceInfo, error)
InstanceGroups(ctx context.Context) (map[string][]string, error)
InstanceList(ctx context.Context, options InstanceListOptions) ([]types.Instance, error)
InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]InstanceConfiguration, error)
NotifyInstanceStopped(ctx context.Context, instance chan<- string)
}

52
pkg/sablier/sablier.go Normal file
View File

@@ -0,0 +1,52 @@
package sablier
import (
"context"
"github.com/google/go-cmp/cmp"
"log/slog"
"time"
)
//go:generate go tool mockgen -package sabliertest -source=sablier.go -destination=sabliertest/mocks_sablier.go *
type Sablier interface {
RequestSession(ctx context.Context, names []string, duration time.Duration) (*SessionState, error)
RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*SessionState, error)
RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error)
RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (*SessionState, error)
RemoveInstance(ctx context.Context, name string) error
SetGroups(groups map[string][]string)
}
type sablier struct {
provider Provider
sessions Store
groups map[string][]string
l *slog.Logger
}
func New(logger *slog.Logger, store Store, provider Provider) Sablier {
return &sablier{
provider: provider,
sessions: store,
groups: map[string][]string{},
l: logger,
}
}
func (s *sablier) SetGroups(groups map[string][]string) {
if groups == nil {
groups = map[string][]string{}
}
if diff := cmp.Diff(s.groups, groups); diff != "" {
// TODO: Change this log for a friendly logging, groups rarely change, so we can put some effort on displaying what changed
s.l.Info("set groups", slog.Any("old", s.groups), slog.Any("new", groups), slog.Any("diff", diff))
s.groups = groups
}
}
func (s *sablier) RemoveInstance(ctx context.Context, name string) error {
return s.sessions.Delete(ctx, name)
}

View File

@@ -0,0 +1,129 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: sablier.go
//
// Generated by this command:
//
// mockgen -package sabliertest -source=sablier.go -destination=sabliertest/mocks_sablier.go *
//
// Package sabliertest is a generated GoMock package.
package sabliertest
import (
context "context"
reflect "reflect"
time "time"
sablier "github.com/sablierapp/sablier/pkg/sablier"
gomock "go.uber.org/mock/gomock"
)
// MockSablier is a mock of Sablier interface.
type MockSablier struct {
ctrl *gomock.Controller
recorder *MockSablierMockRecorder
isgomock struct{}
}
// MockSablierMockRecorder is the mock recorder for MockSablier.
type MockSablierMockRecorder struct {
mock *MockSablier
}
// NewMockSablier creates a new mock instance.
func NewMockSablier(ctrl *gomock.Controller) *MockSablier {
mock := &MockSablier{ctrl: ctrl}
mock.recorder = &MockSablierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSablier) EXPECT() *MockSablierMockRecorder {
return m.recorder
}
// RemoveInstance mocks base method.
func (m *MockSablier) RemoveInstance(ctx context.Context, name string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveInstance", ctx, name)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveInstance indicates an expected call of RemoveInstance.
func (mr *MockSablierMockRecorder) RemoveInstance(ctx, name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveInstance", reflect.TypeOf((*MockSablier)(nil).RemoveInstance), ctx, name)
}
// RequestReadySession mocks base method.
func (m *MockSablier) RequestReadySession(ctx context.Context, names []string, duration, timeout time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySession", ctx, names, duration, timeout)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySession indicates an expected call of RequestReadySession.
func (mr *MockSablierMockRecorder) RequestReadySession(ctx, names, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySession", reflect.TypeOf((*MockSablier)(nil).RequestReadySession), ctx, names, duration, timeout)
}
// RequestReadySessionGroup mocks base method.
func (m *MockSablier) RequestReadySessionGroup(ctx context.Context, group string, duration, timeout time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestReadySessionGroup", ctx, group, duration, timeout)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestReadySessionGroup indicates an expected call of RequestReadySessionGroup.
func (mr *MockSablierMockRecorder) RequestReadySessionGroup(ctx, group, duration, timeout any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestReadySessionGroup", reflect.TypeOf((*MockSablier)(nil).RequestReadySessionGroup), ctx, group, duration, timeout)
}
// RequestSession mocks base method.
func (m *MockSablier) RequestSession(ctx context.Context, names []string, duration time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSession", ctx, names, duration)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSession indicates an expected call of RequestSession.
func (mr *MockSablierMockRecorder) RequestSession(ctx, names, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSession", reflect.TypeOf((*MockSablier)(nil).RequestSession), ctx, names, duration)
}
// RequestSessionGroup mocks base method.
func (m *MockSablier) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (*sablier.SessionState, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestSessionGroup", ctx, group, duration)
ret0, _ := ret[0].(*sablier.SessionState)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RequestSessionGroup indicates an expected call of RequestSessionGroup.
func (mr *MockSablierMockRecorder) RequestSessionGroup(ctx, group, duration any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestSessionGroup", reflect.TypeOf((*MockSablier)(nil).RequestSessionGroup), ctx, group, duration)
}
// SetGroups mocks base method.
func (m *MockSablier) SetGroups(groups map[string][]string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetGroups", groups)
}
// SetGroups indicates an expected call of SetGroups.
func (mr *MockSablierMockRecorder) SetGroups(groups any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroups", reflect.TypeOf((*MockSablier)(nil).SetGroups), groups)
}

41
pkg/sablier/session.go Normal file
View File

@@ -0,0 +1,41 @@
package sablier
import (
"encoding/json"
"maps"
)
type SessionState struct {
Instances map[string]InstanceInfoWithError `json:"instances"`
}
func (s *SessionState) IsReady() bool {
if s.Instances == nil {
s.Instances = map[string]InstanceInfoWithError{}
}
for _, v := range s.Instances {
if v.Error != nil || v.Instance.Status != InstanceStatusReady {
return false
}
}
return true
}
func (s *SessionState) Status() string {
if s.IsReady() {
return "ready"
}
return "not-ready"
}
func (s *SessionState) MarshalJSON() ([]byte, error) {
instances := maps.Values(s.Instances)
return json.Marshal(map[string]any{
"instances": instances,
"status": s.Status(),
})
}

View File

@@ -0,0 +1,143 @@
package sablier
import (
"context"
"fmt"
"log/slog"
"maps"
"slices"
"sync"
"time"
)
type InstanceInfoWithError struct {
Instance InstanceInfo `json:"instance"`
Error error `json:"error"`
}
func (s *sablier) RequestSession(ctx context.Context, names []string, duration time.Duration) (sessionState *SessionState, err error) {
if len(names) == 0 {
return nil, fmt.Errorf("names cannot be empty")
}
var wg sync.WaitGroup
mx := sync.Mutex{}
sessionState = &SessionState{
Instances: map[string]InstanceInfoWithError{},
}
wg.Add(len(names))
for i := 0; i < len(names); i++ {
go func(name string) {
defer wg.Done()
state, err := s.InstanceRequest(ctx, name, duration)
mx.Lock()
defer mx.Unlock()
sessionState.Instances[name] = InstanceInfoWithError{
Instance: state,
Error: err,
}
}(names[i])
}
wg.Wait()
return sessionState, nil
}
func (s *sablier) RequestSessionGroup(ctx context.Context, group string, duration time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestSession(ctx, names, duration)
}
func (s *sablier) RequestReadySession(ctx context.Context, names []string, duration time.Duration, timeout time.Duration) (*SessionState, error) {
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
return nil, err
}
if session.IsReady() {
return session, nil
}
ticker := time.NewTicker(5 * time.Second)
readiness := make(chan *SessionState)
errch := make(chan error)
quit := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
session, err := s.RequestSession(ctx, names, duration)
if err != nil {
errch <- err
return
}
if session.IsReady() {
readiness <- session
}
case <-quit:
ticker.Stop()
return
}
}
}()
select {
case <-ctx.Done():
s.l.DebugContext(ctx, "request cancelled", slog.Any("reason", ctx.Err()))
close(quit)
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled by user: %w", ctx.Err())
}
return nil, fmt.Errorf("request cancelled by user")
case status := <-readiness:
close(quit)
return status, nil
case err := <-errch:
close(quit)
return nil, err
case <-time.After(timeout):
close(quit)
return nil, fmt.Errorf("session was not ready after %s", timeout.String())
}
}
func (s *sablier) RequestReadySessionGroup(ctx context.Context, group string, duration time.Duration, timeout time.Duration) (sessionState *SessionState, err error) {
if len(group) == 0 {
return nil, fmt.Errorf("group is mandatory")
}
names, ok := s.groups[group]
if !ok {
return nil, ErrGroupNotFound{
Group: group,
AvailableGroups: slices.Collect(maps.Keys(s.groups)),
}
}
if len(names) == 0 {
return nil, fmt.Errorf("group has no member")
}
return s.RequestReadySession(ctx, names, duration, timeout)
}

View File

@@ -1,21 +1,21 @@
package sessions
package sablier_test
import (
"context"
"github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/pkg/provider/providertest"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/storetest"
"go.uber.org/mock/gomock"
"testing"
"time"
"github.com/sablierapp/sablier/app/instance"
"gotest.tools/v3/assert"
)
func TestSessionState_IsReady(t *testing.T) {
type fields struct {
Instances map[string]InstanceState
Instances map[string]sablier.InstanceInfoWithError
Error error
}
tests := []struct {
@@ -26,9 +26,9 @@ func TestSessionState_IsReady(t *testing.T) {
{
name: "all instances are ready",
fields: fields{
Instances: createMap([]instance.State{
{Name: "nginx", Status: instance.Ready},
{Name: "apache", Status: instance.Ready},
Instances: createMap([]sablier.InstanceInfo{
{Name: "nginx", Status: sablier.InstanceStatusReady},
{Name: "apache", Status: sablier.InstanceStatusReady},
}),
},
want: true,
@@ -36,9 +36,9 @@ func TestSessionState_IsReady(t *testing.T) {
{
name: "one instance is not ready",
fields: fields{
Instances: createMap([]instance.State{
{Name: "nginx", Status: instance.Ready},
{Name: "apache", Status: instance.NotReady},
Instances: createMap([]sablier.InstanceInfo{
{Name: "nginx", Status: sablier.InstanceStatusReady},
{Name: "apache", Status: sablier.InstanceStatusNotReady},
}),
},
want: false,
@@ -46,16 +46,16 @@ func TestSessionState_IsReady(t *testing.T) {
{
name: "no instances specified",
fields: fields{
Instances: createMap([]instance.State{}),
Instances: createMap([]sablier.InstanceInfo{}),
},
want: true,
},
{
name: "one instance has an error",
fields: fields{
Instances: createMap([]instance.State{
{Name: "nginx-error", Status: instance.Unrecoverable, Message: "connection timeout"},
{Name: "apache", Status: instance.Ready},
Instances: createMap([]sablier.InstanceInfo{
{Name: "nginx-error", Status: sablier.InstanceStatusUnrecoverable, Message: "connection timeout"},
{Name: "apache", Status: sablier.InstanceStatusReady},
}),
},
want: false,
@@ -63,7 +63,7 @@ func TestSessionState_IsReady(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &SessionState{
s := &sablier.SessionState{
Instances: tt.fields.Instances,
}
if got := s.IsReady(); got != tt.want {
@@ -73,11 +73,11 @@ func TestSessionState_IsReady(t *testing.T) {
}
}
func createMap(instances []instance.State) map[string]InstanceState {
states := make(map[string]InstanceState)
func createMap(instances []sablier.InstanceInfo) map[string]sablier.InstanceInfoWithError {
states := make(map[string]sablier.InstanceInfoWithError)
for _, v := range instances {
states[v.Name] = InstanceState{
states[v.Name] = sablier.InstanceInfoWithError{
Instance: v,
Error: nil,
}
@@ -86,14 +86,14 @@ func createMap(instances []instance.State) map[string]InstanceState {
return states
}
func setupSessionManager(t *testing.T) (Manager, *storetest.MockStore, *providertest.MockProvider) {
func setupSessionManager(t *testing.T) (sablier.Sablier, *storetest.MockStore, *providertest.MockProvider) {
t.Helper()
ctrl := gomock.NewController(t)
p := providertest.NewMockProvider(ctrl)
s := storetest.NewMockStore(ctrl)
m := NewSessionsManager(slogt.New(t), s, p)
m := sablier.New(slogt.New(t), s, p)
return m, s, p
}
@@ -101,7 +101,7 @@ func TestSessionsManager(t *testing.T) {
t.Run("RemoveInstance", func(t *testing.T) {
manager, store, _ := setupSessionManager(t)
store.EXPECT().Delete(gomock.Any(), "test")
err := manager.RemoveInstance("test")
err := manager.RemoveInstance(t.Context(), "test")
assert.NilError(t, err)
})
}
@@ -110,10 +110,10 @@ func TestSessionsManager_RequestReadySessionCancelledByUser(t *testing.T) {
t.Run("request ready session is cancelled by user", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
manager, store, provider := setupSessionManager(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil).AnyTimes()
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
provider.EXPECT().InstanceInspect(ctx, gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil)
provider.EXPECT().InstanceInspect(ctx, gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil)
errchan := make(chan error)
go func() {
@@ -132,10 +132,10 @@ func TestSessionsManager_RequestReadySessionCancelledByTimeout(t *testing.T) {
t.Run("request ready session is cancelled by timeout", func(t *testing.T) {
manager, store, provider := setupSessionManager(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil).AnyTimes()
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
provider.EXPECT().InstanceInspect(t.Context(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.NotReady}, nil)
provider.EXPECT().InstanceInspect(t.Context(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusNotReady}, nil)
errchan := make(chan error)
go func() {
@@ -151,7 +151,7 @@ func TestSessionsManager_RequestReadySession(t *testing.T) {
t.Run("request ready session is ready", func(t *testing.T) {
manager, store, _ := setupSessionManager(t)
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(instance.State{Name: "apache", Status: instance.Ready}, nil).AnyTimes()
store.EXPECT().Get(gomock.Any(), gomock.Any()).Return(sablier.InstanceInfo{Name: "apache", Status: sablier.InstanceStatusReady}, nil).AnyTimes()
store.EXPECT().Put(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
errchan := make(chan error)

15
pkg/sablier/store.go Normal file
View File

@@ -0,0 +1,15 @@
package sablier
import (
"context"
"time"
)
//go:generate go tool mockgen -package storetest -source=store.go -destination=../store/storetest/mocks_store.go *
type Store interface {
Get(context.Context, string) (InstanceInfo, error)
Put(context.Context, InstanceInfo, time.Duration) error
Delete(context.Context, string) error
OnExpire(context.Context, func(string)) error
}

View File

@@ -3,24 +3,24 @@ package inmemory
import (
"context"
"encoding/json"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"github.com/sablierapp/sablier/pkg/tinykv"
"time"
)
var _ store.Store = (*InMemory)(nil)
var _ sablier.Store = (*InMemory)(nil)
var _ json.Marshaler = (*InMemory)(nil)
var _ json.Unmarshaler = (*InMemory)(nil)
func NewInMemory() store.Store {
func NewInMemory() sablier.Store {
return &InMemory{
kv: tinykv.New[instance.State](1*time.Second, nil),
kv: tinykv.New[sablier.InstanceInfo](1*time.Second, nil),
}
}
type InMemory struct {
kv tinykv.KV[instance.State]
kv tinykv.KV[sablier.InstanceInfo]
}
func (i InMemory) UnmarshalJSON(bytes []byte) error {
@@ -31,15 +31,15 @@ func (i InMemory) MarshalJSON() ([]byte, error) {
return i.kv.MarshalJSON()
}
func (i InMemory) Get(_ context.Context, s string) (instance.State, error) {
func (i InMemory) Get(_ context.Context, s string) (sablier.InstanceInfo, error) {
val, ok := i.kv.Get(s)
if !ok {
return instance.State{}, store.ErrKeyNotFound
return sablier.InstanceInfo{}, store.ErrKeyNotFound
}
return val, nil
}
func (i InMemory) Put(_ context.Context, state instance.State, duration time.Duration) error {
func (i InMemory) Put(_ context.Context, state sablier.InstanceInfo, duration time.Duration) error {
return i.kv.Put(state.Name, state, duration)
}
@@ -49,7 +49,7 @@ func (i InMemory) Delete(_ context.Context, s string) error {
}
func (i InMemory) OnExpire(_ context.Context, f func(string)) error {
i.kv.SetOnExpire(func(k string, _ instance.State) {
i.kv.SetOnExpire(func(k string, _ sablier.InstanceInfo) {
f(k)
})
return nil

View File

@@ -2,7 +2,7 @@ package inmemory
import (
"context"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"gotest.tools/v3/assert"
"testing"
@@ -24,7 +24,7 @@ func TestInMemory(t *testing.T) {
ctx := context.Background()
vk := NewInMemory()
err := vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second)
err := vk.Put(ctx, sablier.InstanceInfo{Name: "test"}, 1*time.Second)
assert.NilError(t, err)
i, err := vk.Get(ctx, "test")
@@ -40,7 +40,7 @@ func TestInMemory(t *testing.T) {
ctx := context.Background()
vk := NewInMemory()
err := vk.Put(ctx, instance.State{Name: "test"}, 30*time.Second)
err := vk.Put(ctx, sablier.InstanceInfo{Name: "test"}, 30*time.Second)
assert.NilError(t, err)
i, err := vk.Get(ctx, "test")
@@ -66,7 +66,7 @@ func TestInMemory(t *testing.T) {
})
assert.NilError(t, err)
err = vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second)
err = vk.Put(ctx, sablier.InstanceInfo{Name: "test"}, 1*time.Second)
assert.NilError(t, err)
expired := <-expirations
assert.Equal(t, expired, "test")

View File

@@ -1,19 +1,7 @@
package store
import (
"context"
"errors"
"github.com/sablierapp/sablier/app/instance"
"time"
)
var ErrKeyNotFound = errors.New("key not found")
//go:generate go tool mockgen -package storetest -source=store.go -destination=storetest/mocks_store.go *
type Store interface {
Get(context.Context, string) (instance.State, error)
Put(context.Context, instance.State, time.Duration) error
Delete(context.Context, string) error
OnExpire(context.Context, func(string)) error
}

View File

@@ -3,7 +3,7 @@
//
// Generated by this command:
//
// mockgen -package storetest -source=store.go -destination=storetest/mocks_store.go *
// mockgen -package storetest -source=store.go -destination=../store/storetest/mocks_store.go *
//
// Package storetest is a generated GoMock package.
@@ -14,7 +14,7 @@ import (
reflect "reflect"
time "time"
instance "github.com/sablierapp/sablier/app/instance"
sablier "github.com/sablierapp/sablier/pkg/sablier"
gomock "go.uber.org/mock/gomock"
)
@@ -57,10 +57,10 @@ func (mr *MockStoreMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
}
// Get mocks base method.
func (m *MockStore) Get(arg0 context.Context, arg1 string) (instance.State, error) {
func (m *MockStore) Get(arg0 context.Context, arg1 string) (sablier.InstanceInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1)
ret0, _ := ret[0].(instance.State)
ret0, _ := ret[0].(sablier.InstanceInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -86,7 +86,7 @@ func (mr *MockStoreMockRecorder) OnExpire(arg0, arg1 any) *gomock.Call {
}
// Put mocks base method.
func (m *MockStore) Put(arg0 context.Context, arg1 instance.State, arg2 time.Duration) error {
func (m *MockStore) Put(arg0 context.Context, arg1 sablier.InstanceInfo, arg2 time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Put", arg0, arg1, arg2)
ret0, _ := ret[0].(error)

View File

@@ -3,7 +3,7 @@ package valkey
import (
"context"
"encoding/json"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"github.com/valkey-io/valkey-go"
"log/slog"
@@ -11,13 +11,13 @@ import (
"time"
)
var _ store.Store = (*ValKey)(nil)
var _ sablier.Store = (*ValKey)(nil)
type ValKey struct {
Client valkey.Client
}
func New(ctx context.Context, client valkey.Client) (store.Store, error) {
func New(ctx context.Context, client valkey.Client) (sablier.Store, error) {
err := client.Do(ctx, client.B().Ping().Build()).Error()
if err != nil {
return nil, err
@@ -33,25 +33,25 @@ func New(ctx context.Context, client valkey.Client) (store.Store, error) {
return &ValKey{Client: client}, nil
}
func (v *ValKey) Get(ctx context.Context, s string) (instance.State, error) {
func (v *ValKey) Get(ctx context.Context, s string) (sablier.InstanceInfo, error) {
b, err := v.Client.Do(ctx, v.Client.B().Get().Key(s).Build()).AsBytes()
if valkey.IsValkeyNil(err) {
return instance.State{}, store.ErrKeyNotFound
return sablier.InstanceInfo{}, store.ErrKeyNotFound
}
if err != nil {
return instance.State{}, err
return sablier.InstanceInfo{}, err
}
var i instance.State
var i sablier.InstanceInfo
err = json.Unmarshal(b, &i)
if err != nil {
return instance.State{}, err
return sablier.InstanceInfo{}, err
}
return i, nil
}
func (v *ValKey) Put(ctx context.Context, state instance.State, duration time.Duration) error {
func (v *ValKey) Put(ctx context.Context, state sablier.InstanceInfo, duration time.Duration) error {
value, err := json.Marshal(state)
if err != nil {
return err

View File

@@ -2,7 +2,7 @@ package valkey
import (
"context"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store"
"github.com/testcontainers/testcontainers-go"
tcvalkey "github.com/testcontainers/testcontainers-go/modules/valkey"
@@ -44,53 +44,51 @@ func setupValKey(t *testing.T) *ValKey {
}
func TestValKey(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
ctx := t.Context()
vk := setupValKey(t)
t.Parallel()
t.Run("ValKeyErrNotFound", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
vk := setupValKey(t)
_, err := vk.Get(ctx, "test")
_, err := vk.Get(ctx, "ValKeyErrNotFound")
assert.ErrorIs(t, err, store.ErrKeyNotFound)
})
t.Run("ValKeyPut", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
vk := setupValKey(t)
err := vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second)
err := vk.Put(ctx, sablier.InstanceInfo{Name: "ValKeyPut"}, 1*time.Second)
assert.NilError(t, err)
i, err := vk.Get(ctx, "test")
i, err := vk.Get(ctx, "ValKeyPut")
assert.NilError(t, err)
assert.Equal(t, i.Name, "test")
assert.Equal(t, i.Name, "ValKeyPut")
<-time.After(2 * time.Second)
_, err = vk.Get(ctx, "test")
_, err = vk.Get(ctx, "ValKeyPut")
assert.ErrorIs(t, err, store.ErrKeyNotFound)
})
t.Run("ValKeyDelete", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
vk := setupValKey(t)
err := vk.Put(ctx, instance.State{Name: "test"}, 30*time.Second)
err := vk.Put(ctx, sablier.InstanceInfo{Name: "ValKeyDelete"}, 30*time.Second)
assert.NilError(t, err)
i, err := vk.Get(ctx, "test")
i, err := vk.Get(ctx, "ValKeyDelete")
assert.NilError(t, err)
assert.Equal(t, i.Name, "test")
assert.Equal(t, i.Name, "ValKeyDelete")
err = vk.Delete(ctx, "test")
err = vk.Delete(ctx, "ValKeyDelete")
assert.NilError(t, err)
_, err = vk.Get(ctx, "test")
_, err = vk.Get(ctx, "ValKeyDelete")
assert.ErrorIs(t, err, store.ErrKeyNotFound)
})
t.Run("ValKeyOnExpire", func(t *testing.T) {
t.Parallel()
vk := setupValKey(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -100,9 +98,9 @@ func TestValKey(t *testing.T) {
})
assert.NilError(t, err)
err = vk.Put(ctx, instance.State{Name: "test"}, 1*time.Second)
err = vk.Put(ctx, sablier.InstanceInfo{Name: "ValKeyOnExpire"}, 1*time.Second)
assert.NilError(t, err)
expired := <-expirations
assert.Equal(t, expired, "test")
assert.Equal(t, expired, "ValKeyOnExpire")
})
}