refactor: reorganize code structure (#556)

* refactor: rename providers to Provider

* refactor folders

* fix build cmd

* fix build cmd

* fix build cmd

* fix cmd start
This commit is contained in:
Alexis Couvreur
2025-03-10 14:11:40 -04:00
committed by GitHub
parent 8122a888b1
commit fca9c79289
83 changed files with 474 additions and 698 deletions

View File

@@ -30,7 +30,7 @@ jobs:
cache-dependency-path: go.sum cache-dependency-path: go.sum
- name: Build - name: Build
run: go build -v . run: go build -v ./cmd/sablier
- name: Test - name: Test
run: go test -v -json -race -covermode atomic -coverprofile coverage.txt ./... 2>&1 | go tool go-junit-report -parser gojson > junit.xml run: go test -v -json -race -covermode atomic -coverprofile coverage.txt ./... 2>&1 | go tool go-junit-report -parser gojson > junit.xml

View File

@@ -11,20 +11,20 @@ GIT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD)
BUILDTIME := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") BUILDTIME := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
BUILDUSER := $(shell whoami)@$(shell hostname) BUILDUSER := $(shell whoami)@$(shell hostname)
VPREFIX := github.com/sablierapp/sablier/version VPREFIX := github.com/sablierapp/sablier/pkg/version
GO_LDFLAGS := -s -w -X $(VPREFIX).Branch=$(GIT_BRANCH) -X $(VPREFIX).Version=$(VERSION) -X $(VPREFIX).Revision=$(GIT_REVISION) -X $(VPREFIX).BuildUser=$(BUILDUSER) -X $(VPREFIX).BuildDate=$(BUILDTIME) GO_LDFLAGS := -s -w -X $(VPREFIX).Branch=$(GIT_BRANCH) -X $(VPREFIX).Version=$(VERSION) -X $(VPREFIX).Revision=$(GIT_REVISION) -X $(VPREFIX).BuildUser=$(BUILDUSER) -X $(VPREFIX).BuildDate=$(BUILDTIME)
$(PLATFORMS): $(PLATFORMS):
CGO_ENABLED=0 GOOS=$(os) GOARCH=$(arch) go build -trimpath -tags=nomsgpack -v -ldflags="${GO_LDFLAGS}" -o 'sablier_$(VERSION)_$(os)-$(arch)' . CGO_ENABLED=0 GOOS=$(os) GOARCH=$(arch) go build -trimpath -tags=nomsgpack -v -ldflags="${GO_LDFLAGS}" -o 'sablier_$(VERSION)_$(os)-$(arch)' ./cmd/sablier
run: run:
go run main.go start --storage.file=state.json --logging.level=debug go run ./cmd/sablier start --storage.file=state.json --logging.level=debug
gen: gen:
go generate -v ./... go generate -v ./...
build: build:
go build -v . go build -v ./cmd/sablier
test: test:
go test -v ./... go test -v ./...

View File

@@ -1,31 +0,0 @@
package healthcheck
import (
"io"
"net/http"
)
const (
healthy = true
unhealthy = false
)
func Health(url string) (string, bool) {
resp, err := http.Get(url)
if err != nil {
return err.Error(), unhealthy
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return err.Error(), unhealthy
}
if resp.StatusCode >= 400 {
return string(body), unhealthy
}
return string(body), healthy
}

View File

@@ -1,10 +0,0 @@
package models
import "time"
type BlockingRequest struct {
Names []string `form:"names"`
Group string `form:"group"`
SessionDuration time.Duration `form:"session_duration"`
Timeout time.Duration `form:"timeout"`
}

View File

@@ -1,15 +0,0 @@
package models
import (
"time"
)
type DynamicRequest struct {
Group string `form:"group"`
Names []string `form:"names"`
ShowDetails bool `form:"show_details"`
DisplayName string `form:"display_name"`
Theme string `form:"theme"`
SessionDuration time.Duration `form:"session_duration"`
RefreshFrequency time.Duration `form:"refresh_frequency"`
}

View File

@@ -1,15 +0,0 @@
package routes
import (
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/theme"
)
type ServeStrategy struct {
Theme *theme.Themes
SessionsManager sablier.Sablier
StrategyConfig config.Strategy
SessionsConfig config.Sessions
}

View File

@@ -1,13 +0,0 @@
package routes
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/version"
)
func GetVersion(c *gin.Context) {
c.JSON(http.StatusOK, version.Map())
}

View File

@@ -1,32 +0,0 @@
package routes
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/version"
"gotest.tools/v3/assert"
)
func TestGetVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
version.Branch = "testing"
version.Revision = "8ffebca"
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
expected, _ := json.Marshal(version.Map())
GetVersion(c)
res := recorder.Result()
defer res.Body.Close()
data, _ := io.ReadAll(res.Body)
assert.Equal(t, res.StatusCode, http.StatusOK)
assert.Equal(t, string(data), string(expected))
}

View File

@@ -1,184 +0,0 @@
package app
import (
"context"
"fmt"
"github.com/docker/docker/client"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory"
"github.com/sablierapp/sablier/pkg/theme"
k8s "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"log/slog"
"os"
"os/signal"
"syscall"
"time"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/internal/server"
"github.com/sablierapp/sablier/version"
)
func Start(ctx context.Context, conf config.Config) error {
// Create context that listens for the interrupt signal from the OS.
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer stop()
logger := setupLogger(conf.Logging)
logger.Info("running Sablier version " + version.Info())
provider, err := NewProvider(ctx, logger, conf.Provider)
if err != nil {
return err
}
store := inmemory.NewInMemory()
err = store.OnExpire(ctx, onSessionExpires(ctx, provider, logger))
if err != nil {
return err
}
s := sablier.New(logger, store, provider)
groups, err := provider.InstanceGroups(ctx)
if err != nil {
logger.WarnContext(ctx, "initial group scan failed", slog.Any("reason", err))
} else {
s.SetGroups(groups)
}
updateGroups := make(chan map[string][]string)
go WatchGroups(ctx, provider, 2*time.Second, updateGroups, logger)
go func() {
for groups := range updateGroups {
s.SetGroups(groups)
}
}()
instanceStopped := make(chan string)
go provider.NotifyInstanceStopped(ctx, instanceStopped)
go func() {
for stopped := range instanceStopped {
err := s.RemoveInstance(ctx, stopped)
if err != nil {
logger.Warn("could not remove instance", slog.Any("error", err))
}
}
}()
if conf.Provider.AutoStopOnStartup {
err := s.StopAllUnregisteredInstances(ctx)
if err != nil {
logger.ErrorContext(ctx, "unable to stop unregistered instances", slog.Any("reason", err))
}
}
var t *theme.Themes
if conf.Strategy.Dynamic.CustomThemesPath != "" {
logger.DebugContext(ctx, "loading themes from custom theme path", slog.String("path", conf.Strategy.Dynamic.CustomThemesPath))
custom := os.DirFS(conf.Strategy.Dynamic.CustomThemesPath)
t, err = theme.NewWithCustomThemes(custom, logger)
if err != nil {
return err
}
} else {
logger.DebugContext(ctx, "loading themes without custom theme path", slog.String("reason", "--strategy.dynamic.custom-themes-path is empty"))
t, err = theme.New(logger)
if err != nil {
return err
}
}
strategy := &routes.ServeStrategy{
Theme: t,
SessionsManager: s,
StrategyConfig: conf.Strategy,
SessionsConfig: conf.Sessions,
}
go server.Start(ctx, logger, conf.Server, strategy)
// Listen for the interrupt signal.
<-ctx.Done()
stop()
logger.InfoContext(ctx, "shutting down gracefully, press Ctrl+C again to force")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
logger.InfoContext(ctx, "Server exiting")
return nil
}
func onSessionExpires(ctx context.Context, provider sablier.Provider, logger *slog.Logger) func(key string) {
return func(_key string) {
go func(key string) {
logger.InfoContext(ctx, "instance expired", slog.String("instance", key))
err := provider.InstanceStop(ctx, key)
if err != nil {
logger.ErrorContext(ctx, "instance expired could not be stopped from provider", slog.String("instance", key), slog.Any("error", err))
}
}(_key)
}
}
func NewProvider(ctx context.Context, logger *slog.Logger, config config.Provider) (sablier.Provider, error) {
if err := config.IsValid(); err != nil {
return nil, err
}
switch config.Name {
case "swarm", "docker_swarm":
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return nil, fmt.Errorf("cannot create docker swarm client: %v", err)
}
return dockerswarm.NewDockerSwarmProvider(ctx, cli, logger)
case "docker":
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return nil, fmt.Errorf("cannot create docker client: %v", err)
}
return docker.NewDockerClassicProvider(ctx, cli, logger)
case "kubernetes":
kubeclientConfig, err := rest.InClusterConfig()
if err != nil {
return nil, err
}
kubeclientConfig.QPS = config.Kubernetes.QPS
kubeclientConfig.Burst = config.Kubernetes.Burst
cli, err := k8s.NewForConfig(kubeclientConfig)
if err != nil {
return nil, err
}
return kubernetes.NewKubernetesProvider(ctx, cli, logger, config.Kubernetes)
}
return nil, fmt.Errorf("unimplemented provider %s", config.Name)
}
func WatchGroups(ctx context.Context, provider sablier.Provider, frequency time.Duration, send chan<- map[string][]string, logger *slog.Logger) {
ticker := time.NewTicker(frequency)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
groups, err := provider.InstanceGroups(ctx)
if err != nil {
logger.Error("cannot retrieve group from provider", slog.Any("reason", err))
} else if groups != nil {
send <- groups
}
}
}
}

View File

@@ -1,27 +0,0 @@
package cmd
import (
"fmt"
"os"
"github.com/sablierapp/sablier/app/http/healthcheck"
"github.com/spf13/cobra"
)
var newHealthCommand = func() *cobra.Command {
return &cobra.Command{
Use: "health",
Short: "Calls the health endpoint of a Sablier instance",
Run: func(cmd *cobra.Command, args []string) {
details, healthy := healthcheck.Health(cmd.Flag("url").Value.String())
if healthy {
fmt.Fprintf(os.Stderr, "healthy: %v\n", details)
os.Exit(0)
} else {
fmt.Fprintf(os.Stderr, "unhealthy: %v\n", details)
os.Exit(1)
}
},
}
}

View File

@@ -0,0 +1,52 @@
package healthcheck
import (
"fmt"
"github.com/spf13/cobra"
"io"
"net/http"
"os"
)
const (
healthy = true
unhealthy = false
)
func Health(url string) (string, bool) {
resp, err := http.Get(url)
if err != nil {
return err.Error(), unhealthy
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return err.Error(), unhealthy
}
if resp.StatusCode >= 400 {
return string(body), unhealthy
}
return string(body), healthy
}
func NewCmd() *cobra.Command {
return &cobra.Command{
Use: "health",
Short: "Calls the health endpoint of a Sablier instance",
Run: func(cmd *cobra.Command, args []string) {
details, healthy := Health(cmd.Flag("url").Value.String())
if healthy {
fmt.Fprintf(os.Stderr, "healthy: %v\n", details)
os.Exit(0)
} else {
fmt.Fprintf(os.Stderr, "unhealthy: %v\n", details)
os.Exit(1)
}
},
}
}

View File

@@ -1,194 +0,0 @@
package cmd
import (
"bufio"
"bytes"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sablierapp/sablier/config"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"gotest.tools/v3/assert"
)
func TestDefault(t *testing.T) {
testDir, err := os.Getwd()
require.NoError(t, err, "error getting the current working directory")
wantConfig, err := os.ReadFile(filepath.Join(testDir, "testdata", "config_default.json"))
require.NoError(t, err, "error reading test config file")
// CHANGE `startCmd` behavior to only print the config, this is for testing purposes only
newStartCommand = mockStartCommand
t.Run("config file", func(t *testing.T) {
conf = config.NewConfig()
cmd := NewRootCommand()
output := &bytes.Buffer{}
cmd.SetOut(output)
cmd.SetArgs([]string{
"start",
})
cmd.Execute()
gotOutput := output.String()
assert.Equal(t, string(wantConfig), gotOutput)
})
}
func TestPrecedence(t *testing.T) {
testDir, err := os.Getwd()
require.NoError(t, err, "error getting the current working directory")
// CHANGE `startCmd` behavior to only print the config, this is for testing purposes only
newStartCommand = mockStartCommand
t.Run("config file", func(t *testing.T) {
wantConfig, err := os.ReadFile(filepath.Join(testDir, "testdata", "config_yaml_wanted.json"))
require.NoError(t, err, "error reading test config file")
conf = config.NewConfig()
cmd := NewRootCommand()
output := &bytes.Buffer{}
cmd.SetOut(output)
cmd.SetArgs([]string{
"--configFile", filepath.Join(testDir, "testdata", "config.yml"),
"start",
})
cmd.Execute()
gotOutput := output.String()
assert.Equal(t, string(wantConfig), gotOutput)
})
t.Run("env var", func(t *testing.T) {
setEnvsFromFile(filepath.Join(testDir, "testdata", "config.env"))
defer unsetEnvsFromFile(filepath.Join(testDir, "testdata", "config.env"))
wantConfig, err := os.ReadFile(filepath.Join(testDir, "testdata", "config_env_wanted.json"))
require.NoError(t, err, "error reading test config file")
conf = config.NewConfig()
cmd := NewRootCommand()
output := &bytes.Buffer{}
cmd.SetOut(output)
cmd.SetArgs([]string{
"--configFile", filepath.Join(testDir, "testdata", "config.yml"),
"start",
})
cmd.Execute()
gotOutput := output.String()
assert.Equal(t, string(wantConfig), gotOutput)
})
t.Run("flag", func(t *testing.T) {
setEnvsFromFile(filepath.Join(testDir, "testdata", "config.env"))
defer unsetEnvsFromFile(filepath.Join(testDir, "testdata", "config.env"))
wantConfig, err := os.ReadFile(filepath.Join(testDir, "testdata", "config_cli_wanted.json"))
require.NoError(t, err, "error reading test config file")
cmd := NewRootCommand()
output := &bytes.Buffer{}
conf = config.NewConfig()
cmd.SetOut(output)
cmd.SetArgs([]string{
"--configFile", filepath.Join(testDir, "testdata", "config.yml"),
"start",
"--provider.name", "cli",
"--provider.kubernetes.qps", "256",
"--provider.kubernetes.burst", "512",
"--provider.kubernetes.delimiter", "_",
"--server.port", "3333",
"--server.base-path", "/cli/",
"--storage.file", "/tmp/cli.json",
"--sessions.default-duration", "3h",
"--sessions.expiration-interval", "3h",
"--logging.level", "info",
"--strategy.dynamic.custom-themes-path", "/tmp/cli/themes",
// Must use `=` see https://github.com/spf13/cobra/issues/613
"--strategy.dynamic.show-details-by-default=false",
"--strategy.dynamic.default-theme", "cli",
"--strategy.dynamic.default-refresh-frequency", "3h",
"--strategy.blocking.default-timeout", "3h",
})
cmd.Execute()
gotOutput := output.String()
assert.Equal(t, string(wantConfig), gotOutput)
})
}
func setEnvsFromFile(path string) {
readFile, err := os.Open(path)
if err != nil {
panic(err)
}
defer readFile.Close()
if err != nil {
panic(err)
}
fileScanner := bufio.NewScanner(readFile)
fileScanner.Split(bufio.ScanLines)
for fileScanner.Scan() {
splitted := strings.Split(fileScanner.Text(), "=")
os.Setenv(splitted[0], splitted[1])
}
}
func unsetEnvsFromFile(path string) {
readFile, err := os.Open(path)
if err != nil {
panic(err)
}
defer readFile.Close()
if err != nil {
panic(err)
}
fileScanner := bufio.NewScanner(readFile)
fileScanner.Split(bufio.ScanLines)
for fileScanner.Scan() {
splitted := strings.Split(fileScanner.Text(), "=")
os.Unsetenv(splitted[0])
}
}
func mockStartCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "start",
Short: "InstanceStart the Sablier server",
Run: func(cmd *cobra.Command, args []string) {
viper.Unmarshal(&conf)
out := cmd.OutOrStdout()
encoder := json.NewEncoder(out)
encoder.SetIndent("", " ")
encoder.Encode(conf)
},
}
return cmd
}

View File

@@ -1,16 +1,17 @@
package cmd package main
import ( import (
"fmt" "fmt"
"github.com/sablierapp/sablier/cmd/healthcheck"
"github.com/sablierapp/sablier/cmd/version"
"github.com/sablierapp/sablier/pkg/config"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"log/slog" "log/slog"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/sablierapp/sablier/config"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
) )
const ( const (
@@ -21,7 +22,7 @@ const (
var conf = config.NewConfig() var conf = config.NewConfig()
var cfgFile string var cfgFile string
func Execute() { func main() {
cmd := NewRootCommand() cmd := NewRootCommand()
if err := cmd.Execute(); err != nil { if err := cmd.Execute(); err != nil {
os.Exit(1) os.Exit(1)
@@ -42,7 +43,7 @@ It provides an integrations with multiple reverse proxies and different loading
rootCmd.PersistentFlags().StringVar(&cfgFile, "configFile", "", "Config file path. If not defined, looks for sablier.(yml|yaml|toml) in /etc/sablier/ > $XDG_CONFIG_HOME > $HOME/.config/ and current directory") rootCmd.PersistentFlags().StringVar(&cfgFile, "configFile", "", "Config file path. If not defined, looks for sablier.(yml|yaml|toml) in /etc/sablier/ > $XDG_CONFIG_HOME > $HOME/.config/ and current directory")
startCmd := newStartCommand() startCmd := NewCmd()
// Provider flags // Provider flags
startCmd.Flags().StringVar(&conf.Provider.Name, "provider.name", "docker", fmt.Sprintf("Provider to use to manage containers %v", config.GetProviders())) startCmd.Flags().StringVar(&conf.Provider.Name, "provider.name", "docker", fmt.Sprintf("Provider to use to manage containers %v", config.GetProviders()))
viper.BindPFlag("provider.name", startCmd.Flags().Lookup("provider.name")) viper.BindPFlag("provider.name", startCmd.Flags().Lookup("provider.name"))
@@ -69,7 +70,7 @@ It provides an integrations with multiple reverse proxies and different loading
viper.BindPFlag("sessions.expiration-interval", startCmd.Flags().Lookup("sessions.expiration-interval")) viper.BindPFlag("sessions.expiration-interval", startCmd.Flags().Lookup("sessions.expiration-interval"))
// logging level // logging level
rootCmd.PersistentFlags().StringVar(&conf.Logging.Level, "logging.level", strings.ToLower(slog.LevelInfo.String()), "The logging level. Can be one of [panic, fatal, error, warn, info, debug, trace]") rootCmd.PersistentFlags().StringVar(&conf.Logging.Level, "logging.level", strings.ToLower(slog.LevelInfo.String()), "The logging level. Can be one of [error, warn, info, debug]")
viper.BindPFlag("logging.level", rootCmd.PersistentFlags().Lookup("logging.level")) viper.BindPFlag("logging.level", rootCmd.PersistentFlags().Lookup("logging.level"))
// strategy // strategy
@@ -85,9 +86,9 @@ It provides an integrations with multiple reverse proxies and different loading
viper.BindPFlag("strategy.blocking.default-timeout", startCmd.Flags().Lookup("strategy.blocking.default-timeout")) viper.BindPFlag("strategy.blocking.default-timeout", startCmd.Flags().Lookup("strategy.blocking.default-timeout"))
rootCmd.AddCommand(startCmd) rootCmd.AddCommand(startCmd)
rootCmd.AddCommand(newVersionCommand()) rootCmd.AddCommand(version.NewCmd())
healthCmd := newHealthCommand() healthCmd := healthcheck.NewCmd()
healthCmd.Flags().String("url", "http://localhost:10000/health", "Sablier health endpoint") healthCmd.Flags().String("url", "http://localhost:10000/health", "Sablier health endpoint")
rootCmd.AddCommand(healthCmd) rootCmd.AddCommand(healthCmd)
@@ -147,3 +148,21 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) {
} }
}) })
} }
func NewCmd() *cobra.Command {
return &cobra.Command{
Use: "start",
Short: "Start the Sablier server",
Run: func(cmd *cobra.Command, args []string) {
err := viper.Unmarshal(&conf)
if err != nil {
panic(err)
}
err = Start(cmd.Context(), conf)
if err != nil {
panic(err)
}
},
}
}

View File

@@ -1,8 +1,8 @@
package app package main
import ( import (
"github.com/lmittmann/tint" "github.com/lmittmann/tint"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"log/slog" "log/slog"
"os" "os"
"strings" "strings"

50
cmd/sablier/provider.go Normal file
View File

@@ -0,0 +1,50 @@
package main
import (
"context"
"fmt"
"github.com/docker/docker/client"
"github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/docker"
"github.com/sablierapp/sablier/pkg/provider/dockerswarm"
"github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier"
k8s "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"log/slog"
)
func setupProvider(ctx context.Context, logger *slog.Logger, config config.Provider) (sablier.Provider, error) {
if err := config.IsValid(); err != nil {
return nil, err
}
switch config.Name {
case "swarm", "docker_swarm":
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return nil, fmt.Errorf("cannot create docker swarm client: %v", err)
}
return dockerswarm.New(ctx, cli, logger)
case "docker":
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return nil, fmt.Errorf("cannot create docker client: %v", err)
}
return docker.New(ctx, cli, logger)
case "kubernetes":
kubeclientConfig, err := rest.InClusterConfig()
if err != nil {
return nil, err
}
kubeclientConfig.QPS = config.Kubernetes.QPS
kubeclientConfig.Burst = config.Kubernetes.Burst
cli, err := k8s.NewForConfig(kubeclientConfig)
if err != nil {
return nil, err
}
return kubernetes.New(ctx, cli, logger, config.Kubernetes)
}
return nil, fmt.Errorf("unimplemented provider %s", config.Name)
}

92
cmd/sablier/sablier.go Normal file
View File

@@ -0,0 +1,92 @@
package main
import (
"context"
"fmt"
"github.com/sablierapp/sablier/internal/api"
"github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/store/inmemory"
"github.com/sablierapp/sablier/pkg/version"
"log/slog"
"os/signal"
"syscall"
"time"
"github.com/sablierapp/sablier/internal/server"
)
func Start(ctx context.Context, conf config.Config) error {
// Create context that listens for the interrupt signal from the OS.
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer stop()
logger := setupLogger(conf.Logging)
logger.Info("running Sablier version " + version.Info())
provider, err := setupProvider(ctx, logger, conf.Provider)
if err != nil {
return fmt.Errorf("cannot setup provider: %w", err)
}
store := inmemory.NewInMemory()
err = store.OnExpire(ctx, sablier.OnInstanceExpired(ctx, provider, logger))
if err != nil {
return err
}
s := sablier.New(logger, store, provider)
groups, err := provider.InstanceGroups(ctx)
if err != nil {
logger.WarnContext(ctx, "initial group scan failed", slog.Any("reason", err))
} else {
s.SetGroups(groups)
}
go s.GroupWatch(ctx)
instanceStopped := make(chan string)
go provider.NotifyInstanceStopped(ctx, instanceStopped)
go func() {
for stopped := range instanceStopped {
err := s.RemoveInstance(ctx, stopped)
if err != nil {
logger.Warn("could not remove instance", slog.Any("error", err))
}
}
}()
if conf.Provider.AutoStopOnStartup {
err := s.StopAllUnregisteredInstances(ctx)
if err != nil {
logger.ErrorContext(ctx, "unable to stop unregistered instances", slog.Any("reason", err))
}
}
t, err := setupTheme(ctx, conf, logger)
if err != nil {
return fmt.Errorf("cannot setup theme: %w", err)
}
strategy := &api.ServeStrategy{
Theme: t,
Sablier: s,
StrategyConfig: conf.Strategy,
SessionsConfig: conf.Sessions,
}
go server.Start(ctx, logger, conf.Server, strategy)
// Listen for the interrupt signal.
<-ctx.Done()
stop()
logger.InfoContext(ctx, "shutting down gracefully, press Ctrl+C again to force")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
logger.InfoContext(ctx, "Server exiting")
return nil
}

28
cmd/sablier/theme.go Normal file
View File

@@ -0,0 +1,28 @@
package main
import (
"context"
"github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/theme"
"log/slog"
"os"
)
func setupTheme(ctx context.Context, conf config.Config, logger *slog.Logger) (*theme.Themes, error) {
if conf.Strategy.Dynamic.CustomThemesPath != "" {
logger.DebugContext(ctx, "loading themes from custom theme path", slog.String("path", conf.Strategy.Dynamic.CustomThemesPath))
custom := os.DirFS(conf.Strategy.Dynamic.CustomThemesPath)
t, err := theme.NewWithCustomThemes(custom, logger)
if err != nil {
return nil, err
}
return t, nil
}
logger.DebugContext(ctx, "loading themes without custom theme path", slog.String("reason", "--strategy.dynamic.custom-themes-path is empty"))
t, err := theme.New(logger)
if err != nil {
return nil, err
}
return t, nil
}

View File

@@ -1,22 +0,0 @@
package cmd
import (
"github.com/sablierapp/sablier/app"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var newStartCommand = func() *cobra.Command {
return &cobra.Command{
Use: "start",
Short: "InstanceStart the Sablier server",
Run: func(cmd *cobra.Command, args []string) {
viper.Unmarshal(&conf)
err := app.Start(cmd.Context(), conf)
if err != nil {
panic(err)
}
},
}
}

View File

@@ -1,13 +1,13 @@
package cmd package version
import ( import (
"fmt" "fmt"
"github.com/sablierapp/sablier/pkg/version"
"github.com/sablierapp/sablier/version"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var newVersionCommand = func() *cobra.Command { func NewCmd() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "version", Use: "version",
Short: "Print the version Sablier", Short: "Print the version Sablier",

15
internal/api/api.go Normal file
View File

@@ -0,0 +1,15 @@
package api
import (
config2 "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/sablier"
"github.com/sablierapp/sablier/pkg/theme"
)
type ServeStrategy struct {
Theme *theme.Themes
Sablier sablier.Sablier
StrategyConfig config2.Strategy
SessionsConfig config2.Sessions
}

View File

@@ -3,8 +3,7 @@ package api
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/app/http/routes" config2 "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/pkg/sablier/sabliertest" "github.com/sablierapp/sablier/pkg/sablier/sabliertest"
"github.com/sablierapp/sablier/pkg/theme" "github.com/sablierapp/sablier/pkg/theme"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
@@ -14,7 +13,7 @@ import (
"testing" "testing"
) )
func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strategy *routes.ServeStrategy, mock *sabliertest.MockSablier) { func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strategy *ServeStrategy, mock *sabliertest.MockSablier) {
t.Helper() t.Helper()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
@@ -24,11 +23,11 @@ func NewApiTest(t *testing.T) (app *gin.Engine, router *gin.RouterGroup, strateg
app = gin.New() app = gin.New()
router = app.Group("/api") router = app.Group("/api")
mock = sabliertest.NewMockSablier(ctrl) mock = sabliertest.NewMockSablier(ctrl)
strategy = &routes.ServeStrategy{ strategy = &ServeStrategy{
Theme: th, Theme: th,
SessionsManager: mock, Sablier: mock,
StrategyConfig: config.NewStrategyConfig(), StrategyConfig: config2.NewStrategyConfig(),
SessionsConfig: config.NewSessionsConfig(), SessionsConfig: config2.NewSessionsConfig(),
} }
return app, router, strategy, mock return app, router, strategy, mock

View File

@@ -3,15 +3,21 @@ package api
import ( import (
"errors" "errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/app/http/routes/models"
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
"net/http" "net/http"
"time"
) )
func StartBlocking(router *gin.RouterGroup, s *routes.ServeStrategy) { type BlockingRequest struct {
Names []string `form:"names"`
Group string `form:"group"`
SessionDuration time.Duration `form:"session_duration"`
Timeout time.Duration `form:"timeout"`
}
func StartBlocking(router *gin.RouterGroup, s *ServeStrategy) {
router.GET("/strategies/blocking", func(c *gin.Context) { router.GET("/strategies/blocking", func(c *gin.Context) {
request := models.BlockingRequest{ request := BlockingRequest{
SessionDuration: s.SessionsConfig.DefaultDuration, SessionDuration: s.SessionsConfig.DefaultDuration,
Timeout: s.StrategyConfig.Blocking.DefaultTimeout, Timeout: s.StrategyConfig.Blocking.DefaultTimeout,
} }
@@ -34,9 +40,9 @@ func StartBlocking(router *gin.RouterGroup, s *routes.ServeStrategy) {
var sessionState *sablier.SessionState var sessionState *sablier.SessionState
var err error var err error
if len(request.Names) > 0 { if len(request.Names) > 0 {
sessionState, err = s.SessionsManager.RequestReadySession(c.Request.Context(), request.Names, request.SessionDuration, request.Timeout) sessionState, err = s.Sablier.RequestReadySession(c.Request.Context(), request.Names, request.SessionDuration, request.Timeout)
} else { } else {
sessionState, err = s.SessionsManager.RequestReadySessionGroup(c.Request.Context(), request.Group, request.SessionDuration, request.Timeout) sessionState, err = s.Sablier.RequestReadySessionGroup(c.Request.Context(), request.Group, request.SessionDuration, request.Timeout)
var groupNotFoundError sablier.ErrGroupNotFound var groupNotFoundError sablier.ErrGroupNotFound
if errors.As(err, &groupNotFoundError) { if errors.As(err, &groupNotFoundError) {
AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError)) AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError))

View File

@@ -8,16 +8,25 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/pkg/theme"
"github.com/sablierapp/sablier/app/http/routes/models"
theme2 "github.com/sablierapp/sablier/pkg/theme"
) )
func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) { type DynamicRequest struct {
Group string `form:"group"`
Names []string `form:"names"`
ShowDetails bool `form:"show_details"`
DisplayName string `form:"display_name"`
Theme string `form:"theme"`
SessionDuration time.Duration `form:"session_duration"`
RefreshFrequency time.Duration `form:"refresh_frequency"`
}
func StartDynamic(router *gin.RouterGroup, s *ServeStrategy) {
router.GET("/strategies/dynamic", func(c *gin.Context) { router.GET("/strategies/dynamic", func(c *gin.Context) {
request := models.DynamicRequest{ request := DynamicRequest{
Theme: s.StrategyConfig.Dynamic.DefaultTheme, Theme: s.StrategyConfig.Dynamic.DefaultTheme,
ShowDetails: s.StrategyConfig.Dynamic.ShowDetailsByDefault, ShowDetails: s.StrategyConfig.Dynamic.ShowDetailsByDefault,
RefreshFrequency: s.StrategyConfig.Dynamic.DefaultRefreshFrequency, RefreshFrequency: s.StrategyConfig.Dynamic.DefaultRefreshFrequency,
@@ -42,9 +51,9 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
var sessionState *sablier.SessionState var sessionState *sablier.SessionState
var err error var err error
if len(request.Names) > 0 { if len(request.Names) > 0 {
sessionState, err = s.SessionsManager.RequestSession(c, request.Names, request.SessionDuration) sessionState, err = s.Sablier.RequestSession(c, request.Names, request.SessionDuration)
} else { } else {
sessionState, err = s.SessionsManager.RequestSessionGroup(c, request.Group, request.SessionDuration) sessionState, err = s.Sablier.RequestSessionGroup(c, request.Group, request.SessionDuration)
var groupNotFoundError sablier.ErrGroupNotFound var groupNotFoundError sablier.ErrGroupNotFound
if errors.As(err, &groupNotFoundError) { if errors.As(err, &groupNotFoundError) {
AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError)) AbortWithProblemDetail(c, ProblemGroupNotFound(groupNotFoundError))
@@ -64,7 +73,7 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
AddSablierHeader(c, sessionState) AddSablierHeader(c, sessionState)
renderOptions := theme2.Options{ renderOptions := theme.Options{
DisplayName: request.DisplayName, DisplayName: request.DisplayName,
ShowDetails: request.ShowDetails, ShowDetails: request.ShowDetails,
SessionDuration: request.SessionDuration, SessionDuration: request.SessionDuration,
@@ -75,7 +84,7 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
writer := bufio.NewWriter(buf) writer := bufio.NewWriter(buf)
err = s.Theme.Render(request.Theme, renderOptions, writer) err = s.Theme.Render(request.Theme, renderOptions, writer)
var themeNotFound theme2.ErrThemeNotFound var themeNotFound theme.ErrThemeNotFound
if errors.As(err, &themeNotFound) { if errors.As(err, &themeNotFound) {
AbortWithProblemDetail(c, ProblemThemeNotFound(themeNotFound)) AbortWithProblemDetail(c, ProblemThemeNotFound(themeNotFound))
return return
@@ -89,7 +98,7 @@ func StartDynamic(router *gin.RouterGroup, s *routes.ServeStrategy) {
}) })
} }
func sessionStateToRenderOptionsInstanceState(sessionState *sablier.SessionState) (instances []theme2.Instance) { func sessionStateToRenderOptionsInstanceState(sessionState *sablier.SessionState) (instances []theme.Instance) {
if sessionState == nil { if sessionState == nil {
return return
} }
@@ -105,7 +114,7 @@ func sessionStateToRenderOptionsInstanceState(sessionState *sablier.SessionState
return return
} }
func instanceStateToRenderOptionsRequestState(instanceState sablier.InstanceInfo) theme2.Instance { func instanceStateToRenderOptionsRequestState(instanceState sablier.InstanceInfo) theme.Instance {
var err error var err error
if instanceState.Message == "" { if instanceState.Message == "" {
@@ -114,7 +123,7 @@ func instanceStateToRenderOptionsRequestState(instanceState sablier.InstanceInfo
err = errors.New(instanceState.Message) err = errors.New(instanceState.Message)
} }
return theme2.Instance{ return theme.Instance{
Name: instanceState.Name, Name: instanceState.Name,
Status: string(instanceState.Status), Status: string(instanceState.Status),
CurrentReplicas: instanceState.CurrentReplicas, CurrentReplicas: instanceState.CurrentReplicas,

View File

@@ -2,11 +2,10 @@ package api
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes"
"net/http" "net/http"
) )
func ListThemes(router *gin.RouterGroup, s *routes.ServeStrategy) { func ListThemes(router *gin.RouterGroup, s *ServeStrategy) {
handler := func(c *gin.Context) { handler := func(c *gin.Context) {
c.JSON(http.StatusOK, map[string]interface{}{ c.JSON(http.StatusOK, map[string]interface{}{
"themes": s.Theme.List(), "themes": s.Theme.List(),

View File

@@ -3,12 +3,11 @@ package server
import ( import (
"context" "context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/internal/api" "github.com/sablierapp/sablier/internal/api"
"github.com/sablierapp/sablier/pkg/config"
) )
func registerRoutes(ctx context.Context, router *gin.Engine, serverConf config.Server, s *routes.ServeStrategy) { func registerRoutes(ctx context.Context, router *gin.Engine, serverConf config.Server, s *api.ServeStrategy) {
// Enables automatic redirection if the current route cannot be matched but a // Enables automatic redirection if the current route cannot be matched but a
// handler for the path with (without) the trailing slash exists. // handler for the path with (without) the trailing slash exists.
router.RedirectTrailingSlash = true router.RedirectTrailingSlash = true

View File

@@ -5,14 +5,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/app/http/routes" "github.com/sablierapp/sablier/internal/api"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"log/slog" "log/slog"
"net/http" "net/http"
"time" "time"
) )
func setupRouter(ctx context.Context, logger *slog.Logger, serverConf config.Server, s *routes.ServeStrategy) *gin.Engine { func setupRouter(ctx context.Context, logger *slog.Logger, serverConf config.Server, s *api.ServeStrategy) *gin.Engine {
r := gin.New() r := gin.New()
r.Use(StructuredLogger(logger)) r.Use(StructuredLogger(logger))
@@ -23,7 +23,7 @@ func setupRouter(ctx context.Context, logger *slog.Logger, serverConf config.Ser
return r return r
} }
func Start(ctx context.Context, logger *slog.Logger, serverConf config.Server, s *routes.ServeStrategy) { func Start(ctx context.Context, logger *slog.Logger, serverConf config.Server, s *api.ServeStrategy) {
start := time.Now() start := time.Now()
if logger.Enabled(ctx, slog.LevelDebug) { if logger.Enabled(ctx, slog.LevelDebug) {

11
main.go
View File

@@ -1,11 +0,0 @@
package main
import (
"github.com/gin-gonic/gin"
"github.com/sablierapp/sablier/cmd"
)
func main() {
gin.SetMode(gin.ReleaseMode)
cmd.Execute()
}

View File

@@ -14,11 +14,11 @@ type Provider struct {
} }
type Kubernetes struct { type Kubernetes struct {
//QPS limit for K8S API access client-side throttle // QPS limit for K8S API access client-side throttle
QPS float32 `mapstructure:"QPS" yaml:"QPS" default:"5"` QPS float32 `mapstructure:"QPS" yaml:"QPS" default:"5"`
//Maximum burst for client-side throttle // Maximum burst for client-side throttle
Burst int `mapstructure:"BURST" yaml:"Burst" default:"10"` Burst int `mapstructure:"BURST" yaml:"Burst" default:"10"`
//Delimiter used for namespace/resource type/name resolution. Defaults to "_" for backward compatibility. But you should use "/" or ".". // Delimiter used for namespace/resource type/name resolution. Defaults to "_" for backward compatibility. But you should use "/" or ".".
Delimiter string `mapstructure:"DELIMITER" yaml:"Delimiter" default:"_"` Delimiter string `mapstructure:"DELIMITER" yaml:"Delimiter" default:"_"`
} }
@@ -31,7 +31,7 @@ func NewProviderConfig() Provider {
Kubernetes: Kubernetes{ Kubernetes: Kubernetes{
QPS: 5, QPS: 5,
Burst: 10, Burst: 10,
Delimiter: "_", //Delimiter used for namespace/resource type/name resolution. Defaults to "_" for backward compatibility. But you should use "/" or ".". Delimiter: "_",
}, },
} }
} }

View File

@@ -7,7 +7,7 @@ import (
"log/slog" "log/slog"
) )
func (p *DockerClassicProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) { func (p *Provider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
spec, err := p.Client.ContainerInspect(ctx, name) spec, err := p.Client.ContainerInspect(ctx, name)
if err != nil { if err != nil {
return sablier.InstanceInfo{}, fmt.Errorf("cannot inspect container: %w", err) return sablier.InstanceInfo{}, fmt.Errorf("cannot inspect container: %w", err)

View File

@@ -263,7 +263,7 @@ func TestDockerClassicProvider_GetState(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := docker.NewDockerClassicProvider(ctx, c.client, slogt.New(t)) p, err := docker.New(ctx, c.client, slogt.New(t))
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)

View File

@@ -11,7 +11,7 @@ import (
"strings" "strings"
) )
func (p *DockerClassicProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) { func (p *Provider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs() args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) args.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))
@@ -49,7 +49,7 @@ func containerToInstance(c dockertypes.Container) sablier.InstanceConfiguration
} }
} }
func (p *DockerClassicProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { func (p *Provider) InstanceGroups(ctx context.Context) (map[string][]string, error) {
args := filters.NewArgs() args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) args.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))

View File

@@ -18,7 +18,7 @@ func TestDockerClassicProvider_InstanceList(t *testing.T) {
ctx := t.Context() ctx := t.Context()
dind := setupDinD(t, ctx) dind := setupDinD(t, ctx)
p, err := docker.NewDockerClassicProvider(ctx, dind.client, slogt.New(t)) p, err := docker.New(ctx, dind.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
c1, err := dind.CreateMimic(ctx, MimicOptions{ c1, err := dind.CreateMimic(ctx, MimicOptions{
@@ -77,7 +77,7 @@ func TestDockerClassicProvider_GetGroups(t *testing.T) {
ctx := t.Context() ctx := t.Context()
dind := setupDinD(t, ctx) dind := setupDinD(t, ctx)
p, err := docker.NewDockerClassicProvider(ctx, dind.client, slogt.New(t)) p, err := docker.New(ctx, dind.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
c1, err := dind.CreateMimic(ctx, MimicOptions{ c1, err := dind.CreateMimic(ctx, MimicOptions{

View File

@@ -6,7 +6,7 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
) )
func (p *DockerClassicProvider) InstanceStart(ctx context.Context, name string) error { func (p *Provider) InstanceStart(ctx context.Context, name string) error {
// TODO: InstanceStart should block until the container is ready. // TODO: InstanceStart should block until the container is ready.
err := p.Client.ContainerStart(ctx, name, container.StartOptions{}) err := p.Client.ContainerStart(ctx, name, container.StartOptions{})
if err != nil { if err != nil {

View File

@@ -47,7 +47,7 @@ func TestDockerClassicProvider_Start(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := docker.NewDockerClassicProvider(ctx, c.client, slogt.New(t)) p, err := docker.New(ctx, c.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
name, err := tt.args.do(c) name, err := tt.args.do(c)

View File

@@ -7,7 +7,7 @@ import (
"log/slog" "log/slog"
) )
func (p *DockerClassicProvider) InstanceStop(ctx context.Context, name string) error { func (p *Provider) InstanceStop(ctx context.Context, name string) error {
p.l.DebugContext(ctx, "stopping container", slog.String("name", name)) p.l.DebugContext(ctx, "stopping container", slog.String("name", name))
err := p.Client.ContainerStop(ctx, name, container.StopOptions{}) err := p.Client.ContainerStop(ctx, name, container.StopOptions{})
if err != nil { if err != nil {

View File

@@ -57,7 +57,7 @@ func TestDockerClassicProvider_Stop(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := docker.NewDockerClassicProvider(ctx, c.client, slogt.New(t)) p, err := docker.New(ctx, c.client, slogt.New(t))
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)

View File

@@ -9,15 +9,15 @@ import (
) )
// Interface guard // Interface guard
var _ sablier.Provider = (*DockerClassicProvider)(nil) var _ sablier.Provider = (*Provider)(nil)
type DockerClassicProvider struct { type Provider struct {
Client client.APIClient Client client.APIClient
desiredReplicas int32 desiredReplicas int32
l *slog.Logger l *slog.Logger
} }
func NewDockerClassicProvider(ctx context.Context, cli *client.Client, logger *slog.Logger) (*DockerClassicProvider, error) { func New(ctx context.Context, cli *client.Client, logger *slog.Logger) (*Provider, error) {
logger = logger.With(slog.String("provider", "docker")) logger = logger.With(slog.String("provider", "docker"))
serverVersion, err := cli.ServerVersion(ctx) serverVersion, err := cli.ServerVersion(ctx)
@@ -29,7 +29,7 @@ func NewDockerClassicProvider(ctx context.Context, cli *client.Client, logger *s
slog.String("version", serverVersion.Version), slog.String("version", serverVersion.Version),
slog.String("api_version", serverVersion.APIVersion), slog.String("api_version", serverVersion.APIVersion),
) )
return &DockerClassicProvider{ return &Provider{
Client: cli, Client: cli,
desiredReplicas: 1, desiredReplicas: 1,
l: logger, l: logger,

View File

@@ -10,7 +10,7 @@ import (
"strings" "strings"
) )
func (p *DockerClassicProvider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) { func (p *Provider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) {
msgs, errs := p.Client.Events(ctx, events.ListOptions{ msgs, errs := p.Client.Events(ctx, events.ListOptions{
Filters: filters.NewArgs( Filters: filters.NewArgs(
filters.Arg("scope", "local"), filters.Arg("scope", "local"),

View File

@@ -18,7 +18,7 @@ func TestDockerClassicProvider_NotifyInstanceStopped(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
defer cancel() defer cancel()
dind := setupDinD(t, ctx) dind := setupDinD(t, ctx)
p, err := docker.NewDockerClassicProvider(ctx, dind.client, slogt.New(t)) p, err := docker.New(ctx, dind.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
c, err := dind.CreateMimic(ctx, MimicOptions{}) c, err := dind.CreateMimic(ctx, MimicOptions{})

View File

@@ -14,16 +14,16 @@ import (
) )
// Interface guard // Interface guard
var _ sablier.Provider = (*DockerSwarmProvider)(nil) var _ sablier.Provider = (*Provider)(nil)
type DockerSwarmProvider struct { type Provider struct {
Client client.APIClient Client client.APIClient
desiredReplicas int32 desiredReplicas int32
l *slog.Logger l *slog.Logger
} }
func NewDockerSwarmProvider(ctx context.Context, cli *client.Client, logger *slog.Logger) (*DockerSwarmProvider, error) { func New(ctx context.Context, cli *client.Client, logger *slog.Logger) (*Provider, error) {
logger = logger.With(slog.String("provider", "swarm")) logger = logger.With(slog.String("provider", "swarm"))
serverVersion, err := cli.ServerVersion(ctx) serverVersion, err := cli.ServerVersion(ctx)
@@ -36,7 +36,7 @@ func NewDockerSwarmProvider(ctx context.Context, cli *client.Client, logger *slo
slog.String("api_version", serverVersion.APIVersion), slog.String("api_version", serverVersion.APIVersion),
) )
return &DockerSwarmProvider{ return &Provider{
Client: cli, Client: cli,
desiredReplicas: 1, desiredReplicas: 1,
l: logger, l: logger,
@@ -44,7 +44,7 @@ func NewDockerSwarmProvider(ctx context.Context, cli *client.Client, logger *slo
} }
func (p *DockerSwarmProvider) ServiceUpdateReplicas(ctx context.Context, name string, replicas uint64) error { func (p *Provider) ServiceUpdateReplicas(ctx context.Context, name string, replicas uint64) error {
service, err := p.getServiceByName(name, ctx) service, err := p.getServiceByName(name, ctx)
if err != nil { if err != nil {
return err return err
@@ -69,7 +69,7 @@ func (p *DockerSwarmProvider) ServiceUpdateReplicas(ctx context.Context, name st
return nil return nil
} }
func (p *DockerSwarmProvider) getInstanceName(name string, service swarm.Service) string { func (p *Provider) getInstanceName(name string, service swarm.Service) string {
if name == service.Spec.Name { if name == service.Spec.Name {
return name return name
} }

View File

@@ -9,7 +9,7 @@ import (
"log/slog" "log/slog"
) )
func (p *DockerSwarmProvider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) { func (p *Provider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) {
msgs, errs := p.Client.Events(ctx, events.ListOptions{ msgs, errs := p.Client.Events(ctx, events.ListOptions{
Filters: filters.NewArgs( Filters: filters.NewArgs(
filters.Arg("scope", "swarm"), filters.Arg("scope", "swarm"),

View File

@@ -18,7 +18,7 @@ func TestDockerSwarmProvider_NotifyInstanceStopped(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
defer cancel() defer cancel()
dind := setupDinD(t, ctx) dind := setupDinD(t, ctx)
p, err := dockerswarm.NewDockerSwarmProvider(ctx, dind.client, slogt.New(t)) p, err := dockerswarm.New(ctx, dind.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
c, err := dind.CreateMimic(ctx, MimicOptions{}) c, err := dind.CreateMimic(ctx, MimicOptions{})

View File

@@ -10,7 +10,7 @@ import (
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) { func (p *Provider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
service, err := p.getServiceByName(name, ctx) service, err := p.getServiceByName(name, ctx)
if err != nil { if err != nil {
return sablier.InstanceInfo{}, err return sablier.InstanceInfo{}, err
@@ -29,7 +29,7 @@ func (p *DockerSwarmProvider) InstanceInspect(ctx context.Context, name string)
return sablier.ReadyInstanceState(foundName, p.desiredReplicas), nil return sablier.ReadyInstanceState(foundName, p.desiredReplicas), nil
} }
func (p *DockerSwarmProvider) getServiceByName(name string, ctx context.Context) (*swarm.Service, error) { func (p *Provider) getServiceByName(name string, ctx context.Context) (*swarm.Service, error) {
opts := types.ServiceListOptions{ opts := types.ServiceListOptions{
Filters: filters.NewArgs(), Filters: filters.NewArgs(),
Status: true, Status: true,

View File

@@ -127,7 +127,7 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := dockerswarm.NewDockerSwarmProvider(ctx, c.client, slogt.New(t)) p, err := dockerswarm.New(ctx, c.client, slogt.New(t))
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)
@@ -135,7 +135,7 @@ func TestDockerSwarmProvider_GetState(t *testing.T) {
tt.want.Name = name tt.want.Name = name
got, err := p.InstanceInspect(ctx, name) got, err := p.InstanceInspect(ctx, name)
if !cmp.Equal(err, tt.wantErr) { if !cmp.Equal(err, tt.wantErr) {
t.Errorf("DockerSwarmProvider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Provider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
assert.DeepEqual(t, got, tt.want) assert.DeepEqual(t, got, tt.want)

View File

@@ -10,7 +10,7 @@ import (
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) { func (p *Provider) InstanceList(ctx context.Context, _ provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
args := filters.NewArgs() args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) args.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))
args.Add("mode", "replicated") args.Add("mode", "replicated")
@@ -32,7 +32,7 @@ func (p *DockerSwarmProvider) InstanceList(ctx context.Context, _ provider.Insta
return instances, nil return instances, nil
} }
func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.InstanceConfiguration) { func (p *Provider) serviceToInstance(s swarm.Service) (i sablier.InstanceConfiguration) {
var group string var group string
if _, ok := s.Spec.Labels["sablier.enable"]; ok { if _, ok := s.Spec.Labels["sablier.enable"]; ok {
@@ -49,7 +49,7 @@ func (p *DockerSwarmProvider) serviceToInstance(s swarm.Service) (i sablier.Inst
} }
} }
func (p *DockerSwarmProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { func (p *Provider) InstanceGroups(ctx context.Context) (map[string][]string, error) {
f := filters.NewArgs() f := filters.NewArgs()
f.Add("label", fmt.Sprintf("%s=true", "sablier.enable")) f.Add("label", fmt.Sprintf("%s=true", "sablier.enable"))

View File

@@ -20,7 +20,7 @@ func TestDockerClassicProvider_InstanceList(t *testing.T) {
ctx := t.Context() ctx := t.Context()
dind := setupDinD(t, ctx) dind := setupDinD(t, ctx)
p, err := dockerswarm.NewDockerSwarmProvider(ctx, dind.client, slogt.New(t)) p, err := dockerswarm.New(ctx, dind.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
s1, err := dind.CreateMimic(ctx, MimicOptions{ s1, err := dind.CreateMimic(ctx, MimicOptions{
@@ -77,7 +77,7 @@ func TestDockerClassicProvider_GetGroups(t *testing.T) {
ctx := t.Context() ctx := t.Context()
dind := setupDinD(t, ctx) dind := setupDinD(t, ctx)
p, err := dockerswarm.NewDockerSwarmProvider(ctx, dind.client, slogt.New(t)) p, err := dockerswarm.New(ctx, dind.client, slogt.New(t))
assert.NilError(t, err) assert.NilError(t, err)
s1, err := dind.CreateMimic(ctx, MimicOptions{ s1, err := dind.CreateMimic(ctx, MimicOptions{

View File

@@ -2,6 +2,6 @@ package dockerswarm
import "context" import "context"
func (p *DockerSwarmProvider) InstanceStart(ctx context.Context, name string) error { func (p *Provider) InstanceStart(ctx context.Context, name string) error {
return p.ServiceUpdateReplicas(ctx, name, uint64(p.desiredReplicas)) return p.ServiceUpdateReplicas(ctx, name, uint64(p.desiredReplicas))
} }

View File

@@ -126,7 +126,7 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := dockerswarm.NewDockerSwarmProvider(ctx, c.client, slogt.New(t)) p, err := dockerswarm.New(ctx, c.client, slogt.New(t))
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)
@@ -134,7 +134,7 @@ func TestDockerSwarmProvider_Start(t *testing.T) {
tt.want.Name = name tt.want.Name = name
err = p.InstanceStart(ctx, name) err = p.InstanceStart(ctx, name)
if !cmp.Equal(err, tt.wantErr) { if !cmp.Equal(err, tt.wantErr) {
t.Errorf("DockerSwarmProvider.InstanceStop() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Provider.InstanceStop() error = %v, wantErr %v", err, tt.wantErr)
return return
} }

View File

@@ -2,6 +2,6 @@ package dockerswarm
import "context" import "context"
func (p *DockerSwarmProvider) InstanceStop(ctx context.Context, name string) error { func (p *Provider) InstanceStop(ctx context.Context, name string) error {
return p.ServiceUpdateReplicas(ctx, name, 0) return p.ServiceUpdateReplicas(ctx, name, 0)
} }

View File

@@ -94,7 +94,7 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := dockerswarm.NewDockerSwarmProvider(ctx, c.client, slogt.New(t)) p, err := dockerswarm.New(ctx, c.client, slogt.New(t))
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)
@@ -102,7 +102,7 @@ func TestDockerSwarmProvider_Stop(t *testing.T) {
tt.want.Name = name tt.want.Name = name
err = p.InstanceStop(ctx, name) err = p.InstanceStop(ctx, name)
if !cmp.Equal(err, tt.wantErr) { if !cmp.Equal(err, tt.wantErr) {
t.Errorf("DockerSwarmProvider.InstanceStop() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Provider.InstanceStop() error = %v, wantErr %v", err, tt.wantErr)
return return
} }

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
) )
func (p *KubernetesProvider) watchDeployents(instance chan<- string) cache.SharedIndexInformer { func (p *Provider) watchDeployents(instance chan<- string) cache.SharedIndexInformer {
handler := cache.ResourceEventHandlerFuncs{ handler := cache.ResourceEventHandlerFuncs{
UpdateFunc: func(old, new interface{}) { UpdateFunc: func(old, new interface{}) {
newDeployment := new.(*appsv1.Deployment) newDeployment := new.(*appsv1.Deployment)

View File

@@ -7,7 +7,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) DeploymentInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) { func (p *Provider) DeploymentInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) {
d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) d, err := p.Client.AppsV1().Deployments(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{})
if err != nil { if err != nil {
return sablier.InstanceInfo{}, fmt.Errorf("error getting deployment: %w", err) return sablier.InstanceInfo{}, fmt.Errorf("error getting deployment: %w", err)

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -118,7 +118,7 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := kubernetes.NewKubernetesProvider(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)
@@ -126,7 +126,7 @@ func TestKubernetesProvider_DeploymentInspect(t *testing.T) {
tt.want.Name = name tt.want.Name = name
got, err := p.InstanceInspect(ctx, name) got, err := p.InstanceInspect(ctx, name)
if !cmp.Equal(err, tt.wantErr) { if !cmp.Equal(err, tt.wantErr) {
t.Errorf("KubernetesProvider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Provider.InstanceInspect() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
assert.DeepEqual(t, got, tt.want) assert.DeepEqual(t, got, tt.want)

View File

@@ -8,7 +8,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.InstanceConfiguration, error) { func (p *Provider) DeploymentList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{ labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{ MatchLabels: map[string]string{
"sablier.enable": "true", "sablier.enable": "true",
@@ -30,7 +30,7 @@ func (p *KubernetesProvider) DeploymentList(ctx context.Context) ([]sablier.Inst
return instances, nil return instances, nil
} }
func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.InstanceConfiguration { func (p *Provider) deploymentToInstance(d *v1.Deployment) sablier.InstanceConfiguration {
var group string var group string
if _, ok := d.Labels["sablier.enable"]; ok { if _, ok := d.Labels["sablier.enable"]; ok {
@@ -49,7 +49,7 @@ func (p *KubernetesProvider) deploymentToInstance(d *v1.Deployment) sablier.Inst
} }
} }
func (p *KubernetesProvider) DeploymentGroups(ctx context.Context) (map[string][]string, error) { func (p *Provider) DeploymentGroups(ctx context.Context) (map[string][]string, error) {
labelSelector := metav1.LabelSelector{ labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{ MatchLabels: map[string]string{
"sablier.enable": "true", "sablier.enable": "true",

View File

@@ -2,7 +2,7 @@ package kubernetes
import "context" import "context"
func (p *KubernetesProvider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) { func (p *Provider) NotifyInstanceStopped(ctx context.Context, instance chan<- string) {
informer := p.watchDeployents(instance) informer := p.watchDeployents(instance)
go informer.Run(ctx.Done()) go informer.Run(ctx.Done())
informer = p.watchStatefulSets(instance) informer = p.watchStatefulSets(instance)

View File

@@ -3,7 +3,7 @@ package kubernetes_test
import ( import (
"context" "context"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -19,7 +19,7 @@ func TestKubernetesProvider_NotifyInstanceStopped(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
defer cancel() defer cancel()
kind := setupKinD(t, ctx) kind := setupKinD(t, ctx)
p, err := kubernetes.NewKubernetesProvider(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
assert.NilError(t, err) assert.NilError(t, err)
waitC := make(chan string) waitC := make(chan string)

View File

@@ -6,7 +6,7 @@ import (
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *KubernetesProvider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) { func (p *Provider) InstanceInspect(ctx context.Context, name string) (sablier.InstanceInfo, error) {
parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter})
if err != nil { if err != nil {
return sablier.InstanceInfo{}, err return sablier.InstanceInfo{}, err

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
@@ -43,7 +43,7 @@ func TestKubernetesProvider_InstanceInspect(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := kubernetes.NewKubernetesProvider(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
_, err = p.InstanceInspect(ctx, tt.args.name) _, err = p.InstanceInspect(ctx, tt.args.name)
assert.Error(t, err, tt.want.Error()) assert.Error(t, err, tt.want.Error())

View File

@@ -6,7 +6,7 @@ import (
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
) )
func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) { func (p *Provider) InstanceList(ctx context.Context, options provider.InstanceListOptions) ([]sablier.InstanceConfiguration, error) {
deployments, err := p.DeploymentList(ctx) deployments, err := p.DeploymentList(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -20,7 +20,7 @@ func (p *KubernetesProvider) InstanceList(ctx context.Context, options provider.
return append(deployments, statefulSets...), nil return append(deployments, statefulSets...), nil
} }
func (p *KubernetesProvider) InstanceGroups(ctx context.Context) (map[string][]string, error) { func (p *Provider) InstanceGroups(ctx context.Context) (map[string][]string, error) {
deployments, err := p.DeploymentGroups(ctx) deployments, err := p.DeploymentGroups(ctx)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -2,7 +2,7 @@ package kubernetes_test
import ( import (
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider" "github.com/sablierapp/sablier/pkg/provider"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
@@ -19,7 +19,7 @@ func TestKubernetesProvider_InstanceList(t *testing.T) {
ctx := t.Context() ctx := t.Context()
kind := setupKinD(t, ctx) kind := setupKinD(t, ctx)
p, err := kubernetes.NewKubernetesProvider(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
assert.NilError(t, err) assert.NilError(t, err)
d1, err := kind.CreateMimicDeployment(ctx, MimicOptions{ d1, err := kind.CreateMimicDeployment(ctx, MimicOptions{
@@ -93,7 +93,7 @@ func TestKubernetesProvider_InstanceGroups(t *testing.T) {
ctx := t.Context() ctx := t.Context()
kind := setupKinD(t, ctx) kind := setupKinD(t, ctx)
p, err := kubernetes.NewKubernetesProvider(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
assert.NilError(t, err) assert.NilError(t, err)
d1, err := kind.CreateMimicDeployment(ctx, MimicOptions{ d1, err := kind.CreateMimicDeployment(ctx, MimicOptions{

View File

@@ -2,7 +2,7 @@ package kubernetes
import "context" import "context"
func (p *KubernetesProvider) InstanceStart(ctx context.Context, name string) error { func (p *Provider) InstanceStart(ctx context.Context, name string) error {
parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter})
if err != nil { if err != nil {
return err return err

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
@@ -92,7 +92,7 @@ func TestKubernetesProvider_InstanceStart(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := kubernetes.NewKubernetesProvider(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
assert.NilError(t, err) assert.NilError(t, err)
name, err := tt.args.do(kind) name, err := tt.args.do(kind)

View File

@@ -2,7 +2,7 @@ package kubernetes
import "context" import "context"
func (p *KubernetesProvider) InstanceStop(ctx context.Context, name string) error { func (p *Provider) InstanceStop(ctx context.Context, name string) error {
parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter}) parsed, err := ParseName(name, ParseOptions{Delimiter: p.delimiter})
if err != nil { if err != nil {
return err return err

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"testing" "testing"
@@ -92,7 +92,7 @@ func TestKubernetesProvider_InstanceStop(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := kubernetes.NewKubernetesProvider(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, kind.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
assert.NilError(t, err) assert.NilError(t, err)
name, err := tt.args.do(kind) name, err := tt.args.do(kind)

View File

@@ -2,23 +2,23 @@ package kubernetes
import ( import (
"context" "context"
providerConfig "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
"log/slog" "log/slog"
providerConfig "github.com/sablierapp/sablier/config"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
) )
// Interface guard // Interface guard
var _ sablier.Provider = (*KubernetesProvider)(nil) var _ sablier.Provider = (*Provider)(nil)
type KubernetesProvider struct { type Provider struct {
Client kubernetes.Interface Client kubernetes.Interface
delimiter string delimiter string
l *slog.Logger l *slog.Logger
} }
func NewKubernetesProvider(ctx context.Context, client *kubernetes.Clientset, logger *slog.Logger, kubeclientConfig providerConfig.Kubernetes) (*KubernetesProvider, error) { func New(ctx context.Context, client *kubernetes.Clientset, logger *slog.Logger, config providerConfig.Kubernetes) (*Provider, error) {
logger = logger.With(slog.String("provider", "kubernetes")) logger = logger.With(slog.String("provider", "kubernetes"))
info, err := client.ServerVersion() info, err := client.ServerVersion()
@@ -28,13 +28,13 @@ func NewKubernetesProvider(ctx context.Context, client *kubernetes.Clientset, lo
logger.InfoContext(ctx, "connection established with kubernetes", logger.InfoContext(ctx, "connection established with kubernetes",
slog.String("version", info.String()), slog.String("version", info.String()),
slog.Float64("config.qps", float64(kubeclientConfig.QPS)), slog.Float64("config.qps", float64(config.QPS)),
slog.Int("config.burst", kubeclientConfig.Burst), slog.Int("config.burst", config.Burst),
) )
return &KubernetesProvider{ return &Provider{
Client: client, Client: client,
delimiter: kubeclientConfig.Delimiter, delimiter: config.Delimiter,
l: logger, l: logger,
}, nil }, nil

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
) )
func (p *KubernetesProvider) watchStatefulSets(instance chan<- string) cache.SharedIndexInformer { func (p *Provider) watchStatefulSets(instance chan<- string) cache.SharedIndexInformer {
handler := cache.ResourceEventHandlerFuncs{ handler := cache.ResourceEventHandlerFuncs{
UpdateFunc: func(old, new interface{}) { UpdateFunc: func(old, new interface{}) {
newStatefulSet := new.(*appsv1.StatefulSet) newStatefulSet := new.(*appsv1.StatefulSet)

View File

@@ -6,7 +6,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) StatefulSetInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) { func (p *Provider) StatefulSetInspect(ctx context.Context, config ParsedName) (sablier.InstanceInfo, error) {
ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{}) ss, err := p.Client.AppsV1().StatefulSets(config.Namespace).Get(ctx, config.Name, metav1.GetOptions{})
if err != nil { if err != nil {
return sablier.InstanceInfo{}, err return sablier.InstanceInfo{}, err

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/config" "github.com/sablierapp/sablier/pkg/config"
"github.com/sablierapp/sablier/pkg/provider/kubernetes" "github.com/sablierapp/sablier/pkg/provider/kubernetes"
"github.com/sablierapp/sablier/pkg/sablier" "github.com/sablierapp/sablier/pkg/sablier"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@@ -118,7 +118,7 @@ func TestKubernetesProvider_InspectStatefulSet(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
p, err := kubernetes.NewKubernetesProvider(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes) p, err := kubernetes.New(ctx, c.client, slogt.New(t), config.NewProviderConfig().Kubernetes)
name, err := tt.args.do(c) name, err := tt.args.do(c)
assert.NilError(t, err) assert.NilError(t, err)

View File

@@ -8,7 +8,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
) )
func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.InstanceConfiguration, error) { func (p *Provider) StatefulSetList(ctx context.Context) ([]sablier.InstanceConfiguration, error) {
labelSelector := metav1.LabelSelector{ labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{ MatchLabels: map[string]string{
"sablier.enable": "true", "sablier.enable": "true",
@@ -30,7 +30,7 @@ func (p *KubernetesProvider) StatefulSetList(ctx context.Context) ([]sablier.Ins
return instances, nil return instances, nil
} }
func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.InstanceConfiguration { func (p *Provider) statefulSetToInstance(ss *v1.StatefulSet) sablier.InstanceConfiguration {
var group string var group string
if _, ok := ss.Labels["sablier.enable"]; ok { if _, ok := ss.Labels["sablier.enable"]; ok {
@@ -49,7 +49,7 @@ func (p *KubernetesProvider) statefulSetToInstance(ss *v1.StatefulSet) sablier.I
} }
} }
func (p *KubernetesProvider) StatefulSetGroups(ctx context.Context) (map[string][]string, error) { func (p *Provider) StatefulSetGroups(ctx context.Context) (map[string][]string, error) {
labelSelector := metav1.LabelSelector{ labelSelector := metav1.LabelSelector{
MatchLabels: map[string]string{ MatchLabels: map[string]string{
"sablier.enable": "true", "sablier.enable": "true",

View File

@@ -12,7 +12,7 @@ type Workload interface {
UpdateScale(ctx context.Context, workloadName string, scale *autoscalingv1.Scale, opts metav1.UpdateOptions) (*autoscalingv1.Scale, error) UpdateScale(ctx context.Context, workloadName string, scale *autoscalingv1.Scale, opts metav1.UpdateOptions) (*autoscalingv1.Scale, error)
} }
func (p *KubernetesProvider) scale(ctx context.Context, config ParsedName, replicas int32) error { func (p *Provider) scale(ctx context.Context, config ParsedName, replicas int32) error {
var workload Workload var workload Workload
switch config.Kind { switch config.Kind {

View File

@@ -0,0 +1,26 @@
package sablier
import (
"context"
"log/slog"
"time"
)
func (s *sablier) GroupWatch(ctx context.Context) {
// This should be changed to event based instead of polling.
ticker := time.NewTicker(2 * time.Second)
for {
select {
case <-ctx.Done():
s.l.InfoContext(ctx, "stop watching groups", slog.Any("reason", ctx.Err()))
return
case <-ticker.C:
groups, err := s.provider.InstanceGroups(ctx)
if err != nil {
s.l.ErrorContext(ctx, "cannot retrieve group from provider", slog.Any("reason", err))
} else if groups != nil {
s.SetGroups(groups)
}
}
}
}

View File

@@ -0,0 +1,18 @@
package sablier
import (
"context"
"log/slog"
)
func OnInstanceExpired(ctx context.Context, provider Provider, logger *slog.Logger) func(string) {
return func(_key string) {
go func(key string) {
logger.InfoContext(ctx, "instance expired", slog.String("instance", key))
err := provider.InstanceStop(ctx, key)
if err != nil {
logger.ErrorContext(ctx, "instance expired could not be stopped from provider", slog.String("instance", key), slog.Any("error", err))
}
}(_key)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"log/slog" "log/slog"
"sync"
"time" "time"
) )
@@ -18,11 +19,14 @@ type Sablier interface {
RemoveInstance(ctx context.Context, name string) error RemoveInstance(ctx context.Context, name string) error
SetGroups(groups map[string][]string) SetGroups(groups map[string][]string)
StopAllUnregisteredInstances(ctx context.Context) error StopAllUnregisteredInstances(ctx context.Context) error
GroupWatch(ctx context.Context)
} }
type sablier struct { type sablier struct {
provider Provider provider Provider
sessions Store sessions Store
groupsMu sync.RWMutex
groups map[string][]string groups map[string][]string
l *slog.Logger l *slog.Logger
@@ -32,12 +36,15 @@ func New(logger *slog.Logger, store Store, provider Provider) Sablier {
return &sablier{ return &sablier{
provider: provider, provider: provider,
sessions: store, sessions: store,
groupsMu: sync.RWMutex{},
groups: map[string][]string{}, groups: map[string][]string{},
l: logger, l: logger,
} }
} }
func (s *sablier) SetGroups(groups map[string][]string) { func (s *sablier) SetGroups(groups map[string][]string) {
s.groupsMu.Lock()
defer s.groupsMu.Unlock()
if groups == nil { if groups == nil {
groups = map[string][]string{} groups = map[string][]string{}
} }

View File

@@ -42,6 +42,18 @@ func (m *MockSablier) EXPECT() *MockSablierMockRecorder {
return m.recorder return m.recorder
} }
// GroupWatch mocks base method.
func (m *MockSablier) GroupWatch(ctx context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "GroupWatch", ctx)
}
// GroupWatch indicates an expected call of GroupWatch.
func (mr *MockSablierMockRecorder) GroupWatch(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupWatch", reflect.TypeOf((*MockSablier)(nil).GroupWatch), ctx)
}
// RemoveInstance mocks base method. // RemoveInstance mocks base method.
func (m *MockSablier) RemoveInstance(ctx context.Context, name string) error { func (m *MockSablier) RemoveInstance(ctx context.Context, name string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -2,10 +2,10 @@ package theme
import ( import (
"fmt" "fmt"
"github.com/sablierapp/sablier/pkg/version"
"io" "io"
"github.com/sablierapp/sablier/pkg/durations" "github.com/sablierapp/sablier/pkg/durations"
"github.com/sablierapp/sablier/version"
) )
func (t *Themes) Render(name string, opts Options, writer io.Writer) error { func (t *Themes) Render(name string, opts Options, writer io.Writer) error {

View File

@@ -5,13 +5,12 @@ import (
"fmt" "fmt"
"github.com/neilotoole/slogt" "github.com/neilotoole/slogt"
"github.com/sablierapp/sablier/pkg/theme" "github.com/sablierapp/sablier/pkg/theme"
"github.com/sablierapp/sablier/pkg/version"
"log/slog" "log/slog"
"os" "os"
"testing" "testing"
"testing/fstest" "testing/fstest"
"time" "time"
"github.com/sablierapp/sablier/version"
) )
var ( var (