diff --git a/app/http/pages/render.go b/app/http/pages/render.go index e687b0b..8b031d6 100644 --- a/app/http/pages/render.go +++ b/app/http/pages/render.go @@ -30,7 +30,11 @@ type RenderOptions struct { RefreshFrequency time.Duration Theme string CustomThemes fs.FS - Version string + // If custom theme is loaded through os.DirFS, nothing prevents you + // from escaping the prefix with relative path such as .. + // The `AllowedCustomThemes` are the themes that were scanned during initilization + AllowedCustomThemes map[string]bool + Version string } type TemplateValues struct { @@ -46,13 +50,10 @@ func Render(options RenderOptions, writer io.Writer) error { var err error // Load custom theme if provided - if options.CustomThemes != nil { + if options.CustomThemes != nil && options.AllowedCustomThemes[options.Theme] { tpl, err = template.ParseFS(options.CustomThemes, fmt.Sprintf("%s.html", options.Theme)) - } - - // TODO: Optimize this so we don't have to fallback but instead know if it's a embedded theme or custom theme. - if options.CustomThemes == nil || err != nil { - // Load embedded themes if the custom theme + } else { + // Load from the embedded FS tpl, err = template.ParseFS(themes, fmt.Sprintf("themes/%s.html", options.Theme)) } diff --git a/app/http/pages/render_test.go b/app/http/pages/render_test.go index 505fee4..10c1cfd 100644 --- a/app/http/pages/render_test.go +++ b/app/http/pages/render_test.go @@ -129,6 +129,10 @@ func TestRender(t *testing.T) { "marvel.html": {Data: []byte("{{ .DisplayName }}")}, "dc-comics.html": {Data: []byte("batman")}, }, + AllowedCustomThemes: map[string]bool{ + "marvel": true, + "dc-comics": true, + }, Version: "v0.0.0", }, }, @@ -147,6 +151,10 @@ func TestRender(t *testing.T) { "marvel.html": {Data: []byte("thor")}, "dc-comics.html": {Data: []byte("batman")}, }, + AllowedCustomThemes: map[string]bool{ + "marvel": true, + "dc-comics": true, + }, Version: "v0.0.0", }, }, @@ -165,11 +173,36 @@ func TestRender(t *testing.T) { "marvel.html": {Data: []byte("thor")}, "dc-comics.html": {Data: []byte("batman")}, }, + AllowedCustomThemes: map[string]bool{ + "marvel": true, + "dc-comics": true, + }, Version: "v0.0.0", }, }, wantErr: false, }, + { + name: "Error loading non allowed custom theme", + args: args{ + options: RenderOptions{ + DisplayName: "Test", + InstanceStates: instanceStates, + Theme: "dc-comics", + SessionDuration: 10 * time.Minute, + RefreshFrequency: 5 * time.Second, + CustomThemes: fstest.MapFS{ + "marvel.html": {Data: []byte("thor")}, + "dc-comics.html": {Data: []byte("batman")}, + }, + AllowedCustomThemes: map[string]bool{ + "marvel": true, + }, + Version: "v0.0.0", + }, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/app/http/routes/strategies.go b/app/http/routes/strategies.go index 7f23f37..543745d 100644 --- a/app/http/routes/strategies.go +++ b/app/http/routes/strategies.go @@ -2,7 +2,9 @@ package routes import ( "fmt" + "io/fs" "net/http" + "os" "sort" "strings" "time" @@ -18,13 +20,32 @@ import ( "github.com/gin-gonic/gin" ) +var osDirFS = os.DirFS + type ServeStrategy struct { + customThemesFS fs.FS + customThemes map[string]bool + SessionsManager sessions.Manager StrategyConfig config.Strategy } -// ServeDynamic returns a waiting page displaying the session request if the session is not ready -// If the session is ready, returns a redirect 307 with an arbitrary location +func NewServeStrategy(sessionsManager sessions.Manager, conf config.Strategy) *ServeStrategy { + + serveStrategy := &ServeStrategy{ + SessionsManager: sessionsManager, + StrategyConfig: conf, + } + + if conf.Dynamic.CustomThemesPath != "" { + customThemesFs := osDirFS(conf.Dynamic.CustomThemesPath) + serveStrategy.customThemesFS = customThemesFs + serveStrategy.customThemes = loadAllowedThemes(customThemesFs) + } + + return serveStrategy +} + func (s *ServeStrategy) ServeDynamic(c *gin.Context) { request := models.DynamicRequest{ Theme: s.StrategyConfig.Dynamic.DefaultTheme, @@ -45,12 +66,14 @@ func (s *ServeStrategy) ServeDynamic(c *gin.Context) { } renderOptions := pages.RenderOptions{ - DisplayName: request.DisplayName, - SessionDuration: request.SessionDuration, - Theme: request.Theme, - Version: version.Version, - RefreshFrequency: 5 * time.Second, - InstanceStates: sessionStateToRenderOptionsInstanceState(sessionState), + DisplayName: request.DisplayName, + SessionDuration: request.SessionDuration, + Theme: request.Theme, + CustomThemes: s.customThemesFS, + AllowedCustomThemes: s.customThemes, + Version: version.Version, + RefreshFrequency: 5 * time.Second, + InstanceStates: sessionStateToRenderOptionsInstanceState(sessionState), } c.Header("Content-Type", "text/html") @@ -118,3 +141,24 @@ func instanceStateToRenderOptionsRequestState(instanceState *instance.State) pag Error: err, } } + +func loadAllowedThemes(dir fs.FS) (allowedThemes map[string]bool) { + fs.WalkDir(dir, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() { + return nil + } + + if strings.HasSuffix(d.Name(), ".html") { + log.Debugf("found theme at \"%s\" can be loaded using \"%s\"", path, strings.TrimSuffix(path, ".html")) + allowedThemes[strings.TrimSuffix(path, ".html")] = true + } else { + log.Tracef("ignoring file \"%s\" because it has no .html suffix", path) + } + return nil + }) + return +} diff --git a/app/http/routes/strategies_test.go b/app/http/routes/strategies_test.go index 1e617d0..6c89112 100644 --- a/app/http/routes/strategies_test.go +++ b/app/http/routes/strategies_test.go @@ -4,11 +4,14 @@ import ( "bytes" "encoding/json" "io" + "io/fs" "net/http" "net/http/httptest" "net/url" + "reflect" "sync" "testing" + "testing/fstest" "time" "github.com/acouvreur/sablier/app/http/routes/models" @@ -176,3 +179,102 @@ func createMap(instances []*instance.State) (store *sync.Map) { return } + +func TestNewServeStrategy(t *testing.T) { + type args struct { + sessionsManager sessions.Manager + conf config.Strategy + } + tests := []struct { + name string + args args + osDirFS fs.FS + want map[string]bool + }{ + { + name: "load custom themes", + args: args{ + sessionsManager: &SessionsManagerMock{}, + conf: config.Strategy{ + Dynamic: config.DynamicStrategy{ + CustomThemesPath: "my/path/to/themes", + }, + }, + }, + osDirFS: fstest.MapFS{ + "my/path/to/themes/marvel.html": {Data: []byte("thor")}, + "my/path/to/themes/dc-comics.html": {Data: []byte("batman")}, + }, + want: map[string]bool{ + "marvel": true, + "dc-comics": true, + }, + }, + { + name: "load custom themes recursively", + args: args{ + sessionsManager: &SessionsManagerMock{}, + conf: config.Strategy{ + Dynamic: config.DynamicStrategy{ + CustomThemesPath: "my/path/to/themes", + }, + }, + }, + osDirFS: fstest.MapFS{ + "my/path/to/themes/marvel.html": {Data: []byte("thor")}, + "my/path/to/themes/dc-comics.html": {Data: []byte("batman")}, + "my/path/to/themes/inner/dc-comics.html": {Data: []byte("batman")}, + }, + want: map[string]bool{ + "marvel": true, + "dc-comics": true, + "inner/dc-comics": true, + }, + }, + { + name: "do not load custom themes outside of path", + args: args{ + sessionsManager: &SessionsManagerMock{}, + conf: config.Strategy{ + Dynamic: config.DynamicStrategy{ + CustomThemesPath: "my/path/to/themes", + }, + }, + }, + osDirFS: fstest.MapFS{ + "my/path/to/superman.html": {Data: []byte("superman")}, + "my/path/to/themes/marvel.html": {Data: []byte("thor")}, + "my/path/to/themes/dc-comics.html": {Data: []byte("batman")}, + "my/path/to/themes/inner/dc-comics.html": {Data: []byte("batman")}, + }, + want: map[string]bool{ + "marvel": true, + "dc-comics": true, + "inner/dc-comics": true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + oldosDirFS := osDirFS + defer func() { osDirFS = oldosDirFS }() + + myOsDirFS := func(dir string) fs.FS { + fs, err := fs.Sub(tt.osDirFS, dir) + + if err != nil { + panic(err) + } + + return fs + } + + osDirFS = myOsDirFS + + if got := NewServeStrategy(tt.args.sessionsManager, tt.args.conf); !reflect.DeepEqual(got.customThemes, tt.want) { + t.Errorf("NewServeStrategy() = %v, want %v", got.customThemes, tt.want) + } + }) + } +} diff --git a/app/http/server.go b/app/http/server.go index bf0e5ca..269a220 100644 --- a/app/http/server.go +++ b/app/http/server.go @@ -21,7 +21,7 @@ func Start(serverConf config.Server, strategyConf config.Strategy, sessionManage { api := base.Group("/api") { - strategy := routes.ServeStrategy{SessionsManager: sessionManager, StrategyConfig: strategyConf} + strategy := routes.NewServeStrategy(sessionManager, strategyConf) api.GET("/strategies/dynamic", strategy.ServeDynamic) api.GET("/strategies/blocking", strategy.ServeBlocking) } diff --git a/cmd/root.go b/cmd/root.go index 4be5daa..f8d79e2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -62,6 +62,8 @@ func init() { viper.BindPFlag("logging.level", rootCmd.PersistentFlags().Lookup("logging.level")) // strategy + startCmd.Flags().StringVar(&conf.Strategy.Dynamic.CustomThemesPath, "strategy.dynamic.custom-themes-path", "", "Custom themes folder, will load all .html files recursively") + viper.BindPFlag("strategy.dynamic.custom-themes-path", startCmd.Flags().Lookup("strategy.dynamic.custom-themes-path")) startCmd.Flags().StringVar(&conf.Strategy.Dynamic.DefaultTheme, "strategy.dynamic.default-theme", "hacker-terminal", "Default theme used for dynamic strategy") viper.BindPFlag("strategy.dynamic.default-theme", startCmd.Flags().Lookup("strategy.dynamic.default-theme")) startCmd.Flags().DurationVar(&conf.Strategy.Dynamic.DefaultRefreshFrequency, "strategy.dynamic.default-refresh-frequency", 5*time.Second, "Default refresh frequency in the HTML page for dynamic strategy") diff --git a/config/strategy.go b/config/strategy.go index 38d9fac..9a329e7 100644 --- a/config/strategy.go +++ b/config/strategy.go @@ -3,6 +3,7 @@ package config import "time" type DynamicStrategy struct { + CustomThemesPath string `mapstructure:"CUSTOMTHEMESPATH" yaml:"customThemesPath"` DefaultTheme string `mapstructure:"DEFAULTTHEME" yaml:"defaultTheme" default:"hacker-terminal"` DefaultRefreshFrequency time.Duration `mapstructure:"DEFAULTREFRESHFREQUENCY" yaml:"defaultRefreshFrequency" default:"5s"` }