From 4beb6f245ab3ac000257987506037195bf22b0de Mon Sep 17 00:00:00 2001 From: Brian Le Date: Fri, 27 Feb 2026 14:56:31 -0500 Subject: [PATCH 1/3] feat(api): add notion rest client primitives Introduce a dedicated official Notion REST API client package and config loader scaffolding. Adds API base URL, Notion-Version, and token config with env overrides, plus tests for request behavior and config normalization. --- internal/api/client.go | 119 +++++++++++++++++++++++++++++++++ internal/api/client_test.go | 108 ++++++++++++++++++++++++++++++ internal/config/config.go | 96 ++++++++++++++++++++++++++ internal/config/config_test.go | 47 +++++++++++++ 4 files changed, 370 insertions(+) create mode 100644 internal/api/client.go create mode 100644 internal/api/client_test.go create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go diff --git a/internal/api/client.go b/internal/api/client.go new file mode 100644 index 0000000..acd1a0d --- /dev/null +++ b/internal/api/client.go @@ -0,0 +1,119 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/lox/notion-cli/internal/config" +) + +const ( + defaultBaseURL = "https://api.notion.com/v1" + defaultNotionAPIRev = "2022-06-28" +) + +type Client struct { + httpClient *http.Client + baseURL string + notionVersion string + token string +} + +func NewClient(cfg config.APIConfig, token string) (*Client, error) { + token = strings.TrimSpace(token) + if token == "" { + return nil, fmt.Errorf("official API token is required") + } + + baseURL := strings.TrimSpace(cfg.BaseURL) + if baseURL == "" { + baseURL = defaultBaseURL + } + baseURL = strings.TrimRight(baseURL, "/") + + notionVersion := strings.TrimSpace(cfg.NotionVersion) + if notionVersion == "" { + notionVersion = defaultNotionAPIRev + } + + return &Client{ + httpClient: &http.Client{Timeout: 20 * time.Second}, + baseURL: baseURL, + notionVersion: notionVersion, + token: token, + }, nil +} + +func (c *Client) PatchPage(ctx context.Context, pageID string, patch map[string]any) error { + pageID = strings.TrimSpace(pageID) + if pageID == "" { + return fmt.Errorf("page ID is required") + } + if len(patch) == 0 { + return fmt.Errorf("patch payload is required") + } + + return c.doJSON(ctx, http.MethodPatch, "/pages/"+pageID, patch, nil) +} + +func (c *Client) doJSON(ctx context.Context, method, path string, payload any, out any) error { + var bodyReader io.Reader + if payload != nil { + data, err := json.Marshal(payload) + if err != nil { + return err + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader) + if err != nil { + return err + } + req.Header.Set("accept", "application/json") + req.Header.Set("authorization", "Bearer "+c.token) + req.Header.Set("notion-version", c.notionVersion) + if payload != nil { + req.Header.Set("content-type", "application/json") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + if resp.StatusCode >= 400 { + message := strings.TrimSpace(string(respBody)) + if message == "" { + message = http.StatusText(resp.StatusCode) + } else { + var errResp struct { + Message string `json:"message"` + } + if err := json.Unmarshal(respBody, &errResp); err == nil && strings.TrimSpace(errResp.Message) != "" { + message = strings.TrimSpace(errResp.Message) + } + } + return fmt.Errorf("official API %s %s failed (%d): %s", method, path, resp.StatusCode, message) + } + + if out == nil || len(respBody) == 0 { + return nil + } + if err := json.Unmarshal(respBody, out); err != nil { + return fmt.Errorf("parse official API response for %s %s: %w", method, path, err) + } + return nil +} diff --git a/internal/api/client_test.go b/internal/api/client_test.go new file mode 100644 index 0000000..075081f --- /dev/null +++ b/internal/api/client_test.go @@ -0,0 +1,108 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/lox/notion-cli/internal/config" +) + +func TestNewClientRequiresToken(t *testing.T) { + t.Parallel() + + _, err := NewClient(config.APIConfig{}, "") + if err == nil { + t.Fatal("expected token error") + } +} + +func TestPatchPageSendsPatchRequest(t *testing.T) { + t.Parallel() + + var gotMethod string + var gotPath string + var gotAuth string + var gotVersion string + var gotContentType string + var gotBody map[string]any + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotVersion = r.Header.Get("Notion-Version") + gotContentType = r.Header.Get("Content-Type") + + defer func() { _ = r.Body.Close() }() + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"page-id","object":"page"}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{ + BaseURL: srv.URL, + NotionVersion: "2022-06-28", + }, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + patch := map[string]any{ + "archived": true, + } + + if err := client.PatchPage(context.Background(), "page-id", patch); err != nil { + t.Fatalf("patch page: %v", err) + } + + if gotMethod != http.MethodPatch { + t.Fatalf("method mismatch: got %s", gotMethod) + } + if gotPath != "/pages/page-id" { + t.Fatalf("path mismatch: got %s", gotPath) + } + if gotAuth != "Bearer secret-token" { + t.Fatalf("auth mismatch: got %s", gotAuth) + } + if gotVersion != "2022-06-28" { + t.Fatalf("notion-version mismatch: got %s", gotVersion) + } + if gotContentType != "application/json" { + t.Fatalf("content-type mismatch: got %s", gotContentType) + } + + if gotBody["archived"] != true { + t.Fatalf("archived mismatch: %v", gotBody["archived"]) + } +} + +func TestPatchPageReturnsAPIErrorMessage(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"object":"error","message":"unauthorized"}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + err = client.PatchPage(context.Background(), "page-id", map[string]any{"archived": true}) + if err == nil { + t.Fatal("expected API error") + } + if !strings.Contains(err.Error(), "unauthorized") { + t.Fatalf("expected unauthorized message, got: %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..bdece23 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,96 @@ +package config + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" +) + +const ( + configDirName = ".config/notion-cli" + configFileName = "config.json" +) + +type Config struct { + ActiveAccount string `json:"active_account,omitempty"` + API APIConfig `json:"api,omitempty"` +} + +type APIConfig struct { + BaseURL string `json:"base_url,omitempty"` + NotionVersion string `json:"notion_version,omitempty"` + Token string `json:"token,omitempty"` +} + +func Default() Config { + return Config{ + API: APIConfig{ + BaseURL: "https://api.notion.com/v1", + NotionVersion: "2022-06-28", + }, + } +} + +func Load() (Config, error) { + cfg := Default() + + path, err := Path() + if err != nil { + return cfg, err + } + + if data, err := os.ReadFile(path); err == nil { + if err := json.Unmarshal(data, &cfg); err != nil { + return cfg, err + } + } else if !errors.Is(err, os.ErrNotExist) { + return cfg, err + } + + applyEnvOverrides(&cfg) + normalize(&cfg) + return cfg, nil +} + +func Path() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, configDirName, configFileName), nil +} + +func applyEnvOverrides(cfg *Config) { + if cfg == nil { + return + } + + if s := os.Getenv("NOTION_API_BASE_URL"); s != "" { + cfg.API.BaseURL = s + } + if s := os.Getenv("NOTION_API_NOTION_VERSION"); s != "" { + cfg.API.NotionVersion = s + } + if s := os.Getenv("NOTION_API_TOKEN"); s != "" { + cfg.API.Token = s + } +} + +func normalize(cfg *Config) { + if cfg == nil { + return + } + + cfg.API.BaseURL = strings.TrimSpace(cfg.API.BaseURL) + if cfg.API.BaseURL == "" { + cfg.API.BaseURL = "https://api.notion.com/v1" + } + cfg.API.BaseURL = strings.TrimRight(cfg.API.BaseURL, "/") + cfg.API.NotionVersion = strings.TrimSpace(cfg.API.NotionVersion) + if cfg.API.NotionVersion == "" { + cfg.API.NotionVersion = "2022-06-28" + } + cfg.API.Token = strings.TrimSpace(cfg.API.Token) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..47d39c4 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,47 @@ +package config + +import "testing" + +func TestApplyEnvOverrides(t *testing.T) { + t.Setenv("NOTION_API_BASE_URL", "https://api.example.com/v1/") + t.Setenv("NOTION_API_NOTION_VERSION", "2022-06-28") + t.Setenv("NOTION_API_TOKEN", "api-token") + + cfg := Default() + applyEnvOverrides(&cfg) + normalize(&cfg) + + if cfg.API.BaseURL != "https://api.example.com/v1" { + t.Fatalf("unexpected api.base_url normalization: %q", cfg.API.BaseURL) + } + if cfg.API.NotionVersion != "2022-06-28" { + t.Fatalf("unexpected api.notion_version: %q", cfg.API.NotionVersion) + } + if cfg.API.Token != "api-token" { + t.Fatalf("unexpected api.token: %q", cfg.API.Token) + } +} + +func TestNormalizeAppliesAPIDefaults(t *testing.T) { + cfg := Config{} + normalize(&cfg) + + if cfg.API.BaseURL != "https://api.notion.com/v1" { + t.Fatalf("unexpected api.base_url default: %q", cfg.API.BaseURL) + } + if cfg.API.NotionVersion != "2022-06-28" { + t.Fatalf("unexpected api.notion_version default: %q", cfg.API.NotionVersion) + } +} + +func TestPathUsesHome(t *testing.T) { + t.Setenv("HOME", "/tmp/example-home") + + path, err := Path() + if err != nil { + t.Fatal(err) + } + if path != "/tmp/example-home/.config/notion-cli/config.json" { + t.Fatalf("unexpected path: %s", path) + } +} From b3b67bf9ed6f4a8f90d2f25b7d9ee3d6b0324aa8 Mon Sep 17 00:00:00 2001 From: Brian Le Date: Sat, 28 Feb 2026 14:55:09 -0500 Subject: [PATCH 2/3] feat(api): add official API client loader for CLI commands --- internal/cli/official_api.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 internal/cli/official_api.go diff --git a/internal/cli/official_api.go b/internal/cli/official_api.go new file mode 100644 index 0000000..48906d9 --- /dev/null +++ b/internal/cli/official_api.go @@ -0,0 +1,22 @@ +package cli + +import ( + "fmt" + + "github.com/lox/notion-cli/internal/api" + "github.com/lox/notion-cli/internal/config" +) + +func RequireOfficialAPIClient() (*api.Client, error) { + cfg, err := config.Load() + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + client, err := api.NewClient(cfg.API, cfg.API.Token) + if err != nil { + return nil, fmt.Errorf("create official API client: %w (set api.token in ~/.config/notion-cli/config.json or NOTION_API_TOKEN)", err) + } + + return client, nil +} From f6d6333dc44c5ab436d85edbe73783b3a48a4f6b Mon Sep 17 00:00:00 2001 From: Brian Le Date: Sat, 18 Apr 2026 12:49:03 -0400 Subject: [PATCH 3/3] refactor(api): typed errors, URL escape, reuse defaults, 429 retry --- internal/api/client.go | 222 ++++++++++++++++++++++++++++-------- internal/api/client_test.go | 123 ++++++++++++++++++-- internal/config/config.go | 15 ++- 3 files changed, 298 insertions(+), 62 deletions(-) diff --git a/internal/api/client.go b/internal/api/client.go index acd1a0d..363375f 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -5,18 +5,83 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" + "net/url" + "strconv" "strings" "time" "github.com/lox/notion-cli/internal/config" ) -const ( - defaultBaseURL = "https://api.notion.com/v1" - defaultNotionAPIRev = "2022-06-28" -) +// APIError is a structured error returned by the Notion API. +type APIError struct { + StatusCode int + Code string + Message string + RequestID string +} + +func (e *APIError) Error() string { + parts := make([]string, 0, 3) + parts = append(parts, fmt.Sprintf("notion API error (status %d)", e.StatusCode)) + if e.Code != "" { + parts = append(parts, fmt.Sprintf("code=%s", e.Code)) + } + if e.RequestID != "" { + parts = append(parts, fmt.Sprintf("request_id=%s", e.RequestID)) + } + msg := strings.Join(parts, " ") + if e.Message != "" { + msg = msg + ": " + e.Message + } + return msg +} + +// Page represents a Notion page response (minimal fields the client surfaces today). +type Page struct { + ID string `json:"id"` + Object string `json:"object"` + Archived bool `json:"archived"` +} + +// Icon represents a Notion page icon. Left opaque for callers that pass it through. +type Icon map[string]any + +// Cover represents a Notion page cover. Left opaque for callers that pass it through. +type Cover map[string]any + +// PropertyValue represents a Notion property value. Left opaque for callers that pass it through. +type PropertyValue map[string]any + +// PageUpdate is the typed payload for PatchPage. +// All fields are optional; only non-nil fields are sent. +type PageUpdate struct { + Archived *bool + Icon *Icon + Cover *Cover + Properties map[string]PropertyValue +} + +func (u PageUpdate) payload() (map[string]any, error) { + out := make(map[string]any) + if u.Archived != nil { + out["archived"] = *u.Archived + } + if u.Icon != nil { + out["icon"] = *u.Icon + } + if u.Cover != nil { + out["cover"] = *u.Cover + } + if len(u.Properties) > 0 { + out["properties"] = u.Properties + } + if len(out) == 0 { + return nil, fmt.Errorf("page update is empty") + } + return out, nil +} type Client struct { httpClient *http.Client @@ -28,18 +93,18 @@ type Client struct { func NewClient(cfg config.APIConfig, token string) (*Client, error) { token = strings.TrimSpace(token) if token == "" { - return nil, fmt.Errorf("official API token is required") + return nil, fmt.Errorf("notion API token is required") } baseURL := strings.TrimSpace(cfg.BaseURL) if baseURL == "" { - baseURL = defaultBaseURL + baseURL = config.DefaultAPIBaseURL } baseURL = strings.TrimRight(baseURL, "/") notionVersion := strings.TrimSpace(cfg.NotionVersion) if notionVersion == "" { - notionVersion = defaultNotionAPIRev + notionVersion = config.DefaultNotionAPIVersion } return &Client{ @@ -50,70 +115,135 @@ func NewClient(cfg config.APIConfig, token string) (*Client, error) { }, nil } -func (c *Client) PatchPage(ctx context.Context, pageID string, patch map[string]any) error { +func (c *Client) PatchPage(ctx context.Context, pageID string, update PageUpdate) (*Page, error) { pageID = strings.TrimSpace(pageID) if pageID == "" { - return fmt.Errorf("page ID is required") + return nil, fmt.Errorf("page ID is required") } - if len(patch) == 0 { - return fmt.Errorf("patch payload is required") + + payload, err := update.payload() + if err != nil { + return nil, err + } + + pagePath, err := url.JoinPath(c.baseURL, "pages", url.PathEscape(pageID)) + if err != nil { + return nil, fmt.Errorf("build page URL: %w", err) } - return c.doJSON(ctx, http.MethodPatch, "/pages/"+pageID, patch, nil) + var page Page + if err := c.doJSON(ctx, http.MethodPatch, pagePath, payload, &page); err != nil { + return nil, err + } + return &page, nil } -func (c *Client) doJSON(ctx context.Context, method, path string, payload any, out any) error { - var bodyReader io.Reader +func (c *Client) doJSON(ctx context.Context, method, fullURL string, payload any, out any) error { + var body []byte if payload != nil { data, err := json.Marshal(payload) if err != nil { return err } - bodyReader = bytes.NewReader(data) + body = data } - req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader) + resp, err := c.sendOnce(ctx, method, fullURL, body, payload != nil) if err != nil { return err } - req.Header.Set("accept", "application/json") - req.Header.Set("authorization", "Bearer "+c.token) - req.Header.Set("notion-version", c.notionVersion) - if payload != nil { - req.Header.Set("content-type", "application/json") - } - resp, err := c.httpClient.Do(req) - if err != nil { - return err + // Minimal 429 handling: honor Retry-After once, retry once. + if resp.StatusCode == http.StatusTooManyRequests { + retryAfter := parseRetryAfter(resp.Header.Get("Retry-After")) + _ = resp.Body.Close() + if retryAfter > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryAfter): + } + } + resp, err = c.sendOnce(ctx, method, fullURL, body, payload != nil) + if err != nil { + return err + } } - defer func() { _ = resp.Body.Close() }() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return err - } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { - message := strings.TrimSpace(string(respBody)) - if message == "" { - message = http.StatusText(resp.StatusCode) - } else { - var errResp struct { - Message string `json:"message"` - } - if err := json.Unmarshal(respBody, &errResp); err == nil && strings.TrimSpace(errResp.Message) != "" { - message = strings.TrimSpace(errResp.Message) - } - } - return fmt.Errorf("official API %s %s failed (%d): %s", method, path, resp.StatusCode, message) + return parseAPIError(resp) } - if out == nil || len(respBody) == 0 { + if out == nil { return nil } - if err := json.Unmarshal(respBody, out); err != nil { - return fmt.Errorf("parse official API response for %s %s: %w", method, path, err) + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return fmt.Errorf("parse notion API response for %s %s: %w", method, fullURL, err) } return nil } + +func (c *Client) sendOnce(ctx context.Context, method, fullURL string, body []byte, hasPayload bool) (*http.Response, error) { + var bodyReader *bytes.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + var req *http.Request + var err error + if bodyReader != nil { + req, err = http.NewRequestWithContext(ctx, method, fullURL, bodyReader) + } else { + req, err = http.NewRequestWithContext(ctx, method, fullURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Notion-Version", c.notionVersion) + if hasPayload { + req.Header.Set("Content-Type", "application/json") + } + + return c.httpClient.Do(req) +} + +func parseAPIError(resp *http.Response) error { + apiErr := &APIError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("X-Notion-Request-Id"), + } + + var parsed struct { + Code string `json:"code"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&parsed); err == nil { + apiErr.Code = strings.TrimSpace(parsed.Code) + apiErr.Message = strings.TrimSpace(parsed.Message) + } + if apiErr.Message == "" { + apiErr.Message = http.StatusText(resp.StatusCode) + } + return apiErr +} + +func parseRetryAfter(h string) time.Duration { + h = strings.TrimSpace(h) + if h == "" { + return 0 + } + if secs, err := strconv.Atoi(h); err == nil && secs >= 0 { + return time.Duration(secs) * time.Second + } + if t, err := http.ParseTime(h); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } + return 0 +} diff --git a/internal/api/client_test.go b/internal/api/client_test.go index 075081f..f635e6b 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -3,10 +3,12 @@ package api import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" - "strings" + "sync/atomic" "testing" + "time" "github.com/lox/notion-cli/internal/config" ) @@ -28,6 +30,7 @@ func TestPatchPageSendsPatchRequest(t *testing.T) { var gotAuth string var gotVersion string var gotContentType string + var gotAccept string var gotBody map[string]any srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -36,6 +39,7 @@ func TestPatchPageSendsPatchRequest(t *testing.T) { gotAuth = r.Header.Get("Authorization") gotVersion = r.Header.Get("Notion-Version") gotContentType = r.Header.Get("Content-Type") + gotAccept = r.Header.Get("Accept") defer func() { _ = r.Body.Close() }() if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { @@ -43,7 +47,7 @@ func TestPatchPageSendsPatchRequest(t *testing.T) { } w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"id":"page-id","object":"page"}`)) + _, _ = w.Write([]byte(`{"id":"page-id","object":"page","archived":true}`)) })) defer srv.Close() @@ -55,11 +59,9 @@ func TestPatchPageSendsPatchRequest(t *testing.T) { t.Fatalf("new client: %v", err) } - patch := map[string]any{ - "archived": true, - } - - if err := client.PatchPage(context.Background(), "page-id", patch); err != nil { + archived := true + page, err := client.PatchPage(context.Background(), "page-id", PageUpdate{Archived: &archived}) + if err != nil { t.Fatalf("patch page: %v", err) } @@ -78,18 +80,50 @@ func TestPatchPageSendsPatchRequest(t *testing.T) { if gotContentType != "application/json" { t.Fatalf("content-type mismatch: got %s", gotContentType) } + if gotAccept != "application/json" { + t.Fatalf("accept mismatch: got %s", gotAccept) + } if gotBody["archived"] != true { t.Fatalf("archived mismatch: %v", gotBody["archived"]) } + if page == nil || page.ID != "page-id" || !page.Archived { + t.Fatalf("unexpected page: %+v", page) + } } -func TestPatchPageReturnsAPIErrorMessage(t *testing.T) { +func TestPatchPageEscapesPageID(t *testing.T) { t.Parallel() + var gotPath string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.EscapedPath() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"page id"}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + archived := true + if _, err := client.PatchPage(context.Background(), "page id/with slash", PageUpdate{Archived: &archived}); err != nil { + t.Fatalf("patch page: %v", err) + } + if gotPath != "/pages/page%20id%2Fwith%20slash" { + t.Fatalf("unexpected escaped path: %s", gotPath) + } +} + +func TestPatchPageReturnsTypedAPIError(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Notion-Request-Id", "req-123") w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"object":"error","message":"unauthorized"}`)) + _, _ = w.Write([]byte(`{"object":"error","code":"unauthorized","message":"invalid token"}`)) })) defer srv.Close() @@ -98,11 +132,76 @@ func TestPatchPageReturnsAPIErrorMessage(t *testing.T) { t.Fatalf("new client: %v", err) } - err = client.PatchPage(context.Background(), "page-id", map[string]any{"archived": true}) + archived := true + _, err = client.PatchPage(context.Background(), "page-id", PageUpdate{Archived: &archived}) if err == nil { t.Fatal("expected API error") } - if !strings.Contains(err.Error(), "unauthorized") { - t.Fatalf("expected unauthorized message, got: %v", err) + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected *APIError, got %T: %v", err, err) + } + if apiErr.StatusCode != http.StatusUnauthorized { + t.Fatalf("status: got %d", apiErr.StatusCode) + } + if apiErr.Code != "unauthorized" { + t.Fatalf("code: got %q", apiErr.Code) + } + if apiErr.Message != "invalid token" { + t.Fatalf("message: got %q", apiErr.Message) + } + if apiErr.RequestID != "req-123" { + t.Fatalf("request id: got %q", apiErr.RequestID) + } +} + +func TestPatchPageRetriesOn429(t *testing.T) { + t.Parallel() + + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"code":"rate_limited","message":"slow down"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"page-id","archived":true}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + archived := true + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + page, err := client.PatchPage(ctx, "page-id", PageUpdate{Archived: &archived}) + if err != nil { + t.Fatalf("patch page: %v", err) + } + if page.ID != "page-id" { + t.Fatalf("unexpected page: %+v", page) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Fatalf("expected exactly one retry (2 calls), got %d", got) + } +} + +func TestPatchPageEmptyUpdateIsRejected(t *testing.T) { + t.Parallel() + + client, err := NewClient(config.APIConfig{BaseURL: "https://example.invalid"}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + if _, err := client.PatchPage(context.Background(), "page-id", PageUpdate{}); err == nil { + t.Fatal("expected empty-update error") } } diff --git a/internal/config/config.go b/internal/config/config.go index bdece23..17b0452 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,11 @@ import ( const ( configDirName = ".config/notion-cli" configFileName = "config.json" + + // DefaultAPIBaseURL is the default Notion REST API base URL. + DefaultAPIBaseURL = "https://api.notion.com/v1" + // DefaultNotionAPIVersion is the default Notion-Version header value. + DefaultNotionAPIVersion = "2022-06-28" ) type Config struct { @@ -27,8 +32,8 @@ type APIConfig struct { func Default() Config { return Config{ API: APIConfig{ - BaseURL: "https://api.notion.com/v1", - NotionVersion: "2022-06-28", + BaseURL: DefaultAPIBaseURL, + NotionVersion: DefaultNotionAPIVersion, }, } } @@ -83,14 +88,16 @@ func normalize(cfg *Config) { return } + d := Default() + cfg.API.BaseURL = strings.TrimSpace(cfg.API.BaseURL) if cfg.API.BaseURL == "" { - cfg.API.BaseURL = "https://api.notion.com/v1" + cfg.API.BaseURL = d.API.BaseURL } cfg.API.BaseURL = strings.TrimRight(cfg.API.BaseURL, "/") cfg.API.NotionVersion = strings.TrimSpace(cfg.API.NotionVersion) if cfg.API.NotionVersion == "" { - cfg.API.NotionVersion = "2022-06-28" + cfg.API.NotionVersion = d.API.NotionVersion } cfg.API.Token = strings.TrimSpace(cfg.API.Token) }