diff --git a/app/http/routes/strategies_test.go b/app/http/routes/strategies_test.go index faa1d09..83f37b3 100644 --- a/app/http/routes/strategies_test.go +++ b/app/http/routes/strategies_test.go @@ -41,6 +41,8 @@ func (s *SessionsManagerMock) SaveSessions(io.WriteCloser) error { return nil } +func (s *SessionsManagerMock) Stop() {} + func TestServeStrategy_ServeDynamic(t *testing.T) { type arg struct { body models.DynamicRequest diff --git a/app/providers/docker_classic.go b/app/providers/docker_classic.go index 5d3ff37..b2e10e2 100644 --- a/app/providers/docker_classic.go +++ b/app/providers/docker_classic.go @@ -107,7 +107,7 @@ func (provider *DockerClassicProvider) GetState(name string) (instance.State, er } } -func (provider *DockerClassicProvider) NotifyInsanceStopped(ctx context.Context, instance chan string) { +func (provider *DockerClassicProvider) NotifyInsanceStopped(ctx context.Context, instance chan<- string) { msgs, errs := provider.Client.Events(ctx, types.EventsOptions{ Filters: filters.NewArgs( filters.Arg("scope", "local"), @@ -125,11 +125,9 @@ func (provider *DockerClassicProvider) NotifyInsanceStopped(ctx context.Context, case err := <-errs: if errors.Is(err, io.EOF) { log.Debug("provider event stream closed") - close(instance) return } case <-ctx.Done(): - close(instance) return } } diff --git a/app/providers/docker_classic_test.go b/app/providers/docker_classic_test.go index b6ea558..c09e4b6 100644 --- a/app/providers/docker_classic_test.go +++ b/app/providers/docker_classic_test.go @@ -427,15 +427,16 @@ func TestDockerClassicProvider_NotifyInsanceStopped(t *testing.T) { desiredReplicas: 1, } - instanceC := make(chan string) + instanceC := make(chan string, 1) - provider.NotifyInsanceStopped(context.Background(), instanceC) + ctx, cancel := context.WithCancel(context.Background()) + provider.NotifyInsanceStopped(ctx, instanceC) var got []string - for i := range instanceC { - got = append(got, i) - } + got = append(got, <-instanceC) + cancel() + close(instanceC) if !reflect.DeepEqual(got, tt.want) { t.Errorf("NotifyInsanceStopped() = %v, want %v", got, tt.want) diff --git a/app/providers/docker_swarm.go b/app/providers/docker_swarm.go index 02e4da1..5aebb3d 100644 --- a/app/providers/docker_swarm.go +++ b/app/providers/docker_swarm.go @@ -124,7 +124,7 @@ func (provider *DockerSwarmProvider) getInstanceName(name string, service swarm. return fmt.Sprintf("%s (%s)", name, service.Spec.Name) } -func (provider *DockerSwarmProvider) NotifyInsanceStopped(ctx context.Context, instance chan string) { +func (provider *DockerSwarmProvider) NotifyInsanceStopped(ctx context.Context, instance chan<- string) { msgs, errs := provider.Client.Events(ctx, types.EventsOptions{ Filters: filters.NewArgs( filters.Arg("scope", "swarm"), @@ -143,11 +143,9 @@ func (provider *DockerSwarmProvider) NotifyInsanceStopped(ctx context.Context, i case err := <-errs: if errors.Is(err, io.EOF) { log.Debug("provider event stream closed") - close(instance) return } case <-ctx.Done(): - close(instance) return } } diff --git a/app/providers/kubernetes.go b/app/providers/kubernetes.go index 05add0e..cf075e3 100644 --- a/app/providers/kubernetes.go +++ b/app/providers/kubernetes.go @@ -166,5 +166,5 @@ func (provider *KubernetesProvider) getStatefulsetState(config *Config) (instanc return instance.NotReadyInstanceState(config.OriginalName, int(ss.Status.ReadyReplicas), int(config.Replicas)) } -func (provider *KubernetesProvider) NotifyInsanceStopped(ctx context.Context, instance chan string) { +func (provider *KubernetesProvider) NotifyInsanceStopped(ctx context.Context, instance chan<- string) { } diff --git a/app/providers/provider.go b/app/providers/provider.go index bb7148a..58e2ee7 100644 --- a/app/providers/provider.go +++ b/app/providers/provider.go @@ -13,7 +13,7 @@ type Provider interface { Stop(name string) (instance.State, error) GetState(name string) (instance.State, error) - NotifyInsanceStopped(ctx context.Context, instance chan string) + NotifyInsanceStopped(ctx context.Context, instance chan<- string) } func NewProvider(config config.Provider) (Provider, error) { diff --git a/app/sablier.go b/app/sablier.go index def49e0..2d46adb 100644 --- a/app/sablier.go +++ b/app/sablier.go @@ -40,6 +40,7 @@ func Start(conf config.Config) error { } sessionsManager := sessions.NewSessionsManager(store, provider) + defer sessionsManager.Stop() if storage.Enabled() { defer saveSessions(storage, sessionsManager) diff --git a/app/sessions/mocks/provider_mock.go b/app/sessions/mocks/provider_mock.go index 2cf3e9e..0c1cf65 100644 --- a/app/sessions/mocks/provider_mock.go +++ b/app/sessions/mocks/provider_mock.go @@ -25,7 +25,7 @@ func NewProviderMock(stoppedInstances []string) *ProviderMock { } } -func (provider *ProviderMock) NotifyInsanceStopped(ctx context.Context, instance chan string) { +func (provider *ProviderMock) NotifyInsanceStopped(ctx context.Context, instance chan<- string) { go func() { defer close(instance) for i := 0; i < len(provider.stoppedInstances); i++ { diff --git a/app/sessions/sessions_manager.go b/app/sessions/sessions_manager.go index 77db733..0887249 100644 --- a/app/sessions/sessions_manager.go +++ b/app/sessions/sessions_manager.go @@ -20,12 +20,17 @@ type Manager interface { LoadSessions(io.ReadCloser) error SaveSessions(io.WriteCloser) error + + Stop() } type SessionsManager struct { - store tinykv.KV[instance.State] - provider providers.Provider - insanceStopped chan string + events context.Context + cancel context.CancelFunc + + store tinykv.KV[instance.State] + provider providers.Provider + instanceStopped chan string } func NewSessionsManager(store tinykv.KV[instance.State], provider providers.Provider) Manager { @@ -41,12 +46,15 @@ func NewSessionsManager(store tinykv.KV[instance.State], provider providers.Prov } }() - provider.NotifyInsanceStopped(context.Background(), instanceStopped) + events, cancel := context.WithCancel(context.Background()) + provider.NotifyInsanceStopped(events, instanceStopped) return &SessionsManager{ - store: store, - provider: provider, - insanceStopped: instanceStopped, + events: events, + cancel: cancel, + store: store, + provider: provider, + instanceStopped: instanceStopped, } } @@ -132,9 +140,6 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura requestState, exists := s.store.Get(name) - // Trust the stored value - // TODO: Provider background check on the store - // Via polling or whatever if !exists { log.Debugf("starting %s...", name) @@ -212,6 +217,17 @@ func (s *SessionsManager) ExpiresAfter(instance *instance.State, duration time.D s.store.Put(instance.Name, *instance, duration) } +func (s *SessionsManager) Stop() { + // Stop event listeners + s.cancel() + + // Stop receiving stopped instance + close(s.instanceStopped) + + // Stop the store + s.store.Stop() +} + func (s *SessionState) MarshalJSON() ([]byte, error) { instances := []InstanceState{}