From 31063fa1b460d01c8d9ad1dcfcb1746d85db21ce Mon Sep 17 00:00:00 2001 From: Amir Raminfar Date: Tue, 20 Jun 2023 10:43:47 -0700 Subject: [PATCH] feat: removes localhost as a required client. fixes #2259 (#2263) * feat: removes localhost as a required connection * refactors code * fixes tests * adds more tests * adds more tests * refactors * cleans up logs --- docker/client.go | 14 ++-- main.go | 180 ++++++++++++++++++++++++---------------- main_test.go | 113 +++++++++++++++++++++++++ web/routes.go | 8 +- web/routes_auth_test.go | 6 +- web/routes_logs_test.go | 16 ++++ 6 files changed, 255 insertions(+), 82 deletions(-) create mode 100644 main_test.go diff --git a/docker/client.go b/docker/client.go index ddcdc622..832e6bd1 100644 --- a/docker/client.go +++ b/docker/client.go @@ -71,7 +71,7 @@ type Client interface { } // NewClientWithFilters creates a new instance of Client with docker filters -func NewClientWithFilters(f map[string][]string) Client { +func NewClientWithFilters(f map[string][]string) (Client, error) { filterArgs := filters.NewArgs() for key, values := range f { for _, value := range values { @@ -84,13 +84,13 @@ func NewClientWithFilters(f map[string][]string) Client { cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { - log.Fatal(err) + return nil, err } - return &dockerClient{cli, filterArgs} + return &dockerClient{cli, filterArgs}, nil } -func NewClientWithTlsAndFilter(f map[string][]string, connection string) Client { +func NewClientWithTlsAndFilter(f map[string][]string, connection string) (Client, error) { filterArgs := filters.NewArgs() for key, values := range f { for _, value := range values { @@ -102,7 +102,7 @@ func NewClientWithTlsAndFilter(f map[string][]string, connection string) Client remoteUrl, err := url.Parse(connection) if err != nil { - log.Fatal(err) + return nil, err } if remoteUrl.Scheme != "tcp" { @@ -136,10 +136,10 @@ func NewClientWithTlsAndFilter(f map[string][]string, connection string) Client cli, err := client.NewClientWithOpts(opts...) if err != nil { - log.Fatal(err) + return nil, err } - return &dockerClient{cli, filterArgs} + return &dockerClient{cli, filterArgs}, nil } func (d *dockerClient) FindContainer(id string) (Container, error) { diff --git a/main.go b/main.go index f53361fd..4bdb298b 100644 --- a/main.go +++ b/main.go @@ -62,20 +62,7 @@ func (args) Version() string { var content embed.FS func main() { - var args args - var err error - parser := arg.MustParse(&args) - args.Filter = make(map[string][]string) - - for _, filter := range args.FilterStrings { - pos := strings.Index(filter, "=") - if pos == -1 { - parser.Fail("each filter should be of the form key=value") - } - key := filter[:pos] - val := filter[pos+1:] - args.Filter[key] = append(args.Filter[key], val) - } + args := parseArgs() level, _ := log.ParseLevel(args.Level) log.SetLevel(level) @@ -93,64 +80,15 @@ func main() { log.Infof("Dozzle version %s", version) - dockerClient := docker.NewClientWithFilters(args.Filter) - for i := 1; ; i++ { - _, err := dockerClient.ListContainers() - if err == nil { - break - } else if args.WaitForDockerSeconds <= 0 { - log.Fatalf("Could not connect to Docker Engine: %v", err) - } else { - log.Infof("Waiting for Docker Engine (attempt %d): %s", i, err) - time.Sleep(5 * time.Second) - args.WaitForDockerSeconds -= 5 - } + clients := createClients(args, docker.NewClientWithFilters, docker.NewClientWithTlsAndFilter) + + if len(clients) == 0 { + log.Fatal("Could not connect to any Docker Engines") + } else { + log.Infof("Connected to %d Docker Engine(s)", len(clients)) } - clients := make(map[string]docker.Client) - clients["localhost"] = dockerClient - - for _, host := range args.RemoteHost { - log.Infof("Creating client for %s", host) - client := docker.NewClientWithTlsAndFilter(args.Filter, host) - clients[host] = client - } - - if args.Username == "" && args.UsernameFile != nil { - args.Username = args.UsernameFile.Value - } - - if args.Password == "" && args.PasswordFile != nil { - args.Password = args.PasswordFile.Value - } - - if args.Username != "" || args.Password != "" { - if args.Username == "" || args.Password == "" { - log.Fatalf("Username AND password are required for authentication") - } - } - - config := web.Config{ - Addr: args.Addr, - Base: args.Base, - Version: version, - Username: args.Username, - Password: args.Password, - Hostname: args.Hostname, - NoAnalytics: args.NoAnalytics, - } - - assets, err := fs.Sub(content, "dist") - if err != nil { - log.Fatalf("Could not open embedded dist folder: %v", err) - } - - if _, ok := os.LookupEnv("LIVE_FS"); ok { - log.Info("Using live filesystem at ./dist") - assets = os.DirFS("./dist") - } - - srv := web.CreateServer(clients, assets, config) + srv := createServer(args, clients) go doStartEvent(args) go func() { log.Infof("Accepting connections on %s", srv.Addr) @@ -169,7 +107,7 @@ func main() { if err := srv.Shutdown(ctx); err != nil { log.Fatal(err) } - log.Debug("shut down complete") + log.Debug("shutdown complete") } func doStartEvent(arg args) { @@ -198,3 +136,103 @@ func doStartEvent(arg args) { log.Debug(err) } } + +func createClients(args args, localClientFactory func(map[string][]string) (docker.Client, error), remoteClientFactory func(map[string][]string, string) (docker.Client, error)) map[string]docker.Client { + clients := make(map[string]docker.Client) + + if localClient := createLocalClient(args, localClientFactory); localClient != nil { + clients["localhost"] = localClient + } + + for _, host := range args.RemoteHost { + log.Infof("Creating client for %s", host) + client, err := remoteClientFactory(args.Filter, host) + if err == nil { + clients[host] = client + } else { + log.Warnf("Could not create client for %s: %s", host, err) + } + } + + return clients +} + +func createServer(args args, clients map[string]docker.Client) *http.Server { + config := web.Config{ + Addr: args.Addr, + Base: args.Base, + Version: version, + Username: args.Username, + Password: args.Password, + Hostname: args.Hostname, + NoAnalytics: args.NoAnalytics, + } + + assets, err := fs.Sub(content, "dist") + if err != nil { + log.Fatalf("Could not open embedded dist folder: %v", err) + } + + if _, ok := os.LookupEnv("LIVE_FS"); ok { + log.Info("Using live filesystem at ./dist") + assets = os.DirFS("./dist") + } + + return web.CreateServer(clients, assets, config) +} + +func createLocalClient(args args, localClientFactory func(map[string][]string) (docker.Client, error)) docker.Client { + for i := 1; ; i++ { + dockerClient, err := localClientFactory(args.Filter) + + if err == nil { + _, err := dockerClient.ListContainers() + + if err == nil { + log.Debugf("Connected to local Docker Engine") + return dockerClient + + } + } + if args.WaitForDockerSeconds > 0 { + log.Infof("Waiting for Docker Engine (attempt %d): %s", i, err) + time.Sleep(5 * time.Second) + args.WaitForDockerSeconds -= 5 + } else { + log.Debugf("Local Docker Engine not found") + break + } + } + return nil +} + +func parseArgs() args { + var args args + parser := arg.MustParse(&args) + args.Filter = make(map[string][]string) + + for _, filter := range args.FilterStrings { + pos := strings.Index(filter, "=") + if pos == -1 { + parser.Fail("each filter should be of the form key=value") + } + key := filter[:pos] + val := filter[pos+1:] + args.Filter[key] = append(args.Filter[key], val) + } + + if args.Username == "" && args.UsernameFile != nil { + args.Username = args.UsernameFile.Value + } + + if args.Password == "" && args.PasswordFile != nil { + args.Password = args.PasswordFile.Value + } + + if args.Username != "" || args.Password != "" { + if args.Username == "" || args.Password == "" { + log.Fatalf("Username AND password are required for authentication") + } + } + return args +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..85858e4d --- /dev/null +++ b/main_test.go @@ -0,0 +1,113 @@ +package main + +import ( + "errors" + "testing" + + "github.com/amir20/dozzle/docker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type fakeClient struct { + docker.Client + mock.Mock +} + +func (f *fakeClient) ListContainers() ([]docker.Container, error) { + args := f.Called() + return args.Get(0).([]docker.Container), args.Error(1) +} + +func Test_valid_localhost(t *testing.T) { + fakeClientFactory := func(filter map[string][]string) (docker.Client, error) { + client := new(fakeClient) + client.On("ListContainers").Return([]docker.Container{}, nil) + return client, nil + } + + args := args{} + + actualClient := createLocalClient(args, fakeClientFactory) + + assert.NotNil(t, actualClient) +} + +func Test_invalid_localhost(t *testing.T) { + fakeClientFactory := func(filter map[string][]string) (docker.Client, error) { + client := new(fakeClient) + client.On("ListContainers").Return([]docker.Container{}, errors.New("error")) + return client, nil + } + + args := args{} + + actualClient := createLocalClient(args, fakeClientFactory) + + assert.Nil(t, actualClient) +} + +func Test_valid_remote(t *testing.T) { + fakeLocalClientFactory := func(filter map[string][]string) (docker.Client, error) { + client := new(fakeClient) + client.On("ListContainers").Return([]docker.Container{}, errors.New("error")) + return client, nil + } + + fakeRemoteClientFactory := func(filter map[string][]string, host string) (docker.Client, error) { + client := new(fakeClient) + return client, nil + } + + args := args{ + RemoteHost: []string{"tcp://localhost:2375"}, + } + + clients := createClients(args, fakeLocalClientFactory, fakeRemoteClientFactory) + + assert.Equal(t, 1, len(clients)) + assert.Contains(t, clients, "tcp://localhost:2375") + assert.NotContains(t, clients, "localhost") +} + +func Test_valid_remote_and_local(t *testing.T) { + fakeLocalClientFactory := func(filter map[string][]string) (docker.Client, error) { + client := new(fakeClient) + client.On("ListContainers").Return([]docker.Container{}, nil) + return client, nil + } + + fakeRemoteClientFactory := func(filter map[string][]string, host string) (docker.Client, error) { + client := new(fakeClient) + return client, nil + } + + args := args{ + RemoteHost: []string{"tcp://localhost:2375"}, + } + + clients := createClients(args, fakeLocalClientFactory, fakeRemoteClientFactory) + + assert.Equal(t, 2, len(clients)) + assert.Contains(t, clients, "tcp://localhost:2375") + assert.Contains(t, clients, "localhost") +} + +func Test_no_clients(t *testing.T) { + fakeLocalClientFactory := func(filter map[string][]string) (docker.Client, error) { + client := new(fakeClient) + client.On("ListContainers").Return([]docker.Container{}, errors.New("error")) + return client, nil + } + + fakeRemoteClientFactory := func(filter map[string][]string, host string) (docker.Client, error) { + client := new(fakeClient) + return client, nil + } + + args := args{} + + clients := createClients(args, fakeLocalClientFactory, fakeRemoteClientFactory) + + assert.Equal(t, 0, len(clients)) +} diff --git a/web/routes.go b/web/routes.go index 493fe04f..2a0d41e4 100644 --- a/web/routes.go +++ b/web/routes.go @@ -184,9 +184,15 @@ func (h *handler) healthcheck(w http.ResponseWriter, r *http.Request) { } func (h *handler) clientFromRequest(r *http.Request) docker.Client { + if !r.URL.Query().Has("host") { + log.Fatalf("No host parameter found in request %v", r.URL) + } + host := r.URL.Query().Get("host") if client, ok := h.clients[host]; ok { return client } - return h.clients["localhost"] + + log.Fatalf("No client found for host %v and url %v", host, r.URL) + return nil } diff --git a/web/routes_auth_test.go b/web/routes_auth_test.go index adf29f35..aa7fc84c 100644 --- a/web/routes_auth_test.go +++ b/web/routes_auth_test.go @@ -182,7 +182,7 @@ func Test_createRoutes_username_password_valid_session(t *testing.T) { handler := createHandler(mockedClient, nil, Config{Base: "/", Username: "amir", Password: "password"}) // Get cookie first - req, err := http.NewRequest("GET", "/api/logs/stream?id=123&stdout=1&stderr=1", nil) + req, err := http.NewRequest("GET", "/api/logs/stream?id=123&stdout=1&stderr=1&host=localhost", nil) require.NoError(t, err, "NewRequest should not return an error.") session, _ := store.Get(req, sessionName) session.Values[authorityKey] = time.Now().Unix() @@ -191,7 +191,7 @@ func Test_createRoutes_username_password_valid_session(t *testing.T) { cookies := recorder.Result().Cookies() // Test with cookie - req, err = http.NewRequest("GET", "/api/logs/stream?id=123&stdout=1&stderr=1", nil) + req, err = http.NewRequest("GET", "/api/logs/stream?id=123&stdout=1&stderr=1&host=localhost", nil) require.NoError(t, err, "NewRequest should not return an error.") req.AddCookie(cookies[0]) rr := httptest.NewRecorder() @@ -204,7 +204,7 @@ func Test_createRoutes_username_password_invalid_session(t *testing.T) { mockedClient.On("FindContainer", "123").Return(docker.Container{ID: "123"}, nil) mockedClient.On("ContainerLogs", mock.Anything, "since", docker.STDALL).Return(io.NopCloser(strings.NewReader("test data")), io.EOF) handler := createHandler(mockedClient, nil, Config{Base: "/", Username: "amir", Password: "password"}) - req, err := http.NewRequest("GET", "/api/logs/stream?id=123&stdout=1&stderr=1", nil) + req, err := http.NewRequest("GET", "/api/logs/stream?id=123&stdout=1&stderr=1&host=localhost", nil) require.NoError(t, err, "NewRequest should not return an error.") req.AddCookie(&http.Cookie{Name: "session", Value: "baddata"}) rr := httptest.NewRecorder() diff --git a/web/routes_logs_test.go b/web/routes_logs_test.go index 0fd15cfc..05137437 100644 --- a/web/routes_logs_test.go +++ b/web/routes_logs_test.go @@ -24,6 +24,7 @@ func Test_handler_streamLogs_happy(t *testing.T) { q.Add("id", id) q.Add("stdout", "true") q.Add("stderr", "true") + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() require.NoError(t, err, "NewRequest should not return an error.") @@ -50,6 +51,7 @@ func Test_handler_streamLogs_happy_with_id(t *testing.T) { q.Add("id", id) q.Add("stdout", "true") q.Add("stderr", "true") + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() require.NoError(t, err, "NewRequest should not return an error.") @@ -76,6 +78,7 @@ func Test_handler_streamLogs_happy_container_stopped(t *testing.T) { q.Add("id", id) q.Add("stdout", "true") q.Add("stderr", "true") + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() require.NoError(t, err, "NewRequest should not return an error.") @@ -101,6 +104,7 @@ func Test_handler_streamLogs_error_finding_container(t *testing.T) { q.Add("id", id) q.Add("stdout", "true") q.Add("stderr", "true") + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() require.NoError(t, err, "NewRequest should not return an error.") @@ -125,6 +129,7 @@ func Test_handler_streamLogs_error_reading(t *testing.T) { q.Add("id", id) q.Add("stdout", "true") q.Add("stderr", "true") + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() require.NoError(t, err, "NewRequest should not return an error.") @@ -148,6 +153,7 @@ func Test_handler_streamLogs_error_std(t *testing.T) { req, err := http.NewRequest("GET", "/api/logs/stream", nil) q := req.URL.Query() q.Add("id", id) + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() require.NoError(t, err, "NewRequest should not return an error.") @@ -167,6 +173,9 @@ func Test_handler_streamLogs_error_std(t *testing.T) { func Test_handler_streamEvents_happy(t *testing.T) { req, err := http.NewRequest("GET", "/api/events/stream", nil) require.NoError(t, err, "NewRequest should not return an error.") + q := req.URL.Query() + q.Add("host", "localhost") + req.URL.RawQuery = q.Encode() mockedClient := new(MockedClient) messages := make(chan docker.ContainerEvent) errChannel := make(chan error) @@ -199,6 +208,9 @@ func Test_handler_streamEvents_happy(t *testing.T) { func Test_handler_streamEvents_error(t *testing.T) { req, err := http.NewRequest("GET", "/api/events/stream", nil) require.NoError(t, err, "NewRequest should not return an error.") + q := req.URL.Query() + q.Add("host", "localhost") + req.URL.RawQuery = q.Encode() mockedClient := new(MockedClient) messages := make(chan docker.ContainerEvent) errChannel := make(chan error) @@ -224,6 +236,9 @@ func Test_handler_streamEvents_error(t *testing.T) { func Test_handler_streamEvents_error_request(t *testing.T) { req, err := http.NewRequest("GET", "/api/events/stream", nil) require.NoError(t, err, "NewRequest should not return an error.") + q := req.URL.Query() + q.Add("host", "localhost") + req.URL.RawQuery = q.Encode() mockedClient := new(MockedClient) @@ -264,6 +279,7 @@ func Test_handler_between_dates(t *testing.T) { q.Add("id", "123456") q.Add("stdout", "true") q.Add("stderr", "true") + q.Add("host", "localhost") req.URL.RawQuery = q.Encode() mockedClient := new(MockedClient)