diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0cda4c..2bd0546 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,16 +14,13 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.24' - - name: Download dependencies - run: go mod download + - name: Tidy, test, and build + run: make tidy test build - - name: Build - run: go build -v ./... - - - name: Test - run: go test -v -race -coverprofile=coverage.out ./... + - name: Coverage gate + run: make test-cover-check lint: runs-on: ubuntu-latest @@ -32,7 +29,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.24' - name: golangci-lint uses: golangci/golangci-lint-action@v7 diff --git a/.golangci.yml b/.golangci.yml index 5667a5f..86ef99d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,11 +3,15 @@ version: "2" linters: enable: - errcheck + - errorlint + - exhaustive - govet + - gosec - ineffassign + - misspell + - revive - staticcheck - unused - - misspell settings: errcheck: exclude-functions: @@ -19,9 +23,18 @@ linters: - path: _test\.go linters: - errcheck + - gosec - linters: - errcheck source: "defer.*\\.Close\\(\\)" + # Stuttering names (e.g., mail.MailClient) are accepted for clarity at call sites + - linters: + - revive + text: "stutters" + # Standard library package name conflicts are intentional (e.g., calendar, contacts) + - linters: + - revive + text: "avoid package names that conflict" formatters: enable: diff --git a/CLAUDE.md b/CLAUDE.md index c70331f..9a54028 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -9,282 +9,66 @@ gro is a **read-only** command-line interface for Google services written in Go. **Binary name:** `gro` **Module:** `github.com/open-cli-collective/google-readonly` -### Current Features +### Features - Gmail: Search, read, thread viewing, labels, attachments - Google Calendar: List calendars, view events, today/week shortcuts - Google Contacts: List contacts, search, view details, list groups - -### Planned Features -- Google Drive: List files, download content +- Google Drive: List files, search, get details, download, tree view, shared drives ## Quick Commands ```bash -# Build -make build - -# Run tests -make test - -# Run tests with coverage -make test-cover - -# Lint -make lint - -# Format code -make fmt - -# All checks (format, lint, test) -make verify - -# Install locally -make install - -# Clean build artifacts -make clean +make build # Build binary +make test # Run tests with race detection +make test-cover # Tests with HTML coverage report +make lint # Run golangci-lint +make fmt # Format code +make check # CI gate: tidy, lint, test, build +make install # Install to /usr/local/bin ``` -## Architecture - -``` -google-readonly/ -├── main.go # Entry point -├── cmd/gro/ # Main package -│ └── main.go -├── internal/ -│ ├── cmd/ -│ │ ├── root/ # Root command, version -│ │ │ └── root.go -│ │ ├── initcmd/ # OAuth setup (gro init) -│ │ │ ├── init.go -│ │ │ └── init_test.go -│ │ ├── config/ # gro config {show,test,clear} -│ │ │ ├── config.go -│ │ │ └── config_test.go -│ │ ├── mail/ # gro mail {search,read,thread,labels,attachments} -│ │ │ ├── mail.go # Parent command -│ │ │ ├── search.go -│ │ │ ├── read.go -│ │ │ ├── thread.go -│ │ │ ├── labels.go -│ │ │ ├── attachments.go -│ │ │ ├── attachments_list.go -│ │ │ ├── attachments_download.go -│ │ │ ├── output.go # Shared output helpers -│ │ │ └── *_test.go -│ │ │ -│ │ ├── calendar/ # gro calendar {list,events,get,today,week} -│ │ │ ├── calendar.go # Parent command with 'cal' alias -│ │ │ ├── list.go -│ │ │ ├── events.go -│ │ │ ├── get.go -│ │ │ ├── today.go -│ │ │ ├── week.go -│ │ │ ├── dates.go # Date parsing/formatting helpers -│ │ │ ├── output.go # Shared output helpers -│ │ │ └── *_test.go -│ │ │ -│ │ └── contacts/ # gro contacts {list,search,get,groups} -│ │ ├── contacts.go # Parent command with 'ppl' alias -│ │ ├── list.go -│ │ ├── search.go -│ │ ├── get.go -│ │ ├── groups.go -│ │ ├── output.go # Shared output helpers -│ │ └── *_test.go -│ │ -│ ├── gmail/ # Gmail API client -│ │ ├── client.go -│ │ ├── messages.go -│ │ ├── attachments.go -│ │ └── *_test.go -│ │ -│ ├── calendar/ # Google Calendar API client -│ │ ├── client.go -│ │ ├── events.go -│ │ └── *_test.go -│ │ -│ ├── contacts/ # Google People API client (Contacts) -│ │ ├── client.go -│ │ ├── contacts.go -│ │ └── *_test.go -│ │ -│ ├── keychain/ # Secure credential storage -│ │ ├── keychain.go -│ │ ├── keychain_darwin.go # macOS Keychain support -│ │ ├── keychain_linux.go # Linux secret-tool support -│ │ ├── keychain_windows.go # Windows file fallback -│ │ ├── token_source.go # Persistent token source wrapper -│ │ └── keychain_test.go -│ │ -│ ├── zip/ # Secure zip extraction -│ │ ├── extract.go -│ │ └── extract_test.go -│ │ -│ └── version/ # Build-time version injection -│ └── version.go -│ -├── .github/workflows/ -│ ├── ci.yml # Lint and test on PR/push -│ ├── auto-release.yml # Create tags on main push -│ └── release.yml # Build and release binaries -│ -├── packaging/ -│ ├── chocolatey/ # Windows Chocolatey package -│ └── winget/ # Windows Winget manifests -│ -├── Makefile # Build, test, lint targets -├── .goreleaser.yml # Cross-platform builds -└── .golangci.yml # Linter config (v2 format) -``` - -## Key Patterns - -### Read-Only by Design - -This CLI intentionally only supports read operations: -- Uses `gmail.GmailReadonlyScope` exclusively -- Only calls `.List()` and `.Get()` Gmail API methods -- No `.Send()`, `.Delete()`, `.Modify()`, or `.Trash()` operations - -### OAuth2 Configuration +## Documentation -Credentials are stored in `~/.config/google-readonly/`: -- `credentials.json` - OAuth client credentials (from Google Cloud Console) +| Document | Contents | +|----------|----------| +| `docs/architecture.md` | Dependency graph, package responsibilities, file naming conventions | +| `docs/golden-principles.md` | Mechanical rules enforced by structural tests | +| `docs/adding-a-domain.md` | Step-by-step checklist for adding a new Google API | -OAuth tokens are stored securely based on platform: -- **macOS**: System Keychain (via `security` CLI) -- **Linux**: libsecret (via `secret-tool`) if available, otherwise config file -- **Fallback**: `~/.config/google-readonly/token.json` with 0600 permissions - -### Command Patterns - -All commands use the factory pattern with `NewCommand()`: - -```go -func NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "search ", - Short: "Search for messages", - Args: cobra.ExactArgs(1), - RunE: runSearch, - } - cmd.Flags().Int64VarP(&searchMaxResults, "max", "m", 10, "Maximum results") - return cmd -} - -func runSearch(cmd *cobra.Command, args []string) error { - client, err := newGmailClient() - if err != nil { - return err - } - // ... use client -} -``` +## Key Constraints -### Output Formats +- **Read-only by design**: Only `*ReadonlyScope` in `auth.AllScopes`. No write API methods. +- **Interface-at-consumer**: Each `internal/cmd/{domain}/output.go` defines its client interface. +- **ClientFactory DI**: Swappable factory for test mock injection. +- **--json on all leaf commands**: Every leaf subcommand supports `--json/-j`. +- **Structural enforcement**: `internal/architecture/architecture_test.go` enforces all patterns at CI time. -Commands support two output modes: -- **Text** (default): Human-readable formatted output -- **JSON** (`--json`): Machine-readable JSON for scripting - -```go -if jsonOutput { - return printJSON(messages) -} -// ... text output -``` +See `docs/golden-principles.md` for the full set of enforced rules. ## Testing -Tests use `testify` for assertions and table-driven test patterns: - -```go -func TestParseMessage(t *testing.T) { - tests := []struct { - name string - input *gmail.Message - expected *Message - }{ - {"basic message", ...}, - {"multipart message", ...}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := parseMessage(tt.input, true) - assert.Equal(t, tt.expected.Subject, result.Subject) - }) - } -} -``` - Run tests: `make test` -Coverage report: `make test-cover && open coverage.html` +Coverage: `make test-cover && open coverage.html` -## Adding a New Command +Tests use `internal/testutil/` for assertions (`testutil.Equal`, `testutil.NoError`, etc.) and fixtures (`testutil.SampleMessage()`, `testutil.SampleEvent()`, etc.). See `docs/golden-principles.md` for mock and test helper patterns. -1. Create new file in appropriate `internal/cmd/` directory -2. Define the command with `NewCommand()` factory function -3. Register in parent command's `NewCommand()` with `AddCommand()` -4. Add flags if needed -5. Write tests in `*_test.go` +## OAuth2 Configuration -Example: +Credentials: `~/.config/google-readonly/credentials.json` (from Google Cloud Console) -```go -func NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "labels", - Short: "List Gmail labels", - RunE: runLabels, - } - cmd.Flags().BoolVarP(&labelsJSON, "json", "j", false, "Output as JSON") - return cmd -} -``` - -## Dependencies - -Key dependencies: -- `github.com/spf13/cobra` - CLI framework -- `golang.org/x/oauth2` - OAuth2 client -- `google.golang.org/api/gmail/v1` - Gmail API client -- `google.golang.org/api/calendar/v3` - Calendar API client -- `google.golang.org/api/people/v1` - People API client (Contacts) -- `github.com/stretchr/testify` - Testing assertions (dev) - -## Error Message Conventions - -Follow [Go Code Review Comments](https://github.com/go/wiki/wiki/CodeReviewComments#error-strings): - -- Start with lowercase -- Don't end with punctuation -- Be descriptive but concise +Tokens stored securely per platform: +- **macOS**: System Keychain (via `security` CLI) +- **Linux**: libsecret (via `secret-tool`) if available, otherwise config file +- **Fallback**: `~/.config/google-readonly/token.json` with 0600 permissions -```go -// Good -return fmt.Errorf("failed to get message: %w", err) -return fmt.Errorf("attachment not found: %s", filename) +## Error Conventions -// Bad -return fmt.Errorf("Failed to get message: %w", err) // capitalized -return fmt.Errorf("attachment not found.") // ends with punctuation -``` +Follow [Go conventions](https://github.com/go/wiki/wiki/CodeReviewComments#error-strings): lowercase, no trailing punctuation, use `%w` for wrapping. ## Commit Conventions -Use conventional commits: - -``` -type(scope): description - -feat(mail): add attachment download command -fix(keychain): handle missing secret-tool -docs(readme): add installation instructions -``` +Use conventional commits: `type(scope): description` | Prefix | Purpose | Triggers Release? | |--------|---------|-------------------| @@ -292,46 +76,18 @@ docs(readme): add installation instructions | `fix:` | Bug fixes | Yes | | `docs:` | Documentation only | No | | `test:` | Adding/updating tests | No | -| `refactor:` | Code changes that don't fix bugs or add features | No | +| `refactor:` | Code changes (no bug fix or feature) | No | | `chore:` | Maintenance tasks | No | | `ci:` | CI/CD changes | No | -## CI & Release Workflow - -Releases are automated with a dual-gate system: - -**Gate 1 - Path filter:** Only triggers when Go code changes (`**.go`, `go.mod`, `go.sum`) -**Gate 2 - Commit prefix:** Only `feat:` and `fix:` commits create releases +## Dependencies -This means: -- `feat: add command` + Go files changed → release -- `fix: handle edge case` + Go files changed → release -- `docs:`, `ci:`, `test:`, `refactor:` → no release -- Changes only to docs, packaging, workflows → no release +- `github.com/spf13/cobra` - CLI framework +- `golang.org/x/oauth2` - OAuth2 client +- `google.golang.org/api/*` - Google API clients (Gmail, Calendar, People, Drive) ## Common Issues -### "Unable to read credentials file" - -Ensure OAuth credentials are set up: -```bash -mkdir -p ~/.config/google-readonly -# Download credentials.json from Google Cloud Console -mv ~/Downloads/client_secret_*.json ~/.config/google-readonly/credentials.json -``` - -### "Token has been expired or revoked" - -Clear the token and re-authenticate: -```bash -gro config clear -gro init -``` - -## Security +**"Unable to read credentials file"**: Run `gro init` and follow the OAuth setup wizard. -- **Read-only scope**: Cannot modify, send, or delete data -- **Secure token storage**: OAuth tokens stored in system keychain when available -- **File fallback**: When secure storage is unavailable, tokens stored with 0600 permissions -- **Token refresh persistence**: Refreshed tokens are automatically saved -- **No credential exposure**: Credentials never logged or transmitted +**"Token has been expired or revoked"**: Run `gro config clear && gro init`. diff --git a/Makefile b/Makefile index dcde8b0..68bc92d 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ LDFLAGS := -ldflags "-s -w \ DIST_DIR = dist -.PHONY: all build test test-cover test-short lint fmt deps verify clean release checksums install uninstall +.PHONY: all build test test-cover test-cover-check test-short lint fmt tidy deps verify check clean release checksums install uninstall all: build @@ -23,6 +23,15 @@ test-cover: go test -v -race -coverprofile=coverage.out ./... go tool cover -html=coverage.out -o coverage.html +test-cover-check: + @go test -race -coverprofile=coverage.out ./... > /dev/null 2>&1 + @total=$$(go tool cover -func=coverage.out | grep '^total:' | awk '{print $$3}' | tr -d '%'); \ + threshold=60; \ + echo "Total coverage: $${total}% (threshold: $${threshold}%)"; \ + if [ $$(echo "$$total < $$threshold" | bc) -eq 1 ]; then \ + echo "FAIL: coverage below threshold"; exit 1; \ + fi + test-short: go test -v -short ./... @@ -33,6 +42,10 @@ fmt: go fmt ./... goimports -local github.com/open-cli-collective/google-readonly -w . +tidy: + go mod tidy + git diff --exit-code go.mod go.sum + deps: go mod download go mod tidy @@ -40,6 +53,9 @@ deps: verify: go mod verify +# CI gate: everything that must pass before merge +check: tidy lint test build + clean: rm -rf bin/ $(DIST_DIR)/ coverage.out coverage.html $(BINARY) diff --git a/STANDARDS.md b/STANDARDS.md new file mode 100644 index 0000000..7bedb89 --- /dev/null +++ b/STANDARDS.md @@ -0,0 +1,1529 @@ +# Go CLI Style Guide + +This document catalogs coding conventions for Go CLI tools. It is intended for use as an operationalized code review prompt for AI-assisted review, but is also useful as a human reference. + +When reviewing code, flag deviations from these patterns. Be pragmatic: the goal is consistency within a codebase, not pedantic enforcement. If a deviation improves readability or correctness, note it as an intentional departure rather than a defect. + +### Guiding Philosophy + +Prefer clarity, composability, and maintainability over cleverness. Go's strength is boringly readable code — lean into that. Use the standard library aggressively. Resist the urge to abstract prematurely or import a dependency for something you can write in 20 lines. Use judgement, not dogma. + +--- + +## 1. Project Configuration + +### Module Layout + +Every tool gets its own `go.mod`. For a monorepo with shared libraries, use a top-level module with internal packages: + +``` +tools/ +├── go.mod +├── go.sum +├── cmd/ +│ ├── ingest/ +│ │ └── main.go +│ ├── reconcile/ +│ │ └── main.go +│ └── sync/ +│ └── main.go +├── internal/ +│ ├── config/ +│ ├── logging/ +│ └── aws/ +└── pkg/ # only if genuinely intended for external consumption +``` + +`internal/` is the default for shared code. `pkg/` is only for packages explicitly designed as public API for other modules. When in doubt, use `internal/`. + +### Build Configuration + +Pin the Go version in `go.mod` and use a `.tool-versions` or `go.env` for the team: + +``` +go 1.24 +``` + +Use `go.sum` for reproducible builds. Run `go mod tidy` before every commit — CI should fail if `go.mod` and `go.sum` are dirty. + +### Linting + +All projects use `golangci-lint` with a shared `.golangci.yml`. At minimum, enable: + +```yaml +linters: + enable: + - errcheck + - govet + - staticcheck + - unused + - ineffassign + - misspell + - revive + - gosec + - errorlint # enforce error wrapping best practices + - exhaustive # enforce exhaustive switch/select on enums +``` + +`go vet` and `staticcheck` findings are non-negotiable. Treat them as errors in CI. + +### Dependency Hygiene + +Order imports in three groups separated by blank lines: standard library, external dependencies, internal packages: + +```go +import ( + "context" + "fmt" + "os" + + "github.com/spf13/cobra" + "go.uber.org/zap" + + "github.com/yourorg/tools/internal/config" +) +``` + +`goimports` enforces this automatically. Run it on save. + +### Makefile + +Every repo has a `Makefile` at the root. This is the answer to "I just cloned this repo, now what." CI runs the same targets developers run locally. + +```makefile +.PHONY: build lint test tidy check + +# Build all binaries into bin/ +build: + go build -o bin/ ./cmd/... + +# Lint with golangci-lint (config in .golangci.yml) +lint: + golangci-lint run ./... + +# Run tests with race detector +test: + go test -race ./... + +# Tidy and verify modules are clean +tidy: + go mod tidy + git diff --exit-code go.mod go.sum + +# CI gate: everything that must pass before merge +check: tidy lint test build +``` + +**Rules:** + +- `make check` is the CI gate. It must pass before merge. Run it locally before pushing. +- `make build` outputs all binaries to `bin/`. Add `bin/` to `.gitignore`. +- `make tidy` fails if `go.mod` or `go.sum` are dirty — this catches forgotten `go mod tidy` runs. +- Add tool-specific targets as needed (`make migrate`, `make generate`, `make integration-test`), but the four core targets (`build`, `lint`, `test`, `tidy`) are non-negotiable. +- Keep targets simple. If a target exceeds ~5 lines of shell, it belongs in a script in `scripts/` that the Makefile calls. + +--- + +## 2. Type Design + +### Structs for Data, Methods for Behavior + +Go doesn't have records, but the same instinct applies: separate data-carrying types from service types. Data structs should be plain, exported fields. Service types hold dependencies and attach methods: + +```go +// Data: plain struct, no methods beyond serialization +type SyncResult struct { + CompanyID string + Success bool + FailedKeys []string + Duration time.Duration +} + +// Service: holds dependencies, has methods +type Reconciler struct { + db *sql.DB + logger *slog.Logger + clock func() time.Time // injectable for testing +} +``` + +### Prefer Value Semantics for Small Types + +Small data types (< ~128 bytes, no mutability needs) should be passed and returned by value, not pointer. This is Go's equivalent of preferring value types: + +```go +// Good: small, immutable-ish, pass by value +type Tenant struct { + ID string + Name string +} + +type DateRange struct { + Start time.Time + End time.Time +} + +// Pointer receiver appropriate: large struct or needs mutation +type IngestionState struct { + // ... many fields, mutated over lifetime +} + +func (s *IngestionState) MarkComplete(key string) { ... } +``` + +### Strongly-Typed Identifiers + +Wrap primitive identifiers in named types to prevent parameter confusion: + +```go +type TenantID string +type CompanyID string +type BusinessID string + +func GetConnection(tenant TenantID, company CompanyID) (*Connection, error) { ... } +``` + +This makes `GetConnection(companyID, tenantID)` a compile error instead of a subtle bug. Use sparingly — only where misorderings are a real risk (multiple string parameters of the same shape). + +### Constructor Functions + +Use `NewX` functions when a type requires initialization, validation, or has unexported fields. Return the concrete type, not an interface: + +```go +func NewReconciler(db *sql.DB, logger *slog.Logger, opts ...Option) *Reconciler { + r := &Reconciler{ + db: db, + logger: logger, + clock: time.Now, + } + for _, opt := range opts { + opt(r) + } + return r +} +``` + +For simple structs with all-exported fields, struct literals are fine — no constructor needed. + +### Enums via Constants + +Go lacks sum types. Use typed constants with `iota`, and always handle the zero value explicitly: + +```go +type PlatformType int + +const ( + PlatformUnknown PlatformType = iota // zero value = unknown + PlatformAccounting + PlatformBanking + PlatformCommerce +) + +func (p PlatformType) String() string { + switch p { + case PlatformAccounting: + return "Accounting" + case PlatformBanking: + return "Banking" + case PlatformCommerce: + return "Commerce" + default: + return fmt.Sprintf("PlatformType(%d)", p) + } +} +``` + +For cases where you need exhaustiveness checking, `exhaustive` lint catches missing switch arms. + +--- + +## 3. Interface Design + +### Accept Interfaces, Return Structs + +This is the single most important Go design principle. Define interfaces at the call site (consumer), not at the implementation: + +```go +// Good: interface defined where it's consumed, not where it's implemented +// In reconciler.go: +type AccountFetcher interface { + GetAccounts(ctx context.Context, tenant TenantID) ([]Account, error) +} + +type Reconciler struct { + accounts AccountFetcher + // ... +} + +// In accounts.go — no interface declared here, just a concrete type +type AccountStore struct { + db *sql.DB +} + +func (s *AccountStore) GetAccounts(ctx context.Context, tenant TenantID) ([]Account, error) { ... } +``` + +### Keep Interfaces Small + +One to three methods is ideal. If an interface has more than five methods, it's probably doing too much. The standard library's `io.Reader` (one method) is the gold standard. + +```go +// Good: focused interface +type TokenRefresher interface { + RefreshToken(ctx context.Context, tenant TenantID) (Token, error) +} + +// Suspicious: interface is a service dump +type BusinessManager interface { + GetBusiness(ctx context.Context, id string) (*Business, error) + CreateBusiness(ctx context.Context, b Business) error + UpdateBusiness(ctx context.Context, b Business) error + DeleteBusiness(ctx context.Context, id string) error + ListBusinesses(ctx context.Context, tenant TenantID) ([]Business, error) + GetBusinessConnections(ctx context.Context, id string) ([]Connection, error) + // ... this is just a struct with extra steps +} +``` + +### The Empty Interface Smell + +`any` (`interface{}`) in function signatures is almost always a design smell. It means "I gave up on types." Acceptable uses: logging arguments, JSON marshaling boundaries, generic containers. Unacceptable: core business logic parameters. + +--- + +## 4. Error Handling + +### Errors Are Values, Not Exceptions + +Every function that can fail returns an `error`. Check it immediately. Never discard errors silently: + +```go +// Good: check immediately, handle or propagate +result, err := store.GetAccounts(ctx, tenant) +if err != nil { + return nil, fmt.Errorf("fetching accounts for %s: %w", tenant, err) +} +``` + +### Wrapping With Context + +Always wrap errors with `fmt.Errorf("context: %w", err)` to build a trace. The message should describe what *this* function was trying to do, not repeat what the callee said: + +```go +// Good: each layer adds its own context +func (r *Reconciler) Run(ctx context.Context, tenant TenantID) error { + accounts, err := r.accounts.GetAccounts(ctx, tenant) + if err != nil { + return fmt.Errorf("reconciling tenant %s: %w", tenant, err) + } + // ... +} + +// Bad: restating the callee's error +if err != nil { + return fmt.Errorf("failed to get accounts: %w", err) // "failed to" is noise +} +``` + +### Sentinel Errors and Error Types + +Define sentinel errors for conditions callers need to match on. Use custom error types when callers need structured data: + +```go +var ( + ErrNotFound = errors.New("not found") + ErrAlreadyExists = errors.New("already exists") + ErrRateLimited = errors.New("rate limited") +) + +// Callers check with errors.Is: +if errors.Is(err, ErrNotFound) { + // handle missing resource +} + +// Custom error type when callers need details: +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation: %s: %s", e.Field, e.Message) +} + +// Callers check with errors.As: +var ve *ValidationError +if errors.As(err, &ve) { + fmt.Printf("bad field: %s\n", ve.Field) +} +``` + +### Don't Panic + +`panic` is for programmer bugs (impossible states, violated invariants in init), never for runtime errors. A CLI tool that panics on bad input is broken. Recover from panics only at the outermost boundary (e.g., a top-level middleware in a server, or the root command's `RunE`). + +### Eliminate `else` After Error Returns + +Go's error handling naturally produces guard clauses. Embrace them — never nest the happy path inside `else`: + +```go +// Good: guard clause, happy path is un-indented +token, err := auth.GetToken(ctx, tenant) +if err != nil { + return fmt.Errorf("getting token: %w", err) +} +// continue with token... + +// Bad: unnecessary nesting +token, err := auth.GetToken(ctx, tenant) +if err == nil { + // happy path buried in a branch +} else { + return err +} +``` + +--- + +## 5. Context Propagation + +### Context Is Always the First Parameter + +Every function that does I/O, calls other services, or could be cancelled takes `context.Context` as its first parameter. Named `ctx`: + +```go +func (s *SyncService) Sync(ctx context.Context, tenant TenantID, companyID CompanyID) error +``` + +### Never Store Context in Structs + +Context is request-scoped. Storing it in a struct means you're holding onto a cancelled context or sharing one across requests: + +```go +// Bad: context outlives the request +type Worker struct { + ctx context.Context // don't do this +} + +// Good: pass per-call +func (w *Worker) Process(ctx context.Context, job Job) error { ... } +``` + +### Respect Cancellation + +Check `ctx.Err()` or use `select` on `ctx.Done()` in loops and before expensive operations: + +```go +for _, batch := range batches { + if err := ctx.Err(); err != nil { + return fmt.Errorf("cancelled during batch processing: %w", err) + } + if err := processBatch(ctx, batch); err != nil { + return err + } +} +``` + +--- + +## 6. CLI Patterns + +### Cobra for All Tools + +Use Cobra for every CLI tool, even single-command ones. The cognitive cost of "which framework did this tool use" is worse than the tiny overhead of Cobra on a simple tool. Cobra gives you consistent `--help`, flag parsing, subcommand structure, and shell completion for free. Standardize on it and stop thinking about it. + +```go +func main() { + root := &cobra.Command{ + Use: "mytool", + Short: "Does the thing", + // No Run on root — forces subcommand usage + } + + root.AddCommand( + newSyncCmd(), + newReconcileCmd(), + newReportCmd(), + ) + + if err := root.Execute(); err != nil { + os.Exit(1) + } +} +``` + +### Command Factory Functions + +Each subcommand lives in its own file and returns a `*cobra.Command`. Wire dependencies in the `RunE` closure: + +```go +func newSyncCmd() *cobra.Command { + var ( + tenant string + dryRun bool + workers int + ) + + cmd := &cobra.Command{ + Use: "sync [company-id]", + Short: "Sync data for a company", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + companyID := args[0] + + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + logger := logging.New(cfg.LogLevel) + db, err := openDB(ctx, cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("connecting to database: %w", err) + } + defer db.Close() + + svc := NewSyncService(db, logger) + return svc.Run(ctx, TenantID(tenant), CompanyID(companyID), dryRun) + }, + } + + cmd.Flags().StringVar(&tenant, "tenant", "", "tenant identifier (required)") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "preview changes without writing") + cmd.Flags().IntVar(&workers, "workers", 4, "number of parallel workers") + _ = cmd.MarkFlagRequired("tenant") + + return cmd +} +``` + +### Exit Codes + +Use distinct exit codes for different failure modes. Define them as constants: + +```go +const ( + ExitOK = 0 + ExitUsageError = 1 + ExitRuntimeError = 2 + ExitConfigError = 3 + ExitPartialFailure = 4 +) +``` + +Cobra handles exit code 1 for usage errors by default. For other cases, handle in `main`: + +```go +func main() { + if err := root.Execute(); err != nil { + var cfgErr *config.Error + if errors.As(err, &cfgErr) { + os.Exit(ExitConfigError) + } + os.Exit(ExitRuntimeError) + } +} +``` + +### Stdin/Stdout/Stderr Discipline + +Standard output is for *data* (pipeable results). Standard error is for *diagnostics* (logs, progress, errors). Never mix them: + +```go +// Data goes to stdout — can be piped to jq, another tool, etc. +enc := json.NewEncoder(os.Stdout) +enc.Encode(result) + +// Diagnostics go to stderr +logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) +``` + +If the tool's primary output is human-readable (not piped), stdout is fine for both, but design for the pipeable case first. + +### Signal Handling + +CLI tools should handle SIGINT/SIGTERM gracefully. Use `signal.NotifyContext` for cancellation: + +```go +func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := run(ctx); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} +``` + +--- + +## 7. Configuration + +### Layered Config: Env > Flags > File > Defaults + +Configuration sources, in precedence order: environment variables override flags, flags override file values, file values override defaults. Use a single config struct: + +```go +type Config struct { + DatabaseURL string `env:"DATABASE_URL" json:"database_url"` + LogLevel string `env:"LOG_LEVEL" json:"log_level"` + Workers int `env:"WORKERS" json:"workers"` + Timeout time.Duration `env:"TIMEOUT" json:"timeout"` + DryRun bool // flag-only, no file/env +} + +func Load() (*Config, error) { + cfg := Config{ + LogLevel: "info", + Workers: 4, + Timeout: 30 * time.Second, + } + + // Load from file if present, then overlay env vars + // ... + + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + return &cfg, nil +} +``` + +### Validate Early, Fail Fast + +Validate all configuration at startup before doing any work. A config error 30 minutes into a batch job is a waste: + +```go +func (c *Config) Validate() error { + if c.DatabaseURL == "" { + return errors.New("DATABASE_URL is required") + } + if c.Workers < 1 || c.Workers > 64 { + return fmt.Errorf("workers must be 1-64, got %d", c.Workers) + } + return nil +} +``` + +### Testable Time + +Inject a clock function instead of calling `time.Now()` directly. This is the same principle as C#'s `TimeProvider`: + +```go +// In production +svc := &Service{clock: time.Now} + +// In tests +svc := &Service{clock: func() time.Time { return fixedTime }} +``` + +For more complex time needs, define a small interface: + +```go +type Clock interface { + Now() time.Time +} +``` + +--- + +## 8. Concurrency + +### Start Goroutines, Manage Lifetimes + +Every goroutine must have a clear shutdown path. Use `context.Context` for cancellation and `sync.WaitGroup` or `errgroup.Group` for completion: + +```go +g, ctx := errgroup.WithContext(ctx) + +for _, job := range jobs { + g.Go(func() error { + return processJob(ctx, job) + }) +} + +if err := g.Wait(); err != nil { + return fmt.Errorf("processing jobs: %w", err) +} +``` + +### errgroup for Parallel Tasks + +`errgroup` is the default for parallel work in CLI tools. It handles cancellation on first error and waitgroup semantics in one package: + +```go +g, ctx := errgroup.WithContext(ctx) +g.SetLimit(workers) // bounded parallelism + +for _, item := range items { + g.Go(func() error { + return process(ctx, item) + }) +} +return g.Wait() +``` + +### Channels for Pipelines, Not for Synchronization + +Use channels when data flows between stages. Use `sync.WaitGroup`, `errgroup`, or `sync.Mutex` for synchronization. Don't use a `chan struct{}` when a `WaitGroup` is clearer: + +```go +// Good: channel as a pipeline stage +func produce(ctx context.Context, items []Item) <-chan Item { + ch := make(chan Item) + go func() { + defer close(ch) + for _, item := range items { + select { + case ch <- item: + case <-ctx.Done(): + return + } + } + }() + return ch +} +``` + +### Never Launch Unbounded Goroutines + +Always limit concurrency for I/O-bound work. A CLI tool that launches 10,000 goroutines to hit an API will get rate-limited or OOM. Use `errgroup.SetLimit`, a semaphore channel, or a worker pool: + +```go +// Semaphore pattern +sem := make(chan struct{}, maxConcurrency) +for _, item := range items { + sem <- struct{}{} + go func() { + defer func() { <-sem }() + process(ctx, item) + }() +} +``` + +--- + +## 9. Data Access + +### database/sql with pgx or lib/pq + +Use `database/sql` as the interface layer. `pgx` is preferred as the driver for PostgreSQL (better performance, native types). Dapper-style explicit SQL applies equally here — write your own queries, don't hide behind an ORM: + +```go +const accountsQuery = ` + WITH target AS ( + SELECT b.id AS business_id + FROM business b + JOIN financial_institution fi ON b.tenant_id = fi.tenant_id + WHERE fi.fi_identifier = $1 AND b.company_id = $2 + ) + SELECT a.id, a.name, a.type + FROM account a + JOIN target t ON a.business_id = t.business_id + ORDER BY a.name +` + +func (s *AccountStore) GetAccounts(ctx context.Context, tenant TenantID, company CompanyID) ([]Account, error) { + rows, err := s.db.QueryContext(ctx, accountsQuery, string(tenant), string(company)) + if err != nil { + return nil, fmt.Errorf("querying accounts: %w", err) + } + defer rows.Close() + + var accounts []Account + for rows.Next() { + var a Account + if err := rows.Scan(&a.ID, &a.Name, &a.Type); err != nil { + return nil, fmt.Errorf("scanning account row: %w", err) + } + accounts = append(accounts, a) + } + return accounts, rows.Err() +} +``` + +### SQL Best Practices + +These carry over directly from the C# guide: + +- Always specify columns — avoid `SELECT *` +- Always use parameterized queries (`$1`, `$2`), never `fmt.Sprintf` into SQL +- Use CTEs over subqueries for readability +- Paginate large result sets; prefer cursor-based pagination over `OFFSET`/`LIMIT` +- Batch large `IN` clauses (100+ items) with `ANY($1::text[])` or temp tables + +### Transaction Management + +```go +tx, err := db.BeginTx(ctx, nil) +if err != nil { + return fmt.Errorf("beginning transaction: %w", err) +} +defer tx.Rollback() // no-op if committed + +// ... operations using tx ... + +if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) +} +``` + +The `defer tx.Rollback()` pattern is idiomatic — it's a no-op after a successful commit and ensures cleanup on any error path. + +--- + +## 10. Serialization + +### encoding/json from the Standard Library + +Use `encoding/json` by default. For performance-sensitive paths, `json/v2` (when stable) or `github.com/goccy/go-json` are acceptable drop-in replacements. + +### Struct Tags Are the Schema + +```go +type SyncRequest struct { + TenantID string `json:"tenant_id"` + CompanyID string `json:"company_id"` + Platform PlatformType `json:"platform"` + Priority int `json:"priority,omitempty"` +} +``` + +Use `omitempty` deliberately — it means "omit when zero value," which may or may not be what you want. An `int` field with `omitempty` drops `0`, which may be meaningful. + +### Custom Marshaling for Enums + +```go +func (p PlatformType) MarshalJSON() ([]byte, error) { + return json.Marshal(p.String()) +} + +func (p *PlatformType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + switch s { + case "Accounting": + *p = PlatformAccounting + case "Banking": + *p = PlatformBanking + default: + return fmt.Errorf("unknown platform type: %q", s) + } + return nil +} +``` + +--- + +## 11. Logging + +### slog for CLIs, zap for Services + +Use `log/slog` from the standard library for CLI tools. It's zero-dependency, has the right level of abstraction for console programs, and writes to stderr by default (which is what you want for CLIs — see Section 6 on stdout/stderr discipline). + +For long-running web services where logging is on the hot path, `go.uber.org/zap` is acceptable — it's measurably faster due to pre-allocation and zero-reflection design. But for CLI tools, logging throughput is never the bottleneck, and slog's simplicity wins. + +```go +// CLI: slog with text handler for human-readable output +handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, +}) +logger := slog.New(handler) + +// CLI: JSON handler when output will be ingested by log aggregation +handler := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, +}) +logger := slog.New(handler) + +// Good: structured key-value pairs +logger.Info("calculating insights", + "tenant", tenant, + "company_id", companyID, + "platform", platform, +) + +logger.Error("sync failed", + "tenant", tenant, + "company_id", companyID, + "error", err, + "elapsed_ms", elapsed.Milliseconds(), +) +``` + +### Named Fields, Not Interpolation + +Same principle as the C# guide — each value must be a discrete, queryable field: + +```go +// Good: each field independently queryable +logger.Info("processing transaction", + "tenant", tenant, + "company_id", companyID, + "txn_id", txnID, +) + +// Bad: opaque string defeats structured logging +logger.Info(fmt.Sprintf("%s::%s - processing transaction %s", tenant, companyID, txnID)) +``` + +### Logger Propagation + +Pass loggers as dependencies, not globals. Use `slog.With` to add context that applies to all messages in a scope: + +```go +func (s *SyncService) Run(ctx context.Context, tenant TenantID, company CompanyID) error { + log := s.logger.With("tenant", tenant, "company_id", company) + log.Info("starting sync") + // all subsequent log calls in this scope include tenant and company_id +} +``` + +### Log Levels + +| Level | Usage | +|-------|-------| +| Info | Start/completion of major operations, business events | +| Warn | Retry attempts, degraded scenarios, non-critical issues | +| Error | Failures (always include the error value) | +| Debug | Detailed operational info, only enabled in dev/troubleshooting | + +### Performance Timing + +```go +start := time.Now() +// ... work +logger.Info("operation complete", + "tenant", tenant, + "elapsed_ms", time.Since(start).Milliseconds(), +) +``` + +### Log Security + +Never log sensitive information: passwords, tokens, PII, full credit card numbers, SSNs. Be cautious with user attributes — only log what's necessary for debugging. + +--- + +## 12. Error Handling & Result Patterns + +### Guard Clauses and Early Returns + +Same as C#: reject invalid states early, keep the happy path at the lowest indentation level: + +```go +func (s *Service) Process(ctx context.Context, req Request) (*Result, error) { + if req.TenantID == "" { + return nil, &ValidationError{Field: "tenant_id", Message: "required"} + } + + conn, err := s.getConnection(ctx, req.TenantID) + if err != nil { + return nil, fmt.Errorf("getting connection: %w", err) + } + + // happy path continues un-indented... +} +``` + +### Multi-Value Returns for Outcome Disambiguation + +Go's multiple return values serve the same role as C#'s tuple returns: + +```go +// Found vs not-found vs error are three different outcomes +func (s *Store) GetAccount(ctx context.Context, id string) (account Account, found bool, err error) { + row := s.db.QueryRowContext(ctx, query, id) + if err := row.Scan(&account.ID, &account.Name); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return Account{}, false, nil + } + return Account{}, false, fmt.Errorf("scanning account: %w", err) + } + return account, true, nil +} +``` + +### ok-Pattern for Optional Results + +For lookups that may miss, use the comma-ok pattern familiar from map access: + +```go +val, ok := cache[key] +if !ok { + // handle miss +} +``` + +### Collecting Errors in Batch Operations + +For operations that process multiple items where you want partial results, collect errors rather than failing on the first one: + +```go +var errs []error +for _, item := range items { + if err := process(ctx, item); err != nil { + errs = append(errs, fmt.Errorf("item %s: %w", item.ID, err)) + continue + } +} +if len(errs) > 0 { + return fmt.Errorf("partial failure (%d/%d): %w", len(errs), len(items), errors.Join(errs...)) +} +``` + +--- + +## 13. Collection Patterns + +### Nil Slices Over Empty Slices + +In Go, a nil slice and an empty slice behave identically for `len`, `cap`, `range`, and `append`. Prefer nil (the zero value) — don't allocate when there's nothing to hold: + +```go +// Good: zero value is fine +var accounts []Account +// len(accounts) == 0, range works, append works + +// Unnecessary: allocating for no reason +accounts := make([]Account, 0) +accounts := []Account{} +``` + +Exception: JSON serialization. `json.Marshal(nil)` produces `null`, while `json.Marshal([]Account{})` produces `[]`. If the distinction matters to consumers, initialize explicitly. + +### Pre-Allocate When Size Is Known + +```go +results := make([]Result, 0, len(items)) +for _, item := range items { + results = append(results, transform(item)) +} +``` + +### maps Package for Common Operations + +Use `maps.Keys`, `maps.Values`, `maps.Clone` from the standard library instead of hand-rolling: + +```go +import "maps" + +keys := slices.Sorted(maps.Keys(accountsByID)) +``` + +### slices Package for Transformations + +Use `slices.SortFunc`, `slices.Contains`, `slices.Compact`, etc.: + +```go +import "slices" + +slices.SortFunc(accounts, func(a, b Account) int { + return strings.Compare(a.Name, b.Name) +}) + +hasAdmin := slices.ContainsFunc(roles, func(r Role) bool { + return r.Name == "admin" +}) +``` + +### Chunking for Batch Operations + +Same concept as C#'s `.Chunk()` — batch items for APIs with size limits: + +```go +func Chunk[T any](items []T, size int) [][]T { + var chunks [][]T + for size < len(items) { + items, chunks = items[size:], append(chunks, items[:size]) + } + return append(chunks, items) +} + +// Usage: DynamoDB BatchWriteItem supports max 25 items +for _, batch := range Chunk(writeRequests, 25) { + if err := writeBatch(ctx, batch); err != nil { + return err + } +} +``` + +Or use `slices.Chunk` if available on your Go version. + +--- + +## 14. Testing + +### Framework: Standard Library Only + +Use `testing` from the standard library. No testify, no gomega, no ginkgo. Table-driven tests and `t.Helper()` cover nearly everything. If you need mocks, write them by hand or use a small code generator — a mock framework dependency is almost never worth it. + +### Table-Driven Tests + +The default test structure. Each case is a named struct in a slice: + +```go +func TestGetPrimaryKeyName(t *testing.T) { + tests := []struct { + name string + platform PlatformType + want string + wantErr bool + }{ + {name: "accounting", platform: PlatformAccounting, want: "companyAndInsight"}, + {name: "banking", platform: PlatformBanking, want: "extCompanyId"}, + {name: "unknown panics", platform: PlatformUnknown, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetPrimaryKeyName(tt.platform) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Errorf("GetPrimaryKeyName(%v) = %q, want %q", tt.platform, got, tt.want) + } + }) + } +} +``` + +### Test Helpers + +Use `t.Helper()` for functions that report failures on behalf of the caller: + +```go +func assertNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func assertEqual[T comparable](t *testing.T, got, want T) { + t.Helper() + if got != want { + t.Errorf("got %v, want %v", got, want) + } +} +``` + +### Test Fixtures and Golden Files + +For complex test data, use `testdata/` directories. For output comparison, use golden files: + +```go +func TestRenderReport(t *testing.T) { + got := renderReport(testInput) + + golden := filepath.Join("testdata", t.Name()+".golden") + if *update { + os.WriteFile(golden, []byte(got), 0644) + } + + want, _ := os.ReadFile(golden) + if got != string(want) { + t.Errorf("output mismatch; run with -update to refresh golden files") + } +} +``` + +### Fake Implementations Over Mocks + +Write simple fake structs that satisfy interfaces. They're more readable and more maintainable than mock framework magic: + +```go +type fakeAccountStore struct { + accounts []Account + err error +} + +func (f *fakeAccountStore) GetAccounts(ctx context.Context, tenant TenantID) ([]Account, error) { + return f.accounts, f.err +} + +// In test: +store := &fakeAccountStore{ + accounts: []Account{{ID: "1", Name: "Test"}}, +} +svc := NewReconciler(store, slog.Default()) +``` + +### Test Naming + +Pattern: `TestFunctionName_Scenario` using sub-tests for cases: + +```go +func TestReconciler_Run(t *testing.T) { + t.Run("empty accounts returns early", func(t *testing.T) { ... }) + t.Run("mismatched totals returns error", func(t *testing.T) { ... }) + t.Run("successful reconciliation", func(t *testing.T) { ... }) +} +``` + +### Parallel Tests + +Mark tests as parallel when they don't share mutable state: + +```go +func TestExpensiveComputation(t *testing.T) { + t.Parallel() + // ... +} +``` + +For table-driven tests, capture the loop variable and run subtests in parallel: + +```go +for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // ... + }) +} +``` + +--- + +## 15. Formatting & Layout + +### gofmt Is Non-Negotiable + +All code is formatted with `gofmt`. No exceptions, no arguments, no custom settings. Use `goimports` as a superset (handles import ordering too). Configure your editor to run it on save. + +### Line Length + +Go has no official line limit, but target ~100-120 characters for readability. Wrap function signatures and long expressions: + +```go +func (s *SyncService) ProcessBatch( + ctx context.Context, + tenant TenantID, + companyID CompanyID, + items []SyncItem, + opts ProcessOptions, +) (*BatchResult, error) { +``` + +### File Organization + +Within a file, order declarations: + +1. Package-level constants and variables +2. Types (structs, interfaces) +3. Constructor functions (`NewX`) +4. Methods grouped by receiver type +5. Package-level functions (helpers, utilities) + +### One Type Per File (Usually) + +Major types get their own file. Small related types (a struct and its constructor, an interface and a helper) can share a file. If a file exceeds ~400 lines, consider splitting. + +### Comments: Focus on "Why" + +Same as C#: comments explain *why*, not *what*. If the what/how isn't clear, improve the name: + +```go +// Bad: restating the code +// Check if rate is greater than zero +if rate > 0 { ... } + +// Good: explaining domain knowledge +// Fed data uses VEB (bolivar fuerte) instead of the current ISO code VES (bolivar soberano). +if strings.EqualFold(code, "VES") { + return "VEB" +} +``` + +### Package Comments + +Every package should have a doc comment in a `doc.go` or at the top of the primary file: + +```go +// Package reconcile provides tools for reconciling account data +// between external platforms and the internal ledger. +package reconcile +``` + +### Exported Function Documentation + +All exported functions, types, and methods have doc comments. Start with the name of the thing: + +```go +// GetAccounts returns all accounts for the given tenant and company. +// It returns an empty slice (not nil) if no accounts are found. +func (s *Store) GetAccounts(ctx context.Context, tenant TenantID, company CompanyID) ([]Account, error) { +``` + +--- + +## 16. Naming + +### Go Naming Conventions + +These are non-negotiable — they're enforced by the compiler and community norms: + +- **Exported** identifiers are `PascalCase`: `GetAccounts`, `SyncResult`, `ErrNotFound` +- **Unexported** identifiers are `camelCase`: `processItem`, `accountStore`, `defaultTimeout` +- **Acronyms** are all-caps: `ID`, `URL`, `HTTP`, `API` — not `Id`, `Url`, `Http` +- **Package names** are lowercase, single word when possible: `config`, `sync`, `ledger` — not `ledgerUtils`, `sync_helpers` +- **Interface names**: single-method interfaces use the `-er` suffix: `Reader`, `Writer`, `Closer`, `Fetcher`. Multi-method interfaces describe the role: `AccountStore`, `TokenProvider` + +### Receiver Names + +Use one or two letter abbreviations, consistent across all methods on a type. Never `self` or `this`: + +```go +func (r *Reconciler) Run(ctx context.Context) error { ... } +func (r *Reconciler) validate() error { ... } + +func (s *Store) GetAccounts(ctx context.Context) ([]Account, error) { ... } +``` + +### Variable Names + +Short names for short scopes, descriptive names for long scopes: + +```go +// Good: short scope, short name +for i, a := range accounts { ... } + +// Good: longer scope, descriptive name +var activeAccounts []Account +for _, account := range allAccounts { + if account.IsActive { + activeAccounts = append(activeAccounts, account) + } +} +``` + +### Don't Stutter + +Package names qualify their exports. Don't repeat the package name in the type name: + +```go +// Bad: config.ConfigOptions stutters +package config +type ConfigOptions struct { ... } + +// Good: config.Options reads naturally +package config +type Options struct { ... } +``` + +--- + +## 17. Dependency Management + +### Standard Library First + +Before reaching for a third-party package, check if the standard library covers it. Go's stdlib is unusually comprehensive. Common cases where teams reach for dependencies unnecessarily: + +- HTTP clients: `net/http` is excellent. You rarely need a wrapper. +- JSON: `encoding/json` covers most cases. Only reach for alternatives on hot paths with benchmarks. +- Logging: `log/slog` for CLI tools (see Section 11). `zap` is acceptable for web services. +- Testing: `testing` + table-driven tests covers 95% of needs. + +### Acceptable Common Dependencies + +These are fine to pull in without justification: + +- `github.com/spf13/cobra` — CLI framework (mandatory for all tools, see Section 6) +- `github.com/aws/aws-sdk-go-v2` — AWS API access +- `github.com/jackc/pgx/v5` — PostgreSQL driver +- `golang.org/x/sync/errgroup` — parallel goroutine management +- `go.uber.org/zap` — structured logging for web services (not CLIs) + +Everything else needs a reason. "It's popular" is not a reason. + +### Keeping Dependencies Updated + +Run `go get -u ./...` and `go mod tidy` regularly. Pin major versions in `go.mod`. Review changelogs for security patches. + +--- + +## 18. Zero Values and Nullability + +### Embrace the Zero Value + +Go's zero values (`""`, `0`, `false`, `nil` for pointers/slices/maps) are part of the type system. Design types so that the zero value is useful: + +```go +// Good: zero value is a valid, empty state +type BatchResult struct { + Processed int + Failed int + Errors []error // nil = no errors +} + +// r := BatchResult{} is already valid, means "nothing processed, no errors" +``` + +### Pointer Fields Mean "Optional" + +Use pointer fields when the zero value is meaningful and you need to distinguish "unset" from "zero": + +```go +type UpdateRequest struct { + Name *string // nil = don't update, "" = set to empty + Workers *int // nil = don't update, 0 = valid value + Active *bool // nil = don't update, false = valid value +} +``` + +### Validate at Boundaries + +Same philosophy as the C# guide: eliminate nil/zero concerns at the edges. Inside the domain, types should carry only valid state: + +```go +// Boundary: validate and reject +func (h *Handler) HandleSync(req *http.Request) error { + var input SyncRequest + if err := json.NewDecoder(req.Body).Decode(&input); err != nil { + return fmt.Errorf("invalid request body: %w", err) + } + if input.TenantID == "" { + return &ValidationError{Field: "tenant_id", Message: "required"} + } + // domain functions receive validated, non-zero data + return h.svc.Sync(req.Context(), TenantID(input.TenantID), CompanyID(input.CompanyID)) +} +``` + +### Empty Collections Over Nil in JSON + +When serializing for external consumers, initialize slices if `null` vs `[]` matters: + +```go +type Response struct { + Items []Item `json:"items"` +} + +// If Items might be nil, initialize before marshaling: +if resp.Items == nil { + resp.Items = []Item{} +} +``` + +--- + +## 19. AWS SDK Patterns + +### Use SDK v2 + +All new code uses `aws-sdk-go-v2`. Do not use v1 (`aws-sdk-go`). + +### SQS Long Polling + +```go +out, err := client.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ + QueueUrl: &queueURL, + WaitTimeSeconds: 20, + MaxNumberOfMessages: 10, +}) +``` + +### S3 Pagination + +Use the paginator helpers from SDK v2: + +```go +paginator := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{ + Bucket: &bucket, + Prefix: &prefix, +}) + +for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return fmt.Errorf("listing objects: %w", err) + } + for _, obj := range page.Contents { + // process + } +} +``` + +### DynamoDB Batch Operations + +```go +// BatchWriteItem supports max 25 items +for _, batch := range Chunk(writeRequests, 25) { + _, err := client.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{ + RequestItems: map[string][]types.WriteRequest{ + tableName: batch, + }, + }) + if err != nil { + return fmt.Errorf("batch write: %w", err) + } +} +``` + +--- + +## 20. Patterns to Maintain + +### Prefer Standard Library Over Hand-Rolled + +Same principle as the C# guide: check whether Go's stdlib already provides the functionality before writing a utility. In particular: + +- `slices` and `maps` packages replace many hand-rolled loops (Go 1.21+) +- `slog` replaces custom logging for CLI tools (Go 1.21+); `zap` remains appropriate for services +- `errors.Join` replaces custom multi-error types (Go 1.20+) +- `sync.OnceValue` replaces lazy initialization patterns (Go 1.21+) +- `http.NewServeMux` pattern matching replaces many router libraries (Go 1.22+) + +### decimal for Money + +Go's `float64` has the same problems as C#'s `double`. Use `shopspring/decimal` or a similar arbitrary-precision library for monetary calculations. Never `float64` for money: + +```go +import "github.com/shopspring/decimal" + +rate := decimal.NewFromString("0.0425") +monthly := rate.Div(decimal.NewFromInt(12)) +``` + +### time.Time, Not int64 + +Represent timestamps as `time.Time`, not Unix epoch integers. Convert at boundaries (JSON serialization, database storage), not in domain logic. + +### Performance Behind Good Names + +Same as C#: a function named `GetExchangeRate` can use `unsafe.Pointer` arithmetic internally if profiling demands it. The caller sees a clean API. Optimize hot paths, not cold paths. Profile before optimizing. + +### Composition Over Inheritance + +Go doesn't have inheritance. Use embedding for shared structure, interfaces for shared behavior: + +```go +// Embedding: shared fields +type BaseJob struct { + TenantID TenantID + CompanyID CompanyID + CreatedAt time.Time +} + +type SyncJob struct { + BaseJob + Platform PlatformType + Priority int +} + +// Interface: shared behavior +type Processor interface { + Process(ctx context.Context, job Job) error +} +``` diff --git a/cmd/gro/main.go b/cmd/gro/main.go index 6ff498d..5e3ce7b 100644 --- a/cmd/gro/main.go +++ b/cmd/gro/main.go @@ -1,9 +1,17 @@ +// Package main is the entry point for the gro CLI. package main import ( + "context" + "os" + "os/signal" + "syscall" + "github.com/open-cli-collective/google-readonly/internal/cmd/root" ) func main() { - root.Execute() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + root.ExecuteContext(ctx) } diff --git a/docs/adding-a-domain.md b/docs/adding-a-domain.md new file mode 100644 index 0000000..518b5f6 --- /dev/null +++ b/docs/adding-a-domain.md @@ -0,0 +1,112 @@ +# Adding a New Google API Domain + +This checklist covers adding a new Google API (e.g., Google Tasks, Google Sheets) to gro. The structural tests in `internal/architecture/architecture_test.go` automatically enforce steps marked with [enforced]. + +## Checklist + +### 1. Add the OAuth scope + +In `internal/auth/auth.go`, add the readonly scope to `AllScopes`: +```go +var AllScopes = []string{ + gmail.GmailReadonlyScope, + calendar.CalendarReadonlyScope, + people.ContactsReadonlyScope, + drive.DriveReadonlyScope, + tasks.TasksReadonlyScope, // new +} +``` + +[enforced] Only `*ReadonlyScope` constants are permitted. + +### 2. Create the API client package + +Create `internal/{domain}/` with: +- `client.go` — `Client` struct, `NewClient(ctx context.Context) (*Client, error)`, methods +- Data model files as needed +- `*_test.go` — Unit tests for parsing and data models + +The constructor must follow the established pattern: +```go +func NewClient(ctx context.Context) (*Client, error) { + client, err := auth.GetHTTPClient(ctx) + if err != nil { + return nil, fmt.Errorf("loading OAuth client: %w", err) + } + srv, err := tasks.NewService(ctx, option.WithHTTPClient(client)) + if err != nil { + return nil, fmt.Errorf("creating Tasks service: %w", err) + } + return &Client{service: srv}, nil +} +``` + +[enforced] This package must NOT import any `internal/cmd/` package. + +### 3. Create the command package + +Create `internal/cmd/{domain}/` with these files: + +**`output.go`** — [enforced] Must contain: +- An exported interface ending in `Client` (e.g., `TasksClient`) +- A `ClientFactory` variable +- A `newXClient()` wrapper function +- A `printJSON()` function + +**`{domain}.go`** — [enforced] Must contain: +- An exported `NewCommand()` function returning `*cobra.Command` +- `AddCommand()` calls for all subcommands + +**Each subcommand file** — [enforced] Each leaf command must have `--json/-j` flag: +- Unexported `new{Sub}Command()` factory +- `--json/-j` flag for JSON output (exempt for binary download commands) + +### 4. Create test infrastructure + +**`mock_test.go`** — Function-field mock with compile-time interface check: +```go +type MockTasksClient struct { + ListTasksFunc func(ctx context.Context, ...) (...) +} + +var _ TasksClient = (*MockTasksClient)(nil) +``` + +**`handlers_test.go`** — Test helpers using centralized utilities: +```go +func withMockClient(mock TasksClient, f func()) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (TasksClient, error) { + return mock, nil + }, f) +} + +func withFailingClientFactory(f func()) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (TasksClient, error) { + return nil, errors.New("connection failed") + }, f) +} +``` + +Use `testutil.CaptureStdout(t, func() { ... })` for output capture. + +### 5. Add test fixtures + +In `internal/testutil/fixtures.go`, add `SampleX()` functions for the new API types. + +### 6. Register the domain command + +In `internal/cmd/root/root.go`, add: +```go +cmd.AddCommand(tasks.NewCommand()) +``` + +### 7. Update structural test registration + +In `internal/architecture/architecture_test.go`, add the new domain to: +- `domainPackages` slice (e.g., `"tasks"`) +- `apiClientPackages` slice (e.g., `"tasks"`) +- `domainCommands()` map + +### 8. Verify + +Run `make check`. The structural tests will catch any missing patterns. diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..5923591 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,79 @@ +# Architecture + +## Dependency Graph + +``` +cmd/gro/main.go + -> internal/cmd/root/ + -> internal/cmd/mail/ (MailClient interface + ClientFactory) + -> internal/cmd/calendar/ (CalendarClient interface + ClientFactory) + -> internal/cmd/contacts/ (ContactsClient interface + ClientFactory) + -> internal/cmd/drive/ (DriveClient interface + ClientFactory) + -> internal/cmd/initcmd/ (OAuth setup wizard) + -> internal/cmd/config/ (Credential management) + +Each cmd/ package depends on its API client: + internal/cmd/mail/ -> internal/gmail/ + internal/cmd/calendar/ -> internal/calendar/ + internal/cmd/contacts/ -> internal/contacts/ + internal/cmd/drive/ -> internal/drive/ + +All API clients depend on: + internal/auth/ -> internal/keychain/, internal/config/ + +Shared utilities (no internal deps): + internal/testutil/ Test fixtures and assertion helpers + internal/output/ JSON output encoding + internal/format/ Human-readable formatting + internal/errors/ Error types + internal/log/ Logging + internal/cache/ Response caching + internal/zip/ Secure zip extraction + internal/version/ Build-time version injection +``` + +## Data Flow + +``` +User -> cobra command -> ClientFactory(ctx) -> API Client -> auth.GetHTTPClient -> Google API + | + internal/{gmail,calendar,contacts,drive}/ +``` + +## Package Responsibilities + +| Package | Responsibility | +|---------|---------------| +| `cmd/gro/` | Entry point, calls `root.NewCommand()` | +| `internal/cmd/root/` | Root cobra command, registers all domain commands | +| `internal/cmd/{domain}/` | Command handlers, client interface, output formatting | +| `internal/{gmail,calendar,contacts,drive}/` | API client, data models, response parsing | +| `internal/auth/` | OAuth2 config loading, HTTP client creation | +| `internal/keychain/` | Platform-specific secure token storage | +| `internal/testutil/` | Test assertions, fixtures, helpers | +| `internal/architecture/` | Structural tests enforcing codebase conventions | + +## File Naming Conventions + +Each domain command package (`internal/cmd/{domain}/`) contains: + +| File | Purpose | +|------|---------| +| `{domain}.go` | Parent command with `NewCommand()` and `AddCommand()` calls | +| `output.go` | Client interface, `ClientFactory`, `printJSON()`, text formatters | +| `{subcommand}.go` | One file per subcommand with `new{Sub}Command()` factory | +| `mock_test.go` | Mock client with function fields + compile-time interface check | +| `handlers_test.go` | `withMockClient()`, `withFailingClientFactory()`, integration tests | +| `*_test.go` | Additional unit tests | + +Each API client package (`internal/{domain}/`) contains: + +| File | Purpose | +|------|---------| +| `client.go` | `Client` struct, `NewClient(ctx)`, client methods | +| Additional `.go` | Data models, parsing helpers | +| `*_test.go` | Unit tests | + +## Structural Enforcement + +Architectural invariants are enforced by tests in `internal/architecture/architecture_test.go`. These run as part of `make check` and CI. See `docs/golden-principles.md` for the rules being enforced. diff --git a/docs/golden-principles.md b/docs/golden-principles.md new file mode 100644 index 0000000..837d2ef --- /dev/null +++ b/docs/golden-principles.md @@ -0,0 +1,82 @@ +# Golden Principles + +These are the mechanical rules that keep the codebase consistent. Each rule is enforced by structural tests in `internal/architecture/architecture_test.go` and runs automatically in CI via `make check`. + +## 1. Interface-at-consumer + +Every domain command package (`internal/cmd/{domain}/`) defines its own client interface in `output.go`. The API client package (`internal/{domain}/`) does NOT define an interface — it returns a concrete `*Client` struct. + +**Enforced by:** `TestDomainPackagesDefineClientInterface` + +## 2. ClientFactory for dependency injection + +Every domain command package declares a package-level `ClientFactory` variable. Production code calls `ClientFactory(ctx)`. Tests override it to inject mocks. + +```go +var ClientFactory = func(ctx context.Context) (XClient, error) { + return x.NewClient(ctx) +} +``` + +**Enforced by:** `TestDomainPackagesHaveClientFactory` + +## 3. NewCommand() factory + +Parent commands export `NewCommand()` returning `*cobra.Command`. Subcommands use unexported `new{Sub}Command()`. Parent commands register subcommands via `cmd.AddCommand()`. + +**Enforced by:** `TestDomainPackagesExportNewCommand` + +## 4. --json on every leaf command + +All leaf subcommands (commands with no children) support `--json/-j` for machine-readable output. Download commands that output binary file data are exempt. + +**Enforced by:** `TestAllLeafCommandsHaveJSONFlag` + +## 5. Read-only only + +Only `*ReadonlyScope` constants may appear in `auth.AllScopes`. No write API methods (`.Send()`, `.Trash()`, `.BatchModify()`, etc.) in production code. + +**Enforced by:** `TestAllScopesAreReadOnly`, `TestNoWriteAPIMethodsInProductionCode` + +## 6. Dependency direction + +- API client packages must NOT import `internal/cmd/` (clients don't know about commands) +- `internal/auth/` must NOT import API client packages (auth is lower-level) + +**Enforced by:** `TestAPIClientPackagesDoNotImportCmd`, `TestAuthPackageDoesNotImportAPIClients` + +## 7. context.Context on all I/O methods + +Every public method that performs I/O takes `context.Context` as its first parameter. The only exceptions are pure getter methods that return cached data (e.g., `GetLabelName`, `GetLabels`). + +## 8. Error wrapping + +Use `fmt.Errorf("doing X: %w", err)` at every level. Error messages are lowercase and have no trailing punctuation, following [Go conventions](https://github.com/go/wiki/wiki/CodeReviewComments#error-strings). + +## 9. Mock pattern + +Mocks use function fields in `mock_test.go` with a compile-time interface check: + +```go +type MockXClient struct { + MethodFunc func(...) (...) +} + +var _ XClient = (*MockXClient)(nil) + +func (m *MockXClient) Method(...) (...) { + if m.MethodFunc != nil { + return m.MethodFunc(...) + } + return zero, nil +} +``` + +Test helpers `withMockClient` and `withFailingClientFactory` use `testutil.WithFactory` to swap the `ClientFactory`. + +## 10. Centralized test helpers + +- `testutil.CaptureStdout(t, func())` — captures stdout during command execution +- `testutil.WithFactory(&factory, replacement, func())` — generic factory swap +- `testutil.SampleX()` functions — fixture data for all API types +- `testutil.Equal`, `testutil.NoError`, etc. — assertion helpers diff --git a/go.mod b/go.mod index 183f17c..cec1e5c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.24.0 require ( github.com/spf13/cobra v1.8.0 - github.com/stretchr/testify v1.11.1 golang.org/x/oauth2 v0.34.0 google.golang.org/api v0.262.0 ) @@ -14,7 +13,6 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -23,7 +21,6 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect github.com/googleapis/gax-go/v2 v2.16.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect @@ -37,5 +34,4 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20260120174246-409b4a993575 // indirect google.golang.org/grpc v1.78.0 // indirect google.golang.org/protobuf v1.36.11 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9b5fa78..9e1eeba 100644 --- a/go.sum +++ b/go.sum @@ -30,14 +30,8 @@ github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5 github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= -github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= @@ -86,7 +80,5 @@ google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpW google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/architecture/architecture_test.go b/internal/architecture/architecture_test.go new file mode 100644 index 0000000..dca606a --- /dev/null +++ b/internal/architecture/architecture_test.go @@ -0,0 +1,375 @@ +package architecture + +import ( + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/google-readonly/internal/auth" + calcmd "github.com/open-cli-collective/google-readonly/internal/cmd/calendar" + contactscmd "github.com/open-cli-collective/google-readonly/internal/cmd/contacts" + drivecmd "github.com/open-cli-collective/google-readonly/internal/cmd/drive" + mailcmd "github.com/open-cli-collective/google-readonly/internal/cmd/mail" +) + +// domainPackages lists the command packages that must follow structural conventions. +var domainPackages = []string{"mail", "calendar", "contacts", "drive"} + +// apiClientPackages lists the internal API client package directory names. +var apiClientPackages = []string{"gmail", "calendar", "contacts", "drive"} + +// jsonExemptCommands lists leaf commands exempt from the --json flag requirement. +// Key format: "parent subcommand" (e.g., "mail attachments download"). +// Only add exemptions for commands that output binary file data where JSON is inapplicable. +var jsonExemptCommands = map[string]bool{ + "mail attachments download": true, // writes binary attachment files to disk + "drive download": true, // writes binary file data to disk +} + +// domainCommands returns the top-level cobra.Command for each domain package. +func domainCommands() map[string]*cobra.Command { + return map[string]*cobra.Command{ + "mail": mailcmd.NewCommand(), + "calendar": calcmd.NewCommand(), + "contacts": contactscmd.NewCommand(), + "drive": drivecmd.NewCommand(), + } +} + +// findModuleRoot walks up from the working directory to locate go.mod. +func findModuleRoot(t *testing.T) string { + t.Helper() + dir, err := os.Getwd() + if err != nil { + t.Fatalf("getting working directory: %v", err) + } + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + t.Fatal("could not find module root (go.mod)") + } + dir = parent + } +} + +// parseNonTestFiles parses all non-test .go files in a directory. +func parseNonTestFiles(t *testing.T, dir string) []*ast.File { + t.Helper() + fset := token.NewFileSet() + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("reading directory %s: %v", dir, err) + } + var files []*ast.File + for _, entry := range entries { + name := entry.Name() + if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { + continue + } + f, err := parser.ParseFile(fset, filepath.Join(dir, name), nil, parser.ParseComments) + if err != nil { + t.Fatalf("parsing %s: %v", name, err) + } + files = append(files, f) + } + return files +} + +// collectImports returns all import paths from a set of parsed files. +func collectImports(files []*ast.File) []string { + var imports []string + for _, f := range files { + for _, imp := range f.Imports { + path := strings.Trim(imp.Path.Value, `"`) + imports = append(imports, path) + } + } + return imports +} + +type leafInfo struct { + path string + cmd *cobra.Command +} + +// leafCommands recursively collects all leaf commands (commands with no subcommands). +func leafCommands(cmd *cobra.Command, parentPath string) []leafInfo { + subs := cmd.Commands() + if len(subs) == 0 { + return []leafInfo{{path: parentPath, cmd: cmd}} + } + var leaves []leafInfo + for _, sub := range subs { + subPath := parentPath + " " + sub.Name() + leaves = append(leaves, leafCommands(sub, subPath)...) + } + return leaves +} + +// --------------------------------------------------------------------------- +// Structural tests +// --------------------------------------------------------------------------- + +// TestDomainPackagesDefineClientInterface verifies that every domain command package +// declares an exported interface type whose name ends in "Client". +func TestDomainPackagesDefineClientInterface(t *testing.T) { + t.Parallel() + root := findModuleRoot(t) + + for _, pkg := range domainPackages { + t.Run(pkg, func(t *testing.T) { + t.Parallel() + dir := filepath.Join(root, "internal", "cmd", pkg) + files := parseNonTestFiles(t, dir) + + var found bool + for _, f := range files { + for _, decl := range f.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + _, isInterface := typeSpec.Type.(*ast.InterfaceType) + if isInterface && strings.HasSuffix(typeSpec.Name.Name, "Client") { + found = true + if !typeSpec.Name.IsExported() { + t.Errorf("client interface %s must be exported", typeSpec.Name.Name) + } + } + } + } + } + + if !found { + t.Errorf("package internal/cmd/%s must define an exported interface ending in 'Client' (see docs/golden-principles.md)", pkg) + } + }) + } +} + +// TestDomainPackagesHaveClientFactory verifies that every domain command package +// declares a package-level ClientFactory variable for dependency injection. +func TestDomainPackagesHaveClientFactory(t *testing.T) { + t.Parallel() + root := findModuleRoot(t) + + for _, pkg := range domainPackages { + t.Run(pkg, func(t *testing.T) { + t.Parallel() + dir := filepath.Join(root, "internal", "cmd", pkg) + files := parseNonTestFiles(t, dir) + + var found bool + for _, f := range files { + for _, decl := range f.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.VAR { + continue + } + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for _, name := range valueSpec.Names { + if name.Name == "ClientFactory" { + found = true + } + } + } + } + } + + if !found { + t.Errorf("package internal/cmd/%s must define a ClientFactory variable for dependency injection (see docs/golden-principles.md)", pkg) + } + }) + } +} + +// TestDomainPackagesExportNewCommand verifies that every domain command package +// exports a NewCommand() function (top-level, not a method). +func TestDomainPackagesExportNewCommand(t *testing.T) { + t.Parallel() + root := findModuleRoot(t) + + for _, pkg := range domainPackages { + t.Run(pkg, func(t *testing.T) { + t.Parallel() + dir := filepath.Join(root, "internal", "cmd", pkg) + files := parseNonTestFiles(t, dir) + + var found bool + for _, f := range files { + for _, decl := range f.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + // Must be a package-level function (no receiver), named NewCommand + if funcDecl.Name.Name == "NewCommand" && funcDecl.Recv == nil { + found = true + } + } + } + + if !found { + t.Errorf("package internal/cmd/%s must export a NewCommand() function (see docs/golden-principles.md)", pkg) + } + }) + } +} + +// TestAllLeafCommandsHaveJSONFlag verifies that every leaf subcommand +// (commands with no children) declares a --json/-j flag. +func TestAllLeafCommandsHaveJSONFlag(t *testing.T) { + t.Parallel() + + for name, cmd := range domainCommands() { + for _, leaf := range leafCommands(cmd, name) { + t.Run(strings.TrimSpace(leaf.path), func(t *testing.T) { + t.Parallel() + key := strings.TrimSpace(leaf.path) + if jsonExemptCommands[key] { + t.Skipf("exempt from --json requirement") + } + flag := leaf.cmd.Flags().Lookup("json") + if flag == nil { + t.Errorf("leaf command %q must have a --json flag (see docs/golden-principles.md)", key) + return + } + if flag.Shorthand != "j" { + t.Errorf("leaf command %q --json flag must have shorthand 'j', got %q", key, flag.Shorthand) + } + }) + } + } +} + +// TestAPIClientPackagesDoNotImportCmd verifies that API client packages +// (internal/gmail, internal/calendar, etc.) never import command packages. +// Dependency direction must be: cmd -> api client, never the reverse. +func TestAPIClientPackagesDoNotImportCmd(t *testing.T) { + t.Parallel() + root := findModuleRoot(t) + + for _, pkg := range apiClientPackages { + t.Run(pkg, func(t *testing.T) { + t.Parallel() + dir := filepath.Join(root, "internal", pkg) + files := parseNonTestFiles(t, dir) + imports := collectImports(files) + + for _, imp := range imports { + if strings.Contains(imp, "internal/cmd") { + t.Errorf("API client package internal/%s must not import cmd packages, but imports %q", pkg, imp) + } + } + }) + } +} + +// TestAuthPackageDoesNotImportAPIClients verifies that the auth package +// does not depend on any internal API client packages. +// Dependency direction must be: api client -> auth, never the reverse. +func TestAuthPackageDoesNotImportAPIClients(t *testing.T) { + t.Parallel() + root := findModuleRoot(t) + + dir := filepath.Join(root, "internal", "auth") + files := parseNonTestFiles(t, dir) + imports := collectImports(files) + + for _, imp := range imports { + for _, apiPkg := range apiClientPackages { + if strings.HasSuffix(imp, "/internal/"+apiPkg) { + t.Errorf("auth package must not import API client package internal/%s", apiPkg) + } + } + } +} + +// TestAllScopesAreReadOnly verifies that every OAuth scope in auth.AllScopes +// is a read-only scope. This is the primary mechanical enforcement of the +// read-only-by-design principle. +func TestAllScopesAreReadOnly(t *testing.T) { + t.Parallel() + + if len(auth.AllScopes) == 0 { + t.Fatal("auth.AllScopes must not be empty") + } + + for _, scope := range auth.AllScopes { + if !strings.Contains(scope, "readonly") { + t.Errorf("scope %q is not a readonly scope; all scopes in AllScopes must be read-only", scope) + } + } +} + +// TestNoWriteAPIMethodsInProductionCode scans all non-test Go source files +// for Google API write method calls. This is defense-in-depth on top of the +// scope check — even with readonly scopes, we don't want write method calls +// in the codebase since they indicate incorrect intent. +func TestNoWriteAPIMethodsInProductionCode(t *testing.T) { + t.Parallel() + root := findModuleRoot(t) + + // These patterns are specific to Google API client libraries and unlikely + // to appear in other contexts. Generic method names like .Delete() or + // .Insert() are intentionally excluded to avoid false positives. + forbiddenPatterns := []string{ + ".Send(", + ".Trash(", + ".Untrash(", + ".BatchModify(", + ".BatchDelete(", + } + + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == "vendor" || name == ".git" || name == "dist" || name == "bin" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return nil + } + + data, readErr := os.ReadFile(path) + if readErr != nil { + t.Errorf("reading %s: %v", path, readErr) + return nil + } + content := string(data) + rel, _ := filepath.Rel(root, path) + + for _, pattern := range forbiddenPatterns { + if strings.Contains(content, pattern) { + t.Errorf("file %s contains forbidden write API method %q — this CLI is read-only by design", rel, pattern) + } + } + return nil + }) + if err != nil { + t.Fatalf("walking source tree: %v", err) + } +} diff --git a/internal/architecture/doc.go b/internal/architecture/doc.go new file mode 100644 index 0000000..fbc1d9d --- /dev/null +++ b/internal/architecture/doc.go @@ -0,0 +1,5 @@ +// Package architecture contains structural tests that enforce codebase conventions. +// These tests verify that all domain packages follow established patterns: +// client interfaces, factory functions, command structure, and dependency direction. +// See docs/golden-principles.md for the rules these tests enforce. +package architecture diff --git a/internal/auth/auth.go b/internal/auth/auth.go index ca02d3f..cf0e336 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,3 +1,4 @@ +// Package auth provides OAuth2 authentication and credential management for Google APIs. package auth import ( @@ -31,7 +32,7 @@ func GetOAuthConfig() (*oauth2.Config, error) { if err != nil { return nil, err } - b, err := os.ReadFile(credPath) + b, err := os.ReadFile(credPath) //nolint:gosec // Path from user config directory if err != nil { return nil, fmt.Errorf("unable to read credentials file: %w", err) } @@ -62,7 +63,7 @@ func GetHTTPClient(ctx context.Context) (*http.Client, error) { } // Create persistent token source that saves refreshed tokens - tokenSource := keychain.NewPersistentTokenSource(config, tok) + tokenSource := keychain.NewPersistentTokenSource(ctx, config, tok) return oauth2.NewClient(ctx, tokenSource), nil } @@ -77,7 +78,7 @@ func ExchangeAuthCode(ctx context.Context, config *oauth2.Config, code string) ( } func tokenFromFile(file string) (*oauth2.Token, error) { - f, err := os.Open(file) + f, err := os.Open(file) //nolint:gosec // Path from user config directory if err != nil { return nil, err } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 85ae931..c39722b 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -3,11 +3,9 @@ package auth import ( "os" "path/filepath" + "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/open-cli-collective/google-readonly/internal/config" ) @@ -18,12 +16,18 @@ func TestDeprecatedWrappers(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) authDir, err := GetConfigDir() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } configDir, err := config.GetConfigDir() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - assert.Equal(t, configDir, authDir) + if authDir != configDir { + t.Errorf("got %v, want %v", authDir, configDir) + } }) t.Run("GetCredentialsPath delegates to config package", func(t *testing.T) { @@ -31,12 +35,18 @@ func TestDeprecatedWrappers(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) authPath, err := GetCredentialsPath() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } configPath, err := config.GetCredentialsPath() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - assert.Equal(t, configPath, authPath) + if authPath != configPath { + t.Errorf("got %v, want %v", authPath, configPath) + } }) t.Run("GetTokenPath delegates to config package", func(t *testing.T) { @@ -44,43 +54,75 @@ func TestDeprecatedWrappers(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) authPath, err := GetTokenPath() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } configPath, err := config.GetTokenPath() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - assert.Equal(t, configPath, authPath) + if authPath != configPath { + t.Errorf("got %v, want %v", authPath, configPath) + } }) t.Run("ShortenPath delegates to config package", func(t *testing.T) { + t.Parallel() home, err := os.UserHomeDir() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } testPath := filepath.Join(home, ".config", "test") authResult := ShortenPath(testPath) configResult := config.ShortenPath(testPath) - assert.Equal(t, configResult, authResult) + if authResult != configResult { + t.Errorf("got %v, want %v", authResult, configResult) + } }) t.Run("Constants match config package", func(t *testing.T) { - assert.Equal(t, config.DirName, ConfigDirName) - assert.Equal(t, config.CredentialsFile, CredentialsFile) - assert.Equal(t, config.TokenFile, TokenFile) + t.Parallel() + if ConfigDirName != config.DirName { + t.Errorf("got %v, want %v", ConfigDirName, config.DirName) + } + if CredentialsFile != config.CredentialsFile { + t.Errorf("got %v, want %v", CredentialsFile, config.CredentialsFile) + } + if TokenFile != config.TokenFile { + t.Errorf("got %v, want %v", TokenFile, config.TokenFile) + } }) } func TestAllScopes(t *testing.T) { - assert.Len(t, AllScopes, 4) - assert.Contains(t, AllScopes, "https://www.googleapis.com/auth/gmail.readonly") - assert.Contains(t, AllScopes, "https://www.googleapis.com/auth/calendar.readonly") - assert.Contains(t, AllScopes, "https://www.googleapis.com/auth/contacts.readonly") - assert.Contains(t, AllScopes, "https://www.googleapis.com/auth/drive.readonly") + t.Parallel() + if len(AllScopes) != 4 { + t.Errorf("got length %d, want %d", len(AllScopes), 4) + } + scopeSet := strings.Join(AllScopes, " ") + if !strings.Contains(scopeSet, "https://www.googleapis.com/auth/gmail.readonly") { + t.Errorf("expected AllScopes to contain %q", "https://www.googleapis.com/auth/gmail.readonly") + } + if !strings.Contains(scopeSet, "https://www.googleapis.com/auth/calendar.readonly") { + t.Errorf("expected AllScopes to contain %q", "https://www.googleapis.com/auth/calendar.readonly") + } + if !strings.Contains(scopeSet, "https://www.googleapis.com/auth/contacts.readonly") { + t.Errorf("expected AllScopes to contain %q", "https://www.googleapis.com/auth/contacts.readonly") + } + if !strings.Contains(scopeSet, "https://www.googleapis.com/auth/drive.readonly") { + t.Errorf("expected AllScopes to contain %q", "https://www.googleapis.com/auth/drive.readonly") + } } func TestTokenFromFile(t *testing.T) { + t.Parallel() t.Run("reads valid token file", func(t *testing.T) { + t.Parallel() tmpDir := t.TempDir() tokenPath := filepath.Join(tmpDir, "token.json") @@ -91,28 +133,46 @@ func TestTokenFromFile(t *testing.T) { "expiry": "2024-01-01T00:00:00Z" }` err := os.WriteFile(tokenPath, []byte(tokenData), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } token, err := tokenFromFile(tokenPath) - require.NoError(t, err) - assert.Equal(t, "test-access-token", token.AccessToken) - assert.Equal(t, "Bearer", token.TokenType) - assert.Equal(t, "test-refresh-token", token.RefreshToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "test-access-token" { + t.Errorf("got %v, want %v", token.AccessToken, "test-access-token") + } + if token.TokenType != "Bearer" { + t.Errorf("got %v, want %v", token.TokenType, "Bearer") + } + if token.RefreshToken != "test-refresh-token" { + t.Errorf("got %v, want %v", token.RefreshToken, "test-refresh-token") + } }) t.Run("returns error for non-existent file", func(t *testing.T) { + t.Parallel() _, err := tokenFromFile("/nonexistent/token.json") - assert.Error(t, err) + if err == nil { + t.Fatal("expected error, got nil") + } }) t.Run("returns error for invalid JSON", func(t *testing.T) { + t.Parallel() tmpDir := t.TempDir() tokenPath := filepath.Join(tmpDir, "token.json") err := os.WriteFile(tokenPath, []byte("not valid json"), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } _, err = tokenFromFile(tokenPath) - assert.Error(t, err) + if err == nil { + t.Fatal("expected error, got nil") + } }) } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 430cc36..0b07c0a 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -64,7 +64,7 @@ func New(ttlHours int) (*Cache, error) { func (c *Cache) GetDrives() ([]*CachedDrive, error) { path := filepath.Join(c.dir, DrivesFile) - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // Path constructed from known config directory if err != nil { if os.IsNotExist(err) { return nil, nil // Cache miss, not an error @@ -134,7 +134,7 @@ func (c *Cache) GetStatus() (*Status, error) { // Check drives cache drivesPath := filepath.Join(c.dir, DrivesFile) - data, err := os.ReadFile(drivesPath) + data, err := os.ReadFile(drivesPath) //nolint:gosec // Path constructed from known config directory if err == nil { var cache DriveCache if json.Unmarshal(data, &cache) == nil { diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 96212e3..ee442b7 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -7,45 +7,44 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestNew(t *testing.T) { t.Run("creates cache with default TTL", func(t *testing.T) { c, err := New(0) - require.NoError(t, err) - assert.NotNil(t, c) - assert.Equal(t, DefaultTTLHours, c.ttlHours) + testutil.NoError(t, err) + testutil.NotNil(t, c) + testutil.Equal(t, c.ttlHours, DefaultTTLHours) defer c.Clear() }) t.Run("creates cache with custom TTL", func(t *testing.T) { c, err := New(12) - require.NoError(t, err) - assert.Equal(t, 12, c.ttlHours) + testutil.NoError(t, err) + testutil.Equal(t, c.ttlHours, 12) defer c.Clear() }) t.Run("creates cache directory", func(t *testing.T) { c, err := New(24) - require.NoError(t, err) + testutil.NoError(t, err) defer c.Clear() _, err = os.Stat(c.dir) - assert.NoError(t, err) + testutil.NoError(t, err) }) } func TestCache_GetSetDrives(t *testing.T) { c, err := New(24) - require.NoError(t, err) + testutil.NoError(t, err) defer c.Clear() t.Run("returns nil for missing cache", func(t *testing.T) { drives, err := c.GetDrives() - assert.NoError(t, err) - assert.Nil(t, drives) + testutil.NoError(t, err) + testutil.Nil(t, drives) }) t.Run("stores and retrieves drives", func(t *testing.T) { @@ -55,21 +54,21 @@ func TestCache_GetSetDrives(t *testing.T) { } err := c.SetDrives(input) - require.NoError(t, err) + testutil.NoError(t, err) drives, err := c.GetDrives() - require.NoError(t, err) - require.Len(t, drives, 2) - assert.Equal(t, "drive1", drives[0].ID) - assert.Equal(t, "Engineering", drives[0].Name) - assert.Equal(t, "drive2", drives[1].ID) - assert.Equal(t, "Marketing", drives[1].Name) + testutil.NoError(t, err) + testutil.Len(t, drives, 2) + testutil.Equal(t, drives[0].ID, "drive1") + testutil.Equal(t, drives[0].Name, "Engineering") + testutil.Equal(t, drives[1].ID, "drive2") + testutil.Equal(t, drives[1].Name, "Marketing") }) } func TestCache_Expiration(t *testing.T) { c, err := New(1) // 1 hour TTL - require.NoError(t, err) + testutil.NoError(t, err) defer c.Clear() t.Run("returns nil for expired cache", func(t *testing.T) { @@ -83,15 +82,15 @@ func TestCache_Expiration(t *testing.T) { } data, err := json.Marshal(expiredCache) - require.NoError(t, err) + testutil.NoError(t, err) path := filepath.Join(c.dir, DrivesFile) err = os.WriteFile(path, data, 0600) - require.NoError(t, err) + testutil.NoError(t, err) drives, err := c.GetDrives() - assert.NoError(t, err) - assert.Nil(t, drives, "expired cache should return nil") + testutil.NoError(t, err) + testutil.Nil(t, drives) }) t.Run("returns drives for valid cache", func(t *testing.T) { @@ -105,68 +104,68 @@ func TestCache_Expiration(t *testing.T) { } data, err := json.Marshal(freshCache) - require.NoError(t, err) + testutil.NoError(t, err) path := filepath.Join(c.dir, DrivesFile) err = os.WriteFile(path, data, 0600) - require.NoError(t, err) + testutil.NoError(t, err) drives, err := c.GetDrives() - assert.NoError(t, err) - require.Len(t, drives, 1) - assert.Equal(t, "drive1", drives[0].ID) + testutil.NoError(t, err) + testutil.Len(t, drives, 1) + testutil.Equal(t, drives[0].ID, "drive1") }) } func TestCache_CorruptedCache(t *testing.T) { c, err := New(24) - require.NoError(t, err) + testutil.NoError(t, err) defer c.Clear() t.Run("returns nil for corrupted JSON", func(t *testing.T) { path := filepath.Join(c.dir, DrivesFile) err := os.WriteFile(path, []byte("not valid json"), 0600) - require.NoError(t, err) + testutil.NoError(t, err) drives, err := c.GetDrives() - assert.NoError(t, err) - assert.Nil(t, drives, "corrupted cache should return nil") + testutil.NoError(t, err) + testutil.Nil(t, drives) }) } func TestCache_Clear(t *testing.T) { c, err := New(24) - require.NoError(t, err) + testutil.NoError(t, err) // Add some data err = c.SetDrives([]*CachedDrive{{ID: "test", Name: "Test"}}) - require.NoError(t, err) + testutil.NoError(t, err) // Verify file exists path := filepath.Join(c.dir, DrivesFile) _, err = os.Stat(path) - require.NoError(t, err) + testutil.NoError(t, err) // Clear cache err = c.Clear() - require.NoError(t, err) + testutil.NoError(t, err) // Verify directory is gone _, err = os.Stat(c.dir) - assert.True(t, os.IsNotExist(err)) + testutil.True(t, os.IsNotExist(err)) } func TestCache_GetStatus(t *testing.T) { c, err := New(24) - require.NoError(t, err) + testutil.NoError(t, err) defer c.Clear() t.Run("returns status with no cache", func(t *testing.T) { status, err := c.GetStatus() - require.NoError(t, err) - assert.Equal(t, c.dir, status.Dir) - assert.Equal(t, 24, status.TTLHours) - assert.Nil(t, status.DrivesCache) + testutil.NoError(t, err) + testutil.Equal(t, status.Dir, c.dir) + testutil.Equal(t, status.TTLHours, 24) + testutil.Nil(t, status.DrivesCache) }) t.Run("returns status with drives cache", func(t *testing.T) { @@ -174,14 +173,14 @@ func TestCache_GetStatus(t *testing.T) { {ID: "drive1", Name: "Test1"}, {ID: "drive2", Name: "Test2"}, }) - require.NoError(t, err) + testutil.NoError(t, err) status, err := c.GetStatus() - require.NoError(t, err) - require.NotNil(t, status.DrivesCache) - assert.Equal(t, 2, status.DrivesCache.Count) - assert.False(t, status.DrivesCache.IsStale) - assert.True(t, status.DrivesCache.ExpiresAt.After(time.Now())) + testutil.NoError(t, err) + testutil.NotNil(t, status.DrivesCache) + testutil.Equal(t, status.DrivesCache.Count, 2) + testutil.False(t, status.DrivesCache.IsStale) + testutil.True(t, status.DrivesCache.ExpiresAt.After(time.Now())) }) t.Run("marks stale cache as stale", func(t *testing.T) { @@ -196,17 +195,17 @@ func TestCache_GetStatus(t *testing.T) { os.WriteFile(path, data, 0600) status, err := c.GetStatus() - require.NoError(t, err) - require.NotNil(t, status.DrivesCache) - assert.True(t, status.DrivesCache.IsStale) + testutil.NoError(t, err) + testutil.NotNil(t, status.DrivesCache) + testutil.True(t, status.DrivesCache.IsStale) }) } func TestCache_GetDir(t *testing.T) { c, err := New(24) - require.NoError(t, err) + testutil.NoError(t, err) defer c.Clear() - assert.NotEmpty(t, c.GetDir()) - assert.Contains(t, c.GetDir(), "cache") + testutil.NotEmpty(t, c.GetDir()) + testutil.Contains(t, c.GetDir(), "cache") } diff --git a/internal/calendar/client.go b/internal/calendar/client.go index 05d9583..1c4e077 100644 --- a/internal/calendar/client.go +++ b/internal/calendar/client.go @@ -1,3 +1,4 @@ +// Package calendar provides a client for the Google Calendar API. package calendar import ( @@ -19,12 +20,12 @@ type Client struct { func NewClient(ctx context.Context) (*Client, error) { client, err := auth.GetHTTPClient(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("loading OAuth client: %w", err) } srv, err := calendar.NewService(ctx, option.WithHTTPClient(client)) if err != nil { - return nil, fmt.Errorf("unable to create Calendar service: %w", err) + return nil, fmt.Errorf("creating Calendar service: %w", err) } return &Client{ @@ -33,16 +34,16 @@ func NewClient(ctx context.Context) (*Client, error) { } // ListCalendars returns all calendars the user has access to -func (c *Client) ListCalendars() ([]*calendar.CalendarListEntry, error) { - resp, err := c.service.CalendarList.List().Do() +func (c *Client) ListCalendars(ctx context.Context) ([]*calendar.CalendarListEntry, error) { + resp, err := c.service.CalendarList.List().Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list calendars: %w", err) + return nil, fmt.Errorf("listing calendars: %w", err) } return resp.Items, nil } // ListEvents returns events from the specified calendar within the given time range -func (c *Client) ListEvents(calendarID string, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { +func (c *Client) ListEvents(ctx context.Context, calendarID string, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { call := c.service.Events.List(calendarID). SingleEvents(true). OrderBy("startTime") @@ -57,18 +58,18 @@ func (c *Client) ListEvents(calendarID string, timeMin, timeMax string, maxResul call = call.MaxResults(maxResults) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list events: %w", err) + return nil, fmt.Errorf("listing events: %w", err) } return resp.Items, nil } // GetEvent retrieves a single event by ID -func (c *Client) GetEvent(calendarID, eventID string) (*calendar.Event, error) { - event, err := c.service.Events.Get(calendarID, eventID).Do() +func (c *Client) GetEvent(ctx context.Context, calendarID, eventID string) (*calendar.Event, error) { + event, err := c.service.Events.Get(calendarID, eventID).Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get event: %w", err) + return nil, fmt.Errorf("getting event: %w", err) } return event, nil } diff --git a/internal/calendar/client_test.go b/internal/calendar/client_test.go index e056b5d..ea4a97f 100644 --- a/internal/calendar/client_test.go +++ b/internal/calendar/client_test.go @@ -2,13 +2,15 @@ package calendar import ( "testing" - - "github.com/stretchr/testify/assert" ) func TestClientStructure(t *testing.T) { + t.Parallel() t.Run("Client has private service field", func(t *testing.T) { + t.Parallel() client := &Client{} - assert.Nil(t, client.service) + if client.service != nil { + t.Errorf("got %v, want nil", client.service) + } }) } diff --git a/internal/calendar/events_test.go b/internal/calendar/events_test.go index 0b41c0b..ce27647 100644 --- a/internal/calendar/events_test.go +++ b/internal/calendar/events_test.go @@ -1,14 +1,16 @@ package calendar import ( + "strings" "testing" - "github.com/stretchr/testify/assert" "google.golang.org/api/calendar/v3" ) func TestParseEvent(t *testing.T) { + t.Parallel() t.Run("parses basic event", func(t *testing.T) { + t.Parallel() apiEvent := &calendar.Event{ Id: "event123", Summary: "Team Meeting", @@ -26,15 +28,28 @@ func TestParseEvent(t *testing.T) { event := ParseEvent(apiEvent) - assert.Equal(t, "event123", event.ID) - assert.Equal(t, "Team Meeting", event.Summary) - assert.Equal(t, "Weekly sync", event.Description) - assert.Equal(t, "Conference Room A", event.Location) - assert.Equal(t, "confirmed", event.Status) - assert.False(t, event.AllDay) + if got := event.ID; got != "event123" { + t.Errorf("got %v, want %v", got, "event123") + } + if got := event.Summary; got != "Team Meeting" { + t.Errorf("got %v, want %v", got, "Team Meeting") + } + if got := event.Description; got != "Weekly sync" { + t.Errorf("got %v, want %v", got, "Weekly sync") + } + if got := event.Location; got != "Conference Room A" { + t.Errorf("got %v, want %v", got, "Conference Room A") + } + if got := event.Status; got != "confirmed" { + t.Errorf("got %v, want %v", got, "confirmed") + } + if event.AllDay { + t.Error("got true, want false") + } }) t.Run("parses all-day event", func(t *testing.T) { + t.Parallel() apiEvent := &calendar.Event{ Id: "allday123", Summary: "Company Holiday", @@ -48,12 +63,19 @@ func TestParseEvent(t *testing.T) { event := ParseEvent(apiEvent) - assert.Equal(t, "allday123", event.ID) - assert.True(t, event.AllDay) - assert.Equal(t, "2026-01-01", event.Start.Date) + if got := event.ID; got != "allday123" { + t.Errorf("got %v, want %v", got, "allday123") + } + if !event.AllDay { + t.Error("got false, want true") + } + if got := event.Start.Date; got != "2026-01-01" { + t.Errorf("got %v, want %v", got, "2026-01-01") + } }) t.Run("parses event with organizer", func(t *testing.T) { + t.Parallel() apiEvent := &calendar.Event{ Id: "org123", Summary: "Project Review", @@ -72,12 +94,19 @@ func TestParseEvent(t *testing.T) { event := ParseEvent(apiEvent) - assert.NotNil(t, event.Organizer) - assert.Equal(t, "boss@example.com", event.Organizer.Email) - assert.Equal(t, "The Boss", event.Organizer.DisplayName) + if event.Organizer == nil { + t.Fatal("expected non-nil, got nil") + } + if got := event.Organizer.Email; got != "boss@example.com" { + t.Errorf("got %v, want %v", got, "boss@example.com") + } + if got := event.Organizer.DisplayName; got != "The Boss" { + t.Errorf("got %v, want %v", got, "The Boss") + } }) t.Run("parses event with attendees", func(t *testing.T) { + t.Parallel() apiEvent := &calendar.Event{ Id: "att123", Summary: "Team Standup", @@ -104,14 +133,25 @@ func TestParseEvent(t *testing.T) { event := ParseEvent(apiEvent) - assert.Len(t, event.Attendees, 2) - assert.Equal(t, "alice@example.com", event.Attendees[0].Email) - assert.Equal(t, "accepted", event.Attendees[0].Status) - assert.Equal(t, "bob@example.com", event.Attendees[1].Email) - assert.True(t, event.Attendees[1].Optional) + if len(event.Attendees) != 2 { + t.Errorf("got length %d, want %d", len(event.Attendees), 2) + } + if got := event.Attendees[0].Email; got != "alice@example.com" { + t.Errorf("got %v, want %v", got, "alice@example.com") + } + if got := event.Attendees[0].Status; got != "accepted" { + t.Errorf("got %v, want %v", got, "accepted") + } + if got := event.Attendees[1].Email; got != "bob@example.com" { + t.Errorf("got %v, want %v", got, "bob@example.com") + } + if !event.Attendees[1].Optional { + t.Error("got false, want true") + } }) t.Run("handles event with hangout link", func(t *testing.T) { + t.Parallel() apiEvent := &calendar.Event{ Id: "meet123", Summary: "Video Call", @@ -126,12 +166,16 @@ func TestParseEvent(t *testing.T) { event := ParseEvent(apiEvent) - assert.Equal(t, "https://meet.google.com/abc-defg-hij", event.HangoutLink) + if got := event.HangoutLink; got != "https://meet.google.com/abc-defg-hij" { + t.Errorf("got %v, want %v", got, "https://meet.google.com/abc-defg-hij") + } }) } func TestParseCalendar(t *testing.T) { + t.Parallel() t.Run("parses calendar entry", func(t *testing.T) { + t.Parallel() apiCal := &calendar.CalendarListEntry{ Id: "primary", Summary: "My Calendar", @@ -143,15 +187,28 @@ func TestParseCalendar(t *testing.T) { cal := ParseCalendar(apiCal) - assert.Equal(t, "primary", cal.ID) - assert.Equal(t, "My Calendar", cal.Summary) - assert.Equal(t, "Personal calendar", cal.Description) - assert.True(t, cal.Primary) - assert.Equal(t, "owner", cal.AccessRole) - assert.Equal(t, "America/New_York", cal.TimeZone) + if got := cal.ID; got != "primary" { + t.Errorf("got %v, want %v", got, "primary") + } + if got := cal.Summary; got != "My Calendar" { + t.Errorf("got %v, want %v", got, "My Calendar") + } + if got := cal.Description; got != "Personal calendar" { + t.Errorf("got %v, want %v", got, "Personal calendar") + } + if !cal.Primary { + t.Error("got false, want true") + } + if got := cal.AccessRole; got != "owner" { + t.Errorf("got %v, want %v", got, "owner") + } + if got := cal.TimeZone; got != "America/New_York" { + t.Errorf("got %v, want %v", got, "America/New_York") + } }) t.Run("parses shared calendar", func(t *testing.T) { + t.Parallel() apiCal := &calendar.CalendarListEntry{ Id: "shared@group.calendar.google.com", Summary: "Team Calendar", @@ -161,13 +218,19 @@ func TestParseCalendar(t *testing.T) { cal := ParseCalendar(apiCal) - assert.False(t, cal.Primary) - assert.Equal(t, "reader", cal.AccessRole) + if cal.Primary { + t.Error("got true, want false") + } + if got := cal.AccessRole; got != "reader" { + t.Errorf("got %v, want %v", got, "reader") + } }) } func TestEventGetStartTime(t *testing.T) { + t.Parallel() t.Run("parses datetime", func(t *testing.T) { + t.Parallel() event := &Event{ Start: &EventTime{ DateTime: "2026-01-24T10:00:00-05:00", @@ -175,13 +238,22 @@ func TestEventGetStartTime(t *testing.T) { } start, err := event.GetStartTime() - assert.NoError(t, err) - assert.Equal(t, 2026, start.Year()) - assert.Equal(t, 1, int(start.Month())) - assert.Equal(t, 24, start.Day()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := start.Year(); got != 2026 { + t.Errorf("got %v, want %v", got, 2026) + } + if got := int(start.Month()); got != 1 { + t.Errorf("got %v, want %v", got, 1) + } + if got := start.Day(); got != 24 { + t.Errorf("got %v, want %v", got, 24) + } }) t.Run("parses date for all-day event", func(t *testing.T) { + t.Parallel() event := &Event{ AllDay: true, Start: &EventTime{ @@ -190,21 +262,32 @@ func TestEventGetStartTime(t *testing.T) { } start, err := event.GetStartTime() - assert.NoError(t, err) - assert.Equal(t, 24, start.Day()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := start.Day(); got != 24 { + t.Errorf("got %v, want %v", got, 24) + } }) t.Run("handles nil start", func(t *testing.T) { + t.Parallel() event := &Event{} start, err := event.GetStartTime() - assert.NoError(t, err) - assert.True(t, start.IsZero()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !start.IsZero() { + t.Error("got false, want true") + } }) } func TestEventFormatTimeRange(t *testing.T) { + t.Parallel() t.Run("formats same-day event", func(t *testing.T) { + t.Parallel() event := &Event{ Start: &EventTime{ DateTime: "2026-01-24T10:00:00-05:00", @@ -215,12 +298,19 @@ func TestEventFormatTimeRange(t *testing.T) { } result := event.FormatTimeRange() - assert.Contains(t, result, "Jan 24, 2026") - assert.Contains(t, result, "10:00") - assert.Contains(t, result, "11:00") + if !strings.Contains(result, "Jan 24, 2026") { + t.Errorf("expected %q to contain %q", result, "Jan 24, 2026") + } + if !strings.Contains(result, "10:00") { + t.Errorf("expected %q to contain %q", result, "10:00") + } + if !strings.Contains(result, "11:00") { + t.Errorf("expected %q to contain %q", result, "11:00") + } }) t.Run("formats all-day event", func(t *testing.T) { + t.Parallel() event := &Event{ AllDay: true, Start: &EventTime{ @@ -232,7 +322,11 @@ func TestEventFormatTimeRange(t *testing.T) { } result := event.FormatTimeRange() - assert.Contains(t, result, "Jan 24, 2026") - assert.Contains(t, result, "all day") + if !strings.Contains(result, "Jan 24, 2026") { + t.Errorf("expected %q to contain %q", result, "Jan 24, 2026") + } + if !strings.Contains(result, "all day") { + t.Errorf("expected %q to contain %q", result, "all day") + } }) } diff --git a/internal/calendar/interfaces.go b/internal/calendar/interfaces.go deleted file mode 100644 index f57d02e..0000000 --- a/internal/calendar/interfaces.go +++ /dev/null @@ -1,21 +0,0 @@ -package calendar - -import ( - "google.golang.org/api/calendar/v3" -) - -// CalendarClientInterface defines the interface for Calendar client operations. -// This enables unit testing through mock implementations. -type CalendarClientInterface interface { - // ListCalendars returns all calendars the user has access to - ListCalendars() ([]*calendar.CalendarListEntry, error) - - // ListEvents returns events from the specified calendar within the given time range - ListEvents(calendarID string, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) - - // GetEvent retrieves a single event by ID - GetEvent(calendarID, eventID string) (*calendar.Event, error) -} - -// Verify that Client implements CalendarClientInterface -var _ CalendarClientInterface = (*Client)(nil) diff --git a/internal/cmd/calendar/calendar.go b/internal/cmd/calendar/calendar.go index 525ddc7..3b45873 100644 --- a/internal/cmd/calendar/calendar.go +++ b/internal/cmd/calendar/calendar.go @@ -1,3 +1,4 @@ +// Package calendar implements the gro calendar command and subcommands. package calendar import ( diff --git a/internal/cmd/calendar/calendar_test.go b/internal/cmd/calendar/calendar_test.go index 3a39182..c2b21d4 100644 --- a/internal/cmd/calendar/calendar_test.go +++ b/internal/cmd/calendar/calendar_test.go @@ -3,43 +3,43 @@ package calendar import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestCalendarCommand(t *testing.T) { cmd := NewCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "calendar", cmd.Use) + testutil.Equal(t, cmd.Use, "calendar") }) t.Run("has cal alias", func(t *testing.T) { - assert.Contains(t, cmd.Aliases, "cal") + testutil.SliceContains(t, cmd.Aliases, "cal") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "Calendar") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "Calendar") }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "events") + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "events") }) t.Run("has subcommands", func(t *testing.T) { subcommands := cmd.Commands() - assert.GreaterOrEqual(t, len(subcommands), 5) + testutil.GreaterOrEqual(t, len(subcommands), 5) var names []string for _, sub := range subcommands { names = append(names, sub.Name()) } - assert.Contains(t, names, "list") - assert.Contains(t, names, "events") - assert.Contains(t, names, "get") - assert.Contains(t, names, "today") - assert.Contains(t, names, "week") + testutil.SliceContains(t, names, "list") + testutil.SliceContains(t, names, "events") + testutil.SliceContains(t, names, "get") + testutil.SliceContains(t, names, "today") + testutil.SliceContains(t, names, "week") }) } @@ -47,27 +47,27 @@ func TestListCommand(t *testing.T) { cmd := newListCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "list", cmd.Use) + testutil.Equal(t, cmd.Use, "list") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "calendar") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "calendar") }) } @@ -75,48 +75,48 @@ func TestEventsCommand(t *testing.T) { cmd := newEventsCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "events [calendar-id]", cmd.Use) + testutil.Equal(t, cmd.Use, "events [calendar-id]") }) t.Run("accepts optional calendar id argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"calendar-id"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"calendar-id", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) - assert.Equal(t, "10", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") + testutil.Equal(t, flag.DefValue, "10") }) t.Run("has from flag", func(t *testing.T) { flag := cmd.Flags().Lookup("from") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has to flag", func(t *testing.T) { flag := cmd.Flags().Lookup("to") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has calendar flag", func(t *testing.T) { flag := cmd.Flags().Lookup("calendar") - assert.NotNil(t, flag) - assert.Equal(t, "c", flag.Shorthand) - assert.Equal(t, "primary", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "c") + testutil.Equal(t, flag.DefValue, "primary") }) } @@ -124,31 +124,31 @@ func TestGetCommand(t *testing.T) { cmd := newGetCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "get ", cmd.Use) + testutil.Equal(t, cmd.Use, "get ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"event-id"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"event-id", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) t.Run("has calendar flag", func(t *testing.T) { flag := cmd.Flags().Lookup("calendar") - assert.NotNil(t, flag) - assert.Equal(t, "c", flag.Shorthand) - assert.Equal(t, "primary", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "c") + testutil.Equal(t, flag.DefValue, "primary") }) } @@ -156,31 +156,31 @@ func TestTodayCommand(t *testing.T) { cmd := newTodayCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "today", cmd.Use) + testutil.Equal(t, cmd.Use, "today") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) t.Run("has calendar flag", func(t *testing.T) { flag := cmd.Flags().Lookup("calendar") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "today") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "today") }) } @@ -188,30 +188,30 @@ func TestWeekCommand(t *testing.T) { cmd := newWeekCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "week", cmd.Use) + testutil.Equal(t, cmd.Use, "week") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) t.Run("has calendar flag", func(t *testing.T) { flag := cmd.Flags().Lookup("calendar") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "week") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "week") }) } diff --git a/internal/cmd/calendar/dates_test.go b/internal/cmd/calendar/dates_test.go index 286aa1b..d23890d 100644 --- a/internal/cmd/calendar/dates_test.go +++ b/internal/cmd/calendar/dates_test.go @@ -4,11 +4,11 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestParseDate(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -78,22 +78,24 @@ func TestParseDate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result, err := parseDate(tt.input) if tt.wantErr { - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid date format") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "invalid date format") } else { - require.NoError(t, err) - assert.Equal(t, tt.want.Year(), result.Year()) - assert.Equal(t, tt.want.Month(), result.Month()) - assert.Equal(t, tt.want.Day(), result.Day()) + testutil.NoError(t, err) + testutil.Equal(t, result.Year(), tt.want.Year()) + testutil.Equal(t, result.Month(), tt.want.Month()) + testutil.Equal(t, result.Day(), tt.want.Day()) } }) } } func TestEndOfDay(t *testing.T) { + t.Parallel() tests := []struct { name string input time.Time @@ -123,13 +125,15 @@ func TestEndOfDay(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := endOfDay(tt.input) - assert.Equal(t, tt.want, result) + testutil.Equal(t, result, tt.want) }) } } func TestWeekBounds(t *testing.T) { + t.Parallel() loc := time.UTC tests := []struct { @@ -202,31 +206,33 @@ func TestWeekBounds(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() start, end := weekBounds(tt.input) - assert.Equal(t, tt.wantStart, start, "start mismatch") - assert.Equal(t, tt.wantEnd, end, "end mismatch") + testutil.Equal(t, start, tt.wantStart) + testutil.Equal(t, end, tt.wantEnd) // Verify start is Monday - assert.Equal(t, time.Monday, start.Weekday(), "start should be Monday") + testutil.Equal(t, start.Weekday(), time.Monday) // Verify end is Sunday - assert.Equal(t, time.Sunday, end.Weekday(), "end should be Sunday") + testutil.Equal(t, end.Weekday(), time.Sunday) // Verify start is at 00:00:00 - assert.Equal(t, 0, start.Hour()) - assert.Equal(t, 0, start.Minute()) - assert.Equal(t, 0, start.Second()) + testutil.Equal(t, start.Hour(), 0) + testutil.Equal(t, start.Minute(), 0) + testutil.Equal(t, start.Second(), 0) // Verify end is at 23:59:59 - assert.Equal(t, 23, end.Hour()) - assert.Equal(t, 59, end.Minute()) - assert.Equal(t, 59, end.Second()) + testutil.Equal(t, end.Hour(), 23) + testutil.Equal(t, end.Minute(), 59) + testutil.Equal(t, end.Second(), 59) }) } } func TestWeekBoundsSundayEdgeCase(t *testing.T) { + t.Parallel() // Specific test for the Sunday edge case which requires special handling loc := time.UTC @@ -240,23 +246,25 @@ func TestWeekBoundsSundayEdgeCase(t *testing.T) { for _, sunday := range sundays { t.Run(sunday.Format("2006-01-02"), func(t *testing.T) { + t.Parallel() start, end := weekBounds(sunday) // The Sunday should be included in the week - assert.Equal(t, sunday.Year(), end.Year()) - assert.Equal(t, sunday.Month(), end.Month()) - assert.Equal(t, sunday.Day(), end.Day()) + testutil.Equal(t, end.Year(), sunday.Year()) + testutil.Equal(t, end.Month(), sunday.Month()) + testutil.Equal(t, end.Day(), sunday.Day()) // The Monday should be 6 days before the Sunday expectedMonday := sunday.AddDate(0, 0, -6) - assert.Equal(t, expectedMonday.Year(), start.Year()) - assert.Equal(t, expectedMonday.Month(), start.Month()) - assert.Equal(t, expectedMonday.Day(), start.Day()) + testutil.Equal(t, start.Year(), expectedMonday.Year()) + testutil.Equal(t, start.Month(), expectedMonday.Month()) + testutil.Equal(t, start.Day(), expectedMonday.Day()) }) } } func TestTodayBounds(t *testing.T) { + t.Parallel() loc := time.UTC tests := []struct { @@ -305,19 +313,20 @@ func TestTodayBounds(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() start, end := todayBounds(tt.input) - assert.Equal(t, tt.wantStart, start) - assert.Equal(t, tt.wantEnd, end) + testutil.Equal(t, start, tt.wantStart) + testutil.Equal(t, end, tt.wantEnd) // Verify same day - assert.Equal(t, tt.input.Year(), start.Year()) - assert.Equal(t, tt.input.Month(), start.Month()) - assert.Equal(t, tt.input.Day(), start.Day()) + testutil.Equal(t, start.Year(), tt.input.Year()) + testutil.Equal(t, start.Month(), tt.input.Month()) + testutil.Equal(t, start.Day(), tt.input.Day()) - assert.Equal(t, tt.input.Year(), end.Year()) - assert.Equal(t, tt.input.Month(), end.Month()) - assert.Equal(t, tt.input.Day(), end.Day()) + testutil.Equal(t, end.Year(), tt.input.Year()) + testutil.Equal(t, end.Month(), tt.input.Month()) + testutil.Equal(t, end.Day(), tt.input.Day()) }) } } diff --git a/internal/cmd/calendar/events.go b/internal/cmd/calendar/events.go index 718fcb9..9128860 100644 --- a/internal/cmd/calendar/events.go +++ b/internal/cmd/calendar/events.go @@ -38,9 +38,9 @@ Examples: calID = args[0] } - client, err := newCalendarClient() + client, err := newCalendarClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Calendar client: %w", err) + return fmt.Errorf("creating Calendar client: %w", err) } // Parse date range @@ -66,7 +66,7 @@ Examples: timeMax = endOfDay(t).Format(time.RFC3339) } - return listAndPrintEvents(client, EventListOptions{ + return listAndPrintEvents(cmd.Context(), client, EventListOptions{ CalendarID: calID, TimeMin: timeMin, TimeMax: timeMax, diff --git a/internal/cmd/calendar/events_helper.go b/internal/cmd/calendar/events_helper.go index 84e710d..2002d12 100644 --- a/internal/cmd/calendar/events_helper.go +++ b/internal/cmd/calendar/events_helper.go @@ -1,6 +1,7 @@ package calendar import ( + "context" "fmt" "github.com/open-cli-collective/google-readonly/internal/calendar" @@ -19,13 +20,17 @@ type EventListOptions struct { // listAndPrintEvents fetches events and prints them according to the options. // This is a shared helper used by today, week, and events commands. -func listAndPrintEvents(client calendar.CalendarClientInterface, opts EventListOptions) error { - events, err := client.ListEvents(opts.CalendarID, opts.TimeMin, opts.TimeMax, opts.MaxResults) +func listAndPrintEvents(ctx context.Context, client CalendarClient, opts EventListOptions) error { + events, err := client.ListEvents(ctx, opts.CalendarID, opts.TimeMin, opts.TimeMax, opts.MaxResults) if err != nil { return err } if len(events) == 0 { + if opts.JSONOutput { + fmt.Println("[]") + return nil + } if opts.EmptyMessage != "" { fmt.Println(opts.EmptyMessage) } else { diff --git a/internal/cmd/calendar/get.go b/internal/cmd/calendar/get.go index 2963f88..9c21b22 100644 --- a/internal/cmd/calendar/get.go +++ b/internal/cmd/calendar/get.go @@ -29,14 +29,14 @@ Examples: RunE: func(cmd *cobra.Command, args []string) error { eventID := args[0] - client, err := newCalendarClient() + client, err := newCalendarClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Calendar client: %w", err) + return fmt.Errorf("creating Calendar client: %w", err) } - event, err := client.GetEvent(calendarID, eventID) + event, err := client.GetEvent(cmd.Context(), calendarID, eventID) if err != nil { - return fmt.Errorf("failed to get event: %w", err) + return fmt.Errorf("getting event: %w", err) } parsedEvent := calendar.ParseEvent(event) diff --git a/internal/cmd/calendar/handlers_test.go b/internal/cmd/calendar/handlers_test.go index 5dedf90..6ed15dd 100644 --- a/internal/cmd/calendar/handlers_test.go +++ b/internal/cmd/calendar/handlers_test.go @@ -1,61 +1,34 @@ package calendar import ( - "bytes" + "context" "encoding/json" "errors" - "io" - "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "google.golang.org/api/calendar/v3" calendarapi "github.com/open-cli-collective/google-readonly/internal/calendar" "github.com/open-cli-collective/google-readonly/internal/testutil" ) -// captureOutput captures stdout during test execution -func captureOutput(t *testing.T, f func()) string { - t.Helper() - old := os.Stdout - r, w, err := os.Pipe() - require.NoError(t, err) - os.Stdout = w - - f() - - w.Close() - os.Stdout = old - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() -} - // withMockClient sets up a mock client factory for tests -func withMockClient(mock calendarapi.CalendarClientInterface, f func()) { - originalFactory := ClientFactory - ClientFactory = func() (calendarapi.CalendarClientInterface, error) { +func withMockClient(mock CalendarClient, f func()) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (CalendarClient, error) { return mock, nil - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } // withFailingClientFactory sets up a factory that returns an error func withFailingClientFactory(f func()) { - originalFactory := ClientFactory - ClientFactory = func() (calendarapi.CalendarClientInterface, error) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (CalendarClient, error) { return nil, errors.New("connection failed") - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } func TestListCommand_Success(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListCalendarsFunc: func() ([]*calendar.CalendarListEntry, error) { + mock := &MockCalendarClient{ + ListCalendarsFunc: func(_ context.Context) ([]*calendar.CalendarListEntry, error) { return testutil.SampleCalendars(), nil }, } @@ -63,20 +36,20 @@ func TestListCommand_Success(t *testing.T) { cmd := newListCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "primary@example.com") - assert.Contains(t, output, "(primary)") - assert.Contains(t, output, "work@example.com") + testutil.Contains(t, output, "primary@example.com") + testutil.Contains(t, output, "(primary)") + testutil.Contains(t, output, "work@example.com") }) } func TestListCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListCalendarsFunc: func() ([]*calendar.CalendarListEntry, error) { + mock := &MockCalendarClient{ + ListCalendarsFunc: func(_ context.Context) ([]*calendar.CalendarListEntry, error) { return testutil.SampleCalendars(), nil }, } @@ -85,21 +58,21 @@ func TestListCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var calendars []*calendarapi.CalendarInfo err := json.Unmarshal([]byte(output), &calendars) - assert.NoError(t, err) - assert.Len(t, calendars, 2) + testutil.NoError(t, err) + testutil.Len(t, calendars, 2) }) } func TestListCommand_Empty(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListCalendarsFunc: func() ([]*calendar.CalendarListEntry, error) { + mock := &MockCalendarClient{ + ListCalendarsFunc: func(_ context.Context) ([]*calendar.CalendarListEntry, error) { return []*calendar.CalendarListEntry{}, nil }, } @@ -107,18 +80,41 @@ func TestListCommand_Empty(t *testing.T) { cmd := newListCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No calendars found") + testutil.Contains(t, output, "No calendars found") + }) +} + +func TestListCommand_Empty_JSON(t *testing.T) { + mock := &MockCalendarClient{ + ListCalendarsFunc: func(_ context.Context) ([]*calendar.CalendarListEntry, error) { + return []*calendar.CalendarListEntry{}, nil + }, + } + + cmd := newListCommand() + cmd.SetArgs([]string{"--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var calendars []any + err := json.Unmarshal([]byte(output), &calendars) + testutil.NoError(t, err) + testutil.Len(t, calendars, 0) }) } func TestListCommand_APIError(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListCalendarsFunc: func() ([]*calendar.CalendarListEntry, error) { + mock := &MockCalendarClient{ + ListCalendarsFunc: func(_ context.Context) ([]*calendar.CalendarListEntry, error) { return nil, errors.New("API error") }, } @@ -127,8 +123,8 @@ func TestListCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to list calendars") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "listing calendars") }) } @@ -137,15 +133,15 @@ func TestListCommand_ClientCreationError(t *testing.T) { withFailingClientFactory(func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create Calendar client") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "creating Calendar client") }) } func TestEventsCommand_Success(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListEventsFunc: func(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { - assert.Equal(t, "primary", calendarID) + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, calendarID, _, _ string, _ int64) ([]*calendar.Event, error) { + testutil.Equal(t, calendarID, "primary") return []*calendar.Event{testutil.SampleEvent("event1")}, nil }, } @@ -154,19 +150,19 @@ func TestEventsCommand_Success(t *testing.T) { cmd.SetArgs([]string{}) // Uses default "primary" calendar withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "Test Meeting") + testutil.Contains(t, output, "Test Meeting") }) } func TestEventsCommand_WithDateRange(t *testing.T) { var capturedTimeMin, capturedTimeMax string - mock := &testutil.MockCalendarClient{ - ListEventsFunc: func(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, _, timeMin, timeMax string, _ int64) ([]*calendar.Event, error) { capturedTimeMin = timeMin capturedTimeMax = timeMax return []*calendar.Event{}, nil @@ -177,21 +173,21 @@ func TestEventsCommand_WithDateRange(t *testing.T) { cmd.SetArgs([]string{"--from", "2024-01-01", "--to", "2024-01-31"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) // Verify dates were parsed and passed - assert.Contains(t, capturedTimeMin, "2024-01-01") - assert.Contains(t, capturedTimeMax, "2024-01-31") - assert.Contains(t, output, "No events") + testutil.Contains(t, capturedTimeMin, "2024-01-01") + testutil.Contains(t, capturedTimeMax, "2024-01-31") + testutil.Contains(t, output, "No events") }) } func TestEventsCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListEventsFunc: func(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, _, _, _ string, _ int64) ([]*calendar.Event, error) { return []*calendar.Event{testutil.SampleEvent("event1")}, nil }, } @@ -200,15 +196,38 @@ func TestEventsCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var events []*calendarapi.Event err := json.Unmarshal([]byte(output), &events) - assert.NoError(t, err) - assert.Len(t, events, 1) + testutil.NoError(t, err) + testutil.Len(t, events, 1) + }) +} + +func TestEventsCommand_Empty_JSON(t *testing.T) { + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, _, _, _ string, _ int64) ([]*calendar.Event, error) { + return []*calendar.Event{}, nil + }, + } + + cmd := newEventsCommand() + cmd.SetArgs([]string{"--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var events []any + err := json.Unmarshal([]byte(output), &events) + testutil.NoError(t, err) + testutil.Len(t, events, 0) }) } @@ -216,10 +235,10 @@ func TestEventsCommand_InvalidFromDate(t *testing.T) { cmd := newEventsCommand() cmd.SetArgs([]string{"--from", "invalid-date"}) - withMockClient(&testutil.MockCalendarClient{}, func() { + withMockClient(&MockCalendarClient{}, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid --from date") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "invalid --from date") }) } @@ -227,18 +246,18 @@ func TestEventsCommand_InvalidToDate(t *testing.T) { cmd := newEventsCommand() cmd.SetArgs([]string{"--to", "invalid-date"}) - withMockClient(&testutil.MockCalendarClient{}, func() { + withMockClient(&MockCalendarClient{}, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid --to date") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "invalid --to date") }) } func TestGetCommand_Success(t *testing.T) { - mock := &testutil.MockCalendarClient{ - GetEventFunc: func(calendarID, eventID string) (*calendar.Event, error) { - assert.Equal(t, "primary", calendarID) - assert.Equal(t, "event123", eventID) + mock := &MockCalendarClient{ + GetEventFunc: func(_ context.Context, calendarID, eventID string) (*calendar.Event, error) { + testutil.Equal(t, calendarID, "primary") + testutil.Equal(t, eventID, "event123") return testutil.SampleEvent("event123"), nil }, } @@ -247,20 +266,20 @@ func TestGetCommand_Success(t *testing.T) { cmd.SetArgs([]string{"event123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "event123") - assert.Contains(t, output, "Test Meeting") - assert.Contains(t, output, "Conference Room A") + testutil.Contains(t, output, "event123") + testutil.Contains(t, output, "Test Meeting") + testutil.Contains(t, output, "Conference Room A") }) } func TestGetCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockCalendarClient{ - GetEventFunc: func(calendarID, eventID string) (*calendar.Event, error) { + mock := &MockCalendarClient{ + GetEventFunc: func(_ context.Context, _, _ string) (*calendar.Event, error) { return testutil.SampleEvent("event123"), nil }, } @@ -269,21 +288,21 @@ func TestGetCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"event123", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var event calendarapi.Event err := json.Unmarshal([]byte(output), &event) - assert.NoError(t, err) - assert.Equal(t, "event123", event.ID) + testutil.NoError(t, err) + testutil.Equal(t, event.ID, "event123") }) } func TestGetCommand_NotFound(t *testing.T) { - mock := &testutil.MockCalendarClient{ - GetEventFunc: func(calendarID, eventID string) (*calendar.Event, error) { + mock := &MockCalendarClient{ + GetEventFunc: func(_ context.Context, _, _ string) (*calendar.Event, error) { return nil, errors.New("event not found") }, } @@ -293,14 +312,14 @@ func TestGetCommand_NotFound(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get event") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "getting event") }) } func TestTodayCommand_Success(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListEventsFunc: func(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, _, _, _ string, _ int64) ([]*calendar.Event, error) { return []*calendar.Event{testutil.SampleEvent("today_event")}, nil }, } @@ -308,18 +327,41 @@ func TestTodayCommand_Success(t *testing.T) { cmd := newTodayCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + testutil.Contains(t, output, "Test Meeting") + }) +} + +func TestTodayCommand_Empty_JSON(t *testing.T) { + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, _, _, _ string, _ int64) ([]*calendar.Event, error) { + return []*calendar.Event{}, nil + }, + } + + cmd := newTodayCommand() + cmd.SetArgs([]string{"--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "Test Meeting") + var events []any + err := json.Unmarshal([]byte(output), &events) + testutil.NoError(t, err) + testutil.Len(t, events, 0) }) } func TestWeekCommand_Success(t *testing.T) { - mock := &testutil.MockCalendarClient{ - ListEventsFunc: func(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { + mock := &MockCalendarClient{ + ListEventsFunc: func(_ context.Context, _, _, _ string, _ int64) ([]*calendar.Event, error) { return []*calendar.Event{ testutil.SampleEvent("week_event1"), testutil.SampleEvent("week_event2"), @@ -330,12 +372,12 @@ func TestWeekCommand_Success(t *testing.T) { cmd := newWeekCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) // Should show events - assert.Contains(t, output, "Test Meeting") + testutil.Contains(t, output, "Test Meeting") }) } diff --git a/internal/cmd/calendar/list.go b/internal/cmd/calendar/list.go index f9459c4..d607916 100644 --- a/internal/cmd/calendar/list.go +++ b/internal/cmd/calendar/list.go @@ -22,18 +22,22 @@ Examples: gro calendar list gro cal list --json`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newCalendarClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newCalendarClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Calendar client: %w", err) + return fmt.Errorf("creating Calendar client: %w", err) } - calendars, err := client.ListCalendars() + calendars, err := client.ListCalendars(cmd.Context()) if err != nil { - return fmt.Errorf("failed to list calendars: %w", err) + return fmt.Errorf("listing calendars: %w", err) } if len(calendars) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No calendars found.") return nil } diff --git a/internal/cmd/calendar/mock_test.go b/internal/cmd/calendar/mock_test.go new file mode 100644 index 0000000..46e49cb --- /dev/null +++ b/internal/cmd/calendar/mock_test.go @@ -0,0 +1,38 @@ +package calendar + +import ( + "context" + + "google.golang.org/api/calendar/v3" +) + +// MockCalendarClient is a configurable mock for CalendarClient. +type MockCalendarClient struct { + ListCalendarsFunc func(ctx context.Context) ([]*calendar.CalendarListEntry, error) + ListEventsFunc func(ctx context.Context, calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) + GetEventFunc func(ctx context.Context, calendarID, eventID string) (*calendar.Event, error) +} + +// Verify MockCalendarClient implements CalendarClient +var _ CalendarClient = (*MockCalendarClient)(nil) + +func (m *MockCalendarClient) ListCalendars(ctx context.Context) ([]*calendar.CalendarListEntry, error) { + if m.ListCalendarsFunc != nil { + return m.ListCalendarsFunc(ctx) + } + return nil, nil +} + +func (m *MockCalendarClient) ListEvents(ctx context.Context, calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { + if m.ListEventsFunc != nil { + return m.ListEventsFunc(ctx, calendarID, timeMin, timeMax, maxResults) + } + return nil, nil +} + +func (m *MockCalendarClient) GetEvent(ctx context.Context, calendarID, eventID string) (*calendar.Event, error) { + if m.GetEventFunc != nil { + return m.GetEventFunc(ctx, calendarID, eventID) + } + return nil, nil +} diff --git a/internal/cmd/calendar/output.go b/internal/cmd/calendar/output.go index 2c494e6..2d7ce7b 100644 --- a/internal/cmd/calendar/output.go +++ b/internal/cmd/calendar/output.go @@ -4,19 +4,28 @@ import ( "context" "fmt" + calendarv3 "google.golang.org/api/calendar/v3" + "github.com/open-cli-collective/google-readonly/internal/calendar" "github.com/open-cli-collective/google-readonly/internal/output" ) +// CalendarClient defines the interface for Calendar client operations used by calendar commands. +type CalendarClient interface { + ListCalendars(ctx context.Context) ([]*calendarv3.CalendarListEntry, error) + ListEvents(ctx context.Context, calendarID string, timeMin, timeMax string, maxResults int64) ([]*calendarv3.Event, error) + GetEvent(ctx context.Context, calendarID, eventID string) (*calendarv3.Event, error) +} + // ClientFactory is the function used to create Calendar clients. // Override in tests to inject mocks. -var ClientFactory = func() (calendar.CalendarClientInterface, error) { - return calendar.NewClient(context.Background()) +var ClientFactory = func(ctx context.Context) (CalendarClient, error) { + return calendar.NewClient(ctx) } // newCalendarClient creates a new calendar client -func newCalendarClient() (calendar.CalendarClientInterface, error) { - return ClientFactory() +func newCalendarClient(ctx context.Context) (CalendarClient, error) { + return ClientFactory(ctx) } // printJSON outputs data as indented JSON diff --git a/internal/cmd/calendar/output_test.go b/internal/cmd/calendar/output_test.go index da83c03..417f3ce 100644 --- a/internal/cmd/calendar/output_test.go +++ b/internal/cmd/calendar/output_test.go @@ -8,10 +8,8 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/open-cli-collective/google-readonly/internal/calendar" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestPrintJSON(t *testing.T) { @@ -61,7 +59,7 @@ func TestPrintJSON(t *testing.T) { os.Stdout = w err := printJSON(tt.data) - require.NoError(t, err) + testutil.NoError(t, err) w.Close() os.Stdout = oldStdout @@ -70,15 +68,15 @@ func TestPrintJSON(t *testing.T) { io.Copy(&buf, r) output := buf.String() - assert.NotEmpty(t, output) + testutil.NotEmpty(t, output) // Verify it's valid JSON var parsed any err = json.Unmarshal([]byte(output), &parsed) - assert.NoError(t, err, "output should be valid JSON") + testutil.NoError(t, err) if tt.wantJSON != "" { - assert.Equal(t, tt.wantJSON, output) + testutil.Equal(t, output, tt.wantJSON) } }) } @@ -241,10 +239,10 @@ func TestPrintEvent(t *testing.T) { output := buf.String() for _, want := range tt.wantContains { - assert.Contains(t, output, want) + testutil.Contains(t, output, want) } for _, notWant := range tt.wantNotContains { - assert.NotContains(t, output, notWant) + testutil.NotContains(t, output, notWant) } }) } @@ -317,7 +315,7 @@ func TestPrintEventSummary(t *testing.T) { output := buf.String() for _, want := range tt.wantContains { - assert.Contains(t, output, want) + testutil.Contains(t, output, want) } }) } @@ -397,22 +395,22 @@ func TestPrintCalendar(t *testing.T) { output := buf.String() for _, want := range tt.wantContains { - assert.Contains(t, output, want) + testutil.Contains(t, output, want) } // Check that non-primary calendars don't have "(primary)" if !tt.cal.Primary { - assert.NotContains(t, output, "(primary)") + testutil.NotContains(t, output, "(primary)") } // Check that empty descriptions aren't printed if tt.cal.Description == "" { - assert.NotContains(t, output, "Description:") + testutil.NotContains(t, output, "Description:") } // Check that empty timezones aren't printed if tt.cal.TimeZone == "" { - assert.NotContains(t, output, "Timezone:") + testutil.NotContains(t, output, "Timezone:") } }) } @@ -440,8 +438,8 @@ func TestPrintCalendarNoPrimary(t *testing.T) { output := buf.String() // Should have the ID without "(primary)" - assert.Contains(t, output, "ID: other@google.com") - assert.NotContains(t, output, "(primary)") + testutil.Contains(t, output, "ID: other@google.com") + testutil.NotContains(t, output, "(primary)") } func TestPrintAttendeeWithoutStatus(t *testing.T) { @@ -472,8 +470,8 @@ func TestPrintAttendeeWithoutStatus(t *testing.T) { lines := strings.Split(output, "\n") for _, line := range lines { if strings.Contains(line, "Alice") { - assert.NotContains(t, line, "()") - assert.Contains(t, line, "Alice ") + testutil.NotContains(t, line, "()") + testutil.Contains(t, line, "Alice ") } } } diff --git a/internal/cmd/calendar/today.go b/internal/cmd/calendar/today.go index 6fd6390..e2930c9 100644 --- a/internal/cmd/calendar/today.go +++ b/internal/cmd/calendar/today.go @@ -25,16 +25,16 @@ Examples: gro cal today --json gro cal today --calendar work@group.calendar.google.com`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newCalendarClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newCalendarClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Calendar client: %w", err) + return fmt.Errorf("creating Calendar client: %w", err) } now := time.Now() startOfDay, endOfDayTime := todayBounds(now) - return listAndPrintEvents(client, EventListOptions{ + return listAndPrintEvents(cmd.Context(), client, EventListOptions{ CalendarID: calendarID, TimeMin: startOfDay.Format(time.RFC3339), TimeMax: endOfDayTime.Format(time.RFC3339), diff --git a/internal/cmd/calendar/week.go b/internal/cmd/calendar/week.go index 20d93fb..ee1b848 100644 --- a/internal/cmd/calendar/week.go +++ b/internal/cmd/calendar/week.go @@ -25,16 +25,16 @@ Examples: gro cal week --json gro cal week --calendar work@group.calendar.google.com`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newCalendarClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newCalendarClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Calendar client: %w", err) + return fmt.Errorf("creating Calendar client: %w", err) } now := time.Now() startOfWeek, endOfWeek := weekBounds(now) - return listAndPrintEvents(client, EventListOptions{ + return listAndPrintEvents(cmd.Context(), client, EventListOptions{ CalendarID: calendarID, TimeMin: startOfWeek.Format(time.RFC3339), TimeMax: endOfWeek.Format(time.RFC3339), diff --git a/internal/cmd/config/cache.go b/internal/cmd/config/cache.go index 114e5b4..716a35c 100644 --- a/internal/cmd/config/cache.go +++ b/internal/cmd/config/cache.go @@ -39,20 +39,20 @@ func newCacheShowCommand() *cobra.Command { - Configured TTL - Cached data status (when last updated, expiration)`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { cfg, err := configpkg.LoadConfig() if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return fmt.Errorf("loading config: %w", err) } c, err := cache.New(cfg.CacheTTLHours) if err != nil { - return fmt.Errorf("failed to initialize cache: %w", err) + return fmt.Errorf("initializing cache: %w", err) } status, err := c.GetStatus() if err != nil { - return fmt.Errorf("failed to get cache status: %w", err) + return fmt.Errorf("getting cache status: %w", err) } if jsonOutput { @@ -92,19 +92,19 @@ func newCacheClearCommand() *cobra.Command { Short: "Clear all cached data", Long: `Remove all cached data. Cache will be repopulated on next use.`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { cfg, err := configpkg.LoadConfig() if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return fmt.Errorf("loading config: %w", err) } c, err := cache.New(cfg.CacheTTLHours) if err != nil { - return fmt.Errorf("failed to initialize cache: %w", err) + return fmt.Errorf("initializing cache: %w", err) } if err := c.Clear(); err != nil { - return fmt.Errorf("failed to clear cache: %w", err) + return fmt.Errorf("clearing cache: %w", err) } fmt.Println("Cache cleared.") @@ -126,7 +126,7 @@ Examples: gro config cache ttl 12 # Set TTL to 12 hours gro config cache ttl 48 # Set TTL to 48 hours`, Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, args []string) error { ttl, err := strconv.Atoi(args[0]) if err != nil || ttl <= 0 { return fmt.Errorf("invalid TTL value: must be a positive integer (hours)") @@ -134,13 +134,13 @@ Examples: cfg, err := configpkg.LoadConfig() if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return fmt.Errorf("loading config: %w", err) } cfg.CacheTTLHours = ttl if err := configpkg.SaveConfig(cfg); err != nil { - return fmt.Errorf("failed to save config: %w", err) + return fmt.Errorf("saving config: %w", err) } fmt.Printf("Cache TTL set to %d hours.\n", ttl) diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index cacee14..898da2f 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -1,13 +1,14 @@ +// Package config implements the gro config command and subcommands. package config import ( - "context" "fmt" "os" "time" "github.com/spf13/cobra" + "github.com/open-cli-collective/google-readonly/internal/auth" "github.com/open-cli-collective/google-readonly/internal/gmail" "github.com/open-cli-collective/google-readonly/internal/keychain" ) @@ -65,11 +66,11 @@ The credentials.json file (OAuth client config) is not removed.`, } } -func runShow(cmd *cobra.Command, args []string) error { +func runShow(cmd *cobra.Command, _ []string) error { // Check credentials file - credPath, err := gmail.GetCredentialsPath() + credPath, err := auth.GetCredentialsPath() if err != nil { - return fmt.Errorf("failed to get credentials path: %w", err) + return fmt.Errorf("getting credentials path: %w", err) } credStatus := "OK" @@ -111,8 +112,8 @@ func runShow(cmd *cobra.Command, args []string) error { // Show email if we can get it without triggering auth if keychain.HasStoredToken() && credStatus == "OK" { - if client, err := gmail.NewClient(context.Background()); err == nil { - if profile, err := client.GetProfile(); err == nil { + if client, err := gmail.NewClient(cmd.Context()); err == nil { + if profile, err := client.GetProfile(cmd.Context()); err == nil { fmt.Printf("Email: %s\n", profile.EmailAddress) } } @@ -127,7 +128,7 @@ func runShow(cmd *cobra.Command, args []string) error { return nil } -func runTest(cmd *cobra.Command, args []string) error { +func runTest(cmd *cobra.Command, _ []string) error { fmt.Println("Testing Gmail API connection...") fmt.Println() @@ -141,21 +142,21 @@ func runTest(cmd *cobra.Command, args []string) error { fmt.Println(" OAuth token: Found") // Try to create client (tests token validity) - client, err := gmail.NewClient(context.Background()) + client, err := gmail.NewClient(cmd.Context()) if err != nil { fmt.Println(" Token valid: FAILED") fmt.Println() fmt.Println("Token may be expired or revoked.") fmt.Println("Run 'gro config clear' then 'gro init' to re-authenticate.") - return fmt.Errorf("failed to create client: %w", err) + return fmt.Errorf("creating client: %w", err) } fmt.Println(" Token valid: OK") // Test API access - profile, err := client.GetProfile() + profile, err := client.GetProfile(cmd.Context()) if err != nil { fmt.Println(" Gmail API: FAILED") - return fmt.Errorf("failed to access Gmail API: %w", err) + return fmt.Errorf("accessing Gmail API: %w", err) } fmt.Println(" Gmail API: OK") fmt.Printf(" Messages: %d total\n", profile.MessagesTotal) @@ -166,7 +167,7 @@ func runTest(cmd *cobra.Command, args []string) error { return nil } -func runClear(cmd *cobra.Command, args []string) error { +func runClear(_ *cobra.Command, _ []string) error { if !keychain.HasStoredToken() { fmt.Println("No OAuth token found to clear.") return nil @@ -175,7 +176,7 @@ func runClear(cmd *cobra.Command, args []string) error { backend := keychain.GetStorageBackend() if err := keychain.DeleteToken(); err != nil { - return fmt.Errorf("failed to clear token: %w", err) + return fmt.Errorf("clearing token: %w", err) } fmt.Printf("Cleared OAuth token from %s.\n", backend) diff --git a/internal/cmd/config/config_test.go b/internal/cmd/config/config_test.go index b9e9503..76a2670 100644 --- a/internal/cmd/config/config_test.go +++ b/internal/cmd/config/config_test.go @@ -3,32 +3,32 @@ package config import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestConfigCommand(t *testing.T) { cmd := NewCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "config", cmd.Use) + testutil.Equal(t, cmd.Use, "config") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has subcommands", func(t *testing.T) { subcommands := cmd.Commands() - assert.GreaterOrEqual(t, len(subcommands), 4) + testutil.GreaterOrEqual(t, len(subcommands), 4) var names []string for _, sub := range subcommands { names = append(names, sub.Name()) } - assert.Contains(t, names, "show") - assert.Contains(t, names, "test") - assert.Contains(t, names, "clear") - assert.Contains(t, names, "cache") + testutil.SliceContains(t, names, "show") + testutil.SliceContains(t, names, "test") + testutil.SliceContains(t, names, "clear") + testutil.SliceContains(t, names, "cache") }) } @@ -36,23 +36,23 @@ func TestConfigShowCommand(t *testing.T) { cmd := newShowCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "show", cmd.Use) + testutil.Equal(t, cmd.Use, "show") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) + testutil.NotEmpty(t, cmd.Long) }) } @@ -60,23 +60,23 @@ func TestConfigTestCommand(t *testing.T) { cmd := newTestCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "test", cmd.Use) + testutil.Equal(t, cmd.Use, "test") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) + testutil.NotEmpty(t, cmd.Long) }) } @@ -84,24 +84,24 @@ func TestConfigClearCommand(t *testing.T) { cmd := newClearCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "clear", cmd.Use) + testutil.Equal(t, cmd.Use, "clear") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "token") + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "token") }) } @@ -109,29 +109,29 @@ func TestCacheCommand(t *testing.T) { cmd := newCacheCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "cache", cmd.Use) + testutil.Equal(t, cmd.Use, "cache") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "cache") + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "cache") }) t.Run("has subcommands", func(t *testing.T) { subcommands := cmd.Commands() - assert.Equal(t, 3, len(subcommands)) + testutil.Equal(t, len(subcommands), 3) var names []string for _, sub := range subcommands { names = append(names, sub.Name()) } - assert.Contains(t, names, "show") - assert.Contains(t, names, "clear") - assert.Contains(t, names, "ttl") + testutil.SliceContains(t, names, "show") + testutil.SliceContains(t, names, "clear") + testutil.SliceContains(t, names, "ttl") }) } @@ -139,31 +139,31 @@ func TestCacheShowCommand(t *testing.T) { cmd := newCacheShowCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "show", cmd.Use) + testutil.Equal(t, cmd.Use, "show") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "cache") + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "cache") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) } @@ -171,23 +171,23 @@ func TestCacheClearCommand(t *testing.T) { cmd := newCacheClearCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "clear", cmd.Use) + testutil.Equal(t, cmd.Use, "clear") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) + testutil.NotEmpty(t, cmd.Long) }) } @@ -195,26 +195,26 @@ func TestCacheTTLCommand(t *testing.T) { cmd := newCacheTTLCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "ttl ", cmd.Use) + testutil.Equal(t, cmd.Use, "ttl ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"24"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"24", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "TTL") + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "TTL") }) } diff --git a/internal/cmd/contacts/contacts.go b/internal/cmd/contacts/contacts.go index da06c00..0708aa0 100644 --- a/internal/cmd/contacts/contacts.go +++ b/internal/cmd/contacts/contacts.go @@ -1,3 +1,4 @@ +// Package contacts implements the gro contacts command and subcommands. package contacts import ( diff --git a/internal/cmd/contacts/contacts_test.go b/internal/cmd/contacts/contacts_test.go index 0047fcf..40de2e2 100644 --- a/internal/cmd/contacts/contacts_test.go +++ b/internal/cmd/contacts/contacts_test.go @@ -3,41 +3,41 @@ package contacts import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestContactsCommand(t *testing.T) { cmd := NewCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "contacts", cmd.Use) + testutil.Equal(t, cmd.Use, "contacts") }) t.Run("has ppl alias", func(t *testing.T) { - assert.Contains(t, cmd.Aliases, "ppl") + testutil.SliceContains(t, cmd.Aliases, "ppl") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "read-only") + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "read-only") }) t.Run("has subcommands", func(t *testing.T) { subcommands := cmd.Commands() - assert.GreaterOrEqual(t, len(subcommands), 4) + testutil.GreaterOrEqual(t, len(subcommands), 4) var names []string for _, sub := range subcommands { names = append(names, sub.Name()) } - assert.Contains(t, names, "list") - assert.Contains(t, names, "search") - assert.Contains(t, names, "get") - assert.Contains(t, names, "groups") + testutil.SliceContains(t, names, "list") + testutil.SliceContains(t, names, "search") + testutil.SliceContains(t, names, "get") + testutil.SliceContains(t, names, "groups") }) } @@ -45,35 +45,35 @@ func TestListCommand(t *testing.T) { cmd := newListCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "list", cmd.Use) + testutil.Equal(t, cmd.Use, "list") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("rejects arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) - assert.Equal(t, "10", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") + testutil.Equal(t, flag.DefValue, "10") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) } @@ -81,38 +81,38 @@ func TestSearchCommand(t *testing.T) { cmd := newSearchCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "search ", cmd.Use) + testutil.Equal(t, cmd.Use, "search ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{"query"}) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("rejects no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("rejects multiple arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{"query1", "query2"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) } @@ -120,27 +120,27 @@ func TestGetCommand(t *testing.T) { cmd := newGetCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "get ", cmd.Use) + testutil.Equal(t, cmd.Use, "get ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{"people/c123"}) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("rejects no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) } @@ -148,33 +148,33 @@ func TestGroupsCommand(t *testing.T) { cmd := newGroupsCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "groups", cmd.Use) + testutil.Equal(t, cmd.Use, "groups") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("rejects arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) - assert.Equal(t, "30", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") + testutil.Equal(t, flag.DefValue, "30") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) } diff --git a/internal/cmd/contacts/get.go b/internal/cmd/contacts/get.go index cef9ad2..cc24cd4 100644 --- a/internal/cmd/contacts/get.go +++ b/internal/cmd/contacts/get.go @@ -26,14 +26,14 @@ Examples: RunE: func(cmd *cobra.Command, args []string) error { resourceName := args[0] - client, err := newContactsClient() + client, err := newContactsClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Contacts client: %w", err) + return fmt.Errorf("creating Contacts client: %w", err) } - person, err := client.GetContact(resourceName) + person, err := client.GetContact(cmd.Context(), resourceName) if err != nil { - return fmt.Errorf("failed to get contact: %w", err) + return fmt.Errorf("getting contact: %w", err) } contact := contacts.ParseContact(person) diff --git a/internal/cmd/contacts/groups.go b/internal/cmd/contacts/groups.go index 1cb73d6..d9de898 100644 --- a/internal/cmd/contacts/groups.go +++ b/internal/cmd/contacts/groups.go @@ -26,18 +26,22 @@ Examples: gro contacts groups --max 50 gro ppl groups --json`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newContactsClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newContactsClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Contacts client: %w", err) + return fmt.Errorf("creating Contacts client: %w", err) } - resp, err := client.ListContactGroups("", maxResults) + resp, err := client.ListContactGroups(cmd.Context(), "", maxResults) if err != nil { - return fmt.Errorf("failed to list contact groups: %w", err) + return fmt.Errorf("listing contact groups: %w", err) } if len(resp.ContactGroups) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No contact groups found.") return nil } diff --git a/internal/cmd/contacts/handlers_test.go b/internal/cmd/contacts/handlers_test.go index 60be8cf..7a96385 100644 --- a/internal/cmd/contacts/handlers_test.go +++ b/internal/cmd/contacts/handlers_test.go @@ -1,61 +1,34 @@ package contacts import ( - "bytes" + "context" "encoding/json" "errors" - "io" - "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "google.golang.org/api/people/v1" contactsapi "github.com/open-cli-collective/google-readonly/internal/contacts" "github.com/open-cli-collective/google-readonly/internal/testutil" ) -// captureOutput captures stdout during test execution -func captureOutput(t *testing.T, f func()) string { - t.Helper() - old := os.Stdout - r, w, err := os.Pipe() - require.NoError(t, err) - os.Stdout = w - - f() - - w.Close() - os.Stdout = old - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() -} - // withMockClient sets up a mock client factory for tests -func withMockClient(mock contactsapi.ContactsClientInterface, f func()) { - originalFactory := ClientFactory - ClientFactory = func() (contactsapi.ContactsClientInterface, error) { +func withMockClient(mock ContactsClient, f func()) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (ContactsClient, error) { return mock, nil - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } // withFailingClientFactory sets up a factory that returns an error func withFailingClientFactory(f func()) { - originalFactory := ClientFactory - ClientFactory = func() (contactsapi.ContactsClientInterface, error) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (ContactsClient, error) { return nil, errors.New("connection failed") - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } func TestListCommand_Success(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactsFunc: func(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { + mock := &MockContactsClient{ + ListContactsFunc: func(_ context.Context, _ string, _ int64) (*people.ListConnectionsResponse, error) { return &people.ListConnectionsResponse{ Connections: []*people.Person{ testutil.SamplePerson("people/c123"), @@ -68,20 +41,20 @@ func TestListCommand_Success(t *testing.T) { cmd := newListCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "people/c123") - assert.Contains(t, output, "John Doe") - assert.Contains(t, output, "2 contact(s)") + testutil.Contains(t, output, "people/c123") + testutil.Contains(t, output, "John Doe") + testutil.Contains(t, output, "2 contact(s)") }) } func TestListCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactsFunc: func(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { + mock := &MockContactsClient{ + ListContactsFunc: func(_ context.Context, _ string, _ int64) (*people.ListConnectionsResponse, error) { return &people.ListConnectionsResponse{ Connections: []*people.Person{ testutil.SamplePerson("people/c123"), @@ -94,21 +67,21 @@ func TestListCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var contacts []*contactsapi.Contact err := json.Unmarshal([]byte(output), &contacts) - assert.NoError(t, err) - assert.Len(t, contacts, 1) + testutil.NoError(t, err) + testutil.Len(t, contacts, 1) }) } func TestListCommand_Empty(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactsFunc: func(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { + mock := &MockContactsClient{ + ListContactsFunc: func(_ context.Context, _ string, _ int64) (*people.ListConnectionsResponse, error) { return &people.ListConnectionsResponse{ Connections: []*people.Person{}, }, nil @@ -118,18 +91,43 @@ func TestListCommand_Empty(t *testing.T) { cmd := newListCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No contacts found") + testutil.Contains(t, output, "No contacts found") + }) +} + +func TestListCommand_Empty_JSON(t *testing.T) { + mock := &MockContactsClient{ + ListContactsFunc: func(_ context.Context, _ string, _ int64) (*people.ListConnectionsResponse, error) { + return &people.ListConnectionsResponse{ + Connections: []*people.Person{}, + }, nil + }, + } + + cmd := newListCommand() + cmd.SetArgs([]string{"--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var contacts []any + err := json.Unmarshal([]byte(output), &contacts) + testutil.NoError(t, err) + testutil.Len(t, contacts, 0) }) } func TestListCommand_APIError(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactsFunc: func(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { + mock := &MockContactsClient{ + ListContactsFunc: func(_ context.Context, _ string, _ int64) (*people.ListConnectionsResponse, error) { return nil, errors.New("API error") }, } @@ -138,8 +136,8 @@ func TestListCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to list contacts") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "listing contacts") }) } @@ -148,15 +146,15 @@ func TestListCommand_ClientCreationError(t *testing.T) { withFailingClientFactory(func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create Contacts client") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "creating Contacts client") }) } func TestSearchCommand_Success(t *testing.T) { - mock := &testutil.MockContactsClient{ - SearchContactsFunc: func(query string, pageSize int64) (*people.SearchResponse, error) { - assert.Equal(t, "John", query) + mock := &MockContactsClient{ + SearchContactsFunc: func(_ context.Context, query string, _ int64) (*people.SearchResponse, error) { + testutil.Equal(t, query, "John") return &people.SearchResponse{ Results: []*people.SearchResult{ {Person: testutil.SamplePerson("people/c123")}, @@ -169,19 +167,19 @@ func TestSearchCommand_Success(t *testing.T) { cmd.SetArgs([]string{"John"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "John Doe") - assert.Contains(t, output, "1 contact(s)") + testutil.Contains(t, output, "John Doe") + testutil.Contains(t, output, "1 contact(s)") }) } func TestSearchCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockContactsClient{ - SearchContactsFunc: func(query string, pageSize int64) (*people.SearchResponse, error) { + mock := &MockContactsClient{ + SearchContactsFunc: func(_ context.Context, _ string, _ int64) (*people.SearchResponse, error) { return &people.SearchResponse{ Results: []*people.SearchResult{ {Person: testutil.SamplePerson("people/c123")}, @@ -194,21 +192,21 @@ func TestSearchCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"John", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var contacts []*contactsapi.Contact err := json.Unmarshal([]byte(output), &contacts) - assert.NoError(t, err) - assert.Len(t, contacts, 1) + testutil.NoError(t, err) + testutil.Len(t, contacts, 1) }) } func TestSearchCommand_NoResults(t *testing.T) { - mock := &testutil.MockContactsClient{ - SearchContactsFunc: func(query string, pageSize int64) (*people.SearchResponse, error) { + mock := &MockContactsClient{ + SearchContactsFunc: func(_ context.Context, _ string, _ int64) (*people.SearchResponse, error) { return &people.SearchResponse{ Results: []*people.SearchResult{}, }, nil @@ -219,18 +217,43 @@ func TestSearchCommand_NoResults(t *testing.T) { cmd.SetArgs([]string{"nonexistent"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + testutil.Contains(t, output, "No contacts found") + }) +} + +func TestSearchCommand_NoResults_JSON(t *testing.T) { + mock := &MockContactsClient{ + SearchContactsFunc: func(_ context.Context, _ string, _ int64) (*people.SearchResponse, error) { + return &people.SearchResponse{ + Results: []*people.SearchResult{}, + }, nil + }, + } + + cmd := newSearchCommand() + cmd.SetArgs([]string{"nonexistent", "--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No contacts found") + var contacts []any + err := json.Unmarshal([]byte(output), &contacts) + testutil.NoError(t, err) + testutil.Len(t, contacts, 0) }) } func TestSearchCommand_APIError(t *testing.T) { - mock := &testutil.MockContactsClient{ - SearchContactsFunc: func(query string, pageSize int64) (*people.SearchResponse, error) { + mock := &MockContactsClient{ + SearchContactsFunc: func(_ context.Context, _ string, _ int64) (*people.SearchResponse, error) { return nil, errors.New("API error") }, } @@ -240,15 +263,15 @@ func TestSearchCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to search contacts") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "searching contacts") }) } func TestGetCommand_Success(t *testing.T) { - mock := &testutil.MockContactsClient{ - GetContactFunc: func(resourceName string) (*people.Person, error) { - assert.Equal(t, "people/c123", resourceName) + mock := &MockContactsClient{ + GetContactFunc: func(_ context.Context, resourceName string) (*people.Person, error) { + testutil.Equal(t, resourceName, "people/c123") return testutil.SamplePerson("people/c123"), nil }, } @@ -257,20 +280,20 @@ func TestGetCommand_Success(t *testing.T) { cmd.SetArgs([]string{"people/c123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "people/c123") - assert.Contains(t, output, "John Doe") - assert.Contains(t, output, "john@example.com") + testutil.Contains(t, output, "people/c123") + testutil.Contains(t, output, "John Doe") + testutil.Contains(t, output, "john@example.com") }) } func TestGetCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockContactsClient{ - GetContactFunc: func(resourceName string) (*people.Person, error) { + mock := &MockContactsClient{ + GetContactFunc: func(_ context.Context, _ string) (*people.Person, error) { return testutil.SamplePerson("people/c123"), nil }, } @@ -279,21 +302,21 @@ func TestGetCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"people/c123", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var contact contactsapi.Contact err := json.Unmarshal([]byte(output), &contact) - assert.NoError(t, err) - assert.Equal(t, "people/c123", contact.ResourceName) + testutil.NoError(t, err) + testutil.Equal(t, contact.ResourceName, "people/c123") }) } func TestGetCommand_NotFound(t *testing.T) { - mock := &testutil.MockContactsClient{ - GetContactFunc: func(resourceName string) (*people.Person, error) { + mock := &MockContactsClient{ + GetContactFunc: func(_ context.Context, _ string) (*people.Person, error) { return nil, errors.New("contact not found") }, } @@ -303,14 +326,14 @@ func TestGetCommand_NotFound(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get contact") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "getting contact") }) } func TestGroupsCommand_Success(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactGroupsFunc: func(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { + mock := &MockContactsClient{ + ListContactGroupsFunc: func(_ context.Context, _ string, _ int64) (*people.ListContactGroupsResponse, error) { return &people.ListContactGroupsResponse{ ContactGroups: []*people.ContactGroup{ { @@ -333,20 +356,20 @@ func TestGroupsCommand_Success(t *testing.T) { cmd := newGroupsCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "Friends") - assert.Contains(t, output, "Family") - assert.Contains(t, output, "2 contact group(s)") + testutil.Contains(t, output, "Friends") + testutil.Contains(t, output, "Family") + testutil.Contains(t, output, "2 contact group(s)") }) } func TestGroupsCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactGroupsFunc: func(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { + mock := &MockContactsClient{ + ListContactGroupsFunc: func(_ context.Context, _ string, _ int64) (*people.ListContactGroupsResponse, error) { return &people.ListContactGroupsResponse{ ContactGroups: []*people.ContactGroup{ { @@ -364,22 +387,22 @@ func TestGroupsCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var groups []*contactsapi.ContactGroup err := json.Unmarshal([]byte(output), &groups) - assert.NoError(t, err) - assert.Len(t, groups, 1) - assert.Equal(t, "Friends", groups[0].Name) + testutil.NoError(t, err) + testutil.Len(t, groups, 1) + testutil.Equal(t, groups[0].Name, "Friends") }) } func TestGroupsCommand_Empty(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactGroupsFunc: func(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { + mock := &MockContactsClient{ + ListContactGroupsFunc: func(_ context.Context, _ string, _ int64) (*people.ListContactGroupsResponse, error) { return &people.ListContactGroupsResponse{ ContactGroups: []*people.ContactGroup{}, }, nil @@ -389,18 +412,43 @@ func TestGroupsCommand_Empty(t *testing.T) { cmd := newGroupsCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No contact groups found") + testutil.Contains(t, output, "No contact groups found") + }) +} + +func TestGroupsCommand_Empty_JSON(t *testing.T) { + mock := &MockContactsClient{ + ListContactGroupsFunc: func(_ context.Context, _ string, _ int64) (*people.ListContactGroupsResponse, error) { + return &people.ListContactGroupsResponse{ + ContactGroups: []*people.ContactGroup{}, + }, nil + }, + } + + cmd := newGroupsCommand() + cmd.SetArgs([]string{"--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var groups []any + err := json.Unmarshal([]byte(output), &groups) + testutil.NoError(t, err) + testutil.Len(t, groups, 0) }) } func TestGroupsCommand_APIError(t *testing.T) { - mock := &testutil.MockContactsClient{ - ListContactGroupsFunc: func(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { + mock := &MockContactsClient{ + ListContactGroupsFunc: func(_ context.Context, _ string, _ int64) (*people.ListContactGroupsResponse, error) { return nil, errors.New("API error") }, } @@ -409,7 +457,7 @@ func TestGroupsCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to list contact groups") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "listing contact groups") }) } diff --git a/internal/cmd/contacts/list.go b/internal/cmd/contacts/list.go index 1ae3f2d..9a4ed0d 100644 --- a/internal/cmd/contacts/list.go +++ b/internal/cmd/contacts/list.go @@ -26,18 +26,22 @@ Examples: gro contacts list --max 50 gro ppl list --json`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newContactsClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newContactsClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Contacts client: %w", err) + return fmt.Errorf("creating Contacts client: %w", err) } - resp, err := client.ListContacts("", maxResults) + resp, err := client.ListContacts(cmd.Context(), "", maxResults) if err != nil { - return fmt.Errorf("failed to list contacts: %w", err) + return fmt.Errorf("listing contacts: %w", err) } if len(resp.Connections) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No contacts found.") return nil } diff --git a/internal/cmd/contacts/mock_test.go b/internal/cmd/contacts/mock_test.go new file mode 100644 index 0000000..d0299d4 --- /dev/null +++ b/internal/cmd/contacts/mock_test.go @@ -0,0 +1,46 @@ +package contacts + +import ( + "context" + + "google.golang.org/api/people/v1" +) + +// MockContactsClient is a configurable mock for ContactsClient. +type MockContactsClient struct { + ListContactsFunc func(ctx context.Context, pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) + SearchContactsFunc func(ctx context.Context, query string, pageSize int64) (*people.SearchResponse, error) + GetContactFunc func(ctx context.Context, resourceName string) (*people.Person, error) + ListContactGroupsFunc func(ctx context.Context, pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) +} + +// Verify MockContactsClient implements ContactsClient +var _ ContactsClient = (*MockContactsClient)(nil) + +func (m *MockContactsClient) ListContacts(ctx context.Context, pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { + if m.ListContactsFunc != nil { + return m.ListContactsFunc(ctx, pageToken, pageSize) + } + return nil, nil +} + +func (m *MockContactsClient) SearchContacts(ctx context.Context, query string, pageSize int64) (*people.SearchResponse, error) { + if m.SearchContactsFunc != nil { + return m.SearchContactsFunc(ctx, query, pageSize) + } + return nil, nil +} + +func (m *MockContactsClient) GetContact(ctx context.Context, resourceName string) (*people.Person, error) { + if m.GetContactFunc != nil { + return m.GetContactFunc(ctx, resourceName) + } + return nil, nil +} + +func (m *MockContactsClient) ListContactGroups(ctx context.Context, pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { + if m.ListContactGroupsFunc != nil { + return m.ListContactGroupsFunc(ctx, pageToken, pageSize) + } + return nil, nil +} diff --git a/internal/cmd/contacts/output.go b/internal/cmd/contacts/output.go index a3651f4..38fd86c 100644 --- a/internal/cmd/contacts/output.go +++ b/internal/cmd/contacts/output.go @@ -4,19 +4,29 @@ import ( "context" "fmt" + "google.golang.org/api/people/v1" + "github.com/open-cli-collective/google-readonly/internal/contacts" "github.com/open-cli-collective/google-readonly/internal/output" ) +// ContactsClient defines the interface for Contacts client operations used by contacts commands. +type ContactsClient interface { + ListContacts(ctx context.Context, pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) + SearchContacts(ctx context.Context, query string, pageSize int64) (*people.SearchResponse, error) + GetContact(ctx context.Context, resourceName string) (*people.Person, error) + ListContactGroups(ctx context.Context, pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) +} + // ClientFactory is the function used to create Contacts clients. // Override in tests to inject mocks. -var ClientFactory = func() (contacts.ContactsClientInterface, error) { - return contacts.NewClient(context.Background()) +var ClientFactory = func(ctx context.Context) (ContactsClient, error) { + return contacts.NewClient(ctx) } // newContactsClient creates a new contacts client -func newContactsClient() (contacts.ContactsClientInterface, error) { - return ClientFactory() +func newContactsClient(ctx context.Context) (ContactsClient, error) { + return ClientFactory(ctx) } // printJSON outputs data as indented JSON diff --git a/internal/cmd/contacts/output_test.go b/internal/cmd/contacts/output_test.go index 1c691b8..2b3f995 100644 --- a/internal/cmd/contacts/output_test.go +++ b/internal/cmd/contacts/output_test.go @@ -7,10 +7,8 @@ import ( "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/open-cli-collective/google-readonly/internal/contacts" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestPrintJSON(t *testing.T) { @@ -50,7 +48,7 @@ func TestPrintJSON(t *testing.T) { os.Stdout = w err := printJSON(tt.data) - require.NoError(t, err) + testutil.NoError(t, err) w.Close() os.Stdout = oldStdout @@ -59,12 +57,12 @@ func TestPrintJSON(t *testing.T) { io.Copy(&buf, r) output := buf.String() - assert.NotEmpty(t, output) + testutil.NotEmpty(t, output) // Verify it's valid JSON var parsed any err = json.Unmarshal([]byte(output), &parsed) - assert.NoError(t, err, "output should be valid JSON") + testutil.NoError(t, err) }) } } @@ -239,10 +237,10 @@ func TestPrintContact(t *testing.T) { output := buf.String() for _, want := range tt.wantContains { - assert.Contains(t, output, want) + testutil.Contains(t, output, want) } for _, notWant := range tt.wantNotContains { - assert.NotContains(t, output, notWant) + testutil.NotContains(t, output, notWant) } }) } @@ -303,7 +301,7 @@ func TestPrintContactSummary(t *testing.T) { output := buf.String() for _, want := range tt.wantContains { - assert.Contains(t, output, want) + testutil.Contains(t, output, want) } }) } @@ -377,7 +375,7 @@ func TestPrintContactGroup(t *testing.T) { output := buf.String() for _, want := range tt.wantContains { - assert.Contains(t, output, want) + testutil.Contains(t, output, want) } }) } diff --git a/internal/cmd/contacts/search.go b/internal/cmd/contacts/search.go index 739172e..36ac7d2 100644 --- a/internal/cmd/contacts/search.go +++ b/internal/cmd/contacts/search.go @@ -35,17 +35,21 @@ Examples: RunE: func(cmd *cobra.Command, args []string) error { query := args[0] - client, err := newContactsClient() + client, err := newContactsClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Contacts client: %w", err) + return fmt.Errorf("creating Contacts client: %w", err) } - resp, err := client.SearchContacts(query, maxResults) + resp, err := client.SearchContacts(cmd.Context(), query, maxResults) if err != nil { - return fmt.Errorf("failed to search contacts: %w", err) + return fmt.Errorf("searching contacts: %w", err) } if len(resp.Results) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Printf("No contacts found matching \"%s\".\n", query) return nil } diff --git a/internal/cmd/drive/download.go b/internal/cmd/drive/download.go index 95139d5..53b492e 100644 --- a/internal/cmd/drive/download.go +++ b/internal/cmd/drive/download.go @@ -42,17 +42,19 @@ Export formats: Drawings: pdf, png, svg, jpg`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := newDriveClient() + client, err := newDriveClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Drive client: %w", err) + return fmt.Errorf("creating Drive client: %w", err) } fileID := args[0] + ctx := cmd.Context() + // Get file metadata first - file, err := client.GetFile(fileID) + file, err := client.GetFile(ctx, fileID) if err != nil { - return fmt.Errorf("failed to get file info: %w", err) + return fmt.Errorf("getting file info: %w", err) } var data []byte @@ -67,7 +69,7 @@ Export formats: exportMime, err := drive.GetExportMimeType(file.MimeType, format) if err != nil { - return fmt.Errorf("failed to get export type: %w", err) + return fmt.Errorf("getting export type: %w", err) } if !stdout { @@ -75,9 +77,9 @@ Export formats: fmt.Printf("Format: %s\n", format) } - data, err = client.ExportFile(fileID, exportMime) + data, err = client.ExportFile(ctx, fileID, exportMime) if err != nil { - return fmt.Errorf("failed to export file: %w", err) + return fmt.Errorf("exporting file: %w", err) } } else { // Regular file - download directly @@ -90,9 +92,9 @@ Export formats: fmt.Printf("Downloading: %s\n", file.Name) } - data, err = client.DownloadFile(fileID) + data, err = client.DownloadFile(ctx, fileID) if err != nil { - return fmt.Errorf("failed to download file: %w", err) + return fmt.Errorf("downloading file: %w", err) } } @@ -100,7 +102,7 @@ Export formats: if stdout { _, err = os.Stdout.Write(data) if err != nil { - return fmt.Errorf("failed to write to stdout: %w", err) + return fmt.Errorf("writing to stdout: %w", err) } return nil } @@ -108,7 +110,7 @@ Export formats: outputPath := determineOutputPath(file.Name, format, output) if err := os.WriteFile(outputPath, data, config.OutputFilePerm); err != nil { - return fmt.Errorf("failed to write file: %w", err) + return fmt.Errorf("writing file: %w", err) } fmt.Printf("Size: %s\n", formatpkg.Size(int64(len(data)))) diff --git a/internal/cmd/drive/download_test.go b/internal/cmd/drive/download_test.go index 551e337..69583e5 100644 --- a/internal/cmd/drive/download_test.go +++ b/internal/cmd/drive/download_test.go @@ -3,68 +3,68 @@ package drive import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestDownloadCommand(t *testing.T) { cmd := newDownloadCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "download ", cmd.Use) + testutil.Equal(t, cmd.Use, "download ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"file-id"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"file-id", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has output flag", func(t *testing.T) { flag := cmd.Flags().Lookup("output") - assert.NotNil(t, flag) - assert.Equal(t, "o", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "o") }) t.Run("has format flag", func(t *testing.T) { flag := cmd.Flags().Lookup("format") - assert.NotNil(t, flag) - assert.Equal(t, "f", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "f") }) t.Run("has stdout flag", func(t *testing.T) { flag := cmd.Flags().Lookup("stdout") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has short description", func(t *testing.T) { - assert.Contains(t, cmd.Short, "Download") + testutil.Contains(t, cmd.Short, "Download") }) } func TestDetermineOutputPath(t *testing.T) { t.Run("uses user-specified output path", func(t *testing.T) { result := determineOutputPath("original.doc", "pdf", "/custom/path.pdf") - assert.Equal(t, "/custom/path.pdf", result) + testutil.Equal(t, result, "/custom/path.pdf") }) t.Run("uses original name when no format or output", func(t *testing.T) { result := determineOutputPath("document.pdf", "", "") - assert.Equal(t, "document.pdf", result) + testutil.Equal(t, result, "document.pdf") }) t.Run("replaces extension when format specified", func(t *testing.T) { result := determineOutputPath("Report", "pdf", "") - assert.Equal(t, "Report.pdf", result) + testutil.Equal(t, result, "Report.pdf") }) t.Run("replaces existing extension when format specified", func(t *testing.T) { result := determineOutputPath("Report.gdoc", "docx", "") - assert.Equal(t, "Report.docx", result) + testutil.Equal(t, result, "Report.docx") }) t.Run("handles various export formats", func(t *testing.T) { @@ -86,7 +86,7 @@ func TestDetermineOutputPath(t *testing.T) { for _, tt := range tests { t.Run(tt.format, func(t *testing.T) { result := determineOutputPath(tt.name, tt.format, "") - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } }) diff --git a/internal/cmd/drive/drive.go b/internal/cmd/drive/drive.go index 6d08766..9080df4 100644 --- a/internal/cmd/drive/drive.go +++ b/internal/cmd/drive/drive.go @@ -1,3 +1,4 @@ +// Package drive implements the gro drive command and subcommands. package drive import ( diff --git a/internal/cmd/drive/drives.go b/internal/cmd/drive/drives.go index 17be98f..11d3f91 100644 --- a/internal/cmd/drive/drives.go +++ b/internal/cmd/drive/drives.go @@ -1,6 +1,7 @@ package drive import ( + "context" "fmt" "os" "strings" @@ -31,17 +32,17 @@ Examples: gro drive drives --refresh # Force refresh from API gro drive drives --json # Output as JSON`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newDriveClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newDriveClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Drive client: %w", err) + return fmt.Errorf("creating Drive client: %w", err) } // Initialize cache ttl := config.GetCacheTTLHours() c, err := cache.New(ttl) if err != nil { - return fmt.Errorf("failed to initialize cache: %w", err) + return fmt.Errorf("initializing cache: %w", err) } var drives []*drive.SharedDrive @@ -50,7 +51,7 @@ Examples: if !refresh { cached, err := c.GetDrives() if err != nil { - return fmt.Errorf("failed to read cache: %w", err) + return fmt.Errorf("reading cache: %w", err) } if cached != nil { // Convert from cache type to drive type @@ -66,9 +67,9 @@ Examples: // Fetch from API if no cache hit if drives == nil { - drives, err = client.ListSharedDrives(100) + drives, err = client.ListSharedDrives(cmd.Context(), 100) if err != nil { - return fmt.Errorf("failed to list shared drives: %w", err) + return fmt.Errorf("listing shared drives: %w", err) } // Update cache @@ -86,6 +87,10 @@ Examples: } if len(drives) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No shared drives found.") return nil } @@ -118,7 +123,7 @@ func printSharedDrives(drives []*drive.SharedDrive) { } // resolveDriveScope converts command flags to a DriveScope, resolving drive names via cache -func resolveDriveScope(client drive.DriveClientInterface, myDrive bool, driveFlag string) (drive.DriveScope, error) { +func resolveDriveScope(ctx context.Context, client DriveClient, myDrive bool, driveFlag string) (drive.DriveScope, error) { // --my-drive flag if myDrive { return drive.DriveScope{MyDriveOnly: true}, nil @@ -139,15 +144,15 @@ func resolveDriveScope(client drive.DriveClientInterface, myDrive bool, driveFla ttl := config.GetCacheTTLHours() c, err := cache.New(ttl) if err != nil { - return drive.DriveScope{}, fmt.Errorf("failed to initialize cache: %w", err) + return drive.DriveScope{}, fmt.Errorf("initializing cache: %w", err) } cached, _ := c.GetDrives() if cached == nil { // Cache miss - fetch from API - drives, err := client.ListSharedDrives(100) + drives, err := client.ListSharedDrives(ctx, 100) if err != nil { - return drive.DriveScope{}, fmt.Errorf("failed to list shared drives: %w", err) + return drive.DriveScope{}, fmt.Errorf("listing shared drives: %w", err) } // Update cache diff --git a/internal/cmd/drive/drives_test.go b/internal/cmd/drive/drives_test.go index 6a2d2bc..977d717 100644 --- a/internal/cmd/drive/drives_test.go +++ b/internal/cmd/drive/drives_test.go @@ -1,10 +1,9 @@ package drive import ( + "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/open-cli-collective/google-readonly/internal/drive" "github.com/open-cli-collective/google-readonly/internal/testutil" ) @@ -13,37 +12,37 @@ func TestDrivesCommand(t *testing.T) { cmd := newDrivesCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "drives", cmd.Use) + testutil.Equal(t, cmd.Use, "drives") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has refresh flag", func(t *testing.T) { flag := cmd.Flags().Lookup("refresh") - assert.NotNil(t, flag) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.Contains(t, cmd.Short, "shared drives") + testutil.Contains(t, cmd.Short, "shared drives") }) t.Run("has long description", func(t *testing.T) { - assert.Contains(t, cmd.Long, "Shared Drives") - assert.Contains(t, cmd.Long, "cache") + testutil.Contains(t, cmd.Long, "Shared Drives") + testutil.Contains(t, cmd.Long, "cache") }) } @@ -98,48 +97,48 @@ func TestLooksLikeDriveID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := looksLikeDriveID(tt.input) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } func TestResolveDriveScope(t *testing.T) { t.Run("returns MyDriveOnly when myDrive flag is true", func(t *testing.T) { - mock := &testutil.MockDriveClient{} + mock := &MockDriveClient{} - scope, err := resolveDriveScope(mock, true, "") + scope, err := resolveDriveScope(context.Background(), mock, true, "") - assert.NoError(t, err) - assert.True(t, scope.MyDriveOnly) - assert.False(t, scope.AllDrives) - assert.Empty(t, scope.DriveID) + testutil.NoError(t, err) + testutil.True(t, scope.MyDriveOnly) + testutil.False(t, scope.AllDrives) + testutil.Empty(t, scope.DriveID) }) t.Run("returns AllDrives when no flags provided", func(t *testing.T) { - mock := &testutil.MockDriveClient{} + mock := &MockDriveClient{} - scope, err := resolveDriveScope(mock, false, "") + scope, err := resolveDriveScope(context.Background(), mock, false, "") - assert.NoError(t, err) - assert.True(t, scope.AllDrives) - assert.False(t, scope.MyDriveOnly) - assert.Empty(t, scope.DriveID) + testutil.NoError(t, err) + testutil.True(t, scope.AllDrives) + testutil.False(t, scope.MyDriveOnly) + testutil.Empty(t, scope.DriveID) }) t.Run("returns DriveID directly when input looks like ID", func(t *testing.T) { - mock := &testutil.MockDriveClient{} + mock := &MockDriveClient{} - scope, err := resolveDriveScope(mock, false, "0ALengineering123456") + scope, err := resolveDriveScope(context.Background(), mock, false, "0ALengineering123456") - assert.NoError(t, err) - assert.Equal(t, "0ALengineering123456", scope.DriveID) - assert.False(t, scope.AllDrives) - assert.False(t, scope.MyDriveOnly) + testutil.NoError(t, err) + testutil.Equal(t, scope.DriveID, "0ALengineering123456") + testutil.False(t, scope.AllDrives) + testutil.False(t, scope.MyDriveOnly) }) t.Run("resolves drive name to ID via API", func(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListSharedDrivesFunc: func(pageSize int64) ([]*drive.SharedDrive, error) { + mock := &MockDriveClient{ + ListSharedDrivesFunc: func(_ context.Context, _ int64) ([]*drive.SharedDrive, error) { return []*drive.SharedDrive{ {ID: "0ALeng123", Name: "Engineering"}, {ID: "0ALfin456", Name: "Finance"}, @@ -147,40 +146,40 @@ func TestResolveDriveScope(t *testing.T) { }, } - scope, err := resolveDriveScope(mock, false, "Engineering") + scope, err := resolveDriveScope(context.Background(), mock, false, "Engineering") - assert.NoError(t, err) - assert.Equal(t, "0ALeng123", scope.DriveID) + testutil.NoError(t, err) + testutil.Equal(t, scope.DriveID, "0ALeng123") }) t.Run("resolves drive name case-insensitively", func(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListSharedDrivesFunc: func(pageSize int64) ([]*drive.SharedDrive, error) { + mock := &MockDriveClient{ + ListSharedDrivesFunc: func(_ context.Context, _ int64) ([]*drive.SharedDrive, error) { return []*drive.SharedDrive{ {ID: "0ALeng123", Name: "Engineering"}, }, nil }, } - scope, err := resolveDriveScope(mock, false, "ENGINEERING") + scope, err := resolveDriveScope(context.Background(), mock, false, "ENGINEERING") - assert.NoError(t, err) - assert.Equal(t, "0ALeng123", scope.DriveID) + testutil.NoError(t, err) + testutil.Equal(t, scope.DriveID, "0ALeng123") }) t.Run("returns error when drive name not found", func(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListSharedDrivesFunc: func(pageSize int64) ([]*drive.SharedDrive, error) { + mock := &MockDriveClient{ + ListSharedDrivesFunc: func(_ context.Context, _ int64) ([]*drive.SharedDrive, error) { return []*drive.SharedDrive{ {ID: "0ALeng123", Name: "Engineering"}, }, nil }, } - _, err := resolveDriveScope(mock, false, "NonExistent") + _, err := resolveDriveScope(context.Background(), mock, false, "NonExistent") - assert.Error(t, err) - assert.Contains(t, err.Error(), "shared drive not found") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "shared drive not found") }) } @@ -191,8 +190,8 @@ func TestSearchCommand_MutualExclusivity(t *testing.T) { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "mutually exclusive") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "mutually exclusive") }) } @@ -203,8 +202,8 @@ func TestListCommand_MutualExclusivity(t *testing.T) { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "mutually exclusive") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "mutually exclusive") }) } @@ -215,7 +214,7 @@ func TestTreeCommand_MutualExclusivity(t *testing.T) { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "mutually exclusive") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "mutually exclusive") }) } diff --git a/internal/cmd/drive/get.go b/internal/cmd/drive/get.go index 92c597e..f06bb85 100644 --- a/internal/cmd/drive/get.go +++ b/internal/cmd/drive/get.go @@ -23,15 +23,15 @@ Examples: gro drive get --json # Output as JSON`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := newDriveClient() + client, err := newDriveClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Drive client: %w", err) + return fmt.Errorf("creating Drive client: %w", err) } fileID := args[0] - file, err := client.GetFile(fileID) + file, err := client.GetFile(cmd.Context(), fileID) if err != nil { - return fmt.Errorf("failed to get file %s: %w", fileID, err) + return fmt.Errorf("getting file %s: %w", fileID, err) } if jsonOutput { diff --git a/internal/cmd/drive/get_test.go b/internal/cmd/drive/get_test.go index 9177481..b9a6598 100644 --- a/internal/cmd/drive/get_test.go +++ b/internal/cmd/drive/get_test.go @@ -1,62 +1,46 @@ package drive import ( - "bytes" - "io" - "os" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/open-cli-collective/google-readonly/internal/drive" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestGetCommand(t *testing.T) { cmd := newGetCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "get ", cmd.Use) + testutil.Equal(t, cmd.Use, "get ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"file-id"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"file-id", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.Contains(t, cmd.Short, "Get") + testutil.Contains(t, cmd.Short, "Get") }) } func TestPrintFileDetails(t *testing.T) { - // Capture stdout for testing captureOutput := func(fn func()) string { - old := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - fn() - - w.Close() - os.Stdout = old - - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() + return testutil.CaptureStdout(t, fn) } t.Run("prints all fields for complete file", func(t *testing.T) { @@ -77,17 +61,17 @@ func TestPrintFileDetails(t *testing.T) { printFileDetails(f) }) - assert.Contains(t, output, "File Details") - assert.Contains(t, output, "ID: abc123") - assert.Contains(t, output, "Name: Test Document") - assert.Contains(t, output, "Type: Document") - assert.Contains(t, output, "Size: -") - assert.Contains(t, output, "Created: 2024-01-10 09:30:00") - assert.Contains(t, output, "Modified: 2024-01-15 14:22:00") - assert.Contains(t, output, "Owner: owner@example.com") - assert.Contains(t, output, "Shared: Yes") - assert.Contains(t, output, "Web Link: https://docs.google.com/document/d/abc123/edit") - assert.Contains(t, output, "Parent: parent123") + testutil.Contains(t, output, "File Details") + testutil.Contains(t, output, "ID: abc123") + testutil.Contains(t, output, "Name: Test Document") + testutil.Contains(t, output, "Type: Document") + testutil.Contains(t, output, "Size: -") + testutil.Contains(t, output, "Created: 2024-01-10 09:30:00") + testutil.Contains(t, output, "Modified: 2024-01-15 14:22:00") + testutil.Contains(t, output, "Owner: owner@example.com") + testutil.Contains(t, output, "Shared: Yes") + testutil.Contains(t, output, "Web Link: https://docs.google.com/document/d/abc123/edit") + testutil.Contains(t, output, "Parent: parent123") }) t.Run("prints size for regular files", func(t *testing.T) { @@ -102,7 +86,7 @@ func TestPrintFileDetails(t *testing.T) { printFileDetails(f) }) - assert.Contains(t, output, "Size: 1.5 MB") + testutil.Contains(t, output, "Size: 1.5 MB") }) t.Run("handles unshared file", func(t *testing.T) { @@ -116,7 +100,7 @@ func TestPrintFileDetails(t *testing.T) { printFileDetails(f) }) - assert.Contains(t, output, "Shared: No") + testutil.Contains(t, output, "Shared: No") }) t.Run("handles multiple owners", func(t *testing.T) { @@ -130,7 +114,7 @@ func TestPrintFileDetails(t *testing.T) { printFileDetails(f) }) - assert.Contains(t, output, "Owner: owner1@example.com, owner2@example.com") + testutil.Contains(t, output, "Owner: owner1@example.com, owner2@example.com") }) t.Run("omits missing fields gracefully", func(t *testing.T) { @@ -143,13 +127,13 @@ func TestPrintFileDetails(t *testing.T) { printFileDetails(f) }) - assert.Contains(t, output, "ID: minimal123") - assert.Contains(t, output, "Name: minimal.txt") + testutil.Contains(t, output, "ID: minimal123") + testutil.Contains(t, output, "Name: minimal.txt") // Should not contain empty values or crash - assert.NotContains(t, output, "Created:") - assert.NotContains(t, output, "Modified:") - assert.NotContains(t, output, "Owner:") - assert.NotContains(t, output, "Web Link:") - assert.NotContains(t, output, "Parent:") + testutil.NotContains(t, output, "Created:") + testutil.NotContains(t, output, "Modified:") + testutil.NotContains(t, output, "Owner:") + testutil.NotContains(t, output, "Web Link:") + testutil.NotContains(t, output, "Parent:") }) } diff --git a/internal/cmd/drive/handlers_test.go b/internal/cmd/drive/handlers_test.go index b615e24..9a02b4a 100644 --- a/internal/cmd/drive/handlers_test.go +++ b/internal/cmd/drive/handlers_test.go @@ -1,61 +1,34 @@ package drive import ( - "bytes" + "context" "encoding/json" "errors" - "io" "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - driveapi "github.com/open-cli-collective/google-readonly/internal/drive" "github.com/open-cli-collective/google-readonly/internal/testutil" ) -// captureOutput captures stdout during test execution -func captureOutput(t *testing.T, f func()) string { - t.Helper() - old := os.Stdout - r, w, err := os.Pipe() - require.NoError(t, err) - os.Stdout = w - - f() - - w.Close() - os.Stdout = old - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() -} - // withMockClient sets up a mock client factory for tests -func withMockClient(mock driveapi.DriveClientInterface, f func()) { - originalFactory := ClientFactory - ClientFactory = func() (driveapi.DriveClientInterface, error) { +func withMockClient(mock DriveClient, f func()) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (DriveClient, error) { return mock, nil - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } // withFailingClientFactory sets up a factory that returns an error func withFailingClientFactory(f func()) { - originalFactory := ClientFactory - ClientFactory = func() (driveapi.DriveClientInterface, error) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (DriveClient, error) { return nil, errors.New("connection failed") - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } func TestListCommand_Success(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { - assert.Contains(t, query, "'root' in parents") + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, query string, _ int64) ([]*driveapi.File, error) { + testutil.Contains(t, query, "'root' in parents") return testutil.SampleDriveFiles(2), nil }, } @@ -63,19 +36,19 @@ func TestListCommand_Success(t *testing.T) { cmd := newListCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "file_a") - assert.Contains(t, output, "test-document.pdf") + testutil.Contains(t, output, "file_a") + testutil.Contains(t, output, "test-document.pdf") }) } func TestListCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { return testutil.SampleDriveFiles(1), nil }, } @@ -84,41 +57,64 @@ func TestListCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var files []*driveapi.File err := json.Unmarshal([]byte(output), &files) - assert.NoError(t, err) - assert.Len(t, files, 1) + testutil.NoError(t, err) + testutil.Len(t, files, 1) }) } func TestListCommand_Empty(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { + return []*driveapi.File{}, nil + }, + } + + cmd := newListCommand() + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + testutil.Contains(t, output, "No files found") + }) +} + +func TestListCommand_Empty_JSON(t *testing.T) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { return []*driveapi.File{}, nil }, } cmd := newListCommand() + cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No files found") + var files []any + err := json.Unmarshal([]byte(output), &files) + testutil.NoError(t, err) + testutil.Len(t, files, 0) }) } func TestListCommand_WithFolder(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { - assert.Contains(t, query, "'folder123' in parents") + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, query string, _ int64) ([]*driveapi.File, error) { + testutil.Contains(t, query, "'folder123' in parents") return testutil.SampleDriveFiles(1), nil }, } @@ -127,19 +123,19 @@ func TestListCommand_WithFolder(t *testing.T) { cmd.SetArgs([]string{"folder123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "file_a") + testutil.Contains(t, output, "file_a") }) } func TestListCommand_WithTypeFilter(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { - assert.Contains(t, query, "mimeType") + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, query string, _ int64) ([]*driveapi.File, error) { + testutil.Contains(t, query, "mimeType") return testutil.SampleDriveFiles(1), nil }, } @@ -148,12 +144,12 @@ func TestListCommand_WithTypeFilter(t *testing.T) { cmd.SetArgs([]string{"--type", "document"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "file_a") + testutil.Contains(t, output, "file_a") }) } @@ -161,16 +157,16 @@ func TestListCommand_InvalidType(t *testing.T) { cmd := newListCommand() cmd.SetArgs([]string{"--type", "invalid"}) - withMockClient(&testutil.MockDriveClient{}, func() { + withMockClient(&MockDriveClient{}, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown file type") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "unknown file type") }) } func TestListCommand_APIError(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { return nil, errors.New("API error") }, } @@ -179,8 +175,8 @@ func TestListCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to list files") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "listing files") }) } @@ -189,15 +185,15 @@ func TestListCommand_ClientCreationError(t *testing.T) { withFailingClientFactory(func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create Drive client") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "creating Drive client") }) } func TestSearchCommand_Success(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { - assert.Contains(t, query, "fullText contains 'report'") + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, query string, _ int64) ([]*driveapi.File, error) { + testutil.Contains(t, query, "fullText contains 'report'") return testutil.SampleDriveFiles(2), nil }, } @@ -206,20 +202,20 @@ func TestSearchCommand_Success(t *testing.T) { cmd.SetArgs([]string{"report"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "file_a") - assert.Contains(t, output, "2 file(s)") + testutil.Contains(t, output, "file_a") + testutil.Contains(t, output, "2 file(s)") }) } func TestSearchCommand_NameOnly(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { - assert.Contains(t, query, "name contains 'budget'") + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, query string, _ int64) ([]*driveapi.File, error) { + testutil.Contains(t, query, "name contains 'budget'") return testutil.SampleDriveFiles(1), nil }, } @@ -228,18 +224,18 @@ func TestSearchCommand_NameOnly(t *testing.T) { cmd.SetArgs([]string{"budget", "--name"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "file_a") + testutil.Contains(t, output, "file_a") }) } func TestSearchCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { return testutil.SampleDriveFiles(1), nil }, } @@ -248,21 +244,21 @@ func TestSearchCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"test", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var files []*driveapi.File err := json.Unmarshal([]byte(output), &files) - assert.NoError(t, err) - assert.Len(t, files, 1) + testutil.NoError(t, err) + testutil.Len(t, files, 1) }) } func TestSearchCommand_NoResults(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { return []*driveapi.File{}, nil }, } @@ -271,18 +267,41 @@ func TestSearchCommand_NoResults(t *testing.T) { cmd.SetArgs([]string{"nonexistent"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No files found") + testutil.Contains(t, output, "No files found") + }) +} + +func TestSearchCommand_NoResults_JSON(t *testing.T) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { + return []*driveapi.File{}, nil + }, + } + + cmd := newSearchCommand() + cmd.SetArgs([]string{"nonexistent", "--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var files []any + err := json.Unmarshal([]byte(output), &files) + testutil.NoError(t, err) + testutil.Len(t, files, 0) }) } func TestSearchCommand_APIError(t *testing.T) { - mock := &testutil.MockDriveClient{ - ListFilesFunc: func(query string, pageSize int64) ([]*driveapi.File, error) { + mock := &MockDriveClient{ + ListFilesFunc: func(_ context.Context, _ string, _ int64) ([]*driveapi.File, error) { return nil, errors.New("API error") }, } @@ -292,15 +311,15 @@ func TestSearchCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to search files") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "searching files") }) } func TestGetCommand_Success(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { - assert.Equal(t, "file123", fileID) + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, fileID string) (*driveapi.File, error) { + testutil.Equal(t, fileID, "file123") return testutil.SampleDriveFile("file123"), nil }, } @@ -309,20 +328,20 @@ func TestGetCommand_Success(t *testing.T) { cmd.SetArgs([]string{"file123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "file123") - assert.Contains(t, output, "test-document.pdf") - assert.Contains(t, output, "owner@example.com") + testutil.Contains(t, output, "file123") + testutil.Contains(t, output, "test-document.pdf") + testutil.Contains(t, output, "owner@example.com") }) } func TestGetCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleDriveFile("file123"), nil }, } @@ -331,21 +350,21 @@ func TestGetCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"file123", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var file driveapi.File err := json.Unmarshal([]byte(output), &file) - assert.NoError(t, err) - assert.Equal(t, "file123", file.ID) + testutil.NoError(t, err) + testutil.Equal(t, file.ID, "file123") }) } func TestGetCommand_NotFound(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return nil, errors.New("file not found") }, } @@ -355,8 +374,8 @@ func TestGetCommand_NotFound(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get file") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "getting file") }) } @@ -367,12 +386,12 @@ func TestDownloadCommand_RegularFile(t *testing.T) { os.Chdir(tmpDir) defer os.Chdir(origDir) - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleDriveFile("file123"), nil }, - DownloadFileFunc: func(fileID string) ([]byte, error) { - assert.Equal(t, "file123", fileID) + DownloadFileFunc: func(_ context.Context, fileID string) ([]byte, error) { + testutil.Equal(t, fileID, "file123") return []byte("test content"), nil }, } @@ -381,22 +400,22 @@ func TestDownloadCommand_RegularFile(t *testing.T) { cmd.SetArgs([]string{"file123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "Downloading") - assert.Contains(t, output, "Saved to") + testutil.Contains(t, output, "Downloading") + testutil.Contains(t, output, "Saved to") }) } func TestDownloadCommand_ToStdout(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleDriveFile("file123"), nil }, - DownloadFileFunc: func(fileID string) ([]byte, error) { + DownloadFileFunc: func(_ context.Context, _ string) ([]byte, error) { return []byte("test content"), nil }, } @@ -405,18 +424,18 @@ func TestDownloadCommand_ToStdout(t *testing.T) { cmd.SetArgs([]string{"file123", "--stdout"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Equal(t, "test content", output) + testutil.Equal(t, output, "test content") }) } func TestDownloadCommand_GoogleDocRequiresFormat(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleGoogleDoc("doc123"), nil }, } @@ -426,8 +445,8 @@ func TestDownloadCommand_GoogleDocRequiresFormat(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "requires --format flag") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "requires --format flag") }) } @@ -438,13 +457,13 @@ func TestDownloadCommand_ExportGoogleDoc(t *testing.T) { os.Chdir(tmpDir) defer os.Chdir(origDir) - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleGoogleDoc("doc123"), nil }, - ExportFileFunc: func(fileID, mimeType string) ([]byte, error) { - assert.Equal(t, "doc123", fileID) - assert.Contains(t, mimeType, "pdf") + ExportFileFunc: func(_ context.Context, fileID, mimeType string) ([]byte, error) { + testutil.Equal(t, fileID, "doc123") + testutil.Contains(t, mimeType, "pdf") return []byte("pdf content"), nil }, } @@ -453,19 +472,19 @@ func TestDownloadCommand_ExportGoogleDoc(t *testing.T) { cmd.SetArgs([]string{"doc123", "--format", "pdf"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "Exporting") - assert.Contains(t, output, "Saved to") + testutil.Contains(t, output, "Exporting") + testutil.Contains(t, output, "Saved to") }) } func TestDownloadCommand_RegularFileCannotUseFormat(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleDriveFile("file123"), nil }, } @@ -475,17 +494,17 @@ func TestDownloadCommand_RegularFileCannotUseFormat(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "--format flag is only for Google Workspace files") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "--format flag is only for Google Workspace files") }) } func TestDownloadCommand_APIError(t *testing.T) { - mock := &testutil.MockDriveClient{ - GetFileFunc: func(fileID string) (*driveapi.File, error) { + mock := &MockDriveClient{ + GetFileFunc: func(_ context.Context, _ string) (*driveapi.File, error) { return testutil.SampleDriveFile("file123"), nil }, - DownloadFileFunc: func(fileID string) ([]byte, error) { + DownloadFileFunc: func(_ context.Context, _ string) ([]byte, error) { return nil, errors.New("download failed") }, } @@ -495,8 +514,8 @@ func TestDownloadCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to download file") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "downloading file") }) } @@ -506,7 +525,7 @@ func TestDownloadCommand_ClientCreationError(t *testing.T) { withFailingClientFactory(func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create Drive client") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "creating Drive client") }) } diff --git a/internal/cmd/drive/list.go b/internal/cmd/drive/list.go index 554f43e..cf3c9d4 100644 --- a/internal/cmd/drive/list.go +++ b/internal/cmd/drive/list.go @@ -1,6 +1,7 @@ package drive import ( + "context" "fmt" "os" "strings" @@ -45,9 +46,9 @@ File types: document, spreadsheet, presentation, folder, pdf, image, video, audi return fmt.Errorf("--my-drive and --drive are mutually exclusive") } - client, err := newDriveClient() + client, err := newDriveClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Drive client: %w", err) + return fmt.Errorf("creating Drive client: %w", err) } folderID := "" @@ -56,22 +57,27 @@ File types: document, spreadsheet, presentation, folder, pdf, image, video, audi } // Resolve drive scope for listing - scope, err := resolveDriveScopeForList(client, myDrive, driveFlag, folderID) + ctx := cmd.Context() + scope, err := resolveDriveScopeForList(ctx, client, myDrive, driveFlag, folderID) if err != nil { - return fmt.Errorf("failed to resolve drive scope: %w", err) + return fmt.Errorf("resolving drive scope: %w", err) } query, err := buildListQueryWithScope(folderID, fileType, scope) if err != nil { - return fmt.Errorf("failed to build query: %w", err) + return fmt.Errorf("building query: %w", err) } - files, err := client.ListFilesWithScope(query, maxResults, scope) + files, err := client.ListFilesWithScope(ctx, query, maxResults, scope) if err != nil { - return fmt.Errorf("failed to list files: %w", err) + return fmt.Errorf("listing files: %w", err) } if len(files) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No files found.") return nil } @@ -141,14 +147,14 @@ func buildListQueryWithScope(folderID, fileType string, scope drive.DriveScope) // resolveDriveScopeForList resolves the scope for list operations // List has slightly different behavior - defaults to My Drive root if no flags -func resolveDriveScopeForList(client drive.DriveClientInterface, myDrive bool, driveFlag, folderID string) (drive.DriveScope, error) { +func resolveDriveScopeForList(ctx context.Context, client DriveClient, myDrive bool, driveFlag, folderID string) (drive.DriveScope, error) { // If a folder ID is provided, we need to support all drives to access it if folderID != "" && !myDrive && driveFlag == "" { return drive.DriveScope{AllDrives: true}, nil } // Otherwise use the standard resolution - return resolveDriveScope(client, myDrive, driveFlag) + return resolveDriveScope(ctx, client, myDrive, driveFlag) } // getMimeTypeFilter returns the Drive API query filter for a file type diff --git a/internal/cmd/drive/list_test.go b/internal/cmd/drive/list_test.go index 1a9c76e..162cba1 100644 --- a/internal/cmd/drive/list_test.go +++ b/internal/cmd/drive/list_test.go @@ -3,139 +3,139 @@ package drive import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestListCommand(t *testing.T) { cmd := newListCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "list [folder-id]", cmd.Use) + testutil.Equal(t, cmd.Use, "list [folder-id]") }) t.Run("accepts zero or one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"folder-id"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"folder-id", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) - assert.Equal(t, "25", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") + testutil.Equal(t, flag.DefValue, "25") }) t.Run("has type flag", func(t *testing.T) { flag := cmd.Flags().Lookup("type") - assert.NotNil(t, flag) - assert.Equal(t, "t", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "t") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.Contains(t, cmd.Short, "List") + testutil.Contains(t, cmd.Short, "List") }) } func TestBuildListQuery(t *testing.T) { t.Run("builds query for root folder", func(t *testing.T) { query, err := buildListQuery("", "") - assert.NoError(t, err) - assert.Contains(t, query, "trashed = false") - assert.Contains(t, query, "'root' in parents") + testutil.NoError(t, err) + testutil.Contains(t, query, "trashed = false") + testutil.Contains(t, query, "'root' in parents") }) t.Run("builds query for specific folder", func(t *testing.T) { query, err := buildListQuery("folder123", "") - assert.NoError(t, err) - assert.Contains(t, query, "'folder123' in parents") - assert.NotContains(t, query, "'root' in parents") + testutil.NoError(t, err) + testutil.Contains(t, query, "'folder123' in parents") + testutil.NotContains(t, query, "'root' in parents") }) t.Run("adds type filter for document", func(t *testing.T) { query, err := buildListQuery("", "document") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType = 'application/vnd.google-apps.document'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType = 'application/vnd.google-apps.document'") }) t.Run("adds type filter for spreadsheet", func(t *testing.T) { query, err := buildListQuery("", "spreadsheet") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType = 'application/vnd.google-apps.spreadsheet'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType = 'application/vnd.google-apps.spreadsheet'") }) t.Run("adds type filter for presentation", func(t *testing.T) { query, err := buildListQuery("", "presentation") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType = 'application/vnd.google-apps.presentation'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType = 'application/vnd.google-apps.presentation'") }) t.Run("adds type filter for folder", func(t *testing.T) { query, err := buildListQuery("", "folder") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType = 'application/vnd.google-apps.folder'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType = 'application/vnd.google-apps.folder'") }) t.Run("adds type filter for pdf", func(t *testing.T) { query, err := buildListQuery("", "pdf") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType = 'application/pdf'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType = 'application/pdf'") }) t.Run("adds type filter for image", func(t *testing.T) { query, err := buildListQuery("", "image") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType contains 'image/'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType contains 'image/'") }) t.Run("adds type filter for video", func(t *testing.T) { query, err := buildListQuery("", "video") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType contains 'video/'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType contains 'video/'") }) t.Run("adds type filter for audio", func(t *testing.T) { query, err := buildListQuery("", "audio") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType contains 'audio/'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType contains 'audio/'") }) t.Run("returns error for unknown type", func(t *testing.T) { _, err := buildListQuery("", "unknown") - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown file type") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "unknown file type") }) t.Run("accepts type aliases", func(t *testing.T) { query, err := buildListQuery("", "doc") - assert.NoError(t, err) - assert.Contains(t, query, "application/vnd.google-apps.document") + testutil.NoError(t, err) + testutil.Contains(t, query, "application/vnd.google-apps.document") query, err = buildListQuery("", "sheet") - assert.NoError(t, err) - assert.Contains(t, query, "application/vnd.google-apps.spreadsheet") + testutil.NoError(t, err) + testutil.Contains(t, query, "application/vnd.google-apps.spreadsheet") query, err = buildListQuery("", "slides") - assert.NoError(t, err) - assert.Contains(t, query, "application/vnd.google-apps.presentation") + testutil.NoError(t, err) + testutil.Contains(t, query, "application/vnd.google-apps.presentation") }) t.Run("is case insensitive for type", func(t *testing.T) { query, err := buildListQuery("", "DOCUMENT") - assert.NoError(t, err) - assert.Contains(t, query, "application/vnd.google-apps.document") + testutil.NoError(t, err) + testutil.Contains(t, query, "application/vnd.google-apps.document") }) } @@ -163,10 +163,10 @@ func TestGetMimeTypeFilter(t *testing.T) { t.Run(tt.fileType, func(t *testing.T) { result, err := getMimeTypeFilter(tt.fileType) if tt.hasError { - assert.Error(t, err) + testutil.Error(t, err) } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, result) + testutil.NoError(t, err) + testutil.Equal(t, result, tt.expected) } }) } diff --git a/internal/cmd/drive/mock_test.go b/internal/cmd/drive/mock_test.go new file mode 100644 index 0000000..005a789 --- /dev/null +++ b/internal/cmd/drive/mock_test.go @@ -0,0 +1,66 @@ +package drive + +import ( + "context" + + driveapi "github.com/open-cli-collective/google-readonly/internal/drive" +) + +// MockDriveClient is a configurable mock for DriveClient. +type MockDriveClient struct { + ListFilesFunc func(ctx context.Context, query string, pageSize int64) ([]*driveapi.File, error) + ListFilesWithScopeFunc func(ctx context.Context, query string, pageSize int64, scope driveapi.DriveScope) ([]*driveapi.File, error) + GetFileFunc func(ctx context.Context, fileID string) (*driveapi.File, error) + DownloadFileFunc func(ctx context.Context, fileID string) ([]byte, error) + ExportFileFunc func(ctx context.Context, fileID, mimeType string) ([]byte, error) + ListSharedDrivesFunc func(ctx context.Context, pageSize int64) ([]*driveapi.SharedDrive, error) +} + +// Verify MockDriveClient implements DriveClient +var _ DriveClient = (*MockDriveClient)(nil) + +func (m *MockDriveClient) ListFiles(ctx context.Context, query string, pageSize int64) ([]*driveapi.File, error) { + if m.ListFilesFunc != nil { + return m.ListFilesFunc(ctx, query, pageSize) + } + return nil, nil +} + +func (m *MockDriveClient) ListFilesWithScope(ctx context.Context, query string, pageSize int64, scope driveapi.DriveScope) ([]*driveapi.File, error) { + if m.ListFilesWithScopeFunc != nil { + return m.ListFilesWithScopeFunc(ctx, query, pageSize, scope) + } + // Fall back to ListFiles if no scope function defined + if m.ListFilesFunc != nil { + return m.ListFilesFunc(ctx, query, pageSize) + } + return nil, nil +} + +func (m *MockDriveClient) GetFile(ctx context.Context, fileID string) (*driveapi.File, error) { + if m.GetFileFunc != nil { + return m.GetFileFunc(ctx, fileID) + } + return nil, nil +} + +func (m *MockDriveClient) DownloadFile(ctx context.Context, fileID string) ([]byte, error) { + if m.DownloadFileFunc != nil { + return m.DownloadFileFunc(ctx, fileID) + } + return nil, nil +} + +func (m *MockDriveClient) ExportFile(ctx context.Context, fileID, mimeType string) ([]byte, error) { + if m.ExportFileFunc != nil { + return m.ExportFileFunc(ctx, fileID, mimeType) + } + return nil, nil +} + +func (m *MockDriveClient) ListSharedDrives(ctx context.Context, pageSize int64) ([]*driveapi.SharedDrive, error) { + if m.ListSharedDrivesFunc != nil { + return m.ListSharedDrivesFunc(ctx, pageSize) + } + return nil, nil +} diff --git a/internal/cmd/drive/output.go b/internal/cmd/drive/output.go index 4c7ceb4..ea0d36b 100644 --- a/internal/cmd/drive/output.go +++ b/internal/cmd/drive/output.go @@ -7,15 +7,25 @@ import ( "github.com/open-cli-collective/google-readonly/internal/output" ) +// DriveClient defines the interface for Drive client operations used by drive commands. +type DriveClient interface { + ListFiles(ctx context.Context, query string, pageSize int64) ([]*drive.File, error) + ListFilesWithScope(ctx context.Context, query string, pageSize int64, scope drive.DriveScope) ([]*drive.File, error) + GetFile(ctx context.Context, fileID string) (*drive.File, error) + DownloadFile(ctx context.Context, fileID string) ([]byte, error) + ExportFile(ctx context.Context, fileID string, mimeType string) ([]byte, error) + ListSharedDrives(ctx context.Context, pageSize int64) ([]*drive.SharedDrive, error) +} + // ClientFactory is the function used to create Drive clients. // Override in tests to inject mocks. -var ClientFactory = func() (drive.DriveClientInterface, error) { - return drive.NewClient(context.Background()) +var ClientFactory = func(ctx context.Context) (DriveClient, error) { + return drive.NewClient(ctx) } // newDriveClient creates and returns a new Drive client -func newDriveClient() (drive.DriveClientInterface, error) { - return ClientFactory() +func newDriveClient(ctx context.Context) (DriveClient, error) { + return ClientFactory(ctx) } // printJSON encodes data as indented JSON to stdout diff --git a/internal/cmd/drive/search.go b/internal/cmd/drive/search.go index 625a64c..0d9fa47 100644 --- a/internal/cmd/drive/search.go +++ b/internal/cmd/drive/search.go @@ -50,9 +50,9 @@ File types: document, spreadsheet, presentation, folder, pdf, image, video, audi return fmt.Errorf("--my-drive and --drive are mutually exclusive") } - client, err := newDriveClient() + client, err := newDriveClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Drive client: %w", err) + return fmt.Errorf("creating Drive client: %w", err) } query := "" @@ -62,21 +62,26 @@ File types: document, spreadsheet, presentation, folder, pdf, image, video, audi searchQuery, err := buildSearchQuery(query, nameOnly, fileType, owner, modAfter, modBefore, inFolder) if err != nil { - return fmt.Errorf("failed to build search query: %w", err) + return fmt.Errorf("building search query: %w", err) } // Resolve drive scope - scope, err := resolveDriveScope(client, myDrive, driveFlag) + ctx := cmd.Context() + scope, err := resolveDriveScope(ctx, client, myDrive, driveFlag) if err != nil { - return fmt.Errorf("failed to resolve drive scope: %w", err) + return fmt.Errorf("resolving drive scope: %w", err) } - files, err := client.ListFilesWithScope(searchQuery, maxResults, scope) + files, err := client.ListFilesWithScope(ctx, searchQuery, maxResults, scope) if err != nil { - return fmt.Errorf("failed to search files: %w", err) + return fmt.Errorf("searching files: %w", err) } if len(files) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } if query != "" { fmt.Printf("No files found matching \"%s\".\n", query) } else { diff --git a/internal/cmd/drive/search_test.go b/internal/cmd/drive/search_test.go index a7f8a30..a9b82f0 100644 --- a/internal/cmd/drive/search_test.go +++ b/internal/cmd/drive/search_test.go @@ -3,173 +3,173 @@ package drive import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestSearchCommand(t *testing.T) { cmd := newSearchCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "search [query]", cmd.Use) + testutil.Equal(t, cmd.Use, "search [query]") }) t.Run("accepts zero or one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"query"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"query1", "query2"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) - assert.Equal(t, "25", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") + testutil.Equal(t, flag.DefValue, "25") }) t.Run("has name flag", func(t *testing.T) { flag := cmd.Flags().Lookup("name") - assert.NotNil(t, flag) - assert.Equal(t, "n", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "n") }) t.Run("has type flag", func(t *testing.T) { flag := cmd.Flags().Lookup("type") - assert.NotNil(t, flag) - assert.Equal(t, "t", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "t") }) t.Run("has owner flag", func(t *testing.T) { flag := cmd.Flags().Lookup("owner") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has modified-after flag", func(t *testing.T) { flag := cmd.Flags().Lookup("modified-after") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has modified-before flag", func(t *testing.T) { flag := cmd.Flags().Lookup("modified-before") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has in-folder flag", func(t *testing.T) { flag := cmd.Flags().Lookup("in-folder") - assert.NotNil(t, flag) + testutil.NotNil(t, flag) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) t.Run("has short description", func(t *testing.T) { - assert.Contains(t, cmd.Short, "Search") + testutil.Contains(t, cmd.Short, "Search") }) } func TestBuildSearchQuery(t *testing.T) { t.Run("builds full-text search query", func(t *testing.T) { query, err := buildSearchQuery("quarterly report", false, "", "", "", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "trashed = false") - assert.Contains(t, query, "fullText contains 'quarterly report'") + testutil.NoError(t, err) + testutil.Contains(t, query, "trashed = false") + testutil.Contains(t, query, "fullText contains 'quarterly report'") }) t.Run("builds name-only search query", func(t *testing.T) { query, err := buildSearchQuery("budget", true, "", "", "", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "name contains 'budget'") - assert.NotContains(t, query, "fullText") + testutil.NoError(t, err) + testutil.Contains(t, query, "name contains 'budget'") + testutil.NotContains(t, query, "fullText") }) t.Run("adds type filter", func(t *testing.T) { query, err := buildSearchQuery("test", false, "document", "", "", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "mimeType = 'application/vnd.google-apps.document'") + testutil.NoError(t, err) + testutil.Contains(t, query, "mimeType = 'application/vnd.google-apps.document'") }) t.Run("returns error for invalid type", func(t *testing.T) { _, err := buildSearchQuery("test", false, "invalid", "", "", "", "") - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown file type") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "unknown file type") }) t.Run("adds owner filter with 'me'", func(t *testing.T) { query, err := buildSearchQuery("", false, "", "me", "", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "'me' in owners") + testutil.NoError(t, err) + testutil.Contains(t, query, "'me' in owners") }) t.Run("adds owner filter with email", func(t *testing.T) { query, err := buildSearchQuery("", false, "", "john@example.com", "", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "'john@example.com' in owners") + testutil.NoError(t, err) + testutil.Contains(t, query, "'john@example.com' in owners") }) t.Run("adds modified-after filter", func(t *testing.T) { query, err := buildSearchQuery("", false, "", "", "2024-01-01", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "modifiedTime > '2024-01-01T00:00:00'") + testutil.NoError(t, err) + testutil.Contains(t, query, "modifiedTime > '2024-01-01T00:00:00'") }) t.Run("adds modified-before filter", func(t *testing.T) { query, err := buildSearchQuery("", false, "", "", "", "2024-12-31", "") - assert.NoError(t, err) - assert.Contains(t, query, "modifiedTime < '2024-12-31T23:59:59'") + testutil.NoError(t, err) + testutil.Contains(t, query, "modifiedTime < '2024-12-31T23:59:59'") }) t.Run("adds folder scope", func(t *testing.T) { query, err := buildSearchQuery("", false, "", "", "", "", "folder123") - assert.NoError(t, err) - assert.Contains(t, query, "'folder123' in parents") + testutil.NoError(t, err) + testutil.Contains(t, query, "'folder123' in parents") }) t.Run("combines multiple filters", func(t *testing.T) { query, err := buildSearchQuery("report", false, "document", "me", "2024-01-01", "", "folder123") - assert.NoError(t, err) - assert.Contains(t, query, "trashed = false") - assert.Contains(t, query, "fullText contains 'report'") - assert.Contains(t, query, "mimeType = 'application/vnd.google-apps.document'") - assert.Contains(t, query, "'me' in owners") - assert.Contains(t, query, "modifiedTime > '2024-01-01T00:00:00'") - assert.Contains(t, query, "'folder123' in parents") + testutil.NoError(t, err) + testutil.Contains(t, query, "trashed = false") + testutil.Contains(t, query, "fullText contains 'report'") + testutil.Contains(t, query, "mimeType = 'application/vnd.google-apps.document'") + testutil.Contains(t, query, "'me' in owners") + testutil.Contains(t, query, "modifiedTime > '2024-01-01T00:00:00'") + testutil.Contains(t, query, "'folder123' in parents") }) t.Run("builds query with no search term", func(t *testing.T) { query, err := buildSearchQuery("", false, "document", "", "", "", "") - assert.NoError(t, err) - assert.Contains(t, query, "trashed = false") - assert.Contains(t, query, "mimeType") - assert.NotContains(t, query, "fullText") - assert.NotContains(t, query, "name contains") + testutil.NoError(t, err) + testutil.Contains(t, query, "trashed = false") + testutil.Contains(t, query, "mimeType") + testutil.NotContains(t, query, "fullText") + testutil.NotContains(t, query, "name contains") }) } func TestEscapeQueryString(t *testing.T) { t.Run("escapes single quotes", func(t *testing.T) { result := escapeQueryString("it's a test") - assert.Equal(t, "it\\'s a test", result) + testutil.Equal(t, result, "it\\'s a test") }) t.Run("handles string without quotes", func(t *testing.T) { result := escapeQueryString("simple query") - assert.Equal(t, "simple query", result) + testutil.Equal(t, result, "simple query") }) t.Run("handles multiple quotes", func(t *testing.T) { result := escapeQueryString("don't won't can't") - assert.Equal(t, "don\\'t won\\'t can\\'t", result) + testutil.Equal(t, result, "don\\'t won\\'t can\\'t") }) t.Run("handles empty string", func(t *testing.T) { result := escapeQueryString("") - assert.Equal(t, "", result) + testutil.Equal(t, result, "") }) } diff --git a/internal/cmd/drive/tree.go b/internal/cmd/drive/tree.go index c92fc1e..a55f45d 100644 --- a/internal/cmd/drive/tree.go +++ b/internal/cmd/drive/tree.go @@ -1,6 +1,7 @@ package drive import ( + "context" "fmt" "sort" @@ -48,9 +49,9 @@ Examples: return fmt.Errorf("--my-drive and --drive are mutually exclusive") } - client, err := newDriveClient() + client, err := newDriveClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Drive client: %w", err) + return fmt.Errorf("creating Drive client: %w", err) } folderID := "root" @@ -61,18 +62,18 @@ Examples: rootName = "" // Will be fetched from folder info } else if driveFlag != "" { // Resolve shared drive - scope, err := resolveDriveScope(client, false, driveFlag) + scope, err := resolveDriveScope(cmd.Context(), client, false, driveFlag) if err != nil { - return fmt.Errorf("failed to resolve drive: %w", err) + return fmt.Errorf("resolving drive: %w", err) } folderID = scope.DriveID rootName = driveFlag // Use the provided name } // Build the tree - tree, err := buildTreeWithScope(client, folderID, rootName, depth, files) + tree, err := buildTreeWithScope(cmd.Context(), client, folderID, rootName, depth, files) if err != nil { - return fmt.Errorf("failed to build folder tree: %w", err) + return fmt.Errorf("building folder tree: %w", err) } if jsonOutput { @@ -94,12 +95,12 @@ Examples: } // buildTree recursively builds the folder tree structure -func buildTree(client drive.DriveClientInterface, folderID string, depth int, includeFiles bool) (*TreeNode, error) { - return buildTreeWithScope(client, folderID, "", depth, includeFiles) +func buildTree(ctx context.Context, client DriveClient, folderID string, depth int, includeFiles bool) (*TreeNode, error) { + return buildTreeWithScope(ctx, client, folderID, "", depth, includeFiles) } // buildTreeWithScope builds folder tree with optional root name override -func buildTreeWithScope(client drive.DriveClientInterface, folderID, rootName string, depth int, includeFiles bool) (*TreeNode, error) { +func buildTreeWithScope(ctx context.Context, client DriveClient, folderID, rootName string, depth int, includeFiles bool) (*TreeNode, error) { // Get folder info var folderName string var folderType string @@ -111,9 +112,9 @@ func buildTreeWithScope(client drive.DriveClientInterface, folderID, rootName st folderName = rootName folderType = "Shared Drive" } else { - folder, err := client.GetFile(folderID) + folder, err := client.GetFile(ctx, folderID) if err != nil { - return nil, fmt.Errorf("failed to get folder info: %w", err) + return nil, fmt.Errorf("getting folder info: %w", err) } folderName = folder.Name folderType = drive.GetTypeName(folder.MimeType) @@ -138,9 +139,9 @@ func buildTreeWithScope(client drive.DriveClientInterface, folderID, rootName st // Use ListFilesWithScope to support shared drives scope := drive.DriveScope{AllDrives: true} - children, err := client.ListFilesWithScope(query, 100, scope) + children, err := client.ListFilesWithScope(ctx, query, 100, scope) if err != nil { - return nil, fmt.Errorf("failed to list children: %w", err) + return nil, fmt.Errorf("listing children: %w", err) } // Sort children: folders first, then by name @@ -157,7 +158,7 @@ func buildTreeWithScope(client drive.DriveClientInterface, folderID, rootName st for _, child := range children { if child.MimeType == drive.MimeTypeFolder { // Recursively build subtree for folders (don't pass rootName on recursion) - childNode, err := buildTreeWithScope(client, child.ID, "", depth-1, includeFiles) + childNode, err := buildTreeWithScope(ctx, client, child.ID, "", depth-1, includeFiles) if err != nil { // Log error but continue with other children continue diff --git a/internal/cmd/drive/tree_test.go b/internal/cmd/drive/tree_test.go index 40c0673..6f3b538 100644 --- a/internal/cmd/drive/tree_test.go +++ b/internal/cmd/drive/tree_test.go @@ -1,76 +1,61 @@ package drive import ( - "bytes" + "context" "fmt" - "io" - "os" "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/open-cli-collective/google-readonly/internal/drive" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestTreeCommand(t *testing.T) { cmd := newTreeCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "tree [folder-id]", cmd.Use) + testutil.Equal(t, cmd.Use, "tree [folder-id]") }) t.Run("accepts zero or one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"folder-id"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"folder-id", "extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has depth flag", func(t *testing.T) { flag := cmd.Flags().Lookup("depth") - assert.NotNil(t, flag) - assert.Equal(t, "d", flag.Shorthand) - assert.Equal(t, "2", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "d") + testutil.Equal(t, flag.DefValue, "2") }) t.Run("has files flag", func(t *testing.T) { flag := cmd.Flags().Lookup("files") - assert.NotNil(t, flag) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.Contains(t, cmd.Short, "folder structure") + testutil.Contains(t, cmd.Short, "folder structure") }) } func TestPrintTree(t *testing.T) { - // Capture stdout for testing captureOutput := func(fn func()) string { - old := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - fn() - - w.Close() - os.Stdout = old - - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() + return testutil.CaptureStdout(t, fn) } t.Run("prints single node", func(t *testing.T) { @@ -84,7 +69,7 @@ func TestPrintTree(t *testing.T) { printTree(node, "", true) }) - assert.Equal(t, "My Drive\n", output) + testutil.Equal(t, output, "My Drive\n") }) t.Run("prints tree with children", func(t *testing.T) { @@ -102,9 +87,9 @@ func TestPrintTree(t *testing.T) { printTree(node, "", true) }) - assert.Contains(t, output, "My Drive") - assert.Contains(t, output, "├── Documents") - assert.Contains(t, output, "└── Photos") + testutil.Contains(t, output, "My Drive") + testutil.Contains(t, output, "├── Documents") + testutil.Contains(t, output, "└── Photos") }) t.Run("prints nested tree", func(t *testing.T) { @@ -130,11 +115,11 @@ func TestPrintTree(t *testing.T) { printTree(node, "", true) }) - assert.Contains(t, output, "My Drive") - assert.Contains(t, output, "├── Projects") - assert.Contains(t, output, "│ ├── Project A") - assert.Contains(t, output, "│ └── Project B") - assert.Contains(t, output, "└── Documents") + testutil.Contains(t, output, "My Drive") + testutil.Contains(t, output, "├── Projects") + testutil.Contains(t, output, "│ ├── Project A") + testutil.Contains(t, output, "│ └── Project B") + testutil.Contains(t, output, "└── Documents") }) t.Run("prints deeply nested tree", func(t *testing.T) { @@ -165,10 +150,10 @@ func TestPrintTree(t *testing.T) { printTree(node, "", true) }) - assert.Contains(t, output, "Root") - assert.Contains(t, output, "└── Level1") - assert.Contains(t, output, " └── Level2") - assert.Contains(t, output, " └── Level3") + testutil.Contains(t, output, "Root") + testutil.Contains(t, output, "└── Level1") + testutil.Contains(t, output, " └── Level2") + testutil.Contains(t, output, " └── Level3") }) t.Run("handles empty children", func(t *testing.T) { @@ -183,7 +168,7 @@ func TestPrintTree(t *testing.T) { printTree(node, "", true) }) - assert.Equal(t, "Empty Folder\n", output) + testutil.Equal(t, output, "Empty Folder\n") }) } @@ -198,10 +183,10 @@ func TestTreeNode(t *testing.T) { }, } - assert.Equal(t, "abc123", node.ID) - assert.Equal(t, "Test", node.Name) - assert.Equal(t, "Folder", node.Type) - assert.Len(t, node.Children, 1) + testutil.Equal(t, node.ID, "abc123") + testutil.Equal(t, node.Name, "Test") + testutil.Equal(t, node.Type, "Folder") + testutil.Len(t, node.Children, 1) }) t.Run("handles nil children", func(t *testing.T) { @@ -212,11 +197,11 @@ func TestTreeNode(t *testing.T) { Children: nil, } - assert.Nil(t, node.Children) + testutil.Nil(t, node.Children) }) } -// mockDriveClient implements drive.DriveClientInterface for testing +// mockDriveClient implements DriveClient for testing type mockDriveClient struct { files map[string]*drive.File // fileID -> File children map[string][]*drive.File // folderID -> children @@ -229,14 +214,14 @@ func newMockDriveClient() *mockDriveClient { } } -func (m *mockDriveClient) GetFile(fileID string) (*drive.File, error) { +func (m *mockDriveClient) GetFile(_ context.Context, fileID string) (*drive.File, error) { if f, ok := m.files[fileID]; ok { return f, nil } return nil, fmt.Errorf("file not found: %s", fileID) } -func (m *mockDriveClient) ListFiles(query string, _ int64) ([]*drive.File, error) { +func (m *mockDriveClient) ListFiles(_ context.Context, query string, _ int64) ([]*drive.File, error) { // Extract folderID from query like "'folder123' in parents and trashed = false" // The query format is: "'' in parents and trashed = false" for folderID, files := range m.children { @@ -248,20 +233,20 @@ func (m *mockDriveClient) ListFiles(query string, _ int64) ([]*drive.File, error return []*drive.File{}, nil } -func (m *mockDriveClient) ListFilesWithScope(query string, pageSize int64, _ drive.DriveScope) ([]*drive.File, error) { +func (m *mockDriveClient) ListFilesWithScope(ctx context.Context, query string, pageSize int64, _ drive.DriveScope) ([]*drive.File, error) { // Delegate to ListFiles for testing purposes - return m.ListFiles(query, pageSize) + return m.ListFiles(ctx, query, pageSize) } -func (m *mockDriveClient) DownloadFile(_ string) ([]byte, error) { +func (m *mockDriveClient) DownloadFile(_ context.Context, _ string) ([]byte, error) { return nil, fmt.Errorf("not implemented") } -func (m *mockDriveClient) ExportFile(_ string, _ string) ([]byte, error) { +func (m *mockDriveClient) ExportFile(_ context.Context, _ string, _ string) ([]byte, error) { return nil, fmt.Errorf("not implemented") } -func (m *mockDriveClient) ListSharedDrives(_ int64) ([]*drive.SharedDrive, error) { +func (m *mockDriveClient) ListSharedDrives(_ context.Context, _ int64) ([]*drive.SharedDrive, error) { return nil, fmt.Errorf("not implemented") } @@ -277,13 +262,13 @@ func TestBuildTree(t *testing.T) { mock.files["folder1"] = &drive.File{ID: "folder1", Name: "Documents", MimeType: drive.MimeTypeFolder} mock.files["folder2"] = &drive.File{ID: "folder2", Name: "Photos", MimeType: drive.MimeTypeFolder} - tree, err := buildTree(mock, "root", 1, false) + tree, err := buildTree(context.Background(), mock, "root", 1, false) - assert.NoError(t, err) - assert.Equal(t, "root", tree.ID) - assert.Equal(t, "My Drive", tree.Name) - assert.Equal(t, "Folder", tree.Type) - assert.Len(t, tree.Children, 2) + testutil.NoError(t, err) + testutil.Equal(t, tree.ID, "root") + testutil.Equal(t, tree.Name, "My Drive") + testutil.Equal(t, tree.Type, "Folder") + testutil.Len(t, tree.Children, 2) }) t.Run("builds tree for specific folder", func(t *testing.T) { @@ -297,13 +282,13 @@ func TestBuildTree(t *testing.T) { {ID: "doc1", Name: "Notes.txt", MimeType: "text/plain"}, } - tree, err := buildTree(mock, "folder123", 1, true) + tree, err := buildTree(context.Background(), mock, "folder123", 1, true) - assert.NoError(t, err) - assert.Equal(t, "folder123", tree.ID) - assert.Equal(t, "My Folder", tree.Name) - assert.Len(t, tree.Children, 1) - assert.Equal(t, "Notes.txt", tree.Children[0].Name) + testutil.NoError(t, err) + testutil.Equal(t, tree.ID, "folder123") + testutil.Equal(t, tree.Name, "My Folder") + testutil.Len(t, tree.Children, 1) + testutil.Equal(t, tree.Children[0].Name, "Notes.txt") }) t.Run("respects depth limit", func(t *testing.T) { @@ -318,13 +303,13 @@ func TestBuildTree(t *testing.T) { mock.files["folder2"] = &drive.File{ID: "folder2", Name: "Level2", MimeType: drive.MimeTypeFolder} // With depth 1, should not recurse into Level1 - tree, err := buildTree(mock, "root", 1, false) + tree, err := buildTree(context.Background(), mock, "root", 1, false) - assert.NoError(t, err) - assert.Len(t, tree.Children, 1) - assert.Equal(t, "Level1", tree.Children[0].Name) + testutil.NoError(t, err) + testutil.Len(t, tree.Children, 1) + testutil.Equal(t, tree.Children[0].Name, "Level1") // Children of Level1 should be empty due to depth limit - assert.Empty(t, tree.Children[0].Children) + testutil.Len(t, tree.Children[0].Children, 0) }) t.Run("returns node with no children at depth 0", func(t *testing.T) { @@ -333,11 +318,11 @@ func TestBuildTree(t *testing.T) { {ID: "folder1", Name: "Folder", MimeType: drive.MimeTypeFolder}, } - tree, err := buildTree(mock, "root", 0, false) + tree, err := buildTree(context.Background(), mock, "root", 0, false) - assert.NoError(t, err) - assert.Equal(t, "My Drive", tree.Name) - assert.Nil(t, tree.Children) + testutil.NoError(t, err) + testutil.Equal(t, tree.Name, "My Drive") + testutil.Nil(t, tree.Children) }) t.Run("includes files when includeFiles is true", func(t *testing.T) { @@ -348,10 +333,10 @@ func TestBuildTree(t *testing.T) { } mock.files["folder1"] = &drive.File{ID: "folder1", Name: "Docs", MimeType: drive.MimeTypeFolder} - tree, err := buildTree(mock, "root", 1, true) + tree, err := buildTree(context.Background(), mock, "root", 1, true) - assert.NoError(t, err) - assert.Len(t, tree.Children, 2) + testutil.NoError(t, err) + testutil.Len(t, tree.Children, 2) }) t.Run("sorts folders before files", func(t *testing.T) { @@ -362,12 +347,12 @@ func TestBuildTree(t *testing.T) { } mock.files["folder1"] = &drive.File{ID: "folder1", Name: "zzz-folder", MimeType: drive.MimeTypeFolder} - tree, err := buildTree(mock, "root", 1, true) + tree, err := buildTree(context.Background(), mock, "root", 1, true) - assert.NoError(t, err) - assert.Len(t, tree.Children, 2) + testutil.NoError(t, err) + testutil.Len(t, tree.Children, 2) // Folder should come first despite alphabetical order - assert.Equal(t, "zzz-folder", tree.Children[0].Name) - assert.Equal(t, "aaa.txt", tree.Children[1].Name) + testutil.Equal(t, tree.Children[0].Name, "zzz-folder") + testutil.Equal(t, tree.Children[1].Name, "aaa.txt") }) } diff --git a/internal/cmd/initcmd/init.go b/internal/cmd/initcmd/init.go index d05524e..5d7ab5f 100644 --- a/internal/cmd/initcmd/init.go +++ b/internal/cmd/initcmd/init.go @@ -1,8 +1,10 @@ +// Package initcmd implements the gro init command for OAuth setup. package initcmd import ( "bufio" "context" + "errors" "fmt" "net/http" "net/url" @@ -13,6 +15,7 @@ import ( "github.com/spf13/cobra" "google.golang.org/api/googleapi" + "github.com/open-cli-collective/google-readonly/internal/auth" "github.com/open-cli-collective/google-readonly/internal/config" "github.com/open-cli-collective/google-readonly/internal/gmail" "github.com/open-cli-collective/google-readonly/internal/keychain" @@ -45,14 +48,14 @@ Prerequisites: return cmd } -func runInit(cmd *cobra.Command, args []string) error { +func runInit(cmd *cobra.Command, _ []string) error { // Step 1: Check for credentials.json - credPath, err := gmail.GetCredentialsPath() + credPath, err := auth.GetCredentialsPath() if err != nil { - return fmt.Errorf("failed to get credentials path: %w", err) + return fmt.Errorf("getting credentials path: %w", err) } - shortPath := gmail.ShortenPath(credPath) + shortPath := auth.ShortenPath(credPath) if _, err := os.Stat(credPath); os.IsNotExist(err) { fmt.Println("Credentials file not found.") fmt.Println() @@ -62,9 +65,9 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Printf("Credentials: %s\n", shortPath) // Step 2: Load OAuth config - config, err := gmail.GetOAuthConfig() + config, err := auth.GetOAuthConfig() if err != nil { - return fmt.Errorf("failed to load OAuth config: %w", err) + return fmt.Errorf("loading OAuth config: %w", err) } // Step 3: Check if already authenticated @@ -73,7 +76,7 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Println() if !noVerify { - err := verifyConnectivity() + err := verifyConnectivity(cmd.Context()) if err == nil { promptCacheTTL() return nil @@ -88,7 +91,7 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Println() fmt.Println("Clearing old token...") if delErr := keychain.DeleteToken(); delErr != nil { - return fmt.Errorf("failed to clear token: %w", delErr) + return fmt.Errorf("clearing token: %w", delErr) } // Fall through to OAuth flow below } else { @@ -111,7 +114,7 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Println("Token: Not found - starting OAuth flow") fmt.Println() - authURL := gmail.GetAuthURL(config) + authURL := auth.GetAuthURL(config) fmt.Println("Open this URL in your browser:") fmt.Println() fmt.Println(authURL) @@ -128,7 +131,7 @@ func runInit(cmd *cobra.Command, args []string) error { reader := bufio.NewReader(os.Stdin) input, err := reader.ReadString('\n') if err != nil { - return fmt.Errorf("failed to read input: %w", err) + return fmt.Errorf("reading input: %w", err) } code := extractAuthCode(input) @@ -140,22 +143,21 @@ func runInit(cmd *cobra.Command, args []string) error { fmt.Println() fmt.Println("Exchanging authorization code...") - ctx := context.Background() - token, err := gmail.ExchangeAuthCode(ctx, config, code) + token, err := auth.ExchangeAuthCode(cmd.Context(), config, code) if err != nil { - return fmt.Errorf("failed to exchange authorization code: %w", err) + return fmt.Errorf("exchanging authorization code: %w", err) } // Step 6: Save token if err := keychain.SetToken(token); err != nil { - return fmt.Errorf("failed to save token: %w", err) + return fmt.Errorf("saving token: %w", err) } fmt.Printf("Token saved to: %s\n", keychain.GetStorageBackend()) // Step 7: Verify connectivity (unless --no-verify) if !noVerify { fmt.Println() - if err := verifyConnectivity(); err != nil { + if err := verifyConnectivity(cmd.Context()); err != nil { return err } promptCacheTTL() @@ -189,21 +191,21 @@ func extractAuthCode(input string) string { } // verifyConnectivity tests the Gmail API connection -func verifyConnectivity() error { +func verifyConnectivity(ctx context.Context) error { fmt.Println("Verifying Gmail API connection...") - client, err := gmail.NewClient(context.Background()) + client, err := gmail.NewClient(ctx) if err != nil { fmt.Println(" OAuth token: FAILED") - return fmt.Errorf("failed to create client: %w", err) + return fmt.Errorf("creating client: %w", err) } fmt.Println(" OAuth token: OK") // Get profile to verify connectivity and get email address - profile, err := client.GetProfile() + profile, err := client.GetProfile(ctx) if err != nil { fmt.Println(" Gmail API: FAILED") - return fmt.Errorf("failed to access Gmail API: %w", err) + return fmt.Errorf("accessing Gmail API: %w", err) } fmt.Println(" Gmail API: OK") fmt.Printf(" Messages: %d total\n", profile.MessagesTotal) @@ -259,23 +261,7 @@ func isAuthError(err error) bool { } // errorAs is a wrapper for errors.As to make testing easier -var errorAs = func(err error, target interface{}) bool { - switch t := target.(type) { - case **googleapi.Error: - for e := err; e != nil; { - if apiErr, ok := e.(*googleapi.Error); ok { - *t = apiErr - return true - } - if unwrapper, ok := e.(interface{ Unwrap() error }); ok { - e = unwrapper.Unwrap() - } else { - break - } - } - } - return false -} +var errorAs = errors.As // promptReauth asks the user if they want to re-authenticate func promptReauth() bool { @@ -334,6 +320,6 @@ func promptCacheTTL() { } if err := config.SaveConfig(cfg); err != nil { - fmt.Printf("Warning: failed to save config: %v\n", err) + fmt.Printf("Warning: saving config: %v\n", err) } } diff --git a/internal/cmd/initcmd/init_test.go b/internal/cmd/initcmd/init_test.go index f6a9f9b..2ab041e 100644 --- a/internal/cmd/initcmd/init_test.go +++ b/internal/cmd/initcmd/init_test.go @@ -5,42 +5,50 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/assert" "google.golang.org/api/googleapi" + + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestInitCommand(t *testing.T) { + t.Parallel() cmd := NewCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "init", cmd.Use) + t.Parallel() + testutil.Equal(t, cmd.Use, "init") }) t.Run("requires no arguments", func(t *testing.T) { + t.Parallel() err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has no-verify flag", func(t *testing.T) { + t.Parallel() flag := cmd.Flags().Lookup("no-verify") - assert.NotNil(t, flag) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + t.Parallel() + testutil.NotEmpty(t, cmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Long) - assert.Contains(t, cmd.Long, "OAuth") + t.Parallel() + testutil.NotEmpty(t, cmd.Long) + testutil.Contains(t, cmd.Long, "OAuth") }) } func TestExtractAuthCode(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -110,13 +118,15 @@ func TestExtractAuthCode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := extractAuthCode(tt.input) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } func TestIsAuthError(t *testing.T) { + t.Parallel() tests := []struct { name string err error @@ -176,8 +186,9 @@ func TestIsAuthError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := isAuthError(tt.err) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } diff --git a/internal/cmd/mail/attachments_download.go b/internal/cmd/mail/attachments_download.go index bea9b12..40337cf 100644 --- a/internal/cmd/mail/attachments_download.go +++ b/internal/cmd/mail/attachments_download.go @@ -1,6 +1,7 @@ package mail import ( + "context" "fmt" "os" "path/filepath" @@ -43,15 +44,15 @@ Examples: return fmt.Errorf("must specify --filename or --all") } - client, err := newGmailClient() + client, err := newGmailClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Gmail client: %w", err) + return fmt.Errorf("creating Gmail client: %w", err) } messageID := args[0] - attachments, err := client.GetAttachments(messageID) + attachments, err := client.GetAttachments(cmd.Context(), messageID) if err != nil { - return fmt.Errorf("failed to get attachments: %w", err) + return fmt.Errorf("getting attachments: %w", err) } if len(attachments) == 0 { @@ -73,13 +74,13 @@ Examples: // Create output directory if needed if err := os.MkdirAll(outputDir, config.OutputDirPerm); err != nil { - return fmt.Errorf("failed to create output directory: %w", err) + return fmt.Errorf("creating output directory: %w", err) } // Get absolute path of download directory for path validation absOutputDir, err := filepath.Abs(outputDir) if err != nil { - return fmt.Errorf("failed to resolve download directory: %w", err) + return fmt.Errorf("resolving download directory: %w", err) } // Download each attachment @@ -94,7 +95,7 @@ Examples: continue } - data, err := downloadAttachment(client, messageID, att) + data, err := downloadAttachment(cmd.Context(), client, messageID, att) if err != nil { fmt.Fprintf(os.Stderr, "Error downloading %s: %v\n", safeFilename, err) continue @@ -135,11 +136,11 @@ Examples: return cmd } -func downloadAttachment(client gmail.GmailClientInterface, messageID string, att *gmail.Attachment) ([]byte, error) { +func downloadAttachment(ctx context.Context, client MailClient, messageID string, att *gmail.Attachment) ([]byte, error) { if att.AttachmentID != "" { - return client.DownloadAttachment(messageID, att.AttachmentID) + return client.DownloadAttachment(ctx, messageID, att.AttachmentID) } - return client.DownloadInlineAttachment(messageID, att.PartID) + return client.DownloadInlineAttachment(ctx, messageID, att.PartID) } func saveAttachment(path string, data []byte) error { diff --git a/internal/cmd/mail/attachments_download_test.go b/internal/cmd/mail/attachments_download_test.go index 46afaee..6f67c48 100644 --- a/internal/cmd/mail/attachments_download_test.go +++ b/internal/cmd/mail/attachments_download_test.go @@ -4,8 +4,7 @@ import ( "path/filepath" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestSafeOutputPath(t *testing.T) { @@ -84,14 +83,14 @@ func TestSafeOutputPath(t *testing.T) { result, err := safeOutputPath(destDir, tt.filename) if tt.expectError { - assert.Error(t, err) + testutil.Error(t, err) if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) + testutil.Contains(t, err.Error(), tt.errorMsg) } } else { - require.NoError(t, err) + testutil.NoError(t, err) // Verify the result is within destDir - assert.True(t, filepath.IsAbs(result) || result == filepath.Join(destDir, filepath.Clean(tt.filename))) + testutil.True(t, filepath.IsAbs(result) || result == filepath.Join(destDir, filepath.Clean(tt.filename))) } }) } @@ -110,11 +109,11 @@ func TestSafeOutputPath_StaysWithinDestDir(t *testing.T) { for _, filename := range validCases { t.Run(filename, func(t *testing.T) { result, err := safeOutputPath(destDir, filename) - require.NoError(t, err) + testutil.NoError(t, err) // Result must start with destDir - assert.True(t, len(result) >= len(destDir)) - assert.Equal(t, destDir, result[:len(destDir)]) + testutil.True(t, len(result) >= len(destDir)) + testutil.Equal(t, result[:len(destDir)], destDir) }) } } diff --git a/internal/cmd/mail/attachments_list.go b/internal/cmd/mail/attachments_list.go index 30bc76a..658e741 100644 --- a/internal/cmd/mail/attachments_list.go +++ b/internal/cmd/mail/attachments_list.go @@ -23,17 +23,21 @@ Examples: gro mail attachments list 18abc123def456 --json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := newGmailClient() + client, err := newGmailClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Gmail client: %w", err) + return fmt.Errorf("creating Gmail client: %w", err) } - attachments, err := client.GetAttachments(args[0]) + attachments, err := client.GetAttachments(cmd.Context(), args[0]) if err != nil { - return fmt.Errorf("failed to get attachments: %w", err) + return fmt.Errorf("getting attachments: %w", err) } if len(attachments) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No attachments found for message.") return nil } diff --git a/internal/cmd/mail/attachments_test.go b/internal/cmd/mail/attachments_test.go index 99cc497..e498fe1 100644 --- a/internal/cmd/mail/attachments_test.go +++ b/internal/cmd/mail/attachments_test.go @@ -3,7 +3,7 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestIsZipFile(t *testing.T) { @@ -28,7 +28,7 @@ func TestIsZipFile(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isZipFile(tt.filename, tt.mimeType) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } @@ -39,19 +39,19 @@ func TestAttachmentsCommand(t *testing.T) { cmd := newAttachmentsCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "attachments", cmd.Use) + testutil.Equal(t, cmd.Use, "attachments") }) t.Run("has subcommands", func(t *testing.T) { subcommands := cmd.Commands() - assert.GreaterOrEqual(t, len(subcommands), 2) + testutil.GreaterOrEqual(t, len(subcommands), 2) var names []string for _, cmd := range subcommands { names = append(names, cmd.Name()) } - assert.Contains(t, names, "list") - assert.Contains(t, names, "download") + testutil.SliceContains(t, names, "list") + testutil.SliceContains(t, names, "download") }) } @@ -59,21 +59,21 @@ func TestListAttachmentsCommand(t *testing.T) { cmd := newListAttachmentsCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "list ", cmd.Use) + testutil.Equal(t, cmd.Use, "list ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"msg123"}) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") }) } @@ -81,15 +81,15 @@ func TestDownloadAttachmentsCommand(t *testing.T) { cmd := newDownloadAttachmentsCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "download ", cmd.Use) + testutil.Equal(t, cmd.Use, "download ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"msg123"}) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("has required flags", func(t *testing.T) { @@ -105,8 +105,8 @@ func TestDownloadAttachmentsCommand(t *testing.T) { for _, f := range flags { flag := cmd.Flags().Lookup(f.name) - assert.NotNil(t, flag, "flag %s should exist", f.name) - assert.Equal(t, f.shorthand, flag.Shorthand, "flag %s should have shorthand %s", f.name, f.shorthand) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, f.shorthand) } }) } diff --git a/internal/cmd/mail/handlers_test.go b/internal/cmd/mail/handlers_test.go index 2f23986..22e265e 100644 --- a/internal/cmd/mail/handlers_test.go +++ b/internal/cmd/mail/handlers_test.go @@ -1,63 +1,36 @@ package mail import ( - "bytes" + "context" "encoding/json" "errors" - "io" - "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "google.golang.org/api/gmail/v1" gmailapi "github.com/open-cli-collective/google-readonly/internal/gmail" "github.com/open-cli-collective/google-readonly/internal/testutil" ) -// captureOutput captures stdout during test execution -func captureOutput(t *testing.T, f func()) string { - t.Helper() - old := os.Stdout - r, w, err := os.Pipe() - require.NoError(t, err) - os.Stdout = w - - f() - - w.Close() - os.Stdout = old - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() -} - // withMockClient sets up a mock client factory for tests -func withMockClient(mock gmailapi.GmailClientInterface, f func()) { - originalFactory := ClientFactory - ClientFactory = func() (gmailapi.GmailClientInterface, error) { +func withMockClient(mock MailClient, f func()) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (MailClient, error) { return mock, nil - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } // withFailingClientFactory sets up a factory that returns an error func withFailingClientFactory(f func()) { - originalFactory := ClientFactory - ClientFactory = func() (gmailapi.GmailClientInterface, error) { + testutil.WithFactory(&ClientFactory, func(_ context.Context) (MailClient, error) { return nil, errors.New("connection failed") - } - defer func() { ClientFactory = originalFactory }() - f() + }, f) } func TestSearchCommand_Success(t *testing.T) { - mock := &testutil.MockGmailClient{ - SearchMessagesFunc: func(query string, maxResults int64) ([]*gmailapi.Message, int, error) { - assert.Equal(t, "is:unread", query) - assert.Equal(t, int64(10), maxResults) + mock := &MockGmailClient{ + SearchMessagesFunc: func(_ context.Context, query string, maxResults int64) ([]*gmailapi.Message, int, error) { + testutil.Equal(t, query, "is:unread") + testutil.Equal(t, maxResults, int64(10)) return testutil.SampleMessages(2), 0, nil }, } @@ -66,21 +39,21 @@ func TestSearchCommand_Success(t *testing.T) { cmd.SetArgs([]string{"is:unread"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) // Verify output contains expected message data - assert.Contains(t, output, "ID: msg_a") - assert.Contains(t, output, "ID: msg_b") - assert.Contains(t, output, "Test Subject") + testutil.Contains(t, output, "ID: msg_a") + testutil.Contains(t, output, "ID: msg_b") + testutil.Contains(t, output, "Test Subject") }) } func TestSearchCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockGmailClient{ - SearchMessagesFunc: func(query string, maxResults int64) ([]*gmailapi.Message, int, error) { + mock := &MockGmailClient{ + SearchMessagesFunc: func(_ context.Context, _ string, _ int64) ([]*gmailapi.Message, int, error) { return testutil.SampleMessages(1), 0, nil }, } @@ -89,23 +62,23 @@ func TestSearchCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"is:unread", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) // Verify JSON output is valid var messages []*gmailapi.Message err := json.Unmarshal([]byte(output), &messages) - assert.NoError(t, err) - assert.Len(t, messages, 1) - assert.Equal(t, "msg_a", messages[0].ID) + testutil.NoError(t, err) + testutil.Len(t, messages, 1) + testutil.Equal(t, messages[0].ID, "msg_a") }) } func TestSearchCommand_NoResults(t *testing.T) { - mock := &testutil.MockGmailClient{ - SearchMessagesFunc: func(query string, maxResults int64) ([]*gmailapi.Message, int, error) { + mock := &MockGmailClient{ + SearchMessagesFunc: func(_ context.Context, _ string, _ int64) ([]*gmailapi.Message, int, error) { return []*gmailapi.Message{}, 0, nil }, } @@ -114,18 +87,41 @@ func TestSearchCommand_NoResults(t *testing.T) { cmd.SetArgs([]string{"nonexistent"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + testutil.Contains(t, output, "No messages found") + }) +} + +func TestSearchCommand_NoResults_JSON(t *testing.T) { + mock := &MockGmailClient{ + SearchMessagesFunc: func(_ context.Context, _ string, _ int64) ([]*gmailapi.Message, int, error) { + return []*gmailapi.Message{}, 0, nil + }, + } + + cmd := newSearchCommand() + cmd.SetArgs([]string{"nonexistent", "--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No messages found") + var messages []any + err := json.Unmarshal([]byte(output), &messages) + testutil.NoError(t, err) + testutil.Len(t, messages, 0) }) } func TestSearchCommand_APIError(t *testing.T) { - mock := &testutil.MockGmailClient{ - SearchMessagesFunc: func(query string, maxResults int64) ([]*gmailapi.Message, int, error) { + mock := &MockGmailClient{ + SearchMessagesFunc: func(_ context.Context, _ string, _ int64) ([]*gmailapi.Message, int, error) { return nil, 0, errors.New("API quota exceeded") }, } @@ -135,8 +131,8 @@ func TestSearchCommand_APIError(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to search messages") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "searching messages") }) } @@ -146,14 +142,14 @@ func TestSearchCommand_ClientCreationError(t *testing.T) { withFailingClientFactory(func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create Gmail client") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "creating Gmail client") }) } func TestSearchCommand_SkippedMessages(t *testing.T) { - mock := &testutil.MockGmailClient{ - SearchMessagesFunc: func(query string, maxResults int64) ([]*gmailapi.Message, int, error) { + mock := &MockGmailClient{ + SearchMessagesFunc: func(_ context.Context, _ string, _ int64) ([]*gmailapi.Message, int, error) { return testutil.SampleMessages(2), 3, nil // 3 messages skipped }, } @@ -162,20 +158,20 @@ func TestSearchCommand_SkippedMessages(t *testing.T) { cmd.SetArgs([]string{"is:unread"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "3 message(s) could not be retrieved") + testutil.Contains(t, output, "3 message(s) could not be retrieved") }) } func TestReadCommand_Success(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetMessageFunc: func(messageID string, includeBody bool) (*gmailapi.Message, error) { - assert.Equal(t, "msg123", messageID) - assert.True(t, includeBody) + mock := &MockGmailClient{ + GetMessageFunc: func(_ context.Context, messageID string, includeBody bool) (*gmailapi.Message, error) { + testutil.Equal(t, messageID, "msg123") + testutil.True(t, includeBody) return testutil.SampleMessage("msg123"), nil }, } @@ -184,20 +180,20 @@ func TestReadCommand_Success(t *testing.T) { cmd.SetArgs([]string{"msg123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "ID: msg123") - assert.Contains(t, output, "Test Subject") - assert.Contains(t, output, "--- Body ---") + testutil.Contains(t, output, "ID: msg123") + testutil.Contains(t, output, "Test Subject") + testutil.Contains(t, output, "--- Body ---") }) } func TestReadCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetMessageFunc: func(messageID string, includeBody bool) (*gmailapi.Message, error) { + mock := &MockGmailClient{ + GetMessageFunc: func(_ context.Context, _ string, _ bool) (*gmailapi.Message, error) { return testutil.SampleMessage("msg123"), nil }, } @@ -206,21 +202,21 @@ func TestReadCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"msg123", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var msg gmailapi.Message err := json.Unmarshal([]byte(output), &msg) - assert.NoError(t, err) - assert.Equal(t, "msg123", msg.ID) + testutil.NoError(t, err) + testutil.Equal(t, msg.ID, "msg123") }) } func TestReadCommand_NotFound(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetMessageFunc: func(messageID string, includeBody bool) (*gmailapi.Message, error) { + mock := &MockGmailClient{ + GetMessageFunc: func(_ context.Context, _ string, _ bool) (*gmailapi.Message, error) { return nil, errors.New("message not found") }, } @@ -230,15 +226,15 @@ func TestReadCommand_NotFound(t *testing.T) { withMockClient(mock, func() { err := cmd.Execute() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read message") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "reading message") }) } func TestThreadCommand_Success(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetThreadFunc: func(id string) ([]*gmailapi.Message, error) { - assert.Equal(t, "thread123", id) + mock := &MockGmailClient{ + GetThreadFunc: func(_ context.Context, id string) ([]*gmailapi.Message, error) { + testutil.Equal(t, id, "thread123") return testutil.SampleMessages(3), nil }, } @@ -247,21 +243,21 @@ func TestThreadCommand_Success(t *testing.T) { cmd.SetArgs([]string{"thread123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "Thread contains 3 message(s)") - assert.Contains(t, output, "Message 1 of 3") - assert.Contains(t, output, "Message 2 of 3") - assert.Contains(t, output, "Message 3 of 3") + testutil.Contains(t, output, "Thread contains 3 message(s)") + testutil.Contains(t, output, "Message 1 of 3") + testutil.Contains(t, output, "Message 2 of 3") + testutil.Contains(t, output, "Message 3 of 3") }) } func TestThreadCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetThreadFunc: func(id string) ([]*gmailapi.Message, error) { + mock := &MockGmailClient{ + GetThreadFunc: func(_ context.Context, _ string) ([]*gmailapi.Message, error) { return testutil.SampleMessages(2), nil }, } @@ -270,21 +266,21 @@ func TestThreadCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"thread123", "--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var messages []*gmailapi.Message err := json.Unmarshal([]byte(output), &messages) - assert.NoError(t, err) - assert.Len(t, messages, 2) + testutil.NoError(t, err) + testutil.Len(t, messages, 2) }) } func TestLabelsCommand_Success(t *testing.T) { - mock := &testutil.MockGmailClient{ - FetchLabelsFunc: func() error { + mock := &MockGmailClient{ + FetchLabelsFunc: func(_ context.Context) error { return nil }, GetLabelsFunc: func() []*gmail.Label { @@ -295,21 +291,21 @@ func TestLabelsCommand_Success(t *testing.T) { cmd := newLabelsCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "NAME") - assert.Contains(t, output, "TYPE") - assert.Contains(t, output, "Work") - assert.Contains(t, output, "user") + testutil.Contains(t, output, "NAME") + testutil.Contains(t, output, "TYPE") + testutil.Contains(t, output, "Work") + testutil.Contains(t, output, "user") }) } func TestLabelsCommand_JSONOutput(t *testing.T) { - mock := &testutil.MockGmailClient{ - FetchLabelsFunc: func() error { + mock := &MockGmailClient{ + FetchLabelsFunc: func(_ context.Context) error { return nil }, GetLabelsFunc: func() []*gmail.Label { @@ -321,21 +317,44 @@ func TestLabelsCommand_JSONOutput(t *testing.T) { cmd.SetArgs([]string{"--json"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) var labels []Label err := json.Unmarshal([]byte(output), &labels) - assert.NoError(t, err) - assert.Greater(t, len(labels), 0) + testutil.NoError(t, err) + testutil.Greater(t, len(labels), 0) + }) +} + +func TestThreadCommand_NoResults_JSON(t *testing.T) { + mock := &MockGmailClient{ + GetThreadFunc: func(_ context.Context, _ string) ([]*gmailapi.Message, error) { + return []*gmailapi.Message{}, nil + }, + } + + cmd := newThreadCommand() + cmd.SetArgs([]string{"thread123", "--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var messages []any + err := json.Unmarshal([]byte(output), &messages) + testutil.NoError(t, err) + testutil.Len(t, messages, 0) }) } func TestLabelsCommand_Empty(t *testing.T) { - mock := &testutil.MockGmailClient{ - FetchLabelsFunc: func() error { + mock := &MockGmailClient{ + FetchLabelsFunc: func(_ context.Context) error { return nil }, GetLabelsFunc: func() []*gmail.Label { @@ -346,18 +365,44 @@ func TestLabelsCommand_Empty(t *testing.T) { cmd := newLabelsCommand() withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No labels found") + testutil.Contains(t, output, "No labels found") + }) +} + +func TestLabelsCommand_Empty_JSON(t *testing.T) { + mock := &MockGmailClient{ + FetchLabelsFunc: func(_ context.Context) error { + return nil + }, + GetLabelsFunc: func() []*gmail.Label { + return []*gmail.Label{} + }, + } + + cmd := newLabelsCommand() + cmd.SetArgs([]string{"--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + var labels []any + err := json.Unmarshal([]byte(output), &labels) + testutil.NoError(t, err) + testutil.Len(t, labels, 0) }) } func TestListAttachmentsCommand_Success(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetAttachmentsFunc: func(messageID string) ([]*gmailapi.Attachment, error) { + mock := &MockGmailClient{ + GetAttachmentsFunc: func(_ context.Context, _ string) ([]*gmailapi.Attachment, error) { return []*gmailapi.Attachment{ testutil.SampleAttachment("report.pdf"), testutil.SampleAttachment("data.xlsx"), @@ -369,20 +414,20 @@ func TestListAttachmentsCommand_Success(t *testing.T) { cmd.SetArgs([]string{"msg123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "2 attachment(s)") - assert.Contains(t, output, "report.pdf") - assert.Contains(t, output, "data.xlsx") + testutil.Contains(t, output, "2 attachment(s)") + testutil.Contains(t, output, "report.pdf") + testutil.Contains(t, output, "data.xlsx") }) } func TestListAttachmentsCommand_NoAttachments(t *testing.T) { - mock := &testutil.MockGmailClient{ - GetAttachmentsFunc: func(messageID string) ([]*gmailapi.Attachment, error) { + mock := &MockGmailClient{ + GetAttachmentsFunc: func(_ context.Context, _ string) ([]*gmailapi.Attachment, error) { return []*gmailapi.Attachment{}, nil }, } @@ -391,11 +436,34 @@ func TestListAttachmentsCommand_NoAttachments(t *testing.T) { cmd.SetArgs([]string{"msg123"}) withMockClient(mock, func() { - output := captureOutput(t, func() { + output := testutil.CaptureStdout(t, func() { + err := cmd.Execute() + testutil.NoError(t, err) + }) + + testutil.Contains(t, output, "No attachments found") + }) +} + +func TestListAttachmentsCommand_NoAttachments_JSON(t *testing.T) { + mock := &MockGmailClient{ + GetAttachmentsFunc: func(_ context.Context, _ string) ([]*gmailapi.Attachment, error) { + return []*gmailapi.Attachment{}, nil + }, + } + + cmd := newListAttachmentsCommand() + cmd.SetArgs([]string{"msg123", "--json"}) + + withMockClient(mock, func() { + output := testutil.CaptureStdout(t, func() { err := cmd.Execute() - assert.NoError(t, err) + testutil.NoError(t, err) }) - assert.Contains(t, output, "No attachments found") + var attachments []any + err := json.Unmarshal([]byte(output), &attachments) + testutil.NoError(t, err) + testutil.Len(t, attachments, 0) }) } diff --git a/internal/cmd/mail/labels.go b/internal/cmd/mail/labels.go index f2e1839..d9b5563 100644 --- a/internal/cmd/mail/labels.go +++ b/internal/cmd/mail/labels.go @@ -34,18 +34,22 @@ Examples: gro mail labels gro mail labels --json`, Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := newGmailClient() + RunE: func(cmd *cobra.Command, _ []string) error { + client, err := newGmailClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Gmail client: %w", err) + return fmt.Errorf("creating Gmail client: %w", err) } - if err := client.FetchLabels(); err != nil { - return fmt.Errorf("failed to fetch labels: %w", err) + if err := client.FetchLabels(cmd.Context()); err != nil { + return fmt.Errorf("fetching labels: %w", err) } gmailLabels := client.GetLabels() if len(gmailLabels) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No labels found.") return nil } diff --git a/internal/cmd/mail/labels_test.go b/internal/cmd/mail/labels_test.go index 6fe9af7..f3de5db 100644 --- a/internal/cmd/mail/labels_test.go +++ b/internal/cmd/mail/labels_test.go @@ -3,90 +3,91 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" gmailapi "google.golang.org/api/gmail/v1" + + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestLabelsCommand(t *testing.T) { cmd := newLabelsCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "labels", cmd.Use) + testutil.Equal(t, cmd.Use, "labels") }) t.Run("requires no arguments", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"extra"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "label") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "label") }) } func TestGetLabelType(t *testing.T) { t.Run("returns category for CATEGORY_ prefix", func(t *testing.T) { label := &gmailapi.Label{Id: "CATEGORY_UPDATES", Type: "system"} - assert.Equal(t, "category", getLabelType(label)) + testutil.Equal(t, getLabelType(label), "category") }) t.Run("returns category for all category types", func(t *testing.T) { categories := []string{"CATEGORY_SOCIAL", "CATEGORY_PROMOTIONS", "CATEGORY_FORUMS", "CATEGORY_PERSONAL"} for _, id := range categories { label := &gmailapi.Label{Id: id, Type: "system"} - assert.Equal(t, "category", getLabelType(label), "expected category for %s", id) + testutil.Equal(t, getLabelType(label), "category") } }) t.Run("returns system for system type", func(t *testing.T) { label := &gmailapi.Label{Id: "INBOX", Type: "system"} - assert.Equal(t, "system", getLabelType(label)) + testutil.Equal(t, getLabelType(label), "system") }) t.Run("returns user for user type", func(t *testing.T) { label := &gmailapi.Label{Id: "Label_123", Type: "user"} - assert.Equal(t, "user", getLabelType(label)) + testutil.Equal(t, getLabelType(label), "user") }) t.Run("returns user for empty type", func(t *testing.T) { label := &gmailapi.Label{Id: "Label_456", Type: ""} - assert.Equal(t, "user", getLabelType(label)) + testutil.Equal(t, getLabelType(label), "user") }) } func TestLabelTypePriority(t *testing.T) { t.Run("user has highest priority (lowest value)", func(t *testing.T) { - assert.Equal(t, 0, labelTypePriority("user")) + testutil.Equal(t, labelTypePriority("user"), 0) }) t.Run("category is second priority", func(t *testing.T) { - assert.Equal(t, 1, labelTypePriority("category")) + testutil.Equal(t, labelTypePriority("category"), 1) }) t.Run("system is third priority", func(t *testing.T) { - assert.Equal(t, 2, labelTypePriority("system")) + testutil.Equal(t, labelTypePriority("system"), 2) }) t.Run("unknown types have lowest priority", func(t *testing.T) { - assert.Equal(t, 3, labelTypePriority("unknown")) - assert.Equal(t, 3, labelTypePriority("")) + testutil.Equal(t, labelTypePriority("unknown"), 3) + testutil.Equal(t, labelTypePriority(""), 3) }) t.Run("priorities maintain correct sort order", func(t *testing.T) { - assert.Less(t, labelTypePriority("user"), labelTypePriority("category")) - assert.Less(t, labelTypePriority("category"), labelTypePriority("system")) - assert.Less(t, labelTypePriority("system"), labelTypePriority("unknown")) + testutil.Less(t, labelTypePriority("user"), labelTypePriority("category")) + testutil.Less(t, labelTypePriority("category"), labelTypePriority("system")) + testutil.Less(t, labelTypePriority("system"), labelTypePriority("unknown")) }) } diff --git a/internal/cmd/mail/mail.go b/internal/cmd/mail/mail.go index 414ab79..ffccfda 100644 --- a/internal/cmd/mail/mail.go +++ b/internal/cmd/mail/mail.go @@ -1,3 +1,4 @@ +// Package mail implements the gro mail command and subcommands. package mail import ( diff --git a/internal/cmd/mail/mail_test.go b/internal/cmd/mail/mail_test.go index c2688b3..c4c35e7 100644 --- a/internal/cmd/mail/mail_test.go +++ b/internal/cmd/mail/mail_test.go @@ -3,32 +3,32 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestMailCommand(t *testing.T) { cmd := NewCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "mail", cmd.Use) + testutil.Equal(t, cmd.Use, "mail") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) + testutil.NotEmpty(t, cmd.Short) }) t.Run("has subcommands", func(t *testing.T) { subcommands := cmd.Commands() - assert.GreaterOrEqual(t, len(subcommands), 5) + testutil.GreaterOrEqual(t, len(subcommands), 5) var names []string for _, sub := range subcommands { names = append(names, sub.Name()) } - assert.Contains(t, names, "search") - assert.Contains(t, names, "read") - assert.Contains(t, names, "thread") - assert.Contains(t, names, "labels") - assert.Contains(t, names, "attachments") + testutil.SliceContains(t, names, "search") + testutil.SliceContains(t, names, "read") + testutil.SliceContains(t, names, "thread") + testutil.SliceContains(t, names, "labels") + testutil.SliceContains(t, names, "attachments") }) } diff --git a/internal/cmd/mail/mock_test.go b/internal/cmd/mail/mock_test.go new file mode 100644 index 0000000..75c3a83 --- /dev/null +++ b/internal/cmd/mail/mock_test.go @@ -0,0 +1,97 @@ +package mail + +import ( + "context" + + "google.golang.org/api/gmail/v1" + + gmailapi "github.com/open-cli-collective/google-readonly/internal/gmail" +) + +// MockGmailClient is a configurable mock for MailClient. +// Set the function fields to control behavior in tests. +type MockGmailClient struct { + GetMessageFunc func(ctx context.Context, messageID string, includeBody bool) (*gmailapi.Message, error) + SearchMessagesFunc func(ctx context.Context, query string, maxResults int64) ([]*gmailapi.Message, int, error) + GetThreadFunc func(ctx context.Context, id string) ([]*gmailapi.Message, error) + FetchLabelsFunc func(ctx context.Context) error + GetLabelNameFunc func(labelID string) string + GetLabelsFunc func() []*gmail.Label + GetAttachmentsFunc func(ctx context.Context, messageID string) ([]*gmailapi.Attachment, error) + DownloadAttachmentFunc func(ctx context.Context, messageID, attachmentID string) ([]byte, error) + DownloadInlineAttachmentFunc func(ctx context.Context, messageID, partID string) ([]byte, error) + GetProfileFunc func(ctx context.Context) (*gmailapi.Profile, error) +} + +// Verify MockGmailClient implements MailClient +var _ MailClient = (*MockGmailClient)(nil) + +func (m *MockGmailClient) GetMessage(ctx context.Context, messageID string, includeBody bool) (*gmailapi.Message, error) { + if m.GetMessageFunc != nil { + return m.GetMessageFunc(ctx, messageID, includeBody) + } + return nil, nil +} + +func (m *MockGmailClient) SearchMessages(ctx context.Context, query string, maxResults int64) ([]*gmailapi.Message, int, error) { + if m.SearchMessagesFunc != nil { + return m.SearchMessagesFunc(ctx, query, maxResults) + } + return nil, 0, nil +} + +func (m *MockGmailClient) GetThread(ctx context.Context, id string) ([]*gmailapi.Message, error) { + if m.GetThreadFunc != nil { + return m.GetThreadFunc(ctx, id) + } + return nil, nil +} + +func (m *MockGmailClient) FetchLabels(ctx context.Context) error { + if m.FetchLabelsFunc != nil { + return m.FetchLabelsFunc(ctx) + } + return nil +} + +func (m *MockGmailClient) GetLabelName(labelID string) string { + if m.GetLabelNameFunc != nil { + return m.GetLabelNameFunc(labelID) + } + return labelID +} + +func (m *MockGmailClient) GetLabels() []*gmail.Label { + if m.GetLabelsFunc != nil { + return m.GetLabelsFunc() + } + return nil +} + +func (m *MockGmailClient) GetAttachments(ctx context.Context, messageID string) ([]*gmailapi.Attachment, error) { + if m.GetAttachmentsFunc != nil { + return m.GetAttachmentsFunc(ctx, messageID) + } + return nil, nil +} + +func (m *MockGmailClient) DownloadAttachment(ctx context.Context, messageID, attachmentID string) ([]byte, error) { + if m.DownloadAttachmentFunc != nil { + return m.DownloadAttachmentFunc(ctx, messageID, attachmentID) + } + return nil, nil +} + +func (m *MockGmailClient) DownloadInlineAttachment(ctx context.Context, messageID, partID string) ([]byte, error) { + if m.DownloadInlineAttachmentFunc != nil { + return m.DownloadInlineAttachmentFunc(ctx, messageID, partID) + } + return nil, nil +} + +func (m *MockGmailClient) GetProfile(ctx context.Context) (*gmailapi.Profile, error) { + if m.GetProfileFunc != nil { + return m.GetProfileFunc(ctx) + } + return nil, nil +} diff --git a/internal/cmd/mail/output.go b/internal/cmd/mail/output.go index 732b293..741d738 100644 --- a/internal/cmd/mail/output.go +++ b/internal/cmd/mail/output.go @@ -5,19 +5,35 @@ import ( "fmt" "strings" + gmailv1 "google.golang.org/api/gmail/v1" + "github.com/open-cli-collective/google-readonly/internal/gmail" "github.com/open-cli-collective/google-readonly/internal/output" ) +// MailClient defines the interface for Gmail client operations used by mail commands. +type MailClient interface { + GetMessage(ctx context.Context, messageID string, includeBody bool) (*gmail.Message, error) + SearchMessages(ctx context.Context, query string, maxResults int64) ([]*gmail.Message, int, error) + GetThread(ctx context.Context, id string) ([]*gmail.Message, error) + FetchLabels(ctx context.Context) error + GetLabelName(labelID string) string + GetLabels() []*gmailv1.Label + GetAttachments(ctx context.Context, messageID string) ([]*gmail.Attachment, error) + DownloadAttachment(ctx context.Context, messageID string, attachmentID string) ([]byte, error) + DownloadInlineAttachment(ctx context.Context, messageID string, partID string) ([]byte, error) + GetProfile(ctx context.Context) (*gmail.Profile, error) +} + // ClientFactory is the function used to create Gmail clients. // Override in tests to inject mocks. -var ClientFactory = func() (gmail.GmailClientInterface, error) { - return gmail.NewClient(context.Background()) +var ClientFactory = func(ctx context.Context) (MailClient, error) { + return gmail.NewClient(ctx) } // newGmailClient creates and returns a new Gmail client -func newGmailClient() (gmail.GmailClientInterface, error) { - return ClientFactory() +func newGmailClient(ctx context.Context) (MailClient, error) { + return ClientFactory(ctx) } // printJSON encodes data as indented JSON to stdout diff --git a/internal/cmd/mail/output_test.go b/internal/cmd/mail/output_test.go index f059ec3..87fb1e1 100644 --- a/internal/cmd/mail/output_test.go +++ b/internal/cmd/mail/output_test.go @@ -3,26 +3,29 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestMessagePrintOptions(t *testing.T) { + t.Parallel() t.Run("default options are all false", func(t *testing.T) { + t.Parallel() opts := MessagePrintOptions{} - assert.False(t, opts.IncludeThreadID) - assert.False(t, opts.IncludeTo) - assert.False(t, opts.IncludeSnippet) - assert.False(t, opts.IncludeBody) + testutil.False(t, opts.IncludeThreadID) + testutil.False(t, opts.IncludeTo) + testutil.False(t, opts.IncludeSnippet) + testutil.False(t, opts.IncludeBody) }) t.Run("options can be set individually", func(t *testing.T) { + t.Parallel() opts := MessagePrintOptions{ IncludeThreadID: true, IncludeBody: true, } - assert.True(t, opts.IncludeThreadID) - assert.False(t, opts.IncludeTo) - assert.False(t, opts.IncludeSnippet) - assert.True(t, opts.IncludeBody) + testutil.True(t, opts.IncludeThreadID) + testutil.False(t, opts.IncludeTo) + testutil.False(t, opts.IncludeSnippet) + testutil.True(t, opts.IncludeBody) }) } diff --git a/internal/cmd/mail/read.go b/internal/cmd/mail/read.go index 362241d..1c05c85 100644 --- a/internal/cmd/mail/read.go +++ b/internal/cmd/mail/read.go @@ -21,14 +21,14 @@ Examples: gro mail read 18abc123def456 --json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := newGmailClient() + client, err := newGmailClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Gmail client: %w", err) + return fmt.Errorf("creating Gmail client: %w", err) } - msg, err := client.GetMessage(args[0], true) + msg, err := client.GetMessage(cmd.Context(), args[0], true) if err != nil { - return fmt.Errorf("failed to read message: %w", err) + return fmt.Errorf("reading message: %w", err) } if jsonOutput { diff --git a/internal/cmd/mail/read_test.go b/internal/cmd/mail/read_test.go index 8f741cf..27bb594 100644 --- a/internal/cmd/mail/read_test.go +++ b/internal/cmd/mail/read_test.go @@ -3,40 +3,40 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestReadCommand(t *testing.T) { cmd := newReadCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "read ", cmd.Use) + testutil.Equal(t, cmd.Use, "read ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"msg123"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"msg1", "msg2"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "message") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "message") }) t.Run("long description mentions message ID source", func(t *testing.T) { - assert.Contains(t, cmd.Long, "search") + testutil.Contains(t, cmd.Long, "search") }) } diff --git a/internal/cmd/mail/sanitize_test.go b/internal/cmd/mail/sanitize_test.go index e662ff0..9025f92 100644 --- a/internal/cmd/mail/sanitize_test.go +++ b/internal/cmd/mail/sanitize_test.go @@ -3,10 +3,11 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestSanitizeOutput(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -121,13 +122,15 @@ func TestSanitizeOutput(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := SanitizeOutput(tt.input) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } func TestSanitizeFilename(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -182,13 +185,15 @@ func TestSanitizeFilename(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := SanitizeFilename(tt.input) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } func TestSanitizeOutput_RealWorldExamples(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -213,8 +218,9 @@ func TestSanitizeOutput_RealWorldExamples(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := SanitizeOutput(tt.input) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } diff --git a/internal/cmd/mail/search.go b/internal/cmd/mail/search.go index b0ac845..66baf1e 100644 --- a/internal/cmd/mail/search.go +++ b/internal/cmd/mail/search.go @@ -26,17 +26,21 @@ Examples: For more query operators, see: https://support.google.com/mail/answer/7190`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := newGmailClient() + client, err := newGmailClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Gmail client: %w", err) + return fmt.Errorf("creating Gmail client: %w", err) } - messages, skipped, err := client.SearchMessages(args[0], maxResults) + messages, skipped, err := client.SearchMessages(cmd.Context(), args[0], maxResults) if err != nil { - return fmt.Errorf("failed to search messages: %w", err) + return fmt.Errorf("searching messages: %w", err) } if len(messages) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No messages found.") return nil } diff --git a/internal/cmd/mail/search_test.go b/internal/cmd/mail/search_test.go index 15a1f52..a00526b 100644 --- a/internal/cmd/mail/search_test.go +++ b/internal/cmd/mail/search_test.go @@ -3,44 +3,44 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestSearchCommand(t *testing.T) { cmd := newSearchCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "search ", cmd.Use) + testutil.Equal(t, cmd.Use, "search ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"query"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"query1", "query2"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has max flag", func(t *testing.T) { flag := cmd.Flags().Lookup("max") - assert.NotNil(t, flag) - assert.Equal(t, "m", flag.Shorthand) - assert.Equal(t, "10", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "m") + testutil.Equal(t, flag.DefValue, "10") }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has examples in long description", func(t *testing.T) { - assert.Contains(t, cmd.Long, "from:") - assert.Contains(t, cmd.Long, "subject:") - assert.Contains(t, cmd.Long, "is:unread") + testutil.Contains(t, cmd.Long, "from:") + testutil.Contains(t, cmd.Long, "subject:") + testutil.Contains(t, cmd.Long, "is:unread") }) } diff --git a/internal/cmd/mail/thread.go b/internal/cmd/mail/thread.go index 2dad55e..6742b28 100644 --- a/internal/cmd/mail/thread.go +++ b/internal/cmd/mail/thread.go @@ -24,17 +24,21 @@ Examples: gro mail thread 18abc123def456 --json`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - client, err := newGmailClient() + client, err := newGmailClient(cmd.Context()) if err != nil { - return fmt.Errorf("failed to create Gmail client: %w", err) + return fmt.Errorf("creating Gmail client: %w", err) } - messages, err := client.GetThread(args[0]) + messages, err := client.GetThread(cmd.Context(), args[0]) if err != nil { - return fmt.Errorf("failed to get thread: %w", err) + return fmt.Errorf("getting thread: %w", err) } if len(messages) == 0 { + if jsonOutput { + fmt.Println("[]") + return nil + } fmt.Println("No messages found in thread.") return nil } diff --git a/internal/cmd/mail/thread_test.go b/internal/cmd/mail/thread_test.go index 47f9749..b608f0e 100644 --- a/internal/cmd/mail/thread_test.go +++ b/internal/cmd/mail/thread_test.go @@ -3,41 +3,41 @@ package mail import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestThreadCommand(t *testing.T) { cmd := newThreadCommand() t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "thread ", cmd.Use) + testutil.Equal(t, cmd.Use, "thread ") }) t.Run("requires exactly one argument", func(t *testing.T) { err := cmd.Args(cmd, []string{}) - assert.Error(t, err) + testutil.Error(t, err) err = cmd.Args(cmd, []string{"thread123"}) - assert.NoError(t, err) + testutil.NoError(t, err) err = cmd.Args(cmd, []string{"thread1", "thread2"}) - assert.Error(t, err) + testutil.Error(t, err) }) t.Run("has json flag", func(t *testing.T) { flag := cmd.Flags().Lookup("json") - assert.NotNil(t, flag) - assert.Equal(t, "j", flag.Shorthand) - assert.Equal(t, "false", flag.DefValue) + testutil.NotNil(t, flag) + testutil.Equal(t, flag.Shorthand, "j") + testutil.Equal(t, flag.DefValue, "false") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, cmd.Short) - assert.Contains(t, cmd.Short, "thread") + testutil.NotEmpty(t, cmd.Short) + testutil.Contains(t, cmd.Short, "thread") }) t.Run("long description explains thread ID", func(t *testing.T) { - assert.Contains(t, cmd.Long, "thread ID") - assert.Contains(t, cmd.Long, "message ID") + testutil.Contains(t, cmd.Long, "thread ID") + testutil.Contains(t, cmd.Long, "message ID") }) } diff --git a/internal/cmd/root/root.go b/internal/cmd/root/root.go index f2636d4..0dd06d9 100644 --- a/internal/cmd/root/root.go +++ b/internal/cmd/root/root.go @@ -1,6 +1,8 @@ +// Package root provides the top-level gro command and global flags. package root import ( + "context" "fmt" "os" @@ -31,14 +33,19 @@ To get started, run: This will guide you through OAuth setup for Google API access.`, Version: version.Version, - PersistentPreRun: func(cmd *cobra.Command, args []string) { + PersistentPreRun: func(_ *cobra.Command, _ []string) { log.Verbose = verbose }, } -// Execute runs the root command +// Execute runs the root command with a background context func Execute() { - if err := rootCmd.Execute(); err != nil { + ExecuteContext(context.Background()) +} + +// ExecuteContext runs the root command with the given context +func ExecuteContext(ctx context.Context) { + if err := rootCmd.ExecuteContext(ctx); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } diff --git a/internal/cmd/root/root_test.go b/internal/cmd/root/root_test.go index e051192..50f7d92 100644 --- a/internal/cmd/root/root_test.go +++ b/internal/cmd/root/root_test.go @@ -3,39 +3,39 @@ package root import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestRootCommand(t *testing.T) { t.Run("has correct use", func(t *testing.T) { - assert.Equal(t, "gro", rootCmd.Use) + testutil.Equal(t, rootCmd.Use, "gro") }) t.Run("has short description", func(t *testing.T) { - assert.NotEmpty(t, rootCmd.Short) + testutil.NotEmpty(t, rootCmd.Short) }) t.Run("has long description", func(t *testing.T) { - assert.NotEmpty(t, rootCmd.Long) - assert.Contains(t, rootCmd.Long, "read-only") + testutil.NotEmpty(t, rootCmd.Long) + testutil.Contains(t, rootCmd.Long, "read-only") }) t.Run("has version set", func(t *testing.T) { - assert.NotEmpty(t, rootCmd.Version) + testutil.NotEmpty(t, rootCmd.Version) }) t.Run("has subcommands", func(t *testing.T) { subcommands := rootCmd.Commands() - assert.GreaterOrEqual(t, len(subcommands), 5) + testutil.GreaterOrEqual(t, len(subcommands), 5) var names []string for _, sub := range subcommands { names = append(names, sub.Name()) } - assert.Contains(t, names, "init") - assert.Contains(t, names, "config") - assert.Contains(t, names, "mail") - assert.Contains(t, names, "calendar") - assert.Contains(t, names, "contacts") + testutil.SliceContains(t, names, "init") + testutil.SliceContains(t, names, "config") + testutil.SliceContains(t, names, "mail") + testutil.SliceContains(t, names, "calendar") + testutil.SliceContains(t, names, "contacts") }) } diff --git a/internal/config/config.go b/internal/config/config.go index 6d30e8e..58c3d22 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -110,7 +110,7 @@ func LoadConfig() (*Config, error) { return nil, err } - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // Path from user config directory if err != nil { if os.IsNotExist(err) { // Return default config diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f602e51..74079e8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,11 +3,9 @@ package config import ( "os" "path/filepath" + "strings" "testing" "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestGetConfigDir(t *testing.T) { @@ -16,24 +14,36 @@ func TestGetConfigDir(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) dir, err := GetConfigDir() - require.NoError(t, err) - assert.Equal(t, filepath.Join(tmpDir, DirName), dir) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dir != filepath.Join(tmpDir, DirName) { + t.Errorf("got %v, want %v", dir, filepath.Join(tmpDir, DirName)) + } // Verify directory was created info, err := os.Stat(dir) - require.NoError(t, err) - assert.True(t, info.IsDir()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !info.IsDir() { + t.Error("got false, want true") + } }) t.Run("uses ~/.config if XDG_CONFIG_HOME not set", func(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", "") dir, err := GetConfigDir() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } home, _ := os.UserHomeDir() expected := filepath.Join(home, ".config", DirName) - assert.Equal(t, expected, dir) + if dir != expected { + t.Errorf("got %v, want %v", dir, expected) + } }) t.Run("creates directory with correct permissions", func(t *testing.T) { @@ -41,11 +51,17 @@ func TestGetConfigDir(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) dir, err := GetConfigDir() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } info, err := os.Stat(dir) - require.NoError(t, err) - assert.Equal(t, os.FileMode(0700), info.Mode().Perm()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.Mode().Perm() != os.FileMode(0700) { + t.Errorf("got %v, want %v", info.Mode().Perm(), os.FileMode(0700)) + } }) } @@ -54,8 +70,12 @@ func TestGetCredentialsPath(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) path, err := GetCredentialsPath() - require.NoError(t, err) - assert.Equal(t, filepath.Join(tmpDir, DirName, CredentialsFile), path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if path != filepath.Join(tmpDir, DirName, CredentialsFile) { + t.Errorf("got %v, want %v", path, filepath.Join(tmpDir, DirName, CredentialsFile)) + } } func TestGetTokenPath(t *testing.T) { @@ -63,13 +83,20 @@ func TestGetTokenPath(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) path, err := GetTokenPath() - require.NoError(t, err) - assert.Equal(t, filepath.Join(tmpDir, DirName, TokenFile), path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if path != filepath.Join(tmpDir, DirName, TokenFile) { + t.Errorf("got %v, want %v", path, filepath.Join(tmpDir, DirName, TokenFile)) + } } func TestShortenPath(t *testing.T) { + t.Parallel() home, err := os.UserHomeDir() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } tests := []struct { name string @@ -105,18 +132,32 @@ func TestShortenPath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := ShortenPath(tt.input) - assert.Equal(t, tt.expected, result) + if result != tt.expected { + t.Errorf("got %v, want %v", result, tt.expected) + } }) } } func TestConstants(t *testing.T) { - assert.Equal(t, "google-readonly", DirName) - assert.Equal(t, "credentials.json", CredentialsFile) - assert.Equal(t, "token.json", TokenFile) - assert.Equal(t, "config.json", ConfigFile) - assert.Equal(t, 24, DefaultCacheTTLHours) + t.Parallel() + if DirName != "google-readonly" { + t.Errorf("got %v, want %v", DirName, "google-readonly") + } + if CredentialsFile != "credentials.json" { + t.Errorf("got %v, want %v", CredentialsFile, "credentials.json") + } + if TokenFile != "token.json" { + t.Errorf("got %v, want %v", TokenFile, "token.json") + } + if ConfigFile != "config.json" { + t.Errorf("got %v, want %v", ConfigFile, "config.json") + } + if DefaultCacheTTLHours != 24 { + t.Errorf("got %v, want %v", DefaultCacheTTLHours, 24) + } } func TestGetConfigPath(t *testing.T) { @@ -124,8 +165,12 @@ func TestGetConfigPath(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) path, err := GetConfigPath() - require.NoError(t, err) - assert.Equal(t, filepath.Join(tmpDir, DirName, ConfigFile), path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if path != filepath.Join(tmpDir, DirName, ConfigFile) { + t.Errorf("got %v, want %v", path, filepath.Join(tmpDir, DirName, ConfigFile)) + } } func TestLoadConfig(t *testing.T) { @@ -134,8 +179,12 @@ func TestLoadConfig(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) cfg, err := LoadConfig() - require.NoError(t, err) - assert.Equal(t, DefaultCacheTTLHours, cfg.CacheTTLHours) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.CacheTTLHours != DefaultCacheTTLHours { + t.Errorf("got %v, want %v", cfg.CacheTTLHours, DefaultCacheTTLHours) + } }) t.Run("loads config from file", func(t *testing.T) { @@ -144,14 +193,22 @@ func TestLoadConfig(t *testing.T) { // Create config directory and file configDir := filepath.Join(tmpDir, DirName) - require.NoError(t, os.MkdirAll(configDir, DirPerm)) + if err := os.MkdirAll(configDir, DirPerm); err != nil { + t.Fatalf("unexpected error: %v", err) + } configData := `{"cache_ttl_hours": 48}` - require.NoError(t, os.WriteFile(filepath.Join(configDir, ConfigFile), []byte(configData), TokenPerm)) + if err := os.WriteFile(filepath.Join(configDir, ConfigFile), []byte(configData), TokenPerm); err != nil { + t.Fatalf("unexpected error: %v", err) + } cfg, err := LoadConfig() - require.NoError(t, err) - assert.Equal(t, 48, cfg.CacheTTLHours) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.CacheTTLHours != 48 { + t.Errorf("got %v, want %v", cfg.CacheTTLHours, 48) + } }) t.Run("applies default for zero or negative TTL", func(t *testing.T) { @@ -159,14 +216,22 @@ func TestLoadConfig(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) configDir := filepath.Join(tmpDir, DirName) - require.NoError(t, os.MkdirAll(configDir, DirPerm)) + if err := os.MkdirAll(configDir, DirPerm); err != nil { + t.Fatalf("unexpected error: %v", err) + } configData := `{"cache_ttl_hours": 0}` - require.NoError(t, os.WriteFile(filepath.Join(configDir, ConfigFile), []byte(configData), TokenPerm)) + if err := os.WriteFile(filepath.Join(configDir, ConfigFile), []byte(configData), TokenPerm); err != nil { + t.Fatalf("unexpected error: %v", err) + } cfg, err := LoadConfig() - require.NoError(t, err) - assert.Equal(t, DefaultCacheTTLHours, cfg.CacheTTLHours) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.CacheTTLHours != DefaultCacheTTLHours { + t.Errorf("got %v, want %v", cfg.CacheTTLHours, DefaultCacheTTLHours) + } }) t.Run("returns error for invalid JSON", func(t *testing.T) { @@ -174,12 +239,18 @@ func TestLoadConfig(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) configDir := filepath.Join(tmpDir, DirName) - require.NoError(t, os.MkdirAll(configDir, DirPerm)) + if err := os.MkdirAll(configDir, DirPerm); err != nil { + t.Fatalf("unexpected error: %v", err) + } - require.NoError(t, os.WriteFile(filepath.Join(configDir, ConfigFile), []byte("not json"), TokenPerm)) + if err := os.WriteFile(filepath.Join(configDir, ConfigFile), []byte("not json"), TokenPerm); err != nil { + t.Fatalf("unexpected error: %v", err) + } _, err := LoadConfig() - assert.Error(t, err) + if err == nil { + t.Fatal("expected error, got nil") + } }) } @@ -190,13 +261,19 @@ func TestSaveConfig(t *testing.T) { cfg := &Config{CacheTTLHours: 12} err := SaveConfig(cfg) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify file was created path, _ := GetConfigPath() data, err := os.ReadFile(path) - require.NoError(t, err) - assert.Contains(t, string(data), `"cache_ttl_hours": 12`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(string(data), `"cache_ttl_hours": 12`) { + t.Errorf("expected %q to contain %q", string(data), `"cache_ttl_hours": 12`) + } }) t.Run("overwrites existing config", func(t *testing.T) { @@ -205,16 +282,24 @@ func TestSaveConfig(t *testing.T) { // Save initial config cfg1 := &Config{CacheTTLHours: 12} - require.NoError(t, SaveConfig(cfg1)) + if err := SaveConfig(cfg1); err != nil { + t.Fatalf("unexpected error: %v", err) + } // Save new config cfg2 := &Config{CacheTTLHours: 36} - require.NoError(t, SaveConfig(cfg2)) + if err := SaveConfig(cfg2); err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify new value loaded, err := LoadConfig() - require.NoError(t, err) - assert.Equal(t, 36, loaded.CacheTTLHours) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loaded.CacheTTLHours != 36 { + t.Errorf("got %v, want %v", loaded.CacheTTLHours, 36) + } }) } @@ -224,10 +309,14 @@ func TestGetCacheTTL(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) cfg := &Config{CacheTTLHours: 12} - require.NoError(t, SaveConfig(cfg)) + if err := SaveConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } ttl := GetCacheTTL() - assert.Equal(t, 12*time.Hour, ttl) + if ttl != 12*time.Hour { + t.Errorf("got %v, want %v", ttl, 12*time.Hour) + } }) t.Run("returns default TTL when no config exists", func(t *testing.T) { @@ -235,7 +324,9 @@ func TestGetCacheTTL(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) ttl := GetCacheTTL() - assert.Equal(t, time.Duration(DefaultCacheTTLHours)*time.Hour, ttl) + if ttl != time.Duration(DefaultCacheTTLHours)*time.Hour { + t.Errorf("got %v, want %v", ttl, time.Duration(DefaultCacheTTLHours)*time.Hour) + } }) } @@ -245,10 +336,14 @@ func TestGetCacheTTLHours(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) cfg := &Config{CacheTTLHours: 48} - require.NoError(t, SaveConfig(cfg)) + if err := SaveConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } hours := GetCacheTTLHours() - assert.Equal(t, 48, hours) + if hours != 48 { + t.Errorf("got %v, want %v", hours, 48) + } }) t.Run("returns default TTL when no config exists", func(t *testing.T) { @@ -256,6 +351,8 @@ func TestGetCacheTTLHours(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) hours := GetCacheTTLHours() - assert.Equal(t, DefaultCacheTTLHours, hours) + if hours != DefaultCacheTTLHours { + t.Errorf("got %v, want %v", hours, DefaultCacheTTLHours) + } }) } diff --git a/internal/contacts/client.go b/internal/contacts/client.go index f7982e5..dfbfb0b 100644 --- a/internal/contacts/client.go +++ b/internal/contacts/client.go @@ -1,3 +1,4 @@ +// Package contacts provides a client for the Google People API. package contacts import ( @@ -19,12 +20,12 @@ type Client struct { func NewClient(ctx context.Context) (*Client, error) { client, err := auth.GetHTTPClient(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("loading OAuth client: %w", err) } srv, err := people.NewService(ctx, option.WithHTTPClient(client)) if err != nil { - return nil, fmt.Errorf("unable to create People service: %w", err) + return nil, fmt.Errorf("creating People service: %w", err) } return &Client{ @@ -33,7 +34,7 @@ func NewClient(ctx context.Context) (*Client, error) { } // ListContacts retrieves contacts from the user's account -func (c *Client) ListContacts(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { +func (c *Client) ListContacts(ctx context.Context, pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { call := c.service.People.Connections.List("people/me"). PersonFields("names,emailAddresses,phoneNumbers,organizations,addresses,biographies,photos"). PageSize(pageSize). @@ -43,42 +44,44 @@ func (c *Client) ListContacts(pageToken string, pageSize int64) (*people.ListCon call = call.PageToken(pageToken) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list contacts: %w", err) + return nil, fmt.Errorf("listing contacts: %w", err) } return resp, nil } // SearchContacts searches for contacts matching a query -func (c *Client) SearchContacts(query string, pageSize int64) (*people.SearchResponse, error) { +func (c *Client) SearchContacts(ctx context.Context, query string, pageSize int64) (*people.SearchResponse, error) { resp, err := c.service.People.SearchContacts(). Query(query). ReadMask("names,emailAddresses,phoneNumbers,organizations,addresses,biographies,photos"). PageSize(int64(pageSize)). + Context(ctx). Do() if err != nil { - return nil, fmt.Errorf("failed to search contacts: %w", err) + return nil, fmt.Errorf("searching contacts: %w", err) } return resp, nil } // GetContact retrieves a specific contact by resource name -func (c *Client) GetContact(resourceName string) (*people.Person, error) { +func (c *Client) GetContact(ctx context.Context, resourceName string) (*people.Person, error) { resp, err := c.service.People.Get(resourceName). PersonFields("names,emailAddresses,phoneNumbers,organizations,addresses,biographies,urls,birthdays,events,relations,photos,metadata"). + Context(ctx). Do() if err != nil { - return nil, fmt.Errorf("failed to get contact: %w", err) + return nil, fmt.Errorf("getting contact: %w", err) } return resp, nil } // ListContactGroups retrieves all contact groups -func (c *Client) ListContactGroups(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { +func (c *Client) ListContactGroups(ctx context.Context, pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { call := c.service.ContactGroups.List(). PageSize(pageSize). GroupFields("name,groupType,memberCount") @@ -87,9 +90,9 @@ func (c *Client) ListContactGroups(pageToken string, pageSize int64) (*people.Li call = call.PageToken(pageToken) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list contact groups: %w", err) + return nil, fmt.Errorf("listing contact groups: %w", err) } return resp, nil diff --git a/internal/contacts/client_test.go b/internal/contacts/client_test.go index eb60418..883a700 100644 --- a/internal/contacts/client_test.go +++ b/internal/contacts/client_test.go @@ -2,13 +2,15 @@ package contacts import ( "testing" - - "github.com/stretchr/testify/assert" ) func TestClientStructure(t *testing.T) { + t.Parallel() t.Run("Client has private service field", func(t *testing.T) { + t.Parallel() client := &Client{} - assert.Nil(t, client.service) + if client.service != nil { + t.Errorf("got %v, want nil", client.service) + } }) } diff --git a/internal/contacts/contacts_test.go b/internal/contacts/contacts_test.go index 6d420cb..614a8b9 100644 --- a/internal/contacts/contacts_test.go +++ b/internal/contacts/contacts_test.go @@ -3,12 +3,13 @@ package contacts import ( "testing" - "github.com/stretchr/testify/assert" "google.golang.org/api/people/v1" ) func TestParseContact(t *testing.T) { + t.Parallel() t.Run("parses basic contact", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c123", Names: []*people.Name{ @@ -22,14 +23,25 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Equal(t, "people/c123", contact.ResourceName) - assert.Equal(t, "John Doe", contact.DisplayName) - assert.Len(t, contact.Names, 1) - assert.Equal(t, "John", contact.Names[0].GivenName) - assert.Equal(t, "Doe", contact.Names[0].FamilyName) + if contact.ResourceName != "people/c123" { + t.Errorf("got %v, want %v", contact.ResourceName, "people/c123") + } + if contact.DisplayName != "John Doe" { + t.Errorf("got %v, want %v", contact.DisplayName, "John Doe") + } + if len(contact.Names) != 1 { + t.Errorf("got length %d, want %d", len(contact.Names), 1) + } + if contact.Names[0].GivenName != "John" { + t.Errorf("got %v, want %v", contact.Names[0].GivenName, "John") + } + if contact.Names[0].FamilyName != "Doe" { + t.Errorf("got %v, want %v", contact.Names[0].FamilyName, "Doe") + } }) t.Run("parses contact with email", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c456", Names: []*people.Name{ @@ -50,15 +62,28 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Len(t, contact.Emails, 2) - assert.Equal(t, "jane@example.com", contact.Emails[0].Value) - assert.Equal(t, "work", contact.Emails[0].Type) - assert.True(t, contact.Emails[0].Primary) - assert.Equal(t, "jane.personal@example.com", contact.Emails[1].Value) - assert.False(t, contact.Emails[1].Primary) + if len(contact.Emails) != 2 { + t.Errorf("got length %d, want %d", len(contact.Emails), 2) + } + if contact.Emails[0].Value != "jane@example.com" { + t.Errorf("got %v, want %v", contact.Emails[0].Value, "jane@example.com") + } + if contact.Emails[0].Type != "work" { + t.Errorf("got %v, want %v", contact.Emails[0].Type, "work") + } + if !contact.Emails[0].Primary { + t.Error("got false, want true") + } + if contact.Emails[1].Value != "jane.personal@example.com" { + t.Errorf("got %v, want %v", contact.Emails[1].Value, "jane.personal@example.com") + } + if contact.Emails[1].Primary { + t.Error("got true, want false") + } }) t.Run("parses contact with phone numbers", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c789", PhoneNumbers: []*people.PhoneNumber{ @@ -69,12 +94,19 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Len(t, contact.Phones, 2) - assert.Equal(t, "+1-555-123-4567", contact.Phones[0].Value) - assert.Equal(t, "mobile", contact.Phones[0].Type) + if len(contact.Phones) != 2 { + t.Errorf("got length %d, want %d", len(contact.Phones), 2) + } + if contact.Phones[0].Value != "+1-555-123-4567" { + t.Errorf("got %v, want %v", contact.Phones[0].Value, "+1-555-123-4567") + } + if contact.Phones[0].Type != "mobile" { + t.Errorf("got %v, want %v", contact.Phones[0].Type, "mobile") + } }) t.Run("parses contact with organization", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c101", Organizations: []*people.Organization{ @@ -88,13 +120,22 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Len(t, contact.Organizations, 1) - assert.Equal(t, "Acme Corp", contact.Organizations[0].Name) - assert.Equal(t, "Software Engineer", contact.Organizations[0].Title) - assert.Equal(t, "Engineering", contact.Organizations[0].Department) + if len(contact.Organizations) != 1 { + t.Errorf("got length %d, want %d", len(contact.Organizations), 1) + } + if contact.Organizations[0].Name != "Acme Corp" { + t.Errorf("got %v, want %v", contact.Organizations[0].Name, "Acme Corp") + } + if contact.Organizations[0].Title != "Software Engineer" { + t.Errorf("got %v, want %v", contact.Organizations[0].Title, "Software Engineer") + } + if contact.Organizations[0].Department != "Engineering" { + t.Errorf("got %v, want %v", contact.Organizations[0].Department, "Engineering") + } }) t.Run("parses contact with address", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c102", Addresses: []*people.Address{ @@ -111,13 +152,22 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Len(t, contact.Addresses, 1) - assert.Equal(t, "home", contact.Addresses[0].Type) - assert.Equal(t, "San Francisco", contact.Addresses[0].City) - assert.Equal(t, "94102", contact.Addresses[0].PostalCode) + if len(contact.Addresses) != 1 { + t.Errorf("got length %d, want %d", len(contact.Addresses), 1) + } + if contact.Addresses[0].Type != "home" { + t.Errorf("got %v, want %v", contact.Addresses[0].Type, "home") + } + if contact.Addresses[0].City != "San Francisco" { + t.Errorf("got %v, want %v", contact.Addresses[0].City, "San Francisco") + } + if contact.Addresses[0].PostalCode != "94102" { + t.Errorf("got %v, want %v", contact.Addresses[0].PostalCode, "94102") + } }) t.Run("parses contact with URLs", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c103", Urls: []*people.Url{ @@ -128,11 +178,16 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Len(t, contact.URLs, 2) - assert.Equal(t, "https://linkedin.com/in/johndoe", contact.URLs[0].Value) + if len(contact.URLs) != 2 { + t.Errorf("got length %d, want %d", len(contact.URLs), 2) + } + if contact.URLs[0].Value != "https://linkedin.com/in/johndoe" { + t.Errorf("got %v, want %v", contact.URLs[0].Value, "https://linkedin.com/in/johndoe") + } }) t.Run("parses contact with biography", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c104", Biographies: []*people.Biography{ @@ -142,10 +197,13 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Equal(t, "A passionate software developer.", contact.Biography) + if contact.Biography != "A passionate software developer." { + t.Errorf("got %v, want %v", contact.Biography, "A passionate software developer.") + } }) t.Run("parses contact with birthday including year", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c105", Birthdays: []*people.Birthday{ @@ -155,10 +213,13 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Equal(t, "1990-06-15", contact.Birthday) + if contact.Birthday != "1990-06-15" { + t.Errorf("got %v, want %v", contact.Birthday, "1990-06-15") + } }) t.Run("parses contact with birthday month/day only", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c106", Birthdays: []*people.Birthday{ @@ -168,10 +229,13 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Equal(t, "12-25", contact.Birthday) + if contact.Birthday != "12-25" { + t.Errorf("got %v, want %v", contact.Birthday, "12-25") + } }) t.Run("parses contact with photo", func(t *testing.T) { + t.Parallel() p := &people.Person{ ResourceName: "people/c107", Photos: []*people.Photo{ @@ -181,17 +245,24 @@ func TestParseContact(t *testing.T) { contact := ParseContact(p) - assert.Equal(t, "https://example.com/photo.jpg", contact.PhotoURL) + if contact.PhotoURL != "https://example.com/photo.jpg" { + t.Errorf("got %v, want %v", contact.PhotoURL, "https://example.com/photo.jpg") + } }) t.Run("handles nil person", func(t *testing.T) { + t.Parallel() contact := ParseContact(nil) - assert.Nil(t, contact) + if contact != nil { + t.Errorf("got %v, want nil", contact) + } }) } func TestParseContactGroup(t *testing.T) { + t.Parallel() t.Run("parses contact group", func(t *testing.T) { + t.Parallel() g := &people.ContactGroup{ ResourceName: "contactGroups/abc123", Name: "Work", @@ -201,125 +272,177 @@ func TestParseContactGroup(t *testing.T) { group := ParseContactGroup(g) - assert.Equal(t, "contactGroups/abc123", group.ResourceName) - assert.Equal(t, "Work", group.Name) - assert.Equal(t, "USER_CONTACT_GROUP", group.GroupType) - assert.Equal(t, int64(42), group.MemberCount) + if group.ResourceName != "contactGroups/abc123" { + t.Errorf("got %v, want %v", group.ResourceName, "contactGroups/abc123") + } + if group.Name != "Work" { + t.Errorf("got %v, want %v", group.Name, "Work") + } + if group.GroupType != "USER_CONTACT_GROUP" { + t.Errorf("got %v, want %v", group.GroupType, "USER_CONTACT_GROUP") + } + if group.MemberCount != int64(42) { + t.Errorf("got %v, want %v", group.MemberCount, int64(42)) + } }) t.Run("handles nil group", func(t *testing.T) { + t.Parallel() group := ParseContactGroup(nil) - assert.Nil(t, group) + if group != nil { + t.Errorf("got %v, want nil", group) + } }) } func TestContactGetDisplayName(t *testing.T) { + t.Parallel() t.Run("returns display name when set", func(t *testing.T) { + t.Parallel() c := &Contact{ ResourceName: "people/c1", DisplayName: "John Doe", } - assert.Equal(t, "John Doe", c.GetDisplayName()) + if c.GetDisplayName() != "John Doe" { + t.Errorf("got %v, want %v", c.GetDisplayName(), "John Doe") + } }) t.Run("falls back to names array", func(t *testing.T) { + t.Parallel() c := &Contact{ ResourceName: "people/c2", Names: []Name{ {DisplayName: "Jane Smith"}, }, } - assert.Equal(t, "Jane Smith", c.GetDisplayName()) + if c.GetDisplayName() != "Jane Smith" { + t.Errorf("got %v, want %v", c.GetDisplayName(), "Jane Smith") + } }) t.Run("falls back to email", func(t *testing.T) { + t.Parallel() c := &Contact{ ResourceName: "people/c3", Emails: []Email{ {Value: "test@example.com"}, }, } - assert.Equal(t, "test@example.com", c.GetDisplayName()) + if c.GetDisplayName() != "test@example.com" { + t.Errorf("got %v, want %v", c.GetDisplayName(), "test@example.com") + } }) t.Run("falls back to resource name", func(t *testing.T) { + t.Parallel() c := &Contact{ ResourceName: "people/c4", } - assert.Equal(t, "people/c4", c.GetDisplayName()) + if c.GetDisplayName() != "people/c4" { + t.Errorf("got %v, want %v", c.GetDisplayName(), "people/c4") + } }) } func TestContactGetPrimaryEmail(t *testing.T) { + t.Parallel() t.Run("returns primary email when marked", func(t *testing.T) { + t.Parallel() c := &Contact{ Emails: []Email{ {Value: "work@example.com", Primary: false}, {Value: "primary@example.com", Primary: true}, }, } - assert.Equal(t, "primary@example.com", c.GetPrimaryEmail()) + if c.GetPrimaryEmail() != "primary@example.com" { + t.Errorf("got %v, want %v", c.GetPrimaryEmail(), "primary@example.com") + } }) t.Run("returns first email when no primary", func(t *testing.T) { + t.Parallel() c := &Contact{ Emails: []Email{ {Value: "first@example.com"}, {Value: "second@example.com"}, }, } - assert.Equal(t, "first@example.com", c.GetPrimaryEmail()) + if c.GetPrimaryEmail() != "first@example.com" { + t.Errorf("got %v, want %v", c.GetPrimaryEmail(), "first@example.com") + } }) t.Run("returns empty string when no emails", func(t *testing.T) { + t.Parallel() c := &Contact{} - assert.Equal(t, "", c.GetPrimaryEmail()) + if c.GetPrimaryEmail() != "" { + t.Errorf("got %v, want %v", c.GetPrimaryEmail(), "") + } }) } func TestContactGetPrimaryPhone(t *testing.T) { + t.Parallel() t.Run("returns first phone", func(t *testing.T) { + t.Parallel() c := &Contact{ Phones: []Phone{ {Value: "+1-555-123-4567"}, {Value: "+1-555-987-6543"}, }, } - assert.Equal(t, "+1-555-123-4567", c.GetPrimaryPhone()) + if c.GetPrimaryPhone() != "+1-555-123-4567" { + t.Errorf("got %v, want %v", c.GetPrimaryPhone(), "+1-555-123-4567") + } }) t.Run("returns empty string when no phones", func(t *testing.T) { + t.Parallel() c := &Contact{} - assert.Equal(t, "", c.GetPrimaryPhone()) + if c.GetPrimaryPhone() != "" { + t.Errorf("got %v, want %v", c.GetPrimaryPhone(), "") + } }) } func TestContactGetOrganization(t *testing.T) { + t.Parallel() t.Run("returns organization name", func(t *testing.T) { + t.Parallel() c := &Contact{ Organizations: []Organization{ {Name: "Acme Corp", Title: "Engineer"}, }, } - assert.Equal(t, "Acme Corp", c.GetOrganization()) + if c.GetOrganization() != "Acme Corp" { + t.Errorf("got %v, want %v", c.GetOrganization(), "Acme Corp") + } }) t.Run("returns title when no name", func(t *testing.T) { + t.Parallel() c := &Contact{ Organizations: []Organization{ {Title: "Freelance Developer"}, }, } - assert.Equal(t, "Freelance Developer", c.GetOrganization()) + if c.GetOrganization() != "Freelance Developer" { + t.Errorf("got %v, want %v", c.GetOrganization(), "Freelance Developer") + } }) t.Run("returns empty string when no organizations", func(t *testing.T) { + t.Parallel() c := &Contact{} - assert.Equal(t, "", c.GetOrganization()) + if c.GetOrganization() != "" { + t.Errorf("got %v, want %v", c.GetOrganization(), "") + } }) } func TestFormatDate(t *testing.T) { + t.Parallel() tests := []struct { name string year int64 @@ -334,13 +457,17 @@ func TestFormatDate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := formatDate(tt.year, tt.month, tt.day) - assert.Equal(t, tt.expect, result) + if result != tt.expect { + t.Errorf("got %v, want %v", result, tt.expect) + } }) } } func TestFormatMonthDay(t *testing.T) { + t.Parallel() tests := []struct { name string month int64 @@ -354,8 +481,11 @@ func TestFormatMonthDay(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := formatMonthDay(tt.month, tt.day) - assert.Equal(t, tt.expect, result) + if result != tt.expect { + t.Errorf("got %v, want %v", result, tt.expect) + } }) } } diff --git a/internal/contacts/interfaces.go b/internal/contacts/interfaces.go deleted file mode 100644 index 95ed9e3..0000000 --- a/internal/contacts/interfaces.go +++ /dev/null @@ -1,24 +0,0 @@ -package contacts - -import ( - "google.golang.org/api/people/v1" -) - -// ContactsClientInterface defines the interface for Contacts client operations. -// This enables unit testing through mock implementations. -type ContactsClientInterface interface { - // ListContacts retrieves contacts from the user's account - ListContacts(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) - - // SearchContacts searches for contacts matching a query - SearchContacts(query string, pageSize int64) (*people.SearchResponse, error) - - // GetContact retrieves a specific contact by resource name - GetContact(resourceName string) (*people.Person, error) - - // ListContactGroups retrieves all contact groups - ListContactGroups(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) -} - -// Verify that Client implements ContactsClientInterface -var _ ContactsClientInterface = (*Client)(nil) diff --git a/internal/drive/client.go b/internal/drive/client.go index 53a8871..fdc69ba 100644 --- a/internal/drive/client.go +++ b/internal/drive/client.go @@ -1,3 +1,4 @@ +// Package drive provides a client for the Google Drive API. package drive import ( @@ -20,12 +21,12 @@ type Client struct { func NewClient(ctx context.Context) (*Client, error) { client, err := auth.GetHTTPClient(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("loading OAuth client: %w", err) } srv, err := drive.NewService(ctx, option.WithHTTPClient(client)) if err != nil { - return nil, fmt.Errorf("unable to create Drive service: %w", err) + return nil, fmt.Errorf("creating Drive service: %w", err) } return &Client{ @@ -37,7 +38,7 @@ func NewClient(ctx context.Context) (*Client, error) { const fileFields = "id,name,mimeType,size,createdTime,modifiedTime,parents,owners,webViewLink,shared,driveId" // ListFiles returns files matching the query (searches My Drive only for backwards compatibility) -func (c *Client) ListFiles(query string, pageSize int64) ([]*File, error) { +func (c *Client) ListFiles(ctx context.Context, query string, pageSize int64) ([]*File, error) { call := c.service.Files.List(). Fields("files(" + fileFields + ")"). OrderBy("modifiedTime desc") @@ -49,9 +50,9 @@ func (c *Client) ListFiles(query string, pageSize int64) ([]*File, error) { call = call.PageSize(pageSize) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list files: %w", err) + return nil, fmt.Errorf("listing files: %w", err) } files := make([]*File, 0, len(resp.Files)) @@ -62,7 +63,7 @@ func (c *Client) ListFiles(query string, pageSize int64) ([]*File, error) { } // ListFilesWithScope returns files matching the query within the specified scope -func (c *Client) ListFilesWithScope(query string, pageSize int64, scope DriveScope) ([]*File, error) { +func (c *Client) ListFilesWithScope(ctx context.Context, query string, pageSize int64, scope DriveScope) ([]*File, error) { call := c.service.Files.List(). Fields("files(" + fileFields + ")"). OrderBy("modifiedTime desc"). @@ -89,9 +90,9 @@ func (c *Client) ListFilesWithScope(query string, pageSize int64, scope DriveSco call = call.PageSize(pageSize) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list files: %w", err) + return nil, fmt.Errorf("listing files: %w", err) } files := make([]*File, 0, len(resp.Files)) @@ -102,51 +103,53 @@ func (c *Client) ListFilesWithScope(query string, pageSize int64, scope DriveSco } // GetFile retrieves a single file by ID (supports files in shared drives) -func (c *Client) GetFile(fileID string) (*File, error) { +func (c *Client) GetFile(ctx context.Context, fileID string) (*File, error) { f, err := c.service.Files.Get(fileID). Fields(fileFields). SupportsAllDrives(true). + Context(ctx). Do() if err != nil { - return nil, fmt.Errorf("failed to get file: %w", err) + return nil, fmt.Errorf("getting file: %w", err) } return ParseFile(f), nil } // DownloadFile downloads a regular (non-Google Workspace) file -func (c *Client) DownloadFile(fileID string) ([]byte, error) { +func (c *Client) DownloadFile(ctx context.Context, fileID string) ([]byte, error) { resp, err := c.service.Files.Get(fileID). SupportsAllDrives(true). + Context(ctx). Download() if err != nil { - return nil, fmt.Errorf("failed to download file: %w", err) + return nil, fmt.Errorf("downloading file: %w", err) } defer resp.Body.Close() data, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read file content: %w", err) + return nil, fmt.Errorf("reading file content: %w", err) } return data, nil } // ExportFile exports a Google Workspace file to the specified MIME type -func (c *Client) ExportFile(fileID string, mimeType string) ([]byte, error) { - resp, err := c.service.Files.Export(fileID, mimeType).Download() +func (c *Client) ExportFile(ctx context.Context, fileID string, mimeType string) ([]byte, error) { + resp, err := c.service.Files.Export(fileID, mimeType).Context(ctx).Download() if err != nil { - return nil, fmt.Errorf("failed to export file: %w", err) + return nil, fmt.Errorf("exporting file: %w", err) } defer resp.Body.Close() data, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read exported content: %w", err) + return nil, fmt.Errorf("reading exported content: %w", err) } return data, nil } // ListSharedDrives returns all shared drives accessible to the user -func (c *Client) ListSharedDrives(pageSize int64) ([]*SharedDrive, error) { +func (c *Client) ListSharedDrives(ctx context.Context, pageSize int64) ([]*SharedDrive, error) { var allDrives []*SharedDrive pageToken := "" @@ -161,9 +164,9 @@ func (c *Client) ListSharedDrives(pageSize int64) ([]*SharedDrive, error) { call = call.PageToken(pageToken) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to list shared drives: %w", err) + return nil, fmt.Errorf("listing shared drives: %w", err) } for _, d := range resp.Drives { diff --git a/internal/drive/files_test.go b/internal/drive/files_test.go index 698c83f..6990c73 100644 --- a/internal/drive/files_test.go +++ b/internal/drive/files_test.go @@ -1,14 +1,18 @@ package drive import ( + "reflect" + "slices" + "strings" "testing" - "github.com/stretchr/testify/assert" "google.golang.org/api/drive/v3" ) func TestParseFile(t *testing.T) { + t.Parallel() t.Run("parses basic file", func(t *testing.T) { + t.Parallel() f := &drive.File{ Id: "123", Name: "test.txt", @@ -23,18 +27,37 @@ func TestParseFile(t *testing.T) { result := ParseFile(f) - assert.Equal(t, "123", result.ID) - assert.Equal(t, "test.txt", result.Name) - assert.Equal(t, "text/plain", result.MimeType) - assert.Equal(t, int64(1024), result.Size) - assert.Equal(t, 2024, result.CreatedTime.Year()) - assert.Equal(t, 2024, result.ModifiedTime.Year()) - assert.Equal(t, []string{"parent1"}, result.Parents) - assert.Equal(t, "https://drive.google.com/file/d/123", result.WebViewLink) - assert.True(t, result.Shared) + if result.ID != "123" { + t.Errorf("got %v, want %v", result.ID, "123") + } + if result.Name != "test.txt" { + t.Errorf("got %v, want %v", result.Name, "test.txt") + } + if result.MimeType != "text/plain" { + t.Errorf("got %v, want %v", result.MimeType, "text/plain") + } + if result.Size != int64(1024) { + t.Errorf("got %v, want %v", result.Size, int64(1024)) + } + if result.CreatedTime.Year() != 2024 { + t.Errorf("got %v, want %v", result.CreatedTime.Year(), 2024) + } + if result.ModifiedTime.Year() != 2024 { + t.Errorf("got %v, want %v", result.ModifiedTime.Year(), 2024) + } + if !reflect.DeepEqual(result.Parents, []string{"parent1"}) { + t.Errorf("got %v, want %v", result.Parents, []string{"parent1"}) + } + if result.WebViewLink != "https://drive.google.com/file/d/123" { + t.Errorf("got %v, want %v", result.WebViewLink, "https://drive.google.com/file/d/123") + } + if !result.Shared { + t.Error("got false, want true") + } }) t.Run("parses file with owners", func(t *testing.T) { + t.Parallel() f := &drive.File{ Id: "123", Name: "shared.txt", @@ -47,10 +70,14 @@ func TestParseFile(t *testing.T) { result := ParseFile(f) - assert.Equal(t, []string{"owner1@example.com", "owner2@example.com"}, result.Owners) + expected := []string{"owner1@example.com", "owner2@example.com"} + if !reflect.DeepEqual(result.Owners, expected) { + t.Errorf("got %v, want %v", result.Owners, expected) + } }) t.Run("handles empty timestamps", func(t *testing.T) { + t.Parallel() f := &drive.File{ Id: "123", Name: "no-times.txt", @@ -61,11 +88,16 @@ func TestParseFile(t *testing.T) { result := ParseFile(f) - assert.True(t, result.CreatedTime.IsZero()) - assert.True(t, result.ModifiedTime.IsZero()) + if !result.CreatedTime.IsZero() { + t.Error("got false, want true") + } + if !result.ModifiedTime.IsZero() { + t.Error("got false, want true") + } }) t.Run("handles malformed timestamps", func(t *testing.T) { + t.Parallel() f := &drive.File{ Id: "123", Name: "bad-times.txt", @@ -76,11 +108,16 @@ func TestParseFile(t *testing.T) { result := ParseFile(f) - assert.True(t, result.CreatedTime.IsZero()) - assert.True(t, result.ModifiedTime.IsZero()) + if !result.CreatedTime.IsZero() { + t.Error("got false, want true") + } + if !result.ModifiedTime.IsZero() { + t.Error("got false, want true") + } }) t.Run("handles nil owners", func(t *testing.T) { + t.Parallel() f := &drive.File{ Id: "123", Name: "no-owners.txt", @@ -90,10 +127,13 @@ func TestParseFile(t *testing.T) { result := ParseFile(f) - assert.Nil(t, result.Owners) + if result.Owners != nil { + t.Errorf("got %v, want nil", result.Owners) + } }) t.Run("handles empty owners slice", func(t *testing.T) { + t.Parallel() f := &drive.File{ Id: "123", Name: "empty-owners.txt", @@ -103,11 +143,14 @@ func TestParseFile(t *testing.T) { result := ParseFile(f) - assert.Nil(t, result.Owners) + if result.Owners != nil { + t.Errorf("got %v, want nil", result.Owners) + } }) } func TestGetTypeName(t *testing.T) { + t.Parallel() tests := []struct { mimeType string expected string @@ -141,14 +184,19 @@ func TestGetTypeName(t *testing.T) { for _, tt := range tests { t.Run(tt.mimeType, func(t *testing.T) { + t.Parallel() result := GetTypeName(tt.mimeType) - assert.Equal(t, tt.expected, result) + if result != tt.expected { + t.Errorf("got %v, want %v", result, tt.expected) + } }) } } func TestIsGoogleWorkspaceFile(t *testing.T) { + t.Parallel() t.Run("returns true for Google Workspace files", func(t *testing.T) { + t.Parallel() workspaceTypes := []string{ MimeTypeDocument, MimeTypeSpreadsheet, @@ -159,11 +207,14 @@ func TestIsGoogleWorkspaceFile(t *testing.T) { } for _, mimeType := range workspaceTypes { - assert.True(t, IsGoogleWorkspaceFile(mimeType), "expected true for %s", mimeType) + if !IsGoogleWorkspaceFile(mimeType) { + t.Errorf("got false, want true for %s", mimeType) + } } }) t.Run("returns false for non-Workspace files", func(t *testing.T) { + t.Parallel() nonWorkspaceTypes := []string{ MimeTypeFolder, MimeTypeShortcut, @@ -175,13 +226,17 @@ func TestIsGoogleWorkspaceFile(t *testing.T) { } for _, mimeType := range nonWorkspaceTypes { - assert.False(t, IsGoogleWorkspaceFile(mimeType), "expected false for %s", mimeType) + if IsGoogleWorkspaceFile(mimeType) { + t.Errorf("got true, want false for %s", mimeType) + } } }) } func TestGetExportMimeType(t *testing.T) { + t.Parallel() t.Run("returns correct MIME type for Document exports", func(t *testing.T) { + t.Parallel() tests := []struct { format string expected string @@ -195,14 +250,20 @@ func TestGetExportMimeType(t *testing.T) { for _, tt := range tests { t.Run(tt.format, func(t *testing.T) { + t.Parallel() result, err := GetExportMimeType(MimeTypeDocument, tt.format) - assert.NoError(t, err) - assert.Equal(t, tt.expected, result) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != tt.expected { + t.Errorf("got %v, want %v", result, tt.expected) + } }) } }) t.Run("returns correct MIME type for Spreadsheet exports", func(t *testing.T) { + t.Parallel() tests := []struct { format string expected string @@ -214,66 +275,113 @@ func TestGetExportMimeType(t *testing.T) { for _, tt := range tests { t.Run(tt.format, func(t *testing.T) { + t.Parallel() result, err := GetExportMimeType(MimeTypeSpreadsheet, tt.format) - assert.NoError(t, err) - assert.Equal(t, tt.expected, result) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != tt.expected { + t.Errorf("got %v, want %v", result, tt.expected) + } }) } }) t.Run("returns correct MIME type for Presentation exports", func(t *testing.T) { + t.Parallel() result, err := GetExportMimeType(MimeTypePresentation, "pptx") - assert.NoError(t, err) - assert.Equal(t, "application/vnd.openxmlformats-officedocument.presentationml.presentation", result) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "application/vnd.openxmlformats-officedocument.presentationml.presentation" { + t.Errorf("got %v, want %v", result, "application/vnd.openxmlformats-officedocument.presentationml.presentation") + } }) t.Run("returns correct MIME type for Drawing exports", func(t *testing.T) { + t.Parallel() result, err := GetExportMimeType(MimeTypeDrawing, "png") - assert.NoError(t, err) - assert.Equal(t, "image/png", result) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "image/png" { + t.Errorf("got %v, want %v", result, "image/png") + } }) t.Run("returns error for unsupported format", func(t *testing.T) { + t.Parallel() _, err := GetExportMimeType(MimeTypeDocument, "xyz") - assert.Error(t, err) - assert.Contains(t, err.Error(), "not supported") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "not supported") { + t.Errorf("expected %q to contain %q", err.Error(), "not supported") + } }) t.Run("returns error for non-exportable file type", func(t *testing.T) { + t.Parallel() _, err := GetExportMimeType("application/pdf", "docx") - assert.Error(t, err) - assert.Contains(t, err.Error(), "does not support export") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "does not support export") { + t.Errorf("expected %q to contain %q", err.Error(), "does not support export") + } }) t.Run("returns error for format not matching file type", func(t *testing.T) { + t.Parallel() // csv is valid for spreadsheets but not documents _, err := GetExportMimeType(MimeTypeDocument, "csv") - assert.Error(t, err) - assert.Contains(t, err.Error(), "not supported for Google Document") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "not supported for Google Document") { + t.Errorf("expected %q to contain %q", err.Error(), "not supported for Google Document") + } }) } func TestGetSupportedExportFormats(t *testing.T) { + t.Parallel() t.Run("returns formats for Document", func(t *testing.T) { + t.Parallel() formats := GetSupportedExportFormats(MimeTypeDocument) - assert.Contains(t, formats, "pdf") - assert.Contains(t, formats, "docx") - assert.Contains(t, formats, "txt") + if !slices.Contains(formats, "pdf") { + t.Errorf("expected formats to contain %q", "pdf") + } + if !slices.Contains(formats, "docx") { + t.Errorf("expected formats to contain %q", "docx") + } + if !slices.Contains(formats, "txt") { + t.Errorf("expected formats to contain %q", "txt") + } }) t.Run("returns formats for Spreadsheet", func(t *testing.T) { + t.Parallel() formats := GetSupportedExportFormats(MimeTypeSpreadsheet) - assert.Contains(t, formats, "xlsx") - assert.Contains(t, formats, "csv") + if !slices.Contains(formats, "xlsx") { + t.Errorf("expected formats to contain %q", "xlsx") + } + if !slices.Contains(formats, "csv") { + t.Errorf("expected formats to contain %q", "csv") + } }) t.Run("returns nil for non-exportable file", func(t *testing.T) { + t.Parallel() formats := GetSupportedExportFormats("application/pdf") - assert.Nil(t, formats) + if formats != nil { + t.Errorf("got %v, want nil", formats) + } }) } func TestGetFileExtension(t *testing.T) { + t.Parallel() tests := []struct { format string expected string @@ -294,8 +402,11 @@ func TestGetFileExtension(t *testing.T) { for _, tt := range tests { t.Run(tt.format, func(t *testing.T) { + t.Parallel() result := GetFileExtension(tt.format) - assert.Equal(t, tt.expected, result) + if result != tt.expected { + t.Errorf("got %v, want %v", result, tt.expected) + } }) } } diff --git a/internal/drive/interfaces.go b/internal/drive/interfaces.go deleted file mode 100644 index 8b55972..0000000 --- a/internal/drive/interfaces.go +++ /dev/null @@ -1,26 +0,0 @@ -package drive - -// DriveClientInterface defines the interface for Drive client operations. -// This enables unit testing through mock implementations. -type DriveClientInterface interface { - // ListFiles returns files matching the query (searches My Drive only for backwards compatibility) - ListFiles(query string, pageSize int64) ([]*File, error) - - // ListFilesWithScope returns files matching the query within the specified scope - ListFilesWithScope(query string, pageSize int64, scope DriveScope) ([]*File, error) - - // GetFile retrieves a single file by ID (supports all drives) - GetFile(fileID string) (*File, error) - - // DownloadFile downloads a regular (non-Google Workspace) file - DownloadFile(fileID string) ([]byte, error) - - // ExportFile exports a Google Workspace file to the specified MIME type - ExportFile(fileID string, mimeType string) ([]byte, error) - - // ListSharedDrives returns all shared drives accessible to the user - ListSharedDrives(pageSize int64) ([]*SharedDrive, error) -} - -// Verify that Client implements DriveClientInterface -var _ DriveClientInterface = (*Client)(nil) diff --git a/internal/errors/errors.go b/internal/errors/errors.go index de8cfc0..726e836 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -2,7 +2,10 @@ // and system errors, allowing commands to provide appropriate guidance. package errors -import "fmt" +import ( + "errors" + "fmt" +) // UserError represents an error caused by invalid user input or action. // These errors are actionable - the user can fix them. @@ -49,7 +52,8 @@ func NewSystemError(message string, cause error, retryable bool) SystemError { // IsRetryable returns true if the error is a retryable SystemError. func IsRetryable(err error) bool { - if sysErr, ok := err.(SystemError); ok { + var sysErr SystemError + if errors.As(err, &sysErr) { return sysErr.Retryable } return false diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go index ce1bae4..b69536a 100644 --- a/internal/errors/errors_test.go +++ b/internal/errors/errors_test.go @@ -4,10 +4,11 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestUserError(t *testing.T) { + t.Parallel() tests := []struct { name string err UserError @@ -27,17 +28,20 @@ func TestUserError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, tt.err.Error()) + t.Parallel() + testutil.Equal(t, tt.err.Error(), tt.expected) }) } } func TestNewUserError(t *testing.T) { + t.Parallel() err := NewUserError("invalid value: %d", 42) - assert.Equal(t, "invalid value: 42", err.Error()) + testutil.Equal(t, err.Error(), "invalid value: 42") } func TestSystemError(t *testing.T) { + t.Parallel() tests := []struct { name string err SystemError @@ -64,32 +68,36 @@ func TestSystemError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, tt.err.Error()) + t.Parallel() + testutil.Equal(t, tt.err.Error(), tt.expected) }) } } func TestSystemErrorUnwrap(t *testing.T) { + t.Parallel() cause := errors.New("underlying error") err := SystemError{ Message: "wrapper", Cause: cause, } - assert.Equal(t, cause, err.Unwrap()) - assert.True(t, errors.Is(err, cause)) + testutil.Equal(t, err.Unwrap(), cause) + testutil.True(t, errors.Is(err, cause)) } func TestNewSystemError(t *testing.T) { + t.Parallel() cause := errors.New("network timeout") err := NewSystemError("API call failed", cause, true) - assert.Equal(t, "API call failed", err.Message) - assert.Equal(t, cause, err.Cause) - assert.True(t, err.Retryable) + testutil.Equal(t, err.Message, "API call failed") + testutil.Equal(t, err.Cause, cause) + testutil.True(t, err.Retryable) } func TestIsRetryable(t *testing.T) { + t.Parallel() tests := []struct { name string err error @@ -119,7 +127,8 @@ func TestIsRetryable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, IsRetryable(tt.err)) + t.Parallel() + testutil.Equal(t, IsRetryable(tt.err), tt.expected) }) } } diff --git a/internal/format/format_test.go b/internal/format/format_test.go index 7add5bf..eae57a7 100644 --- a/internal/format/format_test.go +++ b/internal/format/format_test.go @@ -3,10 +3,11 @@ package format import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestTruncate(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -23,13 +24,15 @@ func TestTruncate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := Truncate(tt.input, tt.maxLen) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } func TestSize(t *testing.T) { + t.Parallel() tests := []struct { name string bytes int64 @@ -47,8 +50,9 @@ func TestSize(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() result := Size(tt.bytes) - assert.Equal(t, tt.expected, result) + testutil.Equal(t, result, tt.expected) }) } } diff --git a/internal/gmail/attachments.go b/internal/gmail/attachments.go index a6375f9..d12b6a6 100644 --- a/internal/gmail/attachments.go +++ b/internal/gmail/attachments.go @@ -1,6 +1,7 @@ package gmail import ( + "context" "encoding/base64" "fmt" "strconv" @@ -10,35 +11,35 @@ import ( ) // GetAttachments retrieves attachment metadata for a message -func (c *Client) GetAttachments(messageID string) ([]*Attachment, error) { - msg, err := c.service.Users.Messages.Get(c.userID, messageID).Format("full").Do() +func (c *Client) GetAttachments(ctx context.Context, messageID string) ([]*Attachment, error) { + msg, err := c.service.Users.Messages.Get(c.userID, messageID).Format("full").Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get message: %w", err) + return nil, fmt.Errorf("getting message: %w", err) } return extractAttachments(msg.Payload, ""), nil } // DownloadAttachment downloads a single attachment by message ID and attachment ID -func (c *Client) DownloadAttachment(messageID string, attachmentID string) ([]byte, error) { - att, err := c.service.Users.Messages.Attachments.Get(c.userID, messageID, attachmentID).Do() +func (c *Client) DownloadAttachment(ctx context.Context, messageID string, attachmentID string) ([]byte, error) { + att, err := c.service.Users.Messages.Attachments.Get(c.userID, messageID, attachmentID).Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to download attachment: %w", err) + return nil, fmt.Errorf("downloading attachment: %w", err) } data, err := base64.URLEncoding.DecodeString(att.Data) if err != nil { - return nil, fmt.Errorf("failed to decode attachment data: %w", err) + return nil, fmt.Errorf("decoding attachment data: %w", err) } return data, nil } // DownloadInlineAttachment downloads an attachment that has inline data -func (c *Client) DownloadInlineAttachment(messageID string, partID string) ([]byte, error) { - msg, err := c.service.Users.Messages.Get(c.userID, messageID).Format("full").Do() +func (c *Client) DownloadInlineAttachment(ctx context.Context, messageID string, partID string) ([]byte, error) { + msg, err := c.service.Users.Messages.Get(c.userID, messageID).Format("full").Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get message: %w", err) + return nil, fmt.Errorf("getting message: %w", err) } part := findPart(msg.Payload, partID) @@ -52,7 +53,7 @@ func (c *Client) DownloadInlineAttachment(messageID string, partID string) ([]by data, err := base64.URLEncoding.DecodeString(part.Body.Data) if err != nil { - return nil, fmt.Errorf("failed to decode inline attachment: %w", err) + return nil, fmt.Errorf("decoding inline attachment: %w", err) } return data, nil diff --git a/internal/gmail/attachments_test.go b/internal/gmail/attachments_test.go index efb95a1..aa4e901 100644 --- a/internal/gmail/attachments_test.go +++ b/internal/gmail/attachments_test.go @@ -3,30 +3,37 @@ package gmail import ( "testing" - "github.com/stretchr/testify/assert" "google.golang.org/api/gmail/v1" ) func TestFindPart(t *testing.T) { + t.Parallel() t.Run("returns payload for empty path", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "text/plain", } result := findPart(payload, "") - assert.Equal(t, payload, result) + if result != payload { + t.Errorf("got %v, want %v", result, payload) + } }) t.Run("finds part at index 0", func(t *testing.T) { + t.Parallel() child := &gmail.MessagePart{MimeType: "text/plain", Filename: "file.txt"} payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{child}, } result := findPart(payload, "0") - assert.Equal(t, child, result) + if result != child { + t.Errorf("got %v, want %v", result, child) + } }) t.Run("finds nested part", func(t *testing.T) { + t.Parallel() deepChild := &gmail.MessagePart{MimeType: "application/pdf", Filename: "nested.pdf"} payload := &gmail.MessagePart{ MimeType: "multipart/mixed", @@ -41,37 +48,49 @@ func TestFindPart(t *testing.T) { }, } result := findPart(payload, "0.1") - assert.Equal(t, deepChild, result) + if result != deepChild { + t.Errorf("got %v, want %v", result, deepChild) + } }) t.Run("returns nil for invalid index", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{{MimeType: "text/plain"}}, } result := findPart(payload, "5") - assert.Nil(t, result) + if result != nil { + t.Errorf("got %v, want nil", result) + } }) t.Run("returns nil for negative index", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{{MimeType: "text/plain"}}, } result := findPart(payload, "-1") - assert.Nil(t, result) + if result != nil { + t.Errorf("got %v, want nil", result) + } }) t.Run("returns nil for non-numeric path", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{{MimeType: "text/plain"}}, } result := findPart(payload, "abc") - assert.Nil(t, result) + if result != nil { + t.Errorf("got %v, want nil", result) + } }) t.Run("returns nil for out of bounds nested path", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{ @@ -82,10 +101,13 @@ func TestFindPart(t *testing.T) { }, } result := findPart(payload, "0.5") - assert.Nil(t, result) + if result != nil { + t.Errorf("got %v, want nil", result) + } }) t.Run("handles deeply nested path", func(t *testing.T) { + t.Parallel() deepest := &gmail.MessagePart{Filename: "deep.txt"} payload := &gmail.MessagePart{ Parts: []*gmail.MessagePart{ @@ -101,6 +123,8 @@ func TestFindPart(t *testing.T) { }, } result := findPart(payload, "0.0.0") - assert.Equal(t, deepest, result) + if result != deepest { + t.Errorf("got %v, want %v", result, deepest) + } }) } diff --git a/internal/gmail/client.go b/internal/gmail/client.go index b59dd85..b3c2b97 100644 --- a/internal/gmail/client.go +++ b/internal/gmail/client.go @@ -1,3 +1,4 @@ +// Package gmail provides a client for the Gmail API. package gmail import ( @@ -5,7 +6,6 @@ import ( "fmt" "sync" - "golang.org/x/oauth2" "google.golang.org/api/gmail/v1" "google.golang.org/api/option" @@ -25,12 +25,12 @@ type Client struct { func NewClient(ctx context.Context) (*Client, error) { client, err := auth.GetHTTPClient(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("loading OAuth client: %w", err) } srv, err := gmail.NewService(ctx, option.WithHTTPClient(client)) if err != nil { - return nil, fmt.Errorf("unable to create Gmail service: %w", err) + return nil, fmt.Errorf("creating Gmail service: %w", err) } return &Client{ @@ -40,7 +40,7 @@ func NewClient(ctx context.Context) (*Client, error) { } // FetchLabels retrieves and caches all labels from the Gmail account -func (c *Client) FetchLabels() error { +func (c *Client) FetchLabels(ctx context.Context) error { // Check with read lock first to avoid unnecessary API calls c.labelsMu.RLock() if c.labelsLoaded { @@ -58,9 +58,9 @@ func (c *Client) FetchLabels() error { return nil } - resp, err := c.service.Users.Labels.List(c.userID).Do() + resp, err := c.service.Users.Labels.List(c.userID).Context(ctx).Do() if err != nil { - return fmt.Errorf("failed to fetch labels: %w", err) + return fmt.Errorf("fetching labels: %w", err) } c.labels = make(map[string]*gmail.Label) @@ -98,11 +98,18 @@ func (c *Client) GetLabels() []*gmail.Label { return labels } +// Profile represents a Gmail user profile. +type Profile struct { + EmailAddress string + MessagesTotal int64 + ThreadsTotal int64 +} + // GetProfile retrieves the authenticated user's profile -func (c *Client) GetProfile() (*Profile, error) { - profile, err := c.service.Users.GetProfile(c.userID).Do() +func (c *Client) GetProfile(ctx context.Context) (*Profile, error) { + profile, err := c.service.Users.GetProfile(c.userID).Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get profile: %w", err) + return nil, fmt.Errorf("getting profile: %w", err) } return &Profile{ EmailAddress: profile.EmailAddress, @@ -110,39 +117,3 @@ func (c *Client) GetProfile() (*Profile, error) { ThreadsTotal: profile.ThreadsTotal, }, nil } - -// GetConfigDir returns the configuration directory path -// Deprecated: Use auth.GetConfigDir() instead -func GetConfigDir() (string, error) { - return auth.GetConfigDir() -} - -// GetCredentialsPath returns the path to credentials.json -// Deprecated: Use auth.GetCredentialsPath() instead -func GetCredentialsPath() (string, error) { - return auth.GetCredentialsPath() -} - -// GetOAuthConfig loads OAuth config from credentials file -// Deprecated: Use auth.GetOAuthConfig() instead -func GetOAuthConfig() (*oauth2.Config, error) { - return auth.GetOAuthConfig() -} - -// ExchangeAuthCode exchanges an authorization code for a token -// Deprecated: Use auth.ExchangeAuthCode() instead -func ExchangeAuthCode(ctx context.Context, config *oauth2.Config, code string) (*oauth2.Token, error) { - return auth.ExchangeAuthCode(ctx, config, code) -} - -// GetAuthURL returns the OAuth authorization URL -// Deprecated: Use auth.GetAuthURL() instead -func GetAuthURL(config *oauth2.Config) string { - return auth.GetAuthURL(config) -} - -// ShortenPath replaces the home directory prefix with ~ for display purposes. -// Deprecated: Use auth.ShortenPath() instead -func ShortenPath(path string) string { - return auth.ShortenPath(path) -} diff --git a/internal/gmail/client_test.go b/internal/gmail/client_test.go index 7cf77d7..00aef2f 100644 --- a/internal/gmail/client_test.go +++ b/internal/gmail/client_test.go @@ -1,19 +1,15 @@ package gmail import ( - "os" - "path/filepath" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" gmailapi "google.golang.org/api/gmail/v1" - - "github.com/open-cli-collective/google-readonly/internal/auth" ) func TestGetLabelName(t *testing.T) { + t.Parallel() t.Run("returns name for cached label", func(t *testing.T) { + t.Parallel() client := &Client{ labels: map[string]*gmailapi.Label{ "Label_123": {Id: "Label_123", Name: "Work"}, @@ -22,41 +18,56 @@ func TestGetLabelName(t *testing.T) { labelsLoaded: true, } - assert.Equal(t, "Work", client.GetLabelName("Label_123")) - assert.Equal(t, "Personal", client.GetLabelName("Label_456")) + if got := client.GetLabelName("Label_123"); got != "Work" { + t.Errorf("got %v, want %v", got, "Work") + } + if got := client.GetLabelName("Label_456"); got != "Personal" { + t.Errorf("got %v, want %v", got, "Personal") + } }) t.Run("returns ID for uncached label", func(t *testing.T) { + t.Parallel() client := &Client{ labels: map[string]*gmailapi.Label{}, labelsLoaded: true, } - assert.Equal(t, "Unknown_Label", client.GetLabelName("Unknown_Label")) + if got := client.GetLabelName("Unknown_Label"); got != "Unknown_Label" { + t.Errorf("got %v, want %v", got, "Unknown_Label") + } }) t.Run("returns ID when labels not loaded", func(t *testing.T) { + t.Parallel() client := &Client{ labels: nil, labelsLoaded: false, } - assert.Equal(t, "Label_123", client.GetLabelName("Label_123")) + if got := client.GetLabelName("Label_123"); got != "Label_123" { + t.Errorf("got %v, want %v", got, "Label_123") + } }) } func TestGetLabels(t *testing.T) { + t.Parallel() t.Run("returns nil when labels not loaded", func(t *testing.T) { + t.Parallel() client := &Client{ labels: nil, labelsLoaded: false, } result := client.GetLabels() - assert.Nil(t, result) + if result != nil { + t.Errorf("got %v, want nil", result) + } }) t.Run("returns all cached labels", func(t *testing.T) { + t.Parallel() label1 := &gmailapi.Label{Id: "Label_1", Name: "Work"} label2 := &gmailapi.Label{Id: "Label_2", Name: "Personal"} @@ -69,60 +80,39 @@ func TestGetLabels(t *testing.T) { } result := client.GetLabels() - assert.Len(t, result, 2) - assert.Contains(t, result, label1) - assert.Contains(t, result, label2) + if len(result) != 2 { + t.Errorf("got length %d, want %d", len(result), 2) + } + found1, found2 := false, false + for _, l := range result { + if l == label1 { + found1 = true + } + if l == label2 { + found2 = true + } + } + if !found1 { + t.Errorf("expected result to contain label1 (Work)") + } + if !found2 { + t.Errorf("expected result to contain label2 (Personal)") + } }) t.Run("returns empty slice for empty cache", func(t *testing.T) { + t.Parallel() client := &Client{ labels: map[string]*gmailapi.Label{}, labelsLoaded: true, } result := client.GetLabels() - assert.NotNil(t, result) - assert.Empty(t, result) - }) -} - -// TestDeprecatedWrappers verifies that the deprecated wrappers delegate correctly to the auth package -func TestDeprecatedWrappers(t *testing.T) { - t.Run("GetConfigDir delegates to auth package", func(t *testing.T) { - tmpDir := t.TempDir() - t.Setenv("XDG_CONFIG_HOME", tmpDir) - - gmailDir, err := GetConfigDir() - require.NoError(t, err) - - authDir, err := auth.GetConfigDir() - require.NoError(t, err) - - assert.Equal(t, authDir, gmailDir) - }) - - t.Run("GetCredentialsPath delegates to auth package", func(t *testing.T) { - tmpDir := t.TempDir() - t.Setenv("XDG_CONFIG_HOME", tmpDir) - - gmailPath, err := GetCredentialsPath() - require.NoError(t, err) - - authPath, err := auth.GetCredentialsPath() - require.NoError(t, err) - - assert.Equal(t, authPath, gmailPath) - }) - - t.Run("ShortenPath delegates to auth package", func(t *testing.T) { - home, err := os.UserHomeDir() - require.NoError(t, err) - - testPath := filepath.Join(home, ".config", "test") - - gmailResult := ShortenPath(testPath) - authResult := auth.ShortenPath(testPath) - - assert.Equal(t, authResult, gmailResult) + if result == nil { + t.Fatal("expected non-nil, got nil") + } + if len(result) != 0 { + t.Errorf("got length %d, want 0", len(result)) + } }) } diff --git a/internal/gmail/interfaces.go b/internal/gmail/interfaces.go deleted file mode 100644 index 3fea340..0000000 --- a/internal/gmail/interfaces.go +++ /dev/null @@ -1,49 +0,0 @@ -package gmail - -import ( - "google.golang.org/api/gmail/v1" -) - -// Profile represents a Gmail user profile. -type Profile struct { - EmailAddress string - MessagesTotal int64 - ThreadsTotal int64 -} - -// GmailClientInterface defines the interface for Gmail client operations. -// This enables unit testing through mock implementations. -type GmailClientInterface interface { - // GetMessage retrieves a single message by ID - GetMessage(messageID string, includeBody bool) (*Message, error) - - // SearchMessages searches for messages matching the query - SearchMessages(query string, maxResults int64) ([]*Message, int, error) - - // GetThread retrieves all messages in a thread - GetThread(id string) ([]*Message, error) - - // FetchLabels retrieves and caches all labels from the Gmail account - FetchLabels() error - - // GetLabelName resolves a label ID to its display name - GetLabelName(labelID string) string - - // GetLabels returns all cached labels - GetLabels() []*gmail.Label - - // GetAttachments retrieves attachment metadata for a message - GetAttachments(messageID string) ([]*Attachment, error) - - // DownloadAttachment downloads a single attachment by message ID and attachment ID - DownloadAttachment(messageID string, attachmentID string) ([]byte, error) - - // DownloadInlineAttachment downloads an attachment that has inline data - DownloadInlineAttachment(messageID string, partID string) ([]byte, error) - - // GetProfile retrieves the authenticated user's profile - GetProfile() (*Profile, error) -} - -// Verify that Client implements GmailClientInterface -var _ GmailClientInterface = (*Client)(nil) diff --git a/internal/gmail/messages.go b/internal/gmail/messages.go index a948a4b..e611e8e 100644 --- a/internal/gmail/messages.go +++ b/internal/gmail/messages.go @@ -1,6 +1,7 @@ package gmail import ( + "context" "encoding/base64" "fmt" "strings" @@ -37,21 +38,21 @@ type Attachment struct { // SearchMessages searches for messages matching the query. // Returns messages, the count of messages that failed to fetch, and any error. -func (c *Client) SearchMessages(query string, maxResults int64) ([]*Message, int, error) { +func (c *Client) SearchMessages(ctx context.Context, query string, maxResults int64) ([]*Message, int, error) { call := c.service.Users.Messages.List(c.userID).Q(query) if maxResults > 0 { call = call.MaxResults(maxResults) } - resp, err := call.Do() + resp, err := call.Context(ctx).Do() if err != nil { - return nil, 0, fmt.Errorf("failed to search messages: %w", err) + return nil, 0, fmt.Errorf("searching messages: %w", err) } var messages []*Message var skipped int for _, msg := range resp.Messages { - m, err := c.GetMessage(msg.Id, false) + m, err := c.GetMessage(ctx, msg.Id, false) if err != nil { skipped++ log.Debug("skipped message %s: %v", msg.Id, err) @@ -68,20 +69,20 @@ func (c *Client) SearchMessages(query string, maxResults int64) ([]*Message, int } // GetMessage retrieves a single message by ID -func (c *Client) GetMessage(messageID string, includeBody bool) (*Message, error) { +func (c *Client) GetMessage(ctx context.Context, messageID string, includeBody bool) (*Message, error) { format := "metadata" if includeBody { format = "full" } // Fetch labels for resolution - if err := c.FetchLabels(); err != nil { + if err := c.FetchLabels(ctx); err != nil { return nil, err } - msg, err := c.service.Users.Messages.Get(c.userID, messageID).Format(format).Do() + msg, err := c.service.Users.Messages.Get(c.userID, messageID).Format(format).Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get message: %w", err) + return nil, fmt.Errorf("getting message: %w", err) } return parseMessage(msg, includeBody, c.GetLabelName), nil @@ -90,24 +91,24 @@ func (c *Client) GetMessage(messageID string, includeBody bool) (*Message, error // GetThread retrieves all messages in a thread. // The id parameter can be either a thread ID or a message ID. // If a message ID is provided, the thread ID is resolved automatically. -func (c *Client) GetThread(id string) ([]*Message, error) { +func (c *Client) GetThread(ctx context.Context, id string) ([]*Message, error) { // Fetch labels for resolution - if err := c.FetchLabels(); err != nil { + if err := c.FetchLabels(ctx); err != nil { return nil, err } - thread, err := c.service.Users.Threads.Get(c.userID, id).Format("full").Do() + thread, err := c.service.Users.Threads.Get(c.userID, id).Format("full").Context(ctx).Do() if err != nil { // If the ID wasn't found as a thread ID, try treating it as a message ID - msg, msgErr := c.service.Users.Messages.Get(c.userID, id).Format("minimal").Do() + msg, msgErr := c.service.Users.Messages.Get(c.userID, id).Format("minimal").Context(ctx).Do() if msgErr != nil { // Return the original thread error if message lookup also fails - return nil, fmt.Errorf("failed to get thread: %w", err) + return nil, fmt.Errorf("getting thread: %w", err) } // Use the thread ID from the message - thread, err = c.service.Users.Threads.Get(c.userID, msg.ThreadId).Format("full").Do() + thread, err = c.service.Users.Threads.Get(c.userID, msg.ThreadId).Format("full").Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get thread: %w", err) + return nil, fmt.Errorf("getting thread: %w", err) } } @@ -160,7 +161,7 @@ func parseMessage(msg *gmail.Message, includeBody bool, resolver LabelResolver) } // extractLabelsAndCategories separates label IDs into user labels and Gmail categories -func extractLabelsAndCategories(labelIds []string, resolver LabelResolver) ([]string, []string) { +func extractLabelsAndCategories(labelIDs []string, resolver LabelResolver) ([]string, []string) { var labels, categories []string // System labels to exclude from display @@ -170,7 +171,7 @@ func extractLabelsAndCategories(labelIds []string, resolver LabelResolver) ([]st "CHAT": true, "CATEGORY_PERSONAL": true, } - for _, labelID := range labelIds { + for _, labelID := range labelIDs { // Check if it's a category if strings.HasPrefix(labelID, "CATEGORY_") { // Convert CATEGORY_UPDATES -> updates diff --git a/internal/gmail/messages_test.go b/internal/gmail/messages_test.go index 8f6e952..f8c96a1 100644 --- a/internal/gmail/messages_test.go +++ b/internal/gmail/messages_test.go @@ -2,14 +2,16 @@ package gmail import ( "encoding/base64" + "sort" "testing" - "github.com/stretchr/testify/assert" "google.golang.org/api/gmail/v1" ) func TestParseMessage(t *testing.T) { + t.Parallel() t.Run("extracts headers correctly", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", ThreadId: "thread456", @@ -26,16 +28,31 @@ func TestParseMessage(t *testing.T) { result := parseMessage(msg, false, nil) - assert.Equal(t, "msg123", result.ID) - assert.Equal(t, "thread456", result.ThreadID) - assert.Equal(t, "Test Subject", result.Subject) - assert.Equal(t, "alice@example.com", result.From) - assert.Equal(t, "bob@example.com", result.To) - assert.Equal(t, "Mon, 1 Jan 2024 12:00:00 +0000", result.Date) - assert.Equal(t, "This is a test...", result.Snippet) + if result.ID != "msg123" { + t.Errorf("got %v, want %v", result.ID, "msg123") + } + if result.ThreadID != "thread456" { + t.Errorf("got %v, want %v", result.ThreadID, "thread456") + } + if result.Subject != "Test Subject" { + t.Errorf("got %v, want %v", result.Subject, "Test Subject") + } + if result.From != "alice@example.com" { + t.Errorf("got %v, want %v", result.From, "alice@example.com") + } + if result.To != "bob@example.com" { + t.Errorf("got %v, want %v", result.To, "bob@example.com") + } + if result.Date != "Mon, 1 Jan 2024 12:00:00 +0000" { + t.Errorf("got %v, want %v", result.Date, "Mon, 1 Jan 2024 12:00:00 +0000") + } + if result.Snippet != "This is a test..." { + t.Errorf("got %v, want %v", result.Snippet, "This is a test...") + } }) t.Run("extracts thread ID", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", ThreadId: "thread789", @@ -46,11 +63,16 @@ func TestParseMessage(t *testing.T) { result := parseMessage(msg, false, nil) - assert.Equal(t, "msg123", result.ID) - assert.Equal(t, "thread789", result.ThreadID) + if result.ID != "msg123" { + t.Errorf("got %v, want %v", result.ID, "msg123") + } + if result.ThreadID != "thread789" { + t.Errorf("got %v, want %v", result.ThreadID, "thread789") + } }) t.Run("handles nil payload", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", ThreadId: "thread456", @@ -61,15 +83,26 @@ func TestParseMessage(t *testing.T) { result := parseMessage(msg, true, nil) // Should not panic, basic fields populated - assert.Equal(t, "msg123", result.ID) - assert.Equal(t, "thread456", result.ThreadID) - assert.Equal(t, "Preview text", result.Snippet) + if result.ID != "msg123" { + t.Errorf("got %v, want %v", result.ID, "msg123") + } + if result.ThreadID != "thread456" { + t.Errorf("got %v, want %v", result.ThreadID, "thread456") + } + if result.Snippet != "Preview text" { + t.Errorf("got %v, want %v", result.Snippet, "Preview text") + } // Headers won't be extracted - assert.Empty(t, result.Subject) - assert.Empty(t, result.Body) + if result.Subject != "" { + t.Errorf("got %q, want empty", result.Subject) + } + if result.Body != "" { + t.Errorf("got %q, want empty", result.Body) + } }) t.Run("handles case-insensitive headers", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", Payload: &gmail.MessagePart{ @@ -83,12 +116,19 @@ func TestParseMessage(t *testing.T) { result := parseMessage(msg, false, nil) - assert.Equal(t, "Upper Case", result.Subject) - assert.Equal(t, "lower@example.com", result.From) - assert.Equal(t, "mixed@example.com", result.To) + if result.Subject != "Upper Case" { + t.Errorf("got %v, want %v", result.Subject, "Upper Case") + } + if result.From != "lower@example.com" { + t.Errorf("got %v, want %v", result.From, "lower@example.com") + } + if result.To != "mixed@example.com" { + t.Errorf("got %v, want %v", result.To, "mixed@example.com") + } }) t.Run("handles missing headers gracefully", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", Payload: &gmail.MessagePart{ @@ -98,16 +138,28 @@ func TestParseMessage(t *testing.T) { result := parseMessage(msg, false, nil) - assert.Equal(t, "msg123", result.ID) - assert.Empty(t, result.Subject) - assert.Empty(t, result.From) - assert.Empty(t, result.To) - assert.Empty(t, result.Date) + if result.ID != "msg123" { + t.Errorf("got %v, want %v", result.ID, "msg123") + } + if result.Subject != "" { + t.Errorf("got %q, want empty", result.Subject) + } + if result.From != "" { + t.Errorf("got %q, want empty", result.From) + } + if result.To != "" { + t.Errorf("got %q, want empty", result.To) + } + if result.Date != "" { + t.Errorf("got %q, want empty", result.Date) + } }) } func TestExtractBody(t *testing.T) { + t.Parallel() t.Run("extracts plain text body", func(t *testing.T) { + t.Parallel() bodyText := "Hello, this is the message body." encoded := base64.URLEncoding.EncodeToString([]byte(bodyText)) @@ -119,10 +171,13 @@ func TestExtractBody(t *testing.T) { } result := extractBody(payload) - assert.Equal(t, bodyText, result) + if result != bodyText { + t.Errorf("got %v, want %v", result, bodyText) + } }) t.Run("extracts plain text from multipart message", func(t *testing.T) { + t.Parallel() bodyText := "Plain text content" encoded := base64.URLEncoding.EncodeToString([]byte(bodyText)) @@ -145,10 +200,13 @@ func TestExtractBody(t *testing.T) { } result := extractBody(payload) - assert.Equal(t, bodyText, result) + if result != bodyText { + t.Errorf("got %v, want %v", result, bodyText) + } }) t.Run("falls back to HTML if no plain text", func(t *testing.T) { + t.Parallel() htmlContent := "

HTML only

" encoded := base64.URLEncoding.EncodeToString([]byte(htmlContent)) @@ -160,10 +218,13 @@ func TestExtractBody(t *testing.T) { } result := extractBody(payload) - assert.Equal(t, htmlContent, result) + if result != htmlContent { + t.Errorf("got %v, want %v", result, htmlContent) + } }) t.Run("handles nested multipart", func(t *testing.T) { + t.Parallel() bodyText := "Nested plain text" encoded := base64.URLEncoding.EncodeToString([]byte(bodyText)) @@ -185,29 +246,38 @@ func TestExtractBody(t *testing.T) { } result := extractBody(payload) - assert.Equal(t, bodyText, result) + if result != bodyText { + t.Errorf("got %v, want %v", result, bodyText) + } }) t.Run("returns empty string for empty body", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "text/plain", Body: &gmail.MessagePartBody{}, } result := extractBody(payload) - assert.Empty(t, result) + if result != "" { + t.Errorf("got %q, want empty", result) + } }) t.Run("returns empty string for nil body", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "text/plain", } result := extractBody(payload) - assert.Empty(t, result) + if result != "" { + t.Errorf("got %q, want empty", result) + } }) t.Run("handles invalid base64 gracefully", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "text/plain", Body: &gmail.MessagePartBody{ @@ -216,12 +286,16 @@ func TestExtractBody(t *testing.T) { } result := extractBody(payload) - assert.Empty(t, result) + if result != "" { + t.Errorf("got %q, want empty", result) + } }) } func TestMessageStruct(t *testing.T) { + t.Parallel() t.Run("message struct has all fields", func(t *testing.T) { + t.Parallel() msg := &Message{ ID: "test-id", ThreadID: "thread-id", @@ -233,19 +307,37 @@ func TestMessageStruct(t *testing.T) { Body: "Full body content", } - assert.Equal(t, "test-id", msg.ID) - assert.Equal(t, "thread-id", msg.ThreadID) - assert.Equal(t, "Test Subject", msg.Subject) - assert.Equal(t, "from@example.com", msg.From) - assert.Equal(t, "to@example.com", msg.To) - assert.Equal(t, "2024-01-01", msg.Date) - assert.Equal(t, "Preview...", msg.Snippet) - assert.Equal(t, "Full body content", msg.Body) + if msg.ID != "test-id" { + t.Errorf("got %v, want %v", msg.ID, "test-id") + } + if msg.ThreadID != "thread-id" { + t.Errorf("got %v, want %v", msg.ThreadID, "thread-id") + } + if msg.Subject != "Test Subject" { + t.Errorf("got %v, want %v", msg.Subject, "Test Subject") + } + if msg.From != "from@example.com" { + t.Errorf("got %v, want %v", msg.From, "from@example.com") + } + if msg.To != "to@example.com" { + t.Errorf("got %v, want %v", msg.To, "to@example.com") + } + if msg.Date != "2024-01-01" { + t.Errorf("got %v, want %v", msg.Date, "2024-01-01") + } + if msg.Snippet != "Preview..." { + t.Errorf("got %v, want %v", msg.Snippet, "Preview...") + } + if msg.Body != "Full body content" { + t.Errorf("got %v, want %v", msg.Body, "Full body content") + } }) } func TestParseMessageWithBody(t *testing.T) { + t.Parallel() t.Run("includes body when requested", func(t *testing.T) { + t.Parallel() bodyText := "This is the full body" encoded := base64.URLEncoding.EncodeToString([]byte(bodyText)) @@ -263,10 +355,13 @@ func TestParseMessageWithBody(t *testing.T) { } result := parseMessage(msg, true, nil) - assert.Equal(t, bodyText, result.Body) + if result.Body != bodyText { + t.Errorf("got %v, want %v", result.Body, bodyText) + } }) t.Run("excludes body when not requested", func(t *testing.T) { + t.Parallel() bodyText := "This should not appear" encoded := base64.URLEncoding.EncodeToString([]byte(bodyText)) @@ -284,12 +379,16 @@ func TestParseMessageWithBody(t *testing.T) { } result := parseMessage(msg, false, nil) - assert.Empty(t, result.Body) + if result.Body != "" { + t.Errorf("got %q, want empty", result.Body) + } }) } func TestExtractAttachments(t *testing.T) { + t.Parallel() t.Run("detects attachment by filename", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{ @@ -309,15 +408,28 @@ func TestExtractAttachments(t *testing.T) { } attachments := extractAttachments(payload, "") - assert.Len(t, attachments, 1) - assert.Equal(t, "report.pdf", attachments[0].Filename) - assert.Equal(t, "application/pdf", attachments[0].MimeType) - assert.Equal(t, int64(12345), attachments[0].Size) - assert.Equal(t, "att123", attachments[0].AttachmentID) - assert.Equal(t, "1", attachments[0].PartID) + if len(attachments) != 1 { + t.Errorf("got length %d, want %d", len(attachments), 1) + } + if attachments[0].Filename != "report.pdf" { + t.Errorf("got %v, want %v", attachments[0].Filename, "report.pdf") + } + if attachments[0].MimeType != "application/pdf" { + t.Errorf("got %v, want %v", attachments[0].MimeType, "application/pdf") + } + if attachments[0].Size != int64(12345) { + t.Errorf("got %v, want %v", attachments[0].Size, int64(12345)) + } + if attachments[0].AttachmentID != "att123" { + t.Errorf("got %v, want %v", attachments[0].AttachmentID, "att123") + } + if attachments[0].PartID != "1" { + t.Errorf("got %v, want %v", attachments[0].PartID, "1") + } }) t.Run("detects attachment by Content-Disposition header", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{ @@ -333,12 +445,19 @@ func TestExtractAttachments(t *testing.T) { } attachments := extractAttachments(payload, "") - assert.Len(t, attachments, 1) - assert.Equal(t, "data.csv", attachments[0].Filename) - assert.False(t, attachments[0].IsInline) + if len(attachments) != 1 { + t.Errorf("got length %d, want %d", len(attachments), 1) + } + if attachments[0].Filename != "data.csv" { + t.Errorf("got %v, want %v", attachments[0].Filename, "data.csv") + } + if attachments[0].IsInline { + t.Error("got true, want false") + } }) t.Run("detects inline attachment", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/related", Parts: []*gmail.MessagePart{ @@ -354,12 +473,19 @@ func TestExtractAttachments(t *testing.T) { } attachments := extractAttachments(payload, "") - assert.Len(t, attachments, 1) - assert.Equal(t, "image.png", attachments[0].Filename) - assert.True(t, attachments[0].IsInline) + if len(attachments) != 1 { + t.Errorf("got length %d, want %d", len(attachments), 1) + } + if attachments[0].Filename != "image.png" { + t.Errorf("got %v, want %v", attachments[0].Filename, "image.png") + } + if !attachments[0].IsInline { + t.Error("got false, want true") + } }) t.Run("handles nested multipart with multiple attachments", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{ @@ -384,24 +510,38 @@ func TestExtractAttachments(t *testing.T) { } attachments := extractAttachments(payload, "") - assert.Len(t, attachments, 2) - assert.Equal(t, "doc1.pdf", attachments[0].Filename) - assert.Equal(t, "1", attachments[0].PartID) - assert.Equal(t, "doc2.pdf", attachments[1].Filename) - assert.Equal(t, "2", attachments[1].PartID) + if len(attachments) != 2 { + t.Errorf("got length %d, want %d", len(attachments), 2) + } + if attachments[0].Filename != "doc1.pdf" { + t.Errorf("got %v, want %v", attachments[0].Filename, "doc1.pdf") + } + if attachments[0].PartID != "1" { + t.Errorf("got %v, want %v", attachments[0].PartID, "1") + } + if attachments[1].Filename != "doc2.pdf" { + t.Errorf("got %v, want %v", attachments[1].Filename, "doc2.pdf") + } + if attachments[1].PartID != "2" { + t.Errorf("got %v, want %v", attachments[1].PartID, "2") + } }) t.Run("handles message with no attachments", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "text/plain", Body: &gmail.MessagePartBody{Data: "simple message"}, } attachments := extractAttachments(payload, "") - assert.Empty(t, attachments) + if len(attachments) != 0 { + t.Errorf("got length %d, want 0", len(attachments)) + } }) t.Run("generates correct part paths for deeply nested", func(t *testing.T) { + t.Parallel() payload := &gmail.MessagePart{ MimeType: "multipart/mixed", Parts: []*gmail.MessagePart{ @@ -425,74 +565,105 @@ func TestExtractAttachments(t *testing.T) { } attachments := extractAttachments(payload, "") - assert.Len(t, attachments, 1) - assert.Equal(t, "nested.png", attachments[0].Filename) - assert.Equal(t, "0.1", attachments[0].PartID) + if len(attachments) != 1 { + t.Errorf("got length %d, want %d", len(attachments), 1) + } + if attachments[0].Filename != "nested.png" { + t.Errorf("got %v, want %v", attachments[0].Filename, "nested.png") + } + if attachments[0].PartID != "0.1" { + t.Errorf("got %v, want %v", attachments[0].PartID, "0.1") + } }) } func TestIsAttachment(t *testing.T) { + t.Parallel() t.Run("returns true for part with filename", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{Filename: "test.pdf"} - assert.True(t, isAttachment(part)) + if !isAttachment(part) { + t.Error("got false, want true") + } }) t.Run("returns true for Content-Disposition attachment", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{ Headers: []*gmail.MessagePartHeader{ {Name: "Content-Disposition", Value: "attachment; filename=\"test.pdf\""}, }, } - assert.True(t, isAttachment(part)) + if !isAttachment(part) { + t.Error("got false, want true") + } }) t.Run("returns false for plain text part", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{ MimeType: "text/plain", Body: &gmail.MessagePartBody{Data: "text"}, } - assert.False(t, isAttachment(part)) + if isAttachment(part) { + t.Error("got true, want false") + } }) t.Run("handles case-insensitive Content-Disposition", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{ Headers: []*gmail.MessagePartHeader{ {Name: "CONTENT-DISPOSITION", Value: "ATTACHMENT"}, }, } - assert.True(t, isAttachment(part)) + if !isAttachment(part) { + t.Error("got false, want true") + } }) } func TestIsInlineAttachment(t *testing.T) { + t.Parallel() t.Run("returns true for inline disposition", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{ Filename: "image.png", Headers: []*gmail.MessagePartHeader{ {Name: "Content-Disposition", Value: "inline; filename=\"image.png\""}, }, } - assert.True(t, isInlineAttachment(part)) + if !isInlineAttachment(part) { + t.Error("got false, want true") + } }) t.Run("returns false for attachment disposition", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{ Filename: "doc.pdf", Headers: []*gmail.MessagePartHeader{ {Name: "Content-Disposition", Value: "attachment; filename=\"doc.pdf\""}, }, } - assert.False(t, isInlineAttachment(part)) + if isInlineAttachment(part) { + t.Error("got true, want false") + } }) t.Run("returns false for no disposition header", func(t *testing.T) { + t.Parallel() part := &gmail.MessagePart{Filename: "file.txt"} - assert.False(t, isInlineAttachment(part)) + if isInlineAttachment(part) { + t.Error("got true, want false") + } }) } func TestParseMessageWithAttachments(t *testing.T) { + t.Parallel() t.Run("extracts attachments when body is requested", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", Payload: &gmail.MessagePart{ @@ -517,12 +688,19 @@ func TestParseMessageWithAttachments(t *testing.T) { } result := parseMessage(msg, true, nil) - assert.Equal(t, "body text", result.Body) - assert.Len(t, result.Attachments, 1) - assert.Equal(t, "attachment.pdf", result.Attachments[0].Filename) + if result.Body != "body text" { + t.Errorf("got %v, want %v", result.Body, "body text") + } + if len(result.Attachments) != 1 { + t.Errorf("got length %d, want %d", len(result.Attachments), 1) + } + if result.Attachments[0].Filename != "attachment.pdf" { + t.Errorf("got %v, want %v", result.Attachments[0].Filename, "attachment.pdf") + } }) t.Run("does not extract attachments when body not requested", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", Payload: &gmail.MessagePart{ @@ -538,43 +716,80 @@ func TestParseMessageWithAttachments(t *testing.T) { } result := parseMessage(msg, false, nil) - assert.Empty(t, result.Attachments) + if len(result.Attachments) != 0 { + t.Errorf("got length %d, want 0", len(result.Attachments)) + } }) } func TestExtractLabelsAndCategories(t *testing.T) { + t.Parallel() t.Run("separates user labels from categories", func(t *testing.T) { - labelIds := []string{"Label_1", "CATEGORY_UPDATES", "Label_2", "CATEGORY_SOCIAL"} + t.Parallel() + labelIDs := []string{"Label_1", "CATEGORY_UPDATES", "Label_2", "CATEGORY_SOCIAL"} resolver := func(id string) string { return id } - labels, categories := extractLabelsAndCategories(labelIds, resolver) + labels, categories := extractLabelsAndCategories(labelIDs, resolver) - assert.ElementsMatch(t, []string{"Label_1", "Label_2"}, labels) - assert.ElementsMatch(t, []string{"updates", "social"}, categories) + sort.Strings(labels) + sort.Strings(categories) + expectedLabels := []string{"Label_1", "Label_2"} + expectedCategories := []string{"social", "updates"} + sort.Strings(expectedLabels) + sort.Strings(expectedCategories) + + if len(labels) != len(expectedLabels) { + t.Fatalf("got labels length %d, want %d", len(labels), len(expectedLabels)) + } + for i := range labels { + if labels[i] != expectedLabels[i] { + t.Errorf("labels[%d]: got %v, want %v", i, labels[i], expectedLabels[i]) + } + } + + if len(categories) != len(expectedCategories) { + t.Fatalf("got categories length %d, want %d", len(categories), len(expectedCategories)) + } + for i := range categories { + if categories[i] != expectedCategories[i] { + t.Errorf("categories[%d]: got %v, want %v", i, categories[i], expectedCategories[i]) + } + } }) t.Run("filters out system labels", func(t *testing.T) { - labelIds := []string{"INBOX", "Label_1", "UNREAD", "STARRED", "IMPORTANT"} + t.Parallel() + labelIDs := []string{"INBOX", "Label_1", "UNREAD", "STARRED", "IMPORTANT"} resolver := func(id string) string { return id } - labels, categories := extractLabelsAndCategories(labelIds, resolver) + labels, categories := extractLabelsAndCategories(labelIDs, resolver) - assert.Equal(t, []string{"Label_1"}, labels) - assert.Empty(t, categories) + if len(labels) != 1 || labels[0] != "Label_1" { + t.Errorf("got labels %v, want %v", labels, []string{"Label_1"}) + } + if len(categories) != 0 { + t.Errorf("got categories length %d, want 0", len(categories)) + } }) t.Run("filters out CATEGORY_PERSONAL", func(t *testing.T) { - labelIds := []string{"CATEGORY_PERSONAL", "CATEGORY_UPDATES"} + t.Parallel() + labelIDs := []string{"CATEGORY_PERSONAL", "CATEGORY_UPDATES"} resolver := func(id string) string { return id } - labels, categories := extractLabelsAndCategories(labelIds, resolver) + labels, categories := extractLabelsAndCategories(labelIDs, resolver) - assert.Empty(t, labels) - assert.Equal(t, []string{"updates"}, categories) + if len(labels) != 0 { + t.Errorf("got labels length %d, want 0", len(labels)) + } + if len(categories) != 1 || categories[0] != "updates" { + t.Errorf("got categories %v, want %v", categories, []string{"updates"}) + } }) t.Run("uses resolver to translate label IDs", func(t *testing.T) { - labelIds := []string{"Label_123", "Label_456"} + t.Parallel() + labelIDs := []string{"Label_123", "Label_456"} resolver := func(id string) string { if id == "Label_123" { return "Work" @@ -585,38 +800,68 @@ func TestExtractLabelsAndCategories(t *testing.T) { return id } - labels, categories := extractLabelsAndCategories(labelIds, resolver) + labels, categories := extractLabelsAndCategories(labelIDs, resolver) - assert.ElementsMatch(t, []string{"Work", "Personal"}, labels) - assert.Empty(t, categories) + sort.Strings(labels) + expectedLabels := []string{"Personal", "Work"} + sort.Strings(expectedLabels) + + if len(labels) != len(expectedLabels) { + t.Fatalf("got labels length %d, want %d", len(labels), len(expectedLabels)) + } + for i := range labels { + if labels[i] != expectedLabels[i] { + t.Errorf("labels[%d]: got %v, want %v", i, labels[i], expectedLabels[i]) + } + } + if len(categories) != 0 { + t.Errorf("got categories length %d, want 0", len(categories)) + } }) t.Run("handles nil resolver", func(t *testing.T) { - labelIds := []string{"Label_1", "CATEGORY_SOCIAL"} + t.Parallel() + labelIDs := []string{"Label_1", "CATEGORY_SOCIAL"} - labels, categories := extractLabelsAndCategories(labelIds, nil) + labels, categories := extractLabelsAndCategories(labelIDs, nil) - assert.Equal(t, []string{"Label_1"}, labels) - assert.Equal(t, []string{"social"}, categories) + if len(labels) != 1 || labels[0] != "Label_1" { + t.Errorf("got labels %v, want %v", labels, []string{"Label_1"}) + } + if len(categories) != 1 || categories[0] != "social" { + t.Errorf("got categories %v, want %v", categories, []string{"social"}) + } }) t.Run("handles empty label IDs", func(t *testing.T) { + t.Parallel() labels, categories := extractLabelsAndCategories([]string{}, nil) - assert.Empty(t, labels) - assert.Empty(t, categories) + if len(labels) != 0 { + t.Errorf("got labels length %d, want 0", len(labels)) + } + if len(categories) != 0 { + t.Errorf("got categories length %d, want 0", len(categories)) + } }) t.Run("handles nil label IDs", func(t *testing.T) { + t.Parallel() labels, categories := extractLabelsAndCategories(nil, nil) - assert.Empty(t, labels) - assert.Empty(t, categories) + if len(labels) != 0 { + t.Errorf("got labels length %d, want 0", len(labels)) + } + if len(categories) != 0 { + t.Errorf("got categories length %d, want 0", len(categories)) + } }) } func TestParseMessageWithLabels(t *testing.T) { + t.Parallel() t.Run("extracts labels and categories from message", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", Payload: &gmail.MessagePart{ @@ -635,11 +880,16 @@ func TestParseMessageWithLabels(t *testing.T) { result := parseMessage(msg, false, resolver) - assert.Equal(t, []string{"Work"}, result.Labels) - assert.Equal(t, []string{"updates"}, result.Categories) + if len(result.Labels) != 1 || result.Labels[0] != "Work" { + t.Errorf("got labels %v, want %v", result.Labels, []string{"Work"}) + } + if len(result.Categories) != 1 || result.Categories[0] != "updates" { + t.Errorf("got categories %v, want %v", result.Categories, []string{"updates"}) + } }) t.Run("handles message with no labels", func(t *testing.T) { + t.Parallel() msg := &gmail.Message{ Id: "msg123", Payload: &gmail.MessagePart{ @@ -650,7 +900,11 @@ func TestParseMessageWithLabels(t *testing.T) { result := parseMessage(msg, false, nil) - assert.Empty(t, result.Labels) - assert.Empty(t, result.Categories) + if len(result.Labels) != 0 { + t.Errorf("got labels length %d, want 0", len(result.Labels)) + } + if len(result.Categories) != 0 { + t.Errorf("got categories length %d, want 0", len(result.Categories)) + } }) } diff --git a/internal/keychain/keychain.go b/internal/keychain/keychain.go index 0e4589e..ca92f5d 100644 --- a/internal/keychain/keychain.go +++ b/internal/keychain/keychain.go @@ -16,12 +16,13 @@ import ( const ( serviceName = config.DirName - tokenKey = "oauth_token" + tokenKey = "oauth_token" //nolint:gosec // Not a credential; key name for keychain lookup ) // StorageBackend represents where tokens are stored type StorageBackend string +// StorageBackend constants define where OAuth tokens are persisted. const ( BackendKeychain StorageBackend = "Keychain" // macOS Keychain BackendSecretTool StorageBackend = "secret-tool" // Linux libsecret @@ -83,20 +84,20 @@ func MigrateFromFile(tokenFilePath string) error { } // Read token from file - f, err := os.Open(tokenFilePath) + f, err := os.Open(tokenFilePath) //nolint:gosec // Path from user config directory if err != nil { - return fmt.Errorf("failed to open token file: %w", err) + return fmt.Errorf("opening token file: %w", err) } defer f.Close() var token oauth2.Token if err := json.NewDecoder(f).Decode(&token); err != nil { - return fmt.Errorf("failed to parse token file: %w", err) + return fmt.Errorf("parsing token file: %w", err) } // Store in secure storage if err := SetToken(&token); err != nil { - return fmt.Errorf("failed to store token in secure storage: %w", err) + return fmt.Errorf("storing token in secure storage: %w", err) } // Securely delete old token file (overwrite with zeros before removal) @@ -123,7 +124,7 @@ func secureDelete(path string) error { } // Overwrite with zeros - f, err := os.OpenFile(path, os.O_WRONLY, 0) + f, err := os.OpenFile(path, os.O_WRONLY, 0) //nolint:gosec // Path from user config directory if err != nil { // If we can't open for writing, try to delete anyway return os.Remove(path) @@ -145,18 +146,18 @@ func getFromConfigFile() (*oauth2.Token, error) { return nil, err } - f, err := os.Open(path) + f, err := os.Open(path) //nolint:gosec // Path from user config directory if err != nil { if os.IsNotExist(err) { return nil, ErrTokenNotFound } - return nil, fmt.Errorf("failed to open token file: %w", err) + return nil, fmt.Errorf("opening token file: %w", err) } defer f.Close() var token oauth2.Token if err := json.NewDecoder(f).Decode(&token); err != nil { - return nil, fmt.Errorf("failed to parse token file: %w", err) + return nil, fmt.Errorf("parsing token file: %w", err) } return &token, nil @@ -171,18 +172,18 @@ func setInConfigFile(token *oauth2.Token) error { // Ensure directory exists dir := filepath.Dir(path) if err := os.MkdirAll(dir, config.DirPerm); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) + return fmt.Errorf("creating config directory: %w", err) } // Write token with restricted permissions - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, config.TokenPerm) + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, config.TokenPerm) //nolint:gosec // Path from user config directory if err != nil { - return fmt.Errorf("failed to create token file: %w", err) + return fmt.Errorf("creating token file: %w", err) } defer f.Close() if err := json.NewEncoder(f).Encode(token); err != nil { - return fmt.Errorf("failed to write token: %w", err) + return fmt.Errorf("writing token: %w", err) } return nil @@ -195,7 +196,7 @@ func deleteFromConfigFile() error { } if err := os.Remove(path); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to delete token file: %w", err) + return fmt.Errorf("deleting token file: %w", err) } return nil diff --git a/internal/keychain/keychain_darwin.go b/internal/keychain/keychain_darwin.go index 4e5a035..33497a0 100644 --- a/internal/keychain/keychain_darwin.go +++ b/internal/keychain/keychain_darwin.go @@ -75,12 +75,12 @@ func getFromKeychain() (*oauth2.Token, error) { output, err := cmd.Output() if err != nil { - return nil, fmt.Errorf("failed to read from keychain: %w", err) + return nil, fmt.Errorf("reading from keychain: %w", err) } var token oauth2.Token if err := json.Unmarshal([]byte(strings.TrimSpace(string(output))), &token); err != nil { - return nil, fmt.Errorf("failed to parse token from keychain: %w", err) + return nil, fmt.Errorf("parsing token from keychain: %w", err) } return &token, nil @@ -89,7 +89,7 @@ func getFromKeychain() (*oauth2.Token, error) { func setInKeychain(token *oauth2.Token) error { data, err := json.Marshal(token) if err != nil { - return fmt.Errorf("failed to serialize token: %w", err) + return fmt.Errorf("serializing token: %w", err) } // Delete existing entry (ignore error if not exists) @@ -106,7 +106,7 @@ func setInKeychain(token *oauth2.Token) error { cmd.Stdin = strings.NewReader(stdinCmd) if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to store in keychain: %w", err) + return fmt.Errorf("storing in keychain: %w", err) } return nil @@ -118,7 +118,7 @@ func deleteFromKeychain() error { "-a", tokenKey) if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to delete from keychain: %w", err) + return fmt.Errorf("deleting from keychain: %w", err) } return nil diff --git a/internal/keychain/keychain_linux.go b/internal/keychain/keychain_linux.go index 31d3755..109d8d8 100644 --- a/internal/keychain/keychain_linux.go +++ b/internal/keychain/keychain_linux.go @@ -92,12 +92,12 @@ func getFromSecretTool() (*oauth2.Token, error) { output, err := cmd.Output() if err != nil { - return nil, fmt.Errorf("failed to read from secret-tool: %w", err) + return nil, fmt.Errorf("reading from secret-tool: %w", err) } var token oauth2.Token if err := json.Unmarshal([]byte(strings.TrimSpace(string(output))), &token); err != nil { - return nil, fmt.Errorf("failed to parse token from secret-tool: %w", err) + return nil, fmt.Errorf("parsing token from secret-tool: %w", err) } return &token, nil @@ -106,7 +106,7 @@ func getFromSecretTool() (*oauth2.Token, error) { func setInSecretTool(token *oauth2.Token) error { data, err := json.Marshal(token) if err != nil { - return fmt.Errorf("failed to serialize token: %w", err) + return fmt.Errorf("serializing token: %w", err) } // Delete existing entry (ignore error if not exists) @@ -119,7 +119,7 @@ func setInSecretTool(token *oauth2.Token) error { cmd.Stdin = strings.NewReader(string(data)) if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to store in secret-tool: %w", err) + return fmt.Errorf("storing in secret-tool: %w", err) } return nil @@ -131,7 +131,7 @@ func deleteFromSecretTool() error { "account", tokenKey) if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to delete from secret-tool: %w", err) + return fmt.Errorf("deleting from secret-tool: %w", err) } return nil diff --git a/internal/keychain/keychain_test.go b/internal/keychain/keychain_test.go index 6adcb30..a964bcc 100644 --- a/internal/keychain/keychain_test.go +++ b/internal/keychain/keychain_test.go @@ -2,13 +2,14 @@ package keychain import ( "encoding/json" + "errors" + "fmt" "os" "path/filepath" + "strings" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.org/x/oauth2" "github.com/open-cli-collective/google-readonly/internal/config" @@ -33,17 +34,29 @@ func TestConfigFile_TokenRoundTrip(t *testing.T) { // Store token err := setInConfigFile(token) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Retrieve token retrieved, err := getFromConfigFile() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - assert.Equal(t, token.AccessToken, retrieved.AccessToken) - assert.Equal(t, token.RefreshToken, retrieved.RefreshToken) - assert.Equal(t, token.TokenType, retrieved.TokenType) + if retrieved.AccessToken != token.AccessToken { + t.Errorf("got %v, want %v", retrieved.AccessToken, token.AccessToken) + } + if retrieved.RefreshToken != token.RefreshToken { + t.Errorf("got %v, want %v", retrieved.RefreshToken, token.RefreshToken) + } + if retrieved.TokenType != token.TokenType { + t.Errorf("got %v, want %v", retrieved.TokenType, token.TokenType) + } // Compare times with tolerance for JSON marshaling - assert.WithinDuration(t, token.Expiry, retrieved.Expiry, time.Second) + if diff := token.Expiry.Sub(retrieved.Expiry); diff < -time.Second || diff > time.Second { + t.Errorf("times differ by %v, max allowed %v", diff, time.Second) + } } func TestConfigFile_Permissions(t *testing.T) { @@ -61,15 +74,21 @@ func TestConfigFile_Permissions(t *testing.T) { } err := setInConfigFile(token) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Check file permissions path := filepath.Join(tmpDir, serviceName, config.TokenFile) info, err := os.Stat(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify 0600 permissions (read/write for owner only) - assert.Equal(t, os.FileMode(0600), info.Mode().Perm()) + if info.Mode().Perm() != os.FileMode(0600) { + t.Errorf("got %v, want %v", info.Mode().Perm(), os.FileMode(0600)) + } } func TestConfigFile_DirectoryPermissions(t *testing.T) { @@ -87,15 +106,21 @@ func TestConfigFile_DirectoryPermissions(t *testing.T) { } err := setInConfigFile(token) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Check directory permissions dir := filepath.Join(tmpDir, serviceName) info, err := os.Stat(dir) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify 0700 permissions (read/write/execute for owner only) - assert.Equal(t, os.FileMode(0700), info.Mode().Perm()) + if info.Mode().Perm() != os.FileMode(0700) { + t.Errorf("got %v, want %v", info.Mode().Perm(), os.FileMode(0700)) + } } func TestConfigFile_NotFound(t *testing.T) { @@ -108,7 +133,9 @@ func TestConfigFile_NotFound(t *testing.T) { defer os.Setenv("XDG_CONFIG_HOME", originalXDG) _, err := getFromConfigFile() - assert.ErrorIs(t, err, ErrTokenNotFound) + if !errors.Is(err, ErrTokenNotFound) { + t.Errorf("got %v, want %v", err, ErrTokenNotFound) + } } func TestConfigFile_InvalidJSON(t *testing.T) { @@ -123,16 +150,24 @@ func TestConfigFile_InvalidJSON(t *testing.T) { // Create config directory configDir := filepath.Join(tmpDir, serviceName) err := os.MkdirAll(configDir, 0700) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Write invalid JSON path := filepath.Join(configDir, config.TokenFile) err = os.WriteFile(path, []byte("invalid json"), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } _, err = getFromConfigFile() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to parse token file") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "parsing token file") { + t.Errorf("expected %q to contain %q", err.Error(), "parsing token file") + } } func TestConfigFile_Overwrite(t *testing.T) { @@ -150,7 +185,9 @@ func TestConfigFile_Overwrite(t *testing.T) { TokenType: "Bearer", } err := setInConfigFile(token1) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Store second token token2 := &oauth2.Token{ @@ -158,12 +195,18 @@ func TestConfigFile_Overwrite(t *testing.T) { TokenType: "Bearer", } err = setInConfigFile(token2) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Retrieve should return second token retrieved, err := getFromConfigFile() - require.NoError(t, err) - assert.Equal(t, "second-token", retrieved.AccessToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if retrieved.AccessToken != "second-token" { + t.Errorf("got %v, want %v", retrieved.AccessToken, "second-token") + } } func TestConfigFile_Delete(t *testing.T) { @@ -181,15 +224,21 @@ func TestConfigFile_Delete(t *testing.T) { TokenType: "Bearer", } err := setInConfigFile(token) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Delete token err = deleteFromConfigFile() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Should be gone _, err = getFromConfigFile() - assert.ErrorIs(t, err, ErrTokenNotFound) + if !errors.Is(err, ErrTokenNotFound) { + t.Errorf("got %v, want %v", err, ErrTokenNotFound) + } } func TestConfigFile_DeleteNonExistent(t *testing.T) { @@ -203,13 +252,17 @@ func TestConfigFile_DeleteNonExistent(t *testing.T) { // Delete should not error on non-existent file err := deleteFromConfigFile() - assert.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } } func TestMigrateFromFile_NoFile(t *testing.T) { // Migration should succeed (no-op) when file doesn't exist err := MigrateFromFile("/nonexistent/path/token.json") - assert.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } } func TestMigrateFromFile_InvalidJSON(t *testing.T) { @@ -232,7 +285,9 @@ func TestMigrateFromFile_InvalidJSON(t *testing.T) { // Create temp file with invalid JSON tokenPath := filepath.Join(tmpDir, "token.json") err := os.WriteFile(tokenPath, []byte("invalid json"), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // If secure storage has a token (e.g., from real keychain), migration is skipped // In that case, we test the direct file parsing instead @@ -241,8 +296,12 @@ func TestMigrateFromFile_InvalidJSON(t *testing.T) { } err = MigrateFromFile(tokenPath) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to parse token file") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "parsing token file") { + t.Errorf("expected %q to contain %q", err.Error(), "parsing token file") + } } func TestMigrateFromFile_Success(t *testing.T) { @@ -273,30 +332,44 @@ func TestMigrateFromFile_Success(t *testing.T) { } tokenPath := filepath.Join(tmpDir, "token.json") data, err := json.Marshal(token) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } err = os.WriteFile(tokenPath, data, 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Run migration err = MigrateFromFile(tokenPath) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify token was stored (uses GetToken to check all backends) retrieved, err := GetToken() - require.NoError(t, err) - assert.Equal(t, "migrated-token", retrieved.AccessToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if retrieved.AccessToken != "migrated-token" { + t.Errorf("got %v, want %v", retrieved.AccessToken, "migrated-token") + } // Clean up: delete the token we just stored defer DeleteToken() // Verify original file was securely deleted (not renamed to backup) _, err = os.Stat(tokenPath) - assert.True(t, os.IsNotExist(err), "original token file should be deleted") + if !os.IsNotExist(err) { + t.Error("original token file should be deleted") + } // Verify no backup file was created (secure delete, not rename) backupPath := tokenPath + ".backup" _, err = os.Stat(backupPath) - assert.True(t, os.IsNotExist(err), "backup file should not exist (secure delete)") + if !os.IsNotExist(err) { + t.Error("backup file should not exist (secure delete)") + } } func TestHasStoredToken_ConfigFile(t *testing.T) { @@ -314,7 +387,9 @@ func TestHasStoredToken_ConfigFile(t *testing.T) { // Should return error when no token file _, err := getFromConfigFile() - assert.ErrorIs(t, err, ErrTokenNotFound) + if !errors.Is(err, ErrTokenNotFound) { + t.Errorf("got %v, want %v", err, ErrTokenNotFound) + } // Store a token in config file token := &oauth2.Token{ @@ -322,24 +397,40 @@ func TestHasStoredToken_ConfigFile(t *testing.T) { TokenType: "Bearer", } err = setInConfigFile(token) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Should successfully retrieve from config file retrieved, err := getFromConfigFile() - require.NoError(t, err) - assert.Equal(t, "test-token", retrieved.AccessToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if retrieved.AccessToken != "test-token" { + t.Errorf("got %v, want %v", retrieved.AccessToken, "test-token") + } } func TestGetStorageBackend(t *testing.T) { // Just verify it returns a valid backend backend := GetStorageBackend() - assert.Contains(t, []StorageBackend{BackendKeychain, BackendSecretTool, BackendFile}, backend) + validBackends := []StorageBackend{BackendKeychain, BackendSecretTool, BackendFile} + found := false + for _, v := range validBackends { + if v == backend { + found = true + break + } + } + if !found { + t.Errorf("got %v, want one of %v", backend, validBackends) + } } -func TestIsSecureStorage(t *testing.T) { +func TestIsSecureStorage(_ *testing.T) { // This will vary by platform - just verify it returns a bool - result := IsSecureStorage() - assert.IsType(t, true, result) + // Go enforces the type at compile time, so no runtime check needed + _ = IsSecureStorage() } func TestTokenFilePath(t *testing.T) { @@ -348,17 +439,25 @@ func TestTokenFilePath(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", tmpDir) path, err := tokenFilePath() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } configPath, err := config.GetTokenPath() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - assert.Equal(t, configPath, path) + if path != configPath { + t.Errorf("got %v, want %v", path, configPath) + } } func TestServiceNameConstant(t *testing.T) { // Verify serviceName matches config.DirName - assert.Equal(t, config.DirName, serviceName) + if serviceName != config.DirName { + t.Errorf("got %v, want %v", serviceName, config.DirName) + } } func TestSecureDelete(t *testing.T) { @@ -369,25 +468,35 @@ func TestSecureDelete(t *testing.T) { // Create file with sensitive data sensitiveData := []byte("super secret token data") err := os.WriteFile(path, sensitiveData, 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify file exists _, err = os.Stat(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Secure delete err = secureDelete(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Verify file is gone _, err = os.Stat(path) - assert.True(t, os.IsNotExist(err)) + if !os.IsNotExist(err) { + t.Error("got false, want true") + } }) t.Run("handles non-existent file", func(t *testing.T) { // Should not error on non-existent file err := secureDelete("/nonexistent/path/file.txt") - assert.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } }) t.Run("overwrites file content before deletion", func(t *testing.T) { @@ -397,50 +506,70 @@ func TestSecureDelete(t *testing.T) { // Create file with known content sensitiveData := []byte("secret123456") err := os.WriteFile(path, sensitiveData, 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Get file size before deletion info, err := os.Stat(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } originalSize := info.Size() // Create a copy to verify overwrite behavior // We'll use a custom path that we keep open to observe the overwrite copyPath := filepath.Join(tmpDir, "observe.txt") err = os.WriteFile(copyPath, sensitiveData, 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Open the file to observe content after overwrite but before unlink // This simulates what forensic tools would see f, err := os.OpenFile(copyPath, os.O_RDWR, 0) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Overwrite with zeros (simulating what secureDelete does) zeros := make([]byte, originalSize) _, err = f.Write(zeros) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } _ = f.Sync() // Read back - should be all zeros _, err = f.Seek(0, 0) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } content := make([]byte, originalSize) _, err = f.Read(content) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } f.Close() // Verify content is all zeros for i, b := range content { - assert.Equal(t, byte(0), b, "byte %d should be zero", i) + if b != byte(0) { + t.Errorf("byte %d: got %v, want %v", i, b, byte(0)) + } } // Now test actual secureDelete err = secureDelete(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // File should be gone _, err = os.Stat(path) - assert.True(t, os.IsNotExist(err)) + if !os.IsNotExist(err) { + t.Error("got false, want true") + } }) } @@ -476,12 +605,20 @@ func TestPersistentTokenSource_NoChange(t *testing.T) { // Call Token() token, err := pts.Token() - require.NoError(t, err) - assert.Equal(t, "initial-token", token.AccessToken) - assert.Equal(t, 1, mock.calls) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "initial-token" { + t.Errorf("got %v, want %v", token.AccessToken, "initial-token") + } + if mock.calls != 1 { + t.Errorf("got %v, want %v", mock.calls, 1) + } // current should remain the same (same pointer) - assert.Same(t, initialToken, pts.current) + if pts.current != initialToken { + t.Errorf("expected same pointer, got different") + } } func TestPersistentTokenSource_RefreshUpdatesCurrent(t *testing.T) { @@ -512,13 +649,23 @@ func TestPersistentTokenSource_RefreshUpdatesCurrent(t *testing.T) { // Call Token() - should detect change and update current token, err := pts.Token() - require.NoError(t, err) - assert.Equal(t, "refreshed-token", token.AccessToken) - assert.Equal(t, 1, mock.calls) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "refreshed-token" { + t.Errorf("got %v, want %v", token.AccessToken, "refreshed-token") + } + if mock.calls != 1 { + t.Errorf("got %v, want %v", mock.calls, 1) + } // Verify current was updated to the refreshed token - assert.Equal(t, "refreshed-token", pts.current.AccessToken) - assert.Equal(t, "new-refresh-token", pts.current.RefreshToken) + if pts.current.AccessToken != "refreshed-token" { + t.Errorf("got %v, want %v", pts.current.AccessToken, "refreshed-token") + } + if pts.current.RefreshToken != "new-refresh-token" { + t.Errorf("got %v, want %v", pts.current.RefreshToken, "new-refresh-token") + } } func TestPersistentTokenSource_NilCurrentUpdatesCurrent(t *testing.T) { @@ -541,19 +688,27 @@ func TestPersistentTokenSource_NilCurrentUpdatesCurrent(t *testing.T) { // Call Token() - should detect as change (nil -> token) and update current token, err := pts.Token() - require.NoError(t, err) - assert.Equal(t, "new-token", token.AccessToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "new-token" { + t.Errorf("got %v, want %v", token.AccessToken, "new-token") + } // Verify current was set - require.NotNil(t, pts.current) - assert.Equal(t, "new-token", pts.current.AccessToken) + if pts.current == nil { + t.Fatal("expected non-nil, got nil") + } + if pts.current.AccessToken != "new-token" { + t.Errorf("got %v, want %v", pts.current.AccessToken, "new-token") + } } func TestPersistentTokenSource_BaseError(t *testing.T) { // Mock returns an error mock := &mockTokenSource{ token: nil, - err: assert.AnError, + err: fmt.Errorf("mock error"), } initialToken := &oauth2.Token{ @@ -569,12 +724,20 @@ func TestPersistentTokenSource_BaseError(t *testing.T) { // Call Token() - should propagate error token, err := pts.Token() - assert.Error(t, err) - assert.Nil(t, token) - assert.Equal(t, 1, mock.calls) + if err == nil { + t.Fatal("expected error, got nil") + } + if token != nil { + t.Errorf("got %v, want nil", token) + } + if mock.calls != 1 { + t.Errorf("got %v, want %v", mock.calls, 1) + } // current should remain unchanged on error - assert.Equal(t, "initial-token", pts.current.AccessToken) + if pts.current.AccessToken != "initial-token" { + t.Errorf("got %v, want %v", pts.current.AccessToken, "initial-token") + } } func TestPersistentTokenSource_MultipleCalls_NoChange(t *testing.T) { @@ -598,15 +761,23 @@ func TestPersistentTokenSource_MultipleCalls_NoChange(t *testing.T) { // Multiple calls should all succeed for i := 0; i < 3; i++ { token, err := pts.Token() - require.NoError(t, err) - assert.Equal(t, "stable-token", token.AccessToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "stable-token" { + t.Errorf("got %v, want %v", token.AccessToken, "stable-token") + } } // Verify mock was called 3 times - assert.Equal(t, 3, mock.calls) + if mock.calls != 3 { + t.Errorf("got %v, want %v", mock.calls, 3) + } // current should still be the same - assert.Same(t, stableToken, pts.current) + if pts.current != stableToken { + t.Errorf("expected same pointer, got different") + } } func TestPersistentTokenSource_ChangeDetection(t *testing.T) { @@ -627,24 +798,38 @@ func TestPersistentTokenSource_ChangeDetection(t *testing.T) { // First call: nil -> token1 (change detected) _, err := pts.Token() - require.NoError(t, err) - assert.Equal(t, "token-1", pts.current.AccessToken) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pts.current.AccessToken != "token-1" { + t.Errorf("got %v, want %v", pts.current.AccessToken, "token-1") + } originalCurrent := pts.current // Second call: token1 -> token2 (change detected) mock.token = token2 _, err = pts.Token() - require.NoError(t, err) - assert.Equal(t, "token-2", pts.current.AccessToken) - assert.NotSame(t, originalCurrent, pts.current) // current was updated + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pts.current.AccessToken != "token-2" { + t.Errorf("got %v, want %v", pts.current.AccessToken, "token-2") + } + if pts.current == originalCurrent { + t.Errorf("expected different pointers, got same") + } // Third call: token2 -> token3 (same AccessToken, no change) secondCurrent := pts.current mock.token = token3 _, err = pts.Token() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // current should not have changed since AccessToken is the same - assert.Same(t, secondCurrent, pts.current) + if pts.current != secondCurrent { + t.Errorf("expected same pointer, got different") + } } func TestPersistentTokenSource_ReturnsCorrectToken(t *testing.T) { @@ -660,8 +845,12 @@ func TestPersistentTokenSource_ReturnsCorrectToken(t *testing.T) { } token, err := pts.Token() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Should return the token from base, not current - assert.Equal(t, "from-base", token.AccessToken) + if token.AccessToken != "from-base" { + t.Errorf("got %v, want %v", token.AccessToken, "from-base") + } } diff --git a/internal/keychain/token_source.go b/internal/keychain/token_source.go index 586edee..f9e2985 100644 --- a/internal/keychain/token_source.go +++ b/internal/keychain/token_source.go @@ -21,9 +21,9 @@ type PersistentTokenSource struct { // NewPersistentTokenSource creates a TokenSource that persists refreshed tokens. // When the underlying oauth2 package refreshes an expired token, this wrapper // detects the change and saves the new token to secure storage. -func NewPersistentTokenSource(config *oauth2.Config, initial *oauth2.Token) oauth2.TokenSource { +func NewPersistentTokenSource(ctx context.Context, config *oauth2.Config, initial *oauth2.Token) oauth2.TokenSource { // Create base token source that handles refresh - base := config.TokenSource(context.Background(), initial) + base := config.TokenSource(ctx, initial) return &PersistentTokenSource{ base: base, diff --git a/internal/log/log_test.go b/internal/log/log_test.go index 8804071..bffd607 100644 --- a/internal/log/log_test.go +++ b/internal/log/log_test.go @@ -5,8 +5,6 @@ import ( "os" "strings" "testing" - - "github.com/stretchr/testify/assert" ) func TestDebug_WhenVerboseTrue(t *testing.T) { @@ -29,8 +27,12 @@ func TestDebug_WhenVerboseTrue(t *testing.T) { buf.ReadFrom(r) output := buf.String() - assert.Contains(t, output, "[DEBUG]") - assert.Contains(t, output, "test message 42") + if !strings.Contains(output, "[DEBUG]") { + t.Errorf("expected %q to contain %q", output, "[DEBUG]") + } + if !strings.Contains(output, "test message 42") { + t.Errorf("expected %q to contain %q", output, "test message 42") + } } func TestDebug_WhenVerboseFalse(t *testing.T) { @@ -53,7 +55,9 @@ func TestDebug_WhenVerboseFalse(t *testing.T) { buf.ReadFrom(r) output := buf.String() - assert.Empty(t, output) + if output != "" { + t.Errorf("got %q, want empty string", output) + } } func TestInfo(t *testing.T) { @@ -70,8 +74,12 @@ func TestInfo(t *testing.T) { buf.ReadFrom(r) output := buf.String() - assert.Equal(t, "info message test\n", output) - assert.False(t, strings.Contains(output, "[INFO]")) // No prefix for info + if output != "info message test\n" { + t.Errorf("got %v, want %v", output, "info message test\n") + } + if strings.Contains(output, "[INFO]") { + t.Error("got true, want false") + } // No prefix for info } func TestWarn(t *testing.T) { @@ -88,8 +96,12 @@ func TestWarn(t *testing.T) { buf.ReadFrom(r) output := buf.String() - assert.Contains(t, output, "[WARN]") - assert.Contains(t, output, "warning: something") + if !strings.Contains(output, "[WARN]") { + t.Errorf("expected %q to contain %q", output, "[WARN]") + } + if !strings.Contains(output, "warning: something") { + t.Errorf("expected %q to contain %q", output, "warning: something") + } } func TestError(t *testing.T) { @@ -106,6 +118,10 @@ func TestError(t *testing.T) { buf.ReadFrom(r) output := buf.String() - assert.Contains(t, output, "[ERROR]") - assert.Contains(t, output, "error occurred: failure") + if !strings.Contains(output, "[ERROR]") { + t.Errorf("expected %q to contain %q", output, "[ERROR]") + } + if !strings.Contains(output, "error occurred: failure") { + t.Errorf("expected %q to contain %q", output, "error occurred: failure") + } } diff --git a/internal/output/output_test.go b/internal/output/output_test.go index eb662a8..766afe8 100644 --- a/internal/output/output_test.go +++ b/internal/output/output_test.go @@ -5,11 +5,11 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func TestJSON(t *testing.T) { + t.Parallel() tests := []struct { name string data any @@ -39,15 +39,17 @@ func TestJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() var buf bytes.Buffer err := JSON(&buf, tt.data) - require.NoError(t, err) - assert.Equal(t, tt.expected, buf.String()) + testutil.NoError(t, err) + testutil.Equal(t, buf.String(), tt.expected) }) } } func TestJSON_indentation(t *testing.T) { + t.Parallel() data := struct { Nested struct { Value string @@ -58,19 +60,20 @@ func TestJSON_indentation(t *testing.T) { var buf bytes.Buffer err := JSON(&buf, data) - require.NoError(t, err) + testutil.NoError(t, err) // Check that indentation uses 2 spaces lines := strings.Split(buf.String(), "\n") - assert.True(t, strings.HasPrefix(lines[1], " "), "expected 2-space indentation") - assert.True(t, strings.HasPrefix(lines[2], " "), "expected 4-space indentation for nested") + testutil.True(t, strings.HasPrefix(lines[1], " ")) + testutil.True(t, strings.HasPrefix(lines[2], " ")) } func TestJSON_error(t *testing.T) { + t.Parallel() // Channels cannot be encoded to JSON data := make(chan int) var buf bytes.Buffer err := JSON(&buf, data) - assert.Error(t, err) + testutil.Error(t, err) } diff --git a/internal/testutil/assert.go b/internal/testutil/assert.go new file mode 100644 index 0000000..24c3c44 --- /dev/null +++ b/internal/testutil/assert.go @@ -0,0 +1,167 @@ +// Package testutil provides test assertion helpers and sample data fixtures. +package testutil + +import ( + "errors" + "reflect" + "strings" + "testing" +) + +// Equal checks that got equals want using comparable constraint. +func Equal[T comparable](t testing.TB, got, want T) { + t.Helper() + if got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +// NoError fails the test immediately if err is not nil. +func NoError(t testing.TB, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// Error checks that err is not nil. +func Error(t testing.TB, err error) { + t.Helper() + if err == nil { + t.Fatal("expected error, got nil") + } +} + +// ErrorIs checks that err matches target using errors.Is. +func ErrorIs(t testing.TB, err, target error) { + t.Helper() + if !errors.Is(err, target) { + t.Errorf("got error %v, want error matching %v", err, target) + } +} + +// Contains checks that s contains substr. +func Contains(t testing.TB, s, substr string) { + t.Helper() + if !strings.Contains(s, substr) { + t.Errorf("expected %q to contain %q", s, substr) + } +} + +// NotContains checks that s does not contain substr. +func NotContains(t testing.TB, s, substr string) { + t.Helper() + if strings.Contains(s, substr) { + t.Errorf("expected %q to not contain %q", s, substr) + } +} + +// Len checks that the slice has the expected length. +func Len[T any](t testing.TB, slice []T, want int) { + t.Helper() + if len(slice) != want { + t.Errorf("got length %d, want %d", len(slice), want) + } +} + +// Nil checks that val is nil. +// Uses reflection to handle nil slices, maps, pointers, channels, and functions +// that appear non-nil when boxed into an interface. +func Nil(t testing.TB, val any) { + t.Helper() + if val == nil { + return + } + v := reflect.ValueOf(val) + switch v.Kind() { //nolint:exhaustive // only nillable kinds are relevant + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface: + if v.IsNil() { + return + } + } + t.Errorf("got %v, want nil", val) +} + +// NotNil fails the test immediately if val is nil. +func NotNil(t testing.TB, val any) { + t.Helper() + if val == nil { + t.Fatal("got nil, want non-nil") + } +} + +// True checks that condition is true. +func True(t testing.TB, condition bool) { + t.Helper() + if !condition { + t.Error("got false, want true") + } +} + +// False checks that condition is false. +func False(t testing.TB, condition bool) { + t.Helper() + if condition { + t.Error("got true, want false") + } +} + +// Empty checks that s is the empty string. +func Empty(t testing.TB, s string) { + t.Helper() + if s != "" { + t.Errorf("got %q, want empty string", s) + } +} + +// NotEmpty checks that s is not the empty string. +func NotEmpty(t testing.TB, s string) { + t.Helper() + if s == "" { + t.Error("got empty string, want non-empty") + } +} + +// Greater checks that a > b. +func Greater(t testing.TB, a, b int) { + t.Helper() + if a <= b { + t.Errorf("got %d, want greater than %d", a, b) + } +} + +// GreaterOrEqual checks that a >= b. +func GreaterOrEqual(t testing.TB, a, b int) { + t.Helper() + if a < b { + t.Errorf("got %d, want >= %d", a, b) + } +} + +// Less checks that a < b. +func Less(t testing.TB, a, b int) { + t.Helper() + if a >= b { + t.Errorf("got %d, want less than %d", a, b) + } +} + +// SliceContains checks that the slice contains the target value. +func SliceContains[T comparable](t testing.TB, slice []T, target T) { + t.Helper() + for _, v := range slice { + if v == target { + return + } + } + t.Errorf("slice %v does not contain %v", slice, target) +} + +// LenSlice checks that an arbitrary slice has the expected length. +// Use this when Len's type parameter cannot be inferred. +func LenSlice(t testing.TB, length, want int) { + t.Helper() + if length != want { + t.Errorf("got length %d, want %d", length, want) + } +} diff --git a/internal/testutil/assert_test.go b/internal/testutil/assert_test.go new file mode 100644 index 0000000..ffc6ea9 --- /dev/null +++ b/internal/testutil/assert_test.go @@ -0,0 +1,401 @@ +package testutil + +import ( + "errors" + "testing" +) + +// mockT captures test failures without stopping the outer test. +type mockT struct { + testing.TB + failed bool +} + +func (m *mockT) Helper() {} +func (m *mockT) Errorf(_ string, _ ...any) { m.failed = true } +func (m *mockT) Error(_ ...any) { m.failed = true } +func (m *mockT) Fatalf(_ string, _ ...any) { m.failed = true } +func (m *mockT) Fatal(_ ...any) { m.failed = true } + +func TestEqual(t *testing.T) { + t.Parallel() + t.Run("passes on equal values", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Equal(mt, 42, 42) + if mt.failed { + t.Error("Equal should not fail for equal values") + } + }) + + t.Run("fails on unequal values", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Equal(mt, 1, 2) + if !mt.failed { + t.Error("Equal should fail for unequal values") + } + }) + + t.Run("works with strings", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Equal(mt, "hello", "hello") + if mt.failed { + t.Error("Equal should not fail for equal strings") + } + }) +} + +func TestNoError(t *testing.T) { + t.Parallel() + t.Run("passes on nil error", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NoError(mt, nil) + if mt.failed { + t.Error("NoError should not fail for nil error") + } + }) + + t.Run("fails on non-nil error", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NoError(mt, errors.New("boom")) + if !mt.failed { + t.Error("NoError should fail for non-nil error") + } + }) +} + +func TestError(t *testing.T) { + t.Parallel() + t.Run("passes on non-nil error", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Error(mt, errors.New("boom")) + if mt.failed { + t.Error("Error should not fail for non-nil error") + } + }) + + t.Run("fails on nil error", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Error(mt, nil) + if !mt.failed { + t.Error("Error should fail for nil error") + } + }) +} + +func TestErrorIs(t *testing.T) { + t.Parallel() + sentinel := errors.New("sentinel") + + t.Run("passes when errors match", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + ErrorIs(mt, sentinel, sentinel) + if mt.failed { + t.Error("ErrorIs should not fail for matching errors") + } + }) + + t.Run("fails when errors don't match", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + ErrorIs(mt, errors.New("other"), sentinel) + if !mt.failed { + t.Error("ErrorIs should fail for non-matching errors") + } + }) +} + +func TestContains(t *testing.T) { + t.Parallel() + t.Run("passes when string contains substr", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Contains(mt, "hello world", "world") + if mt.failed { + t.Error("Contains should not fail when substr is present") + } + }) + + t.Run("fails when string doesn't contain substr", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Contains(mt, "hello world", "xyz") + if !mt.failed { + t.Error("Contains should fail when substr is absent") + } + }) +} + +func TestNotContains(t *testing.T) { + t.Parallel() + t.Run("passes when string doesn't contain substr", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NotContains(mt, "hello world", "xyz") + if mt.failed { + t.Error("NotContains should not fail when substr is absent") + } + }) + + t.Run("fails when string contains substr", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NotContains(mt, "hello world", "world") + if !mt.failed { + t.Error("NotContains should fail when substr is present") + } + }) +} + +func TestLen(t *testing.T) { + t.Parallel() + t.Run("passes on correct length", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Len(mt, []int{1, 2, 3}, 3) + if mt.failed { + t.Error("Len should not fail for correct length") + } + }) + + t.Run("fails on wrong length", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Len(mt, []int{1, 2}, 3) + if !mt.failed { + t.Error("Len should fail for wrong length") + } + }) + + t.Run("works with empty slice", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Len(mt, []string{}, 0) + if mt.failed { + t.Error("Len should not fail for empty slice with want 0") + } + }) +} + +func TestNil(t *testing.T) { + t.Parallel() + t.Run("passes on nil", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Nil(mt, nil) + if mt.failed { + t.Error("Nil should not fail for nil") + } + }) + + t.Run("fails on non-nil", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Nil(mt, "something") + if !mt.failed { + t.Error("Nil should fail for non-nil") + } + }) +} + +func TestNotNil(t *testing.T) { + t.Parallel() + t.Run("passes on non-nil", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NotNil(mt, "something") + if mt.failed { + t.Error("NotNil should not fail for non-nil") + } + }) + + t.Run("fails on nil", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NotNil(mt, nil) + if !mt.failed { + t.Error("NotNil should fail for nil") + } + }) +} + +func TestTrue(t *testing.T) { + t.Parallel() + t.Run("passes on true", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + True(mt, true) + if mt.failed { + t.Error("True should not fail for true") + } + }) + + t.Run("fails on false", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + True(mt, false) + if !mt.failed { + t.Error("True should fail for false") + } + }) +} + +func TestFalse(t *testing.T) { + t.Parallel() + t.Run("passes on false", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + False(mt, false) + if mt.failed { + t.Error("False should not fail for false") + } + }) + + t.Run("fails on true", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + False(mt, true) + if !mt.failed { + t.Error("False should fail for true") + } + }) +} + +func TestEmpty(t *testing.T) { + t.Parallel() + t.Run("passes on empty string", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Empty(mt, "") + if mt.failed { + t.Error("Empty should not fail for empty string") + } + }) + + t.Run("fails on non-empty string", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Empty(mt, "hello") + if !mt.failed { + t.Error("Empty should fail for non-empty string") + } + }) +} + +func TestNotEmpty(t *testing.T) { + t.Parallel() + t.Run("passes on non-empty string", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NotEmpty(mt, "hello") + if mt.failed { + t.Error("NotEmpty should not fail for non-empty string") + } + }) + + t.Run("fails on empty string", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + NotEmpty(mt, "") + if !mt.failed { + t.Error("NotEmpty should fail for empty string") + } + }) +} + +func TestGreater(t *testing.T) { + t.Parallel() + t.Run("passes when a > b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Greater(mt, 5, 3) + if mt.failed { + t.Error("Greater should not fail when a > b") + } + }) + + t.Run("fails when a == b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Greater(mt, 3, 3) + if !mt.failed { + t.Error("Greater should fail when a == b") + } + }) + + t.Run("fails when a < b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Greater(mt, 2, 3) + if !mt.failed { + t.Error("Greater should fail when a < b") + } + }) +} + +func TestGreaterOrEqual(t *testing.T) { + t.Parallel() + t.Run("passes when a > b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + GreaterOrEqual(mt, 5, 3) + if mt.failed { + t.Error("GreaterOrEqual should not fail when a > b") + } + }) + + t.Run("passes when a == b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + GreaterOrEqual(mt, 3, 3) + if mt.failed { + t.Error("GreaterOrEqual should not fail when a == b") + } + }) + + t.Run("fails when a < b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + GreaterOrEqual(mt, 2, 3) + if !mt.failed { + t.Error("GreaterOrEqual should fail when a < b") + } + }) +} + +func TestLess(t *testing.T) { + t.Parallel() + t.Run("passes when a < b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Less(mt, 2, 5) + if mt.failed { + t.Error("Less should not fail when a < b") + } + }) + + t.Run("fails when a == b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Less(mt, 3, 3) + if !mt.failed { + t.Error("Less should fail when a == b") + } + }) + + t.Run("fails when a > b", func(t *testing.T) { + t.Parallel() + mt := &mockT{} + Less(mt, 5, 3) + if !mt.failed { + t.Error("Less should fail when a > b") + } + }) +} diff --git a/internal/testutil/helpers.go b/internal/testutil/helpers.go new file mode 100644 index 0000000..85bc072 --- /dev/null +++ b/internal/testutil/helpers.go @@ -0,0 +1,44 @@ +package testutil + +import ( + "bytes" + "io" + "os" + "testing" +) + +// CaptureStdout captures everything written to os.Stdout during the execution +// of f and returns it as a string. This is useful for testing commands that +// print output directly to stdout. +func CaptureStdout(t testing.TB, f func()) string { + t.Helper() + old := os.Stdout + r, w, err := os.Pipe() + NoError(t, err) + os.Stdout = w + + f() + + // Close error is non-fatal for pipe operations in tests + _ = w.Close() + os.Stdout = old + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + return buf.String() +} + +// WithFactory temporarily replaces a factory function variable with a +// replacement value, executes f, then restores the original. This is the +// generic building block for per-package withMockClient helpers. +// +// Usage: +// +// testutil.WithFactory(&ClientFactory, mockFactory, func() { +// // ClientFactory now returns the mock +// }) +func WithFactory[T any](factoryPtr *T, replacement T, f func()) { + original := *factoryPtr + *factoryPtr = replacement + defer func() { *factoryPtr = original }() + f() +} diff --git a/internal/testutil/mocks.go b/internal/testutil/mocks.go deleted file mode 100644 index 4a10737..0000000 --- a/internal/testutil/mocks.go +++ /dev/null @@ -1,231 +0,0 @@ -// Package testutil provides test utilities including mock implementations -// of client interfaces for unit testing command handlers. -package testutil - -import ( - "google.golang.org/api/calendar/v3" - "google.golang.org/api/gmail/v1" - "google.golang.org/api/people/v1" - - calendarapi "github.com/open-cli-collective/google-readonly/internal/calendar" - contactsapi "github.com/open-cli-collective/google-readonly/internal/contacts" - driveapi "github.com/open-cli-collective/google-readonly/internal/drive" - gmailapi "github.com/open-cli-collective/google-readonly/internal/gmail" -) - -// MockGmailClient is a configurable mock for GmailClientInterface. -// Set the function fields to control behavior in tests. -type MockGmailClient struct { - GetMessageFunc func(messageID string, includeBody bool) (*gmailapi.Message, error) - SearchMessagesFunc func(query string, maxResults int64) ([]*gmailapi.Message, int, error) - GetThreadFunc func(id string) ([]*gmailapi.Message, error) - FetchLabelsFunc func() error - GetLabelNameFunc func(labelID string) string - GetLabelsFunc func() []*gmail.Label - GetAttachmentsFunc func(messageID string) ([]*gmailapi.Attachment, error) - DownloadAttachmentFunc func(messageID, attachmentID string) ([]byte, error) - DownloadInlineAttachmentFunc func(messageID, partID string) ([]byte, error) - GetProfileFunc func() (*gmailapi.Profile, error) -} - -// Verify MockGmailClient implements GmailClientInterface -var _ gmailapi.GmailClientInterface = (*MockGmailClient)(nil) - -func (m *MockGmailClient) GetMessage(messageID string, includeBody bool) (*gmailapi.Message, error) { - if m.GetMessageFunc != nil { - return m.GetMessageFunc(messageID, includeBody) - } - return nil, nil -} - -func (m *MockGmailClient) SearchMessages(query string, maxResults int64) ([]*gmailapi.Message, int, error) { - if m.SearchMessagesFunc != nil { - return m.SearchMessagesFunc(query, maxResults) - } - return nil, 0, nil -} - -func (m *MockGmailClient) GetThread(id string) ([]*gmailapi.Message, error) { - if m.GetThreadFunc != nil { - return m.GetThreadFunc(id) - } - return nil, nil -} - -func (m *MockGmailClient) FetchLabels() error { - if m.FetchLabelsFunc != nil { - return m.FetchLabelsFunc() - } - return nil -} - -func (m *MockGmailClient) GetLabelName(labelID string) string { - if m.GetLabelNameFunc != nil { - return m.GetLabelNameFunc(labelID) - } - return labelID -} - -func (m *MockGmailClient) GetLabels() []*gmail.Label { - if m.GetLabelsFunc != nil { - return m.GetLabelsFunc() - } - return nil -} - -func (m *MockGmailClient) GetAttachments(messageID string) ([]*gmailapi.Attachment, error) { - if m.GetAttachmentsFunc != nil { - return m.GetAttachmentsFunc(messageID) - } - return nil, nil -} - -func (m *MockGmailClient) DownloadAttachment(messageID, attachmentID string) ([]byte, error) { - if m.DownloadAttachmentFunc != nil { - return m.DownloadAttachmentFunc(messageID, attachmentID) - } - return nil, nil -} - -func (m *MockGmailClient) DownloadInlineAttachment(messageID, partID string) ([]byte, error) { - if m.DownloadInlineAttachmentFunc != nil { - return m.DownloadInlineAttachmentFunc(messageID, partID) - } - return nil, nil -} - -func (m *MockGmailClient) GetProfile() (*gmailapi.Profile, error) { - if m.GetProfileFunc != nil { - return m.GetProfileFunc() - } - return nil, nil -} - -// MockCalendarClient is a configurable mock for CalendarClientInterface. -type MockCalendarClient struct { - ListCalendarsFunc func() ([]*calendar.CalendarListEntry, error) - ListEventsFunc func(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) - GetEventFunc func(calendarID, eventID string) (*calendar.Event, error) -} - -// Verify MockCalendarClient implements CalendarClientInterface -var _ calendarapi.CalendarClientInterface = (*MockCalendarClient)(nil) - -func (m *MockCalendarClient) ListCalendars() ([]*calendar.CalendarListEntry, error) { - if m.ListCalendarsFunc != nil { - return m.ListCalendarsFunc() - } - return nil, nil -} - -func (m *MockCalendarClient) ListEvents(calendarID, timeMin, timeMax string, maxResults int64) ([]*calendar.Event, error) { - if m.ListEventsFunc != nil { - return m.ListEventsFunc(calendarID, timeMin, timeMax, maxResults) - } - return nil, nil -} - -func (m *MockCalendarClient) GetEvent(calendarID, eventID string) (*calendar.Event, error) { - if m.GetEventFunc != nil { - return m.GetEventFunc(calendarID, eventID) - } - return nil, nil -} - -// MockContactsClient is a configurable mock for ContactsClientInterface. -type MockContactsClient struct { - ListContactsFunc func(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) - SearchContactsFunc func(query string, pageSize int64) (*people.SearchResponse, error) - GetContactFunc func(resourceName string) (*people.Person, error) - ListContactGroupsFunc func(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) -} - -// Verify MockContactsClient implements ContactsClientInterface -var _ contactsapi.ContactsClientInterface = (*MockContactsClient)(nil) - -func (m *MockContactsClient) ListContacts(pageToken string, pageSize int64) (*people.ListConnectionsResponse, error) { - if m.ListContactsFunc != nil { - return m.ListContactsFunc(pageToken, pageSize) - } - return nil, nil -} - -func (m *MockContactsClient) SearchContacts(query string, pageSize int64) (*people.SearchResponse, error) { - if m.SearchContactsFunc != nil { - return m.SearchContactsFunc(query, pageSize) - } - return nil, nil -} - -func (m *MockContactsClient) GetContact(resourceName string) (*people.Person, error) { - if m.GetContactFunc != nil { - return m.GetContactFunc(resourceName) - } - return nil, nil -} - -func (m *MockContactsClient) ListContactGroups(pageToken string, pageSize int64) (*people.ListContactGroupsResponse, error) { - if m.ListContactGroupsFunc != nil { - return m.ListContactGroupsFunc(pageToken, pageSize) - } - return nil, nil -} - -// MockDriveClient is a configurable mock for DriveClientInterface. -type MockDriveClient struct { - ListFilesFunc func(query string, pageSize int64) ([]*driveapi.File, error) - ListFilesWithScopeFunc func(query string, pageSize int64, scope driveapi.DriveScope) ([]*driveapi.File, error) - GetFileFunc func(fileID string) (*driveapi.File, error) - DownloadFileFunc func(fileID string) ([]byte, error) - ExportFileFunc func(fileID, mimeType string) ([]byte, error) - ListSharedDrivesFunc func(pageSize int64) ([]*driveapi.SharedDrive, error) -} - -// Verify MockDriveClient implements DriveClientInterface -var _ driveapi.DriveClientInterface = (*MockDriveClient)(nil) - -func (m *MockDriveClient) ListFiles(query string, pageSize int64) ([]*driveapi.File, error) { - if m.ListFilesFunc != nil { - return m.ListFilesFunc(query, pageSize) - } - return nil, nil -} - -func (m *MockDriveClient) ListFilesWithScope(query string, pageSize int64, scope driveapi.DriveScope) ([]*driveapi.File, error) { - if m.ListFilesWithScopeFunc != nil { - return m.ListFilesWithScopeFunc(query, pageSize, scope) - } - // Fall back to ListFiles if no scope function defined - if m.ListFilesFunc != nil { - return m.ListFilesFunc(query, pageSize) - } - return nil, nil -} - -func (m *MockDriveClient) GetFile(fileID string) (*driveapi.File, error) { - if m.GetFileFunc != nil { - return m.GetFileFunc(fileID) - } - return nil, nil -} - -func (m *MockDriveClient) DownloadFile(fileID string) ([]byte, error) { - if m.DownloadFileFunc != nil { - return m.DownloadFileFunc(fileID) - } - return nil, nil -} - -func (m *MockDriveClient) ExportFile(fileID, mimeType string) ([]byte, error) { - if m.ExportFileFunc != nil { - return m.ExportFileFunc(fileID, mimeType) - } - return nil, nil -} - -func (m *MockDriveClient) ListSharedDrives(pageSize int64) ([]*driveapi.SharedDrive, error) { - if m.ListSharedDrivesFunc != nil { - return m.ListSharedDrivesFunc(pageSize) - } - return nil, nil -} diff --git a/internal/zip/extract.go b/internal/zip/extract.go index cb48215..618db6d 100644 --- a/internal/zip/extract.go +++ b/internal/zip/extract.go @@ -1,3 +1,4 @@ +// Package zip provides secure zip archive extraction with path traversal protection. package zip import ( @@ -42,7 +43,7 @@ func DefaultOptions() Options { func Extract(zipPath, destDir string, opts Options) error { r, err := zip.OpenReader(zipPath) if err != nil { - return fmt.Errorf("failed to open zip: %w", err) + return fmt.Errorf("opening zip: %w", err) } defer r.Close() @@ -54,10 +55,10 @@ func Extract(zipPath, destDir string, opts Options) error { // Create destination directory destDir, err = filepath.Abs(destDir) if err != nil { - return fmt.Errorf("failed to resolve destination path: %w", err) + return fmt.Errorf("resolving destination path: %w", err) } if err := fs.MkdirAll(destDir, 0755); err != nil { - return fmt.Errorf("failed to create destination: %w", err) + return fmt.Errorf("creating destination: %w", err) } var totalSize int64 @@ -79,14 +80,14 @@ func validateZip(r *zip.Reader, opts Options) error { var totalSize uint64 for _, f := range r.File { // Check for zip bomb (compression ratio attack) - if f.UncompressedSize64 > uint64(opts.MaxFileSize) { + if f.UncompressedSize64 > uint64(opts.MaxFileSize) { //nolint:gosec // MaxFileSize is always positive return fmt.Errorf("file %s exceeds max size: %d bytes", f.Name, f.UncompressedSize64) } totalSize += f.UncompressedSize64 } - if totalSize > uint64(opts.MaxTotalSize) { + if totalSize > uint64(opts.MaxTotalSize) { //nolint:gosec // MaxTotalSize is always positive return fmt.Errorf("total extracted size exceeds limit: %d bytes (max %d)", totalSize, opts.MaxTotalSize) } diff --git a/internal/zip/extract_test.go b/internal/zip/extract_test.go index 57914d6..bbb1bd1 100644 --- a/internal/zip/extract_test.go +++ b/internal/zip/extract_test.go @@ -8,25 +8,24 @@ import ( "path/filepath" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/open-cli-collective/google-readonly/internal/testutil" ) func createTestZip(t *testing.T, files map[string][]byte) string { t.Helper() tmpFile, err := os.CreateTemp("", "test-*.zip") - require.NoError(t, err) + testutil.NoError(t, err) defer tmpFile.Close() w := zip.NewWriter(tmpFile) for name, content := range files { f, err := w.Create(name) - require.NoError(t, err) + testutil.NoError(t, err) _, err = f.Write(content) - require.NoError(t, err) + testutil.NoError(t, err) } - require.NoError(t, w.Close()) + testutil.NoError(t, w.Close()) return tmpFile.Name() } @@ -41,15 +40,15 @@ func TestExtract(t *testing.T) { destDir := t.TempDir() err := Extract(zipPath, destDir, DefaultOptions()) - require.NoError(t, err) + testutil.NoError(t, err) content1, err := os.ReadFile(filepath.Join(destDir, "file1.txt")) - require.NoError(t, err) - assert.Equal(t, "content 1", string(content1)) + testutil.NoError(t, err) + testutil.Equal(t, string(content1), "content 1") content2, err := os.ReadFile(filepath.Join(destDir, "file2.txt")) - require.NoError(t, err) - assert.Equal(t, "content 2", string(content2)) + testutil.NoError(t, err) + testutil.Equal(t, string(content2), "content 2") }) t.Run("extracts nested directories", func(t *testing.T) { @@ -61,11 +60,11 @@ func TestExtract(t *testing.T) { destDir := t.TempDir() err := Extract(zipPath, destDir, DefaultOptions()) - require.NoError(t, err) + testutil.NoError(t, err) content, err := os.ReadFile(filepath.Join(destDir, "dir1", "dir2", "file2.txt")) - require.NoError(t, err) - assert.Equal(t, "nested 2", string(content)) + testutil.NoError(t, err) + testutil.Equal(t, string(content), "nested 2") }) t.Run("creates destination directory if not exists", func(t *testing.T) { @@ -76,23 +75,23 @@ func TestExtract(t *testing.T) { destDir := filepath.Join(t.TempDir(), "new", "nested", "dir") err := Extract(zipPath, destDir, DefaultOptions()) - require.NoError(t, err) + testutil.NoError(t, err) _, err = os.Stat(filepath.Join(destDir, "test.txt")) - assert.NoError(t, err) + testutil.NoError(t, err) }) t.Run("rejects invalid zip file", func(t *testing.T) { tmpFile, err := os.CreateTemp("", "invalid-*.zip") - require.NoError(t, err) + testutil.NoError(t, err) tmpFile.WriteString("not a zip file") tmpFile.Close() defer os.Remove(tmpFile.Name()) destDir := t.TempDir() err = Extract(tmpFile.Name(), destDir, DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to open zip") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "opening zip") }) } @@ -100,7 +99,7 @@ func TestExtractSecurityPathTraversal(t *testing.T) { t.Run("rejects path with leading ..", func(t *testing.T) { // Create a malicious zip with path traversal tmpFile, err := os.CreateTemp("", "malicious-*.zip") - require.NoError(t, err) + testutil.NoError(t, err) defer os.Remove(tmpFile.Name()) w := zip.NewWriter(tmpFile) @@ -110,20 +109,20 @@ func TestExtractSecurityPathTraversal(t *testing.T) { Method: zip.Store, } f, err := w.CreateHeader(header) - require.NoError(t, err) + testutil.NoError(t, err) f.Write([]byte("malicious")) w.Close() tmpFile.Close() destDir := t.TempDir() err = Extract(tmpFile.Name(), destDir, DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid file path") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "invalid file path") }) t.Run("rejects absolute paths", func(t *testing.T) { tmpFile, err := os.CreateTemp("", "malicious-*.zip") - require.NoError(t, err) + testutil.NoError(t, err) defer os.Remove(tmpFile.Name()) w := zip.NewWriter(tmpFile) @@ -132,14 +131,14 @@ func TestExtractSecurityPathTraversal(t *testing.T) { Method: zip.Store, } f, err := w.CreateHeader(header) - require.NoError(t, err) + testutil.NoError(t, err) f.Write([]byte("malicious")) w.Close() tmpFile.Close() destDir := t.TempDir() err = Extract(tmpFile.Name(), destDir, DefaultOptions()) - assert.Error(t, err) + testutil.Error(t, err) }) } @@ -160,8 +159,8 @@ func TestExtractSecurityLimits(t *testing.T) { MaxDepth: MaxDepth, } err := Extract(zipPath, destDir, opts) - assert.Error(t, err) - assert.Contains(t, err.Error(), "too many files") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "too many files") }) t.Run("rejects file exceeding max size", func(t *testing.T) { @@ -178,8 +177,8 @@ func TestExtractSecurityLimits(t *testing.T) { MaxDepth: MaxDepth, } err := Extract(zipPath, destDir, opts) - assert.Error(t, err) - assert.Contains(t, err.Error(), "exceeds max size") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "exceeds max size") }) t.Run("rejects total size exceeding limit", func(t *testing.T) { @@ -197,8 +196,8 @@ func TestExtractSecurityLimits(t *testing.T) { MaxDepth: MaxDepth, } err := Extract(zipPath, destDir, opts) - assert.Error(t, err) - assert.Contains(t, err.Error(), "exceeds limit") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "exceeds limit") }) t.Run("rejects path too deep", func(t *testing.T) { @@ -215,17 +214,17 @@ func TestExtractSecurityLimits(t *testing.T) { MaxDepth: 3, // Less than actual depth } err := Extract(zipPath, destDir, opts) - assert.Error(t, err) - assert.Contains(t, err.Error(), "too deep") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "too deep") }) } func TestDefaultOptions(t *testing.T) { opts := DefaultOptions() - assert.Equal(t, int64(MaxFileSize), opts.MaxFileSize) - assert.Equal(t, int64(MaxTotalSize), opts.MaxTotalSize) - assert.Equal(t, MaxFiles, opts.MaxFiles) - assert.Equal(t, MaxDepth, opts.MaxDepth) + testutil.Equal(t, opts.MaxFileSize, int64(MaxFileSize)) + testutil.Equal(t, opts.MaxTotalSize, int64(MaxTotalSize)) + testutil.Equal(t, opts.MaxFiles, MaxFiles) + testutil.Equal(t, opts.MaxDepth, MaxDepth) } func TestValidateZip(t *testing.T) { @@ -236,11 +235,11 @@ func TestValidateZip(t *testing.T) { defer os.Remove(zipPath) r, err := zip.OpenReader(zipPath) - require.NoError(t, err) + testutil.NoError(t, err) defer r.Close() err = validateZip(&r.Reader, DefaultOptions()) - assert.NoError(t, err) + testutil.NoError(t, err) }) } @@ -252,7 +251,7 @@ type mockFS struct { failAfterN int // fail MkdirAll after N calls (0 = fail immediately) } -func (m *mockFS) MkdirAll(path string, perm os.FileMode) error { +func (m *mockFS) MkdirAll(_ string, _ os.FileMode) error { m.mkdirCalls++ if m.failAfterN > 0 && m.mkdirCalls <= m.failAfterN { return nil @@ -295,15 +294,15 @@ type mockFSWithErrorWriter struct { writer *errorWriter } -func (m *mockFSWithErrorWriter) MkdirAll(path string, perm os.FileMode) error { +func (m *mockFSWithErrorWriter) MkdirAll(_ string, _ os.FileMode) error { return nil } -func (m *mockFSWithErrorWriter) OpenFile(name string, flag int, perm os.FileMode) (io.WriteCloser, error) { +func (m *mockFSWithErrorWriter) OpenFile(_ string, _ int, _ os.FileMode) (io.WriteCloser, error) { return m.writer, nil } -func (m *mockFSWithErrorWriter) Remove(name string) error { +func (m *mockFSWithErrorWriter) Remove(_ string) error { return nil } @@ -323,8 +322,8 @@ func TestExtractFileSystemErrors(t *testing.T) { defer os.Remove(zipPath) err := Extract(zipPath, "/tmp/test-dest", DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to create destination") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "creating destination") }) t.Run("returns error when MkdirAll fails for parent directory", func(t *testing.T) { @@ -340,8 +339,8 @@ func TestExtractFileSystemErrors(t *testing.T) { destDir := t.TempDir() err := Extract(zipPath, destDir, DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "disk full") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "disk full") }) t.Run("returns error when OpenFile fails", func(t *testing.T) { @@ -356,8 +355,8 @@ func TestExtractFileSystemErrors(t *testing.T) { destDir := t.TempDir() err := Extract(zipPath, destDir, DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "too many open files") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "too many open files") }) t.Run("returns error when io.Copy fails", func(t *testing.T) { @@ -375,8 +374,8 @@ func TestExtractFileSystemErrors(t *testing.T) { destDir := t.TempDir() err := Extract(zipPath, destDir, DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "write error") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "write error") }) } @@ -388,7 +387,7 @@ func TestExtractDirectoryEntry(t *testing.T) { t.Run("extracts directory entries", func(t *testing.T) { // Create zip with explicit directory entry tmpFile, err := os.CreateTemp("", "dir-test-*.zip") - require.NoError(t, err) + testutil.NoError(t, err) defer os.Remove(tmpFile.Name()) w := zip.NewWriter(tmpFile) @@ -399,18 +398,18 @@ func TestExtractDirectoryEntry(t *testing.T) { } header.SetMode(os.ModeDir | 0755) _, err = w.CreateHeader(header) - require.NoError(t, err) + testutil.NoError(t, err) w.Close() tmpFile.Close() destDir := t.TempDir() err = Extract(tmpFile.Name(), destDir, DefaultOptions()) - require.NoError(t, err) + testutil.NoError(t, err) // Verify directory was created info, err := os.Stat(filepath.Join(destDir, "mydir")) - require.NoError(t, err) - assert.True(t, info.IsDir()) + testutil.NoError(t, err) + testutil.True(t, info.IsDir()) }) t.Run("returns error when MkdirAll fails for directory entry", func(t *testing.T) { @@ -421,18 +420,18 @@ func TestExtractDirectoryEntry(t *testing.T) { // Create zip with explicit directory entry tmpFile, err := os.CreateTemp("", "dir-test-*.zip") - require.NoError(t, err) + testutil.NoError(t, err) defer os.Remove(tmpFile.Name()) w := zip.NewWriter(tmpFile) _, err = w.Create("mydir/") - require.NoError(t, err) + testutil.NoError(t, err) w.Close() tmpFile.Close() destDir := t.TempDir() err = Extract(tmpFile.Name(), destDir, DefaultOptions()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "cannot create directory") + testutil.Error(t, err) + testutil.Contains(t, err.Error(), "cannot create directory") }) } diff --git a/internal/zip/fs.go b/internal/zip/fs.go index 25646b9..c2fb497 100644 --- a/internal/zip/fs.go +++ b/internal/zip/fs.go @@ -20,7 +20,7 @@ func (osFS) MkdirAll(path string, perm os.FileMode) error { } func (osFS) OpenFile(name string, flag int, perm os.FileMode) (io.WriteCloser, error) { - return os.OpenFile(name, flag, perm) + return os.OpenFile(name, flag, perm) //nolint:gosec // Path validated by caller in extractFile } func (osFS) Remove(name string) error {