diff --git a/.goreleaser.yml b/.goreleaser.yml index 221bfd9..631cdf2 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -40,6 +40,6 @@ changelog: - "^test:" release: - draft: true + draft: false prerelease: auto name_template: "Flashduty CLI {{.Version}}" diff --git a/internal/cli/root.go b/internal/cli/root.go index d851be9..99173ff 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -9,9 +9,11 @@ import ( flashduty "github.com/flashcatcloud/flashduty-sdk" "github.com/spf13/cobra" + "golang.org/x/term" "github.com/flashcatcloud/flashduty-cli/internal/config" "github.com/flashcatcloud/flashduty-cli/internal/output" + "github.com/flashcatcloud/flashduty-cli/internal/update" ) // flashdutyClient defines the SDK operations used by CLI commands. @@ -86,12 +88,48 @@ var ( flagBaseURL string ) +var updateResultCh chan *update.CheckResult + var rootCmd = &cobra.Command{ Use: "flashduty", Short: "Flashduty CLI - incident management from your terminal", Long: "Flashduty CLI - incident management from your terminal.\n\nGet started by running 'flashduty login' to authenticate.", SilenceUsage: true, SilenceErrors: true, + PersistentPreRun: func(cmd *cobra.Command, _ []string) { + path := cmd.CommandPath() + if path == "flashduty update" || path == "flashduty version" { + return + } + if !update.ShouldCheck(versionStr) { + return + } + if !term.IsTerminal(int(os.Stderr.Fd())) { + return + } + updateResultCh = make(chan *update.CheckResult, 1) + go func() { + result, err := update.CheckForUpdate(versionStr) + if err != nil { + return + } + updateResultCh <- result + }() + }, + PersistentPostRun: func(_ *cobra.Command, _ []string) { + if updateResultCh == nil { + return + } + select { + case result := <-updateResultCh: + if result != nil && result.UpdateAvailable { + fmt.Fprintf(os.Stderr, "\nA new version of flashduty is available: v%s -> %s\n", + update.StripV(result.CurrentVersion), result.LatestVersion) + fmt.Fprintf(os.Stderr, "To update, run: flashduty update\n") + } + default: + } + }, } func init() { @@ -125,6 +163,8 @@ func init() { // Phase 3 rootCmd.AddCommand(newInsightCmd()) rootCmd.AddCommand(newAuditCmd()) + + rootCmd.AddCommand(newUpdateCmd()) } // Execute runs the root command. diff --git a/internal/cli/update.go b/internal/cli/update.go new file mode 100644 index 0000000..7b738af --- /dev/null +++ b/internal/cli/update.go @@ -0,0 +1,72 @@ +package cli + +import ( + "fmt" + "os" + "os/exec" + "runtime" + + "github.com/spf13/cobra" + + "github.com/flashcatcloud/flashduty-cli/internal/update" +) + +func newUpdateCmd() *cobra.Command { + var flagCheck bool + + cmd := &cobra.Command{ + Use: "update", + Short: "Update flashduty to the latest version", + RunE: func(cmd *cobra.Command, _ []string) error { + w := cmd.OutOrStdout() + _, _ = fmt.Fprintf(w, "Current version: %s\n", versionStr) + _, _ = fmt.Fprintf(w, "Checking for updates...\n") + + result, err := update.CheckForUpdate(versionStr) + if err != nil { + return fmt.Errorf("failed to check for updates: %w", err) + } + + if !result.UpdateAvailable { + _, _ = fmt.Fprintf(w, "Already up to date (%s).\n", versionStr) + return nil + } + + _, _ = fmt.Fprintf(w, "A new version is available: v%s -> %s\n", + update.StripV(versionStr), result.LatestVersion) + _, _ = fmt.Fprintf(w, "Release: %s\n", result.LatestURL) + + if flagCheck { + return nil + } + + _, _ = fmt.Fprintf(w, "\nUpdating...\n") + return runInstaller(cmd) + }, + } + + cmd.Flags().BoolVar(&flagCheck, "check", false, "Only check for updates, do not install") + return cmd +} + +func runInstaller(cmd *cobra.Command) error { + var c *exec.Cmd + if runtime.GOOS == "windows" { + c = exec.Command("powershell", "-Command", + fmt.Sprintf("irm %s | iex", update.InstallPowerShellURL())) + } else { + c = exec.Command("sh", "-c", + fmt.Sprintf("curl -fsSL %s | sh", update.InstallShellURL())) + } + + c.Stdout = cmd.OutOrStdout() + c.Stderr = cmd.ErrOrStderr() + c.Stdin = os.Stdin + + if err := c.Run(); err != nil { + return fmt.Errorf("update failed: %w", err) + } + + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nUpdate complete. Run 'flashduty version' to verify.\n") + return nil +} diff --git a/internal/update/check.go b/internal/update/check.go new file mode 100644 index 0000000..1b570da --- /dev/null +++ b/internal/update/check.go @@ -0,0 +1,199 @@ +package update + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +const ( + repoOwner = "flashcatcloud" + repoName = "flashduty-cli" + checkInterval = 24 * time.Hour + httpTimeout = 5 * time.Second + stateFileName = "state.yaml" + installShURL = "https://raw.githubusercontent.com/" + repoOwner + "/" + repoName + "/main/install.sh" + installPs1URL = "https://raw.githubusercontent.com/" + repoOwner + "/" + repoName + "/main/install.ps1" + maxResponseBytes = 1 << 20 // 1MB +) + +var apiURL = "https://api.github.com/repos/" + repoOwner + "/" + repoName + "/releases/latest" + +type State struct { + CheckedAt time.Time `yaml:"checked_at"` + LatestVersion string `yaml:"latest_version"` + LatestURL string `yaml:"latest_url"` +} + +type CheckResult struct { + CurrentVersion string + LatestVersion string + LatestURL string + UpdateAvailable bool +} + +type githubRelease struct { + TagName string `json:"tag_name"` + HTMLURL string `json:"html_url"` +} + +func InstallShellURL() string { return installShURL } +func InstallPowerShellURL() string { return installPs1URL } + +func stateDir() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to determine home directory: %w", err) + } + return filepath.Join(home, ".flashduty"), nil +} + +func statePath() (string, error) { + dir, err := stateDir() + if err != nil { + return "", err + } + return filepath.Join(dir, stateFileName), nil +} + +func loadState() *State { + path, err := statePath() + if err != nil { + return &State{} + } + data, err := os.ReadFile(path) + if err != nil { + return &State{} + } + var s State + if err := yaml.Unmarshal(data, &s); err != nil { + return &State{} + } + return &s +} + +func saveState(s *State) error { + dir, err := stateDir() + if err != nil { + return err + } + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create state directory: %w", err) + } + data, err := yaml.Marshal(s) + if err != nil { + return fmt.Errorf("failed to marshal state: %w", err) + } + path, err := statePath() + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func fetchLatestVersion() (string, string, error) { + client := &http.Client{Timeout: httpTimeout} + resp, err := client.Get(apiURL) + if err != nil { + return "", "", fmt.Errorf("failed to fetch latest release: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("GitHub API returned %d", resp.StatusCode) + } + + var rel githubRelease + if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(&rel); err != nil { + return "", "", fmt.Errorf("failed to parse release response: %w", err) + } + if rel.TagName == "" { + return "", "", fmt.Errorf("empty tag_name in response") + } + return rel.TagName, rel.HTMLURL, nil +} + +func StripV(v string) string { + return strings.TrimPrefix(v, "v") +} + +// stripPreRelease removes pre-release suffix (e.g. "1.0.0-rc1" -> "1.0.0"). +func stripPreRelease(v string) string { + if base, _, ok := strings.Cut(v, "-"); ok { + return base + } + return v +} + +func compareSemver(a, b string) int { + a = stripPreRelease(a) + b = stripPreRelease(b) + aParts := strings.Split(a, ".") + bParts := strings.Split(b, ".") + maxLen := max(len(aParts), len(bParts)) + for i := range maxLen { + var ai, bi int + if i < len(aParts) { + ai, _ = strconv.Atoi(aParts[i]) + } + if i < len(bParts) { + bi, _ = strconv.Atoi(bParts[i]) + } + if ai != bi { + return ai - bi + } + } + return 0 +} + +func IsNewer(latestTag, currentVersion string) bool { + latest := StripV(latestTag) + current := StripV(currentVersion) + if latest == current { + return false + } + return compareSemver(latest, current) > 0 +} + +func ShouldCheck(currentVersion string) bool { + if currentVersion == "dev" || currentVersion == "(devel)" { + return false + } + if os.Getenv("FLASHDUTY_NO_UPDATE_CHECK") == "1" { + return false + } + if os.Getenv("CI") != "" || os.Getenv("GITHUB_ACTIONS") != "" || + os.Getenv("JENKINS_URL") != "" || os.Getenv("GITLAB_CI") != "" { + return false + } + state := loadState() + return time.Since(state.CheckedAt) >= checkInterval +} + +func CheckForUpdate(currentVersion string) (*CheckResult, error) { + tag, url, err := fetchLatestVersion() + if err != nil { + return nil, err + } + + _ = saveState(&State{ + CheckedAt: time.Now(), + LatestVersion: tag, + LatestURL: url, + }) + + return &CheckResult{ + CurrentVersion: currentVersion, + LatestVersion: tag, + LatestURL: url, + UpdateAvailable: IsNewer(tag, currentVersion), + }, nil +} diff --git a/internal/update/check_test.go b/internal/update/check_test.go new file mode 100644 index 0000000..673c29d --- /dev/null +++ b/internal/update/check_test.go @@ -0,0 +1,308 @@ +package update + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "gopkg.in/yaml.v3" +) + +// setTestHome overrides the home directory for testing across all platforms. +func setTestHome(t *testing.T, dir string) { + t.Helper() + t.Setenv("HOME", dir) + if runtime.GOOS == "windows" { + t.Setenv("USERPROFILE", dir) + } +} + +// clearCIEnv clears CI-related env vars so ShouldCheck doesn't short-circuit. +func clearCIEnv(t *testing.T) { + t.Helper() + t.Setenv("CI", "") + t.Setenv("GITHUB_ACTIONS", "") + t.Setenv("JENKINS_URL", "") + t.Setenv("GITLAB_CI", "") + t.Setenv("FLASHDUTY_NO_UPDATE_CHECK", "") +} + +func TestStripV(t *testing.T) { + tests := []struct { + in, want string + }{ + {"v0.6.0", "0.6.0"}, + {"0.6.0", "0.6.0"}, + {"v1.0.0", "1.0.0"}, + {"", ""}, + } + for _, tt := range tests { + if got := StripV(tt.in); got != tt.want { + t.Errorf("StripV(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestCompareSemver(t *testing.T) { + tests := []struct { + a, b string + want int + }{ + {"0.7.0", "0.6.0", 1}, + {"0.6.1", "0.6.0", 1}, + {"1.0.0", "0.9.9", 1}, + {"0.10.0", "0.9.0", 1}, + {"0.6.0", "0.7.0", -1}, + {"0.6.0", "0.6.0", 0}, + {"1.0.0", "1.0.0", 0}, + {"1.0.0-rc1", "1.0.0-rc2", 0}, + {"1.0.0-beta", "1.0.0", 0}, + } + for _, tt := range tests { + got := compareSemver(tt.a, tt.b) + switch { + case tt.want > 0 && got <= 0: + t.Errorf("compareSemver(%q, %q) = %d, want >0", tt.a, tt.b, got) + case tt.want < 0 && got >= 0: + t.Errorf("compareSemver(%q, %q) = %d, want <0", tt.a, tt.b, got) + case tt.want == 0 && got != 0: + t.Errorf("compareSemver(%q, %q) = %d, want 0", tt.a, tt.b, got) + } + } +} + +func TestIsNewer(t *testing.T) { + tests := []struct { + latest, current string + want bool + }{ + {"v0.7.0", "0.6.0", true}, + {"v0.7.0", "v0.6.0", true}, + {"0.7.0", "0.6.0", true}, + {"v0.6.0", "0.6.0", false}, + {"v0.5.0", "0.6.0", false}, + {"v0.10.0", "0.9.0", true}, + {"v1.0.0-rc1", "0.9.0", true}, + } + for _, tt := range tests { + if got := IsNewer(tt.latest, tt.current); got != tt.want { + t.Errorf("IsNewer(%q, %q) = %v, want %v", tt.latest, tt.current, got, tt.want) + } + } +} + +func TestShouldCheck_DevVersion(t *testing.T) { + if ShouldCheck("dev") { + t.Error("ShouldCheck(\"dev\") = true, want false") + } + if ShouldCheck("(devel)") { + t.Error("ShouldCheck(\"(devel)\") = true, want false") + } +} + +func TestShouldCheck_EnvDisabled(t *testing.T) { + t.Setenv("FLASHDUTY_NO_UPDATE_CHECK", "1") + if ShouldCheck("0.6.0") { + t.Error("ShouldCheck should return false when FLASHDUTY_NO_UPDATE_CHECK=1") + } +} + +func TestShouldCheck_CI(t *testing.T) { + t.Setenv("CI", "true") + if ShouldCheck("0.6.0") { + t.Error("ShouldCheck should return false in CI") + } +} + +func TestShouldCheck_RecentCheck(t *testing.T) { + tmp := t.TempDir() + setTestHome(t, tmp) + clearCIEnv(t) + + dir := filepath.Join(tmp, ".flashduty") + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatal(err) + } + + s := &State{CheckedAt: time.Now()} + data, err := yaml.Marshal(s) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, stateFileName), data, 0600); err != nil { + t.Fatal(err) + } + + if ShouldCheck("0.6.0") { + t.Error("ShouldCheck should return false when checked recently") + } +} + +func TestShouldCheck_StaleCheck(t *testing.T) { + tmp := t.TempDir() + setTestHome(t, tmp) + clearCIEnv(t) + + dir := filepath.Join(tmp, ".flashduty") + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatal(err) + } + + s := &State{CheckedAt: time.Now().Add(-25 * time.Hour)} + data, err := yaml.Marshal(s) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, stateFileName), data, 0600); err != nil { + t.Fatal(err) + } + + if !ShouldCheck("0.6.0") { + t.Error("ShouldCheck should return true when check is stale") + } +} + +func TestLoadSaveState(t *testing.T) { + tmp := t.TempDir() + setTestHome(t, tmp) + + now := time.Now().Truncate(time.Second) + want := &State{ + CheckedAt: now, + LatestVersion: "v0.7.0", + LatestURL: "https://example.com/release", + } + + if err := saveState(want); err != nil { + t.Fatalf("saveState: %v", err) + } + + got := loadState() + if got.LatestVersion != want.LatestVersion { + t.Errorf("LatestVersion = %q, want %q", got.LatestVersion, want.LatestVersion) + } + if got.LatestURL != want.LatestURL { + t.Errorf("LatestURL = %q, want %q", got.LatestURL, want.LatestURL) + } + if got.CheckedAt.Unix() != want.CheckedAt.Unix() { + t.Errorf("CheckedAt = %v, want %v", got.CheckedAt, want.CheckedAt) + } +} + +func TestLoadState_CorruptFile(t *testing.T) { + tmp := t.TempDir() + setTestHome(t, tmp) + + dir := filepath.Join(tmp, ".flashduty") + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, stateFileName), []byte("{[invalid yaml"), 0600); err != nil { + t.Fatal(err) + } + + got := loadState() + if !got.CheckedAt.IsZero() { + t.Error("corrupt state file should return zero state") + } +} + +func TestFetchLatestVersion(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + rel := githubRelease{ + TagName: "v0.7.0", + HTMLURL: "https://github.com/flashcatcloud/flashduty-cli/releases/tag/v0.7.0", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(rel) + })) + defer srv.Close() + + origURL := apiURL + apiURL = srv.URL + defer func() { apiURL = origURL }() + + tag, url, err := fetchLatestVersion() + if err != nil { + t.Fatalf("fetchLatestVersion: %v", err) + } + if tag != "v0.7.0" { + t.Errorf("tag = %q, want %q", tag, "v0.7.0") + } + if url != "https://github.com/flashcatcloud/flashduty-cli/releases/tag/v0.7.0" { + t.Errorf("url = %q", url) + } +} + +func TestFetchLatestVersion_Error(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + origURL := apiURL + apiURL = srv.URL + defer func() { apiURL = origURL }() + + _, _, err := fetchLatestVersion() + if err == nil { + t.Error("expected error for 404 response") + } +} + +func TestFetchLatestVersion_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{not valid json`)) + })) + defer srv.Close() + + origURL := apiURL + apiURL = srv.URL + defer func() { apiURL = origURL }() + + _, _, err := fetchLatestVersion() + if err == nil { + t.Error("expected error for invalid JSON response") + } +} + +func TestCheckForUpdate(t *testing.T) { + tmp := t.TempDir() + setTestHome(t, tmp) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + rel := githubRelease{ + TagName: "v0.7.0", + HTMLURL: "https://github.com/flashcatcloud/flashduty-cli/releases/tag/v0.7.0", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(rel) + })) + defer srv.Close() + + origURL := apiURL + apiURL = srv.URL + defer func() { apiURL = origURL }() + + result, err := CheckForUpdate("0.6.0") + if err != nil { + t.Fatalf("CheckForUpdate: %v", err) + } + if !result.UpdateAvailable { + t.Error("UpdateAvailable = false, want true") + } + if result.LatestVersion != "v0.7.0" { + t.Errorf("LatestVersion = %q, want %q", result.LatestVersion, "v0.7.0") + } + + state := loadState() + if state.LatestVersion != "v0.7.0" { + t.Errorf("state.LatestVersion = %q, want %q", state.LatestVersion, "v0.7.0") + } +}