diff --git a/cmd/system-probe/api/agentrestart_darwin.go b/cmd/system-probe/api/agentrestart_darwin.go new file mode 100644 index 000000000000..deace920499b --- /dev/null +++ b/cmd/system-probe/api/agentrestart_darwin.go @@ -0,0 +1,48 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build darwin + +package api + +import ( + "fmt" + "net/http" + "os/exec" + "time" + + "github.com/DataDog/datadog-agent/pkg/util/log" +) + +var afterFunc = time.AfterFunc + +var kickstart = func(service string) error { + cmd := exec.Command("/bin/launchctl", "kickstart", "-k", service) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("%s", string(out)) + } + return nil +} + +func handleAgentRestart(w http.ResponseWriter, r *http.Request) { + // Reply 200 immediately so the client receives the response before launchd + // tears down this process when sysprobe is restarted. + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // Restart both services after a short delay so the HTTP response has time + // to be delivered before launchd sends SIGTERM to this process. + afterFunc(100*time.Millisecond, func() { + if err := kickstart("system/com.datadoghq.agent"); err != nil { + log.Errorf("agent-restart: failed to restart com.datadoghq.agent: %v", err) + } + if err := kickstart("system/com.datadoghq.sysprobe"); err != nil { + log.Errorf("agent-restart: failed to restart com.datadoghq.sysprobe: %v", err) + } + }) +} diff --git a/cmd/system-probe/api/agentrestart_darwin_test.go b/cmd/system-probe/api/agentrestart_darwin_test.go new file mode 100644 index 000000000000..5d8f5c84a0bc --- /dev/null +++ b/cmd/system-probe/api/agentrestart_darwin_test.go @@ -0,0 +1,71 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build darwin + +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func withMockKickstart(t *testing.T, mock func(string) error) { + t.Helper() + orig := kickstart + kickstart = mock + t.Cleanup(func() { kickstart = orig }) +} + +// withSyncAfterFunc replaces the timer so the callback runs synchronously inside +// handleAgentRestart, before the function returns. This prevents the real kickstart +// from being restored by t.Cleanup before the timer fires. +func withSyncAfterFunc(t *testing.T) { + t.Helper() + orig := afterFunc + afterFunc = func(_ time.Duration, f func()) *time.Timer { f(); return nil } + t.Cleanup(func() { afterFunc = orig }) +} + +func TestHandleAgentRestart_Returns200Immediately(t *testing.T) { + withSyncAfterFunc(t) + withMockKickstart(t, func(string) error { return nil }) + + req := httptest.NewRequest(http.MethodPost, "/agent-restart", nil) + rr := httptest.NewRecorder() + + handleAgentRestart(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestHandleAgentRestart_ServiceRestartSequence(t *testing.T) { + // expectedServices defines the exact order in which launchd services must be restarted. + // Agent must come before sysprobe because restarting sysprobe sends SIGTERM to this process. + expectedServices := []string{ + "system/com.datadoghq.agent", + "system/com.datadoghq.sysprobe", + } + + withSyncAfterFunc(t) + + var called []string + withMockKickstart(t, func(svc string) error { + called = append(called, svc) + return nil + }) + + req := httptest.NewRequest(http.MethodPost, "/agent-restart", nil) + rr := httptest.NewRecorder() + + handleAgentRestart(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, expectedServices, called) +} diff --git a/cmd/system-probe/api/agentrestart_others.go b/cmd/system-probe/api/agentrestart_others.go new file mode 100644 index 000000000000..0e213f545563 --- /dev/null +++ b/cmd/system-probe/api/agentrestart_others.go @@ -0,0 +1,14 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build !darwin + +package api + +import "net/http" + +func handleAgentRestart(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "not supported on this platform", http.StatusNotImplemented) +} diff --git a/cmd/system-probe/api/agentrestart_others_test.go b/cmd/system-probe/api/agentrestart_others_test.go new file mode 100644 index 000000000000..e691dfb48c89 --- /dev/null +++ b/cmd/system-probe/api/agentrestart_others_test.go @@ -0,0 +1,26 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build !darwin + +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHandleAgentRestart_NotSupportedOnNonDarwin(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/agent-restart", nil) + rr := httptest.NewRecorder() + + handleAgentRestart(rr, req) + + assert.Equal(t, http.StatusNotImplemented, rr.Code) + assert.Contains(t, rr.Body.String(), "not supported on this platform") +} diff --git a/cmd/system-probe/api/server.go b/cmd/system-probe/api/server.go index f9763dbb4f40..ec0eca2026d8 100644 --- a/cmd/system-probe/api/server.go +++ b/cmd/system-probe/api/server.go @@ -64,6 +64,8 @@ func StartServer(cfg *sysconfigtypes.Config, settings settings.Component, rcclie mux.HandleFunc("/debug/selinux_semodule_list", debug.HandleSelinuxSemoduleList) } + mux.Handle("POST /agent-restart", deps.Ipc.HTTPMiddleware(http.HandlerFunc(handleAgentRestart))) + // Register /coverage endpoint for computing code coverage (e2ecoverage build only). coverage.SetupCoverageHandler(mux) diff --git a/comp/core/gui/impl/gui.go b/comp/core/gui/impl/gui.go index 27ca5ece63cc..d0b405e9d0fd 100644 --- a/comp/core/gui/impl/gui.go +++ b/comp/core/gui/impl/gui.go @@ -29,6 +29,7 @@ import ( "github.com/DataDog/datadog-agent/comp/core/flare" guidef "github.com/DataDog/datadog-agent/comp/core/gui/def" "github.com/DataDog/datadog-agent/comp/core/hostname/hostnameinterface/def" + ipc "github.com/DataDog/datadog-agent/comp/core/ipc/def" log "github.com/DataDog/datadog-agent/comp/core/log/def" "github.com/DataDog/datadog-agent/comp/core/status" compdef "github.com/DataDog/datadog-agent/comp/def" @@ -74,6 +75,7 @@ type Requires struct { Status status.Component Lc compdef.Lifecycle Hostname hostnameinterface.Component + Ipc ipc.Component } // Provides defines the output of the gui component. @@ -120,6 +122,12 @@ func NewComponent(deps Requires) Provides { sessionExpiration := deps.Config.GetDuration("GUI_session_expiration") g.auth = newAuthenticator(authToken, sessionExpiration) + setGetAuthToken(deps.Ipc.GetAuthToken) + socketPath := deps.Config.GetString("system_probe_config.sysprobe_socket") + if socketPath == "" { + socketPath = defaultpaths.GetDefaultSystemProbeAddress() + } + setSysprobeSocketPath(socketPath) // register the public routes publicRouter.HandleFunc("GET /{$}", renderIndexPage) @@ -281,6 +289,7 @@ func (g *gui) getAccessToken(w http.ResponseWriter, r *http.Request) { Value: accessToken, Path: "/", HttpOnly: true, + SameSite: http.SameSiteStrictMode, MaxAge: 31536000, // 1 year }) http.Redirect(w, r, "/", http.StatusFound) @@ -292,6 +301,17 @@ func (g *gui) authMiddleware(next http.Handler) http.Handler { // Disable caching w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + // For state-changing requests, reject any cross-origin Origin header to prevent CSRF. + // Same-origin requests from the GUI itself either omit Origin or match the server address. + if r.Method != http.MethodGet && r.Method != http.MethodHead { + if origin := r.Header.Get("Origin"); origin != "" { + if origin != "http://"+g.address { + http.Error(w, "invalid origin", http.StatusForbidden) + return + } + } + } + cookie, _ := r.Cookie("accessToken") if cookie == nil { http.Error(w, "missing accessToken", http.StatusUnauthorized) diff --git a/comp/core/gui/impl/gui_csrf_test.go b/comp/core/gui/impl/gui_csrf_test.go new file mode 100644 index 000000000000..d79de5668824 --- /dev/null +++ b/comp/core/gui/impl/gui_csrf_test.go @@ -0,0 +1,102 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +package guiimpl + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestGUI(t *testing.T) *gui { + t.Helper() + return &gui{ + address: "localhost:5002", + auth: newAuthenticator("test-secret", 5*time.Minute), + intentTokens: make(map[string]bool), + } +} + +func TestGetAccessToken_CookieHasSameSiteStrict(t *testing.T) { + g := newTestGUI(t) + g.intentTokens["test-intent"] = true + + req := httptest.NewRequest(http.MethodGet, "/auth?intent=test-intent", nil) + rr := httptest.NewRecorder() + + g.getAccessToken(rr, req) + + var accessCookie *http.Cookie + for _, c := range rr.Result().Cookies() { + if c.Name == "accessToken" { + accessCookie = c + break + } + } + require.NotNil(t, accessCookie, "accessToken cookie must be set") + assert.Equal(t, http.SameSiteStrictMode, accessCookie.SameSite) + assert.True(t, accessCookie.HttpOnly) +} + +func TestAuthMiddleware_OriginCheck(t *testing.T) { + g := newTestGUI(t) + token := g.auth.GenerateAccessToken() + + okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + method string + origin string + expectedStatus int + }{ + { + name: "POST without Origin is allowed (same-origin browser request)", + method: http.MethodPost, + origin: "", + expectedStatus: http.StatusOK, + }, + { + name: "POST with matching Origin is allowed", + method: http.MethodPost, + origin: "http://localhost:5002", + expectedStatus: http.StatusOK, + }, + { + name: "POST with cross-origin Origin is rejected", + method: http.MethodPost, + origin: "http://evil.com", + expectedStatus: http.StatusForbidden, + }, + { + name: "GET with cross-origin Origin is allowed (safe method)", + method: http.MethodGet, + origin: "http://evil.com", + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/agent/restart", nil) + req.AddCookie(&http.Cookie{Name: "accessToken", Value: token}) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + + rr := httptest.NewRecorder() + g.authMiddleware(okHandler).ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} diff --git a/comp/core/gui/impl/platform_darwin.go b/comp/core/gui/impl/platform_darwin.go index ef6f88be834e..d381a920eefc 100644 --- a/comp/core/gui/impl/platform_darwin.go +++ b/comp/core/gui/impl/platform_darwin.go @@ -6,8 +6,10 @@ package guiimpl import ( - "errors" + "fmt" + "net/http" + sysprobeclient "github.com/DataDog/datadog-agent/pkg/system-probe/api/client" template "github.com/DataDog/datadog-agent/pkg/template/html" ) @@ -25,10 +27,42 @@ const instructionTemplate = `{{define "loginInstruction" }}

Note: If you would like to adjust the GUI session timeout, you can modify the GUI_session_expiration parameter in datadog.yaml {{end}}` +// getAuthToken is a function that fetches the IPC auth token on each call, +// avoiding storage of the credential as a long-lived global. +// sysprobeSocketPath holds the Unix socket path, set once at startup. +var getAuthToken func() string +var sysprobeSocketPath string + +func setGetAuthToken(f func() string) { + getAuthToken = f +} + +func setSysprobeSocketPath(path string) { + sysprobeSocketPath = path +} + func restartEnabled() bool { - return false + return true } func restart() error { - return errors.New("restarting the agent is not implemented on non-windows platforms") + client := sysprobeclient.Get(sysprobeSocketPath) + + url := sysprobeclient.URL("/agent-restart") + req, err := http.NewRequest(http.MethodPost, url, nil) + if err != nil { + return fmt.Errorf("could not build restart request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+getAuthToken()) + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("could not reach system-probe: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("system-probe agent restart failed with status %d; see system-probe logs for details", resp.StatusCode) + } + return nil } diff --git a/comp/core/gui/impl/platform_darwin_test.go b/comp/core/gui/impl/platform_darwin_test.go index 83b170ffca96..da57e487daaf 100644 --- a/comp/core/gui/impl/platform_darwin_test.go +++ b/comp/core/gui/impl/platform_darwin_test.go @@ -6,12 +6,16 @@ package guiimpl import ( + "fmt" "io" + "net" "net/http" "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const expectedBody = ` @@ -65,6 +69,11 @@ const expectedBody = `   Flare +

+
@@ -128,6 +137,120 @@ const expectedBody = ` ` +// startUnixServer starts an HTTP server on a temp Unix socket and returns its path. +// Uses os.CreateTemp under /tmp to stay within the 108-char Unix socket path limit on macOS. +func startUnixServer(t *testing.T, handler http.Handler) string { + t.Helper() + f, err := os.CreateTemp("", "gui-test-*.sock") + require.NoError(t, err) + socketPath := f.Name() + f.Close() + os.Remove(socketPath) + t.Cleanup(func() { os.Remove(socketPath) }) + + l, err := net.Listen("unix", socketPath) + require.NoError(t, err) + srv := &http.Server{Handler: handler} + go srv.Serve(l) //nolint:errcheck + t.Cleanup(func() { srv.Close() }) + return socketPath +} + +func TestRestartEnabled(t *testing.T) { + assert.True(t, restartEnabled()) +} + +func TestSetGetAuthToken(t *testing.T) { + orig := getAuthToken + t.Cleanup(func() { getAuthToken = orig }) + + setGetAuthToken(func() string { return "test-token" }) + assert.Equal(t, "test-token", getAuthToken()) +} + +func TestSetSysprobeSocketPath(t *testing.T) { + orig := sysprobeSocketPath + t.Cleanup(func() { sysprobeSocketPath = orig }) + + setSysprobeSocketPath("/tmp/test.sock") + assert.Equal(t, "/tmp/test.sock", sysprobeSocketPath) +} + +func TestRestart_Success(t *testing.T) { + socketPath := startUnixServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/agent-restart", r.URL.Path) + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + })) + + origSocket := sysprobeSocketPath + origToken := getAuthToken + t.Cleanup(func() { + sysprobeSocketPath = origSocket + getAuthToken = origToken + }) + setSysprobeSocketPath(socketPath) + setGetAuthToken(func() string { return "test-token" }) + + err := restart() + assert.NoError(t, err) +} + +func TestRestart_SysprobeUnreachable(t *testing.T) { + origSocket := sysprobeSocketPath + origToken := getAuthToken + t.Cleanup(func() { + sysprobeSocketPath = origSocket + getAuthToken = origToken + }) + setSysprobeSocketPath("/tmp/gui-test-nonexistent.sock") + setGetAuthToken(func() string { return "token" }) + + err := restart() + require.Error(t, err) + assert.Contains(t, err.Error(), "could not reach system-probe") +} + +func TestRestart_SysprobeReturnsError(t *testing.T) { + socketPath := startUnixServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "launchctl failed", http.StatusInternalServerError) + })) + + origSocket := sysprobeSocketPath + origToken := getAuthToken + t.Cleanup(func() { + sysprobeSocketPath = origSocket + getAuthToken = origToken + }) + setSysprobeSocketPath(socketPath) + setGetAuthToken(func() string { return "token" }) + + err := restart() + require.Error(t, err) + assert.Contains(t, err.Error(), "system-probe agent restart failed with status 500") +} + +func TestRestart_SendsAuthorizationHeader(t *testing.T) { + var receivedAuth string + socketPath := startUnixServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + + origSocket := sysprobeSocketPath + origToken := getAuthToken + t.Cleanup(func() { + sysprobeSocketPath = origSocket + getAuthToken = origToken + }) + setSysprobeSocketPath(socketPath) + setGetAuthToken(func() string { return "secret-ipc-token" }) + + require.NoError(t, restart()) + assert.Equal(t, fmt.Sprintf("Bearer %s", "secret-ipc-token"), receivedAuth) +} + func TestRenderIndexPage(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) if err != nil { diff --git a/comp/core/gui/impl/platform_nix.go b/comp/core/gui/impl/platform_nix.go index 0f93fa02b077..c499cf9604d4 100644 --- a/comp/core/gui/impl/platform_nix.go +++ b/comp/core/gui/impl/platform_nix.go @@ -22,6 +22,9 @@ const instructionTemplate = `{{define "loginInstruction" }}

Note: If you would like to adjust the GUI session timeout, you can modify the GUI_session_expiration parameter in datadog.yaml {{end}}` +func setGetAuthToken(_ func() string) {} +func setSysprobeSocketPath(_ string) {} + func restartEnabled() bool { return false } diff --git a/comp/core/gui/impl/platform_windows.go b/comp/core/gui/impl/platform_windows.go index 4c4165c952d5..8c1b44d6f1f2 100644 --- a/comp/core/gui/impl/platform_windows.go +++ b/comp/core/gui/impl/platform_windows.go @@ -33,6 +33,9 @@ const instructionTemplate = `{{define "loginInstruction" }}

Note: If you would like to adjust the GUI session timeout, you can modify the GUI_session_expiration parameter in datadog.yaml {{end}}` +func setGetAuthToken(_ func() string) {} +func setSysprobeSocketPath(_ string) {} + func restartEnabled() bool { return true }