diff --git a/internal/api/client.go b/internal/api/client.go new file mode 100644 index 0000000..363375f --- /dev/null +++ b/internal/api/client.go @@ -0,0 +1,249 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/lox/notion-cli/internal/config" +) + +// APIError is a structured error returned by the Notion API. +type APIError struct { + StatusCode int + Code string + Message string + RequestID string +} + +func (e *APIError) Error() string { + parts := make([]string, 0, 3) + parts = append(parts, fmt.Sprintf("notion API error (status %d)", e.StatusCode)) + if e.Code != "" { + parts = append(parts, fmt.Sprintf("code=%s", e.Code)) + } + if e.RequestID != "" { + parts = append(parts, fmt.Sprintf("request_id=%s", e.RequestID)) + } + msg := strings.Join(parts, " ") + if e.Message != "" { + msg = msg + ": " + e.Message + } + return msg +} + +// Page represents a Notion page response (minimal fields the client surfaces today). +type Page struct { + ID string `json:"id"` + Object string `json:"object"` + Archived bool `json:"archived"` +} + +// Icon represents a Notion page icon. Left opaque for callers that pass it through. +type Icon map[string]any + +// Cover represents a Notion page cover. Left opaque for callers that pass it through. +type Cover map[string]any + +// PropertyValue represents a Notion property value. Left opaque for callers that pass it through. +type PropertyValue map[string]any + +// PageUpdate is the typed payload for PatchPage. +// All fields are optional; only non-nil fields are sent. +type PageUpdate struct { + Archived *bool + Icon *Icon + Cover *Cover + Properties map[string]PropertyValue +} + +func (u PageUpdate) payload() (map[string]any, error) { + out := make(map[string]any) + if u.Archived != nil { + out["archived"] = *u.Archived + } + if u.Icon != nil { + out["icon"] = *u.Icon + } + if u.Cover != nil { + out["cover"] = *u.Cover + } + if len(u.Properties) > 0 { + out["properties"] = u.Properties + } + if len(out) == 0 { + return nil, fmt.Errorf("page update is empty") + } + return out, nil +} + +type Client struct { + httpClient *http.Client + baseURL string + notionVersion string + token string +} + +func NewClient(cfg config.APIConfig, token string) (*Client, error) { + token = strings.TrimSpace(token) + if token == "" { + return nil, fmt.Errorf("notion API token is required") + } + + baseURL := strings.TrimSpace(cfg.BaseURL) + if baseURL == "" { + baseURL = config.DefaultAPIBaseURL + } + baseURL = strings.TrimRight(baseURL, "/") + + notionVersion := strings.TrimSpace(cfg.NotionVersion) + if notionVersion == "" { + notionVersion = config.DefaultNotionAPIVersion + } + + return &Client{ + httpClient: &http.Client{Timeout: 20 * time.Second}, + baseURL: baseURL, + notionVersion: notionVersion, + token: token, + }, nil +} + +func (c *Client) PatchPage(ctx context.Context, pageID string, update PageUpdate) (*Page, error) { + pageID = strings.TrimSpace(pageID) + if pageID == "" { + return nil, fmt.Errorf("page ID is required") + } + + payload, err := update.payload() + if err != nil { + return nil, err + } + + pagePath, err := url.JoinPath(c.baseURL, "pages", url.PathEscape(pageID)) + if err != nil { + return nil, fmt.Errorf("build page URL: %w", err) + } + + var page Page + if err := c.doJSON(ctx, http.MethodPatch, pagePath, payload, &page); err != nil { + return nil, err + } + return &page, nil +} + +func (c *Client) doJSON(ctx context.Context, method, fullURL string, payload any, out any) error { + var body []byte + if payload != nil { + data, err := json.Marshal(payload) + if err != nil { + return err + } + body = data + } + + resp, err := c.sendOnce(ctx, method, fullURL, body, payload != nil) + if err != nil { + return err + } + + // Minimal 429 handling: honor Retry-After once, retry once. + if resp.StatusCode == http.StatusTooManyRequests { + retryAfter := parseRetryAfter(resp.Header.Get("Retry-After")) + _ = resp.Body.Close() + if retryAfter > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryAfter): + } + } + resp, err = c.sendOnce(ctx, method, fullURL, body, payload != nil) + if err != nil { + return err + } + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + + if out == nil { + return nil + } + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return fmt.Errorf("parse notion API response for %s %s: %w", method, fullURL, err) + } + return nil +} + +func (c *Client) sendOnce(ctx context.Context, method, fullURL string, body []byte, hasPayload bool) (*http.Response, error) { + var bodyReader *bytes.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + var req *http.Request + var err error + if bodyReader != nil { + req, err = http.NewRequestWithContext(ctx, method, fullURL, bodyReader) + } else { + req, err = http.NewRequestWithContext(ctx, method, fullURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Notion-Version", c.notionVersion) + if hasPayload { + req.Header.Set("Content-Type", "application/json") + } + + return c.httpClient.Do(req) +} + +func parseAPIError(resp *http.Response) error { + apiErr := &APIError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("X-Notion-Request-Id"), + } + + var parsed struct { + Code string `json:"code"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&parsed); err == nil { + apiErr.Code = strings.TrimSpace(parsed.Code) + apiErr.Message = strings.TrimSpace(parsed.Message) + } + if apiErr.Message == "" { + apiErr.Message = http.StatusText(resp.StatusCode) + } + return apiErr +} + +func parseRetryAfter(h string) time.Duration { + h = strings.TrimSpace(h) + if h == "" { + return 0 + } + if secs, err := strconv.Atoi(h); err == nil && secs >= 0 { + return time.Duration(secs) * time.Second + } + if t, err := http.ParseTime(h); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } + return 0 +} diff --git a/internal/api/client_test.go b/internal/api/client_test.go new file mode 100644 index 0000000..f635e6b --- /dev/null +++ b/internal/api/client_test.go @@ -0,0 +1,207 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/lox/notion-cli/internal/config" +) + +func TestNewClientRequiresToken(t *testing.T) { + t.Parallel() + + _, err := NewClient(config.APIConfig{}, "") + if err == nil { + t.Fatal("expected token error") + } +} + +func TestPatchPageSendsPatchRequest(t *testing.T) { + t.Parallel() + + var gotMethod string + var gotPath string + var gotAuth string + var gotVersion string + var gotContentType string + var gotAccept string + var gotBody map[string]any + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotVersion = r.Header.Get("Notion-Version") + gotContentType = r.Header.Get("Content-Type") + gotAccept = r.Header.Get("Accept") + + defer func() { _ = r.Body.Close() }() + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"page-id","object":"page","archived":true}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{ + BaseURL: srv.URL, + NotionVersion: "2022-06-28", + }, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + archived := true + page, err := client.PatchPage(context.Background(), "page-id", PageUpdate{Archived: &archived}) + if err != nil { + t.Fatalf("patch page: %v", err) + } + + if gotMethod != http.MethodPatch { + t.Fatalf("method mismatch: got %s", gotMethod) + } + if gotPath != "/pages/page-id" { + t.Fatalf("path mismatch: got %s", gotPath) + } + if gotAuth != "Bearer secret-token" { + t.Fatalf("auth mismatch: got %s", gotAuth) + } + if gotVersion != "2022-06-28" { + t.Fatalf("notion-version mismatch: got %s", gotVersion) + } + if gotContentType != "application/json" { + t.Fatalf("content-type mismatch: got %s", gotContentType) + } + if gotAccept != "application/json" { + t.Fatalf("accept mismatch: got %s", gotAccept) + } + + if gotBody["archived"] != true { + t.Fatalf("archived mismatch: %v", gotBody["archived"]) + } + if page == nil || page.ID != "page-id" || !page.Archived { + t.Fatalf("unexpected page: %+v", page) + } +} + +func TestPatchPageEscapesPageID(t *testing.T) { + t.Parallel() + + var gotPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.EscapedPath() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"page id"}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + archived := true + if _, err := client.PatchPage(context.Background(), "page id/with slash", PageUpdate{Archived: &archived}); err != nil { + t.Fatalf("patch page: %v", err) + } + if gotPath != "/pages/page%20id%2Fwith%20slash" { + t.Fatalf("unexpected escaped path: %s", gotPath) + } +} + +func TestPatchPageReturnsTypedAPIError(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Notion-Request-Id", "req-123") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"object":"error","code":"unauthorized","message":"invalid token"}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + archived := true + _, err = client.PatchPage(context.Background(), "page-id", PageUpdate{Archived: &archived}) + if err == nil { + t.Fatal("expected API error") + } + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected *APIError, got %T: %v", err, err) + } + if apiErr.StatusCode != http.StatusUnauthorized { + t.Fatalf("status: got %d", apiErr.StatusCode) + } + if apiErr.Code != "unauthorized" { + t.Fatalf("code: got %q", apiErr.Code) + } + if apiErr.Message != "invalid token" { + t.Fatalf("message: got %q", apiErr.Message) + } + if apiErr.RequestID != "req-123" { + t.Fatalf("request id: got %q", apiErr.RequestID) + } +} + +func TestPatchPageRetriesOn429(t *testing.T) { + t.Parallel() + + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"code":"rate_limited","message":"slow down"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"page-id","archived":true}`)) + })) + defer srv.Close() + + client, err := NewClient(config.APIConfig{BaseURL: srv.URL}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + archived := true + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + page, err := client.PatchPage(ctx, "page-id", PageUpdate{Archived: &archived}) + if err != nil { + t.Fatalf("patch page: %v", err) + } + if page.ID != "page-id" { + t.Fatalf("unexpected page: %+v", page) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Fatalf("expected exactly one retry (2 calls), got %d", got) + } +} + +func TestPatchPageEmptyUpdateIsRejected(t *testing.T) { + t.Parallel() + + client, err := NewClient(config.APIConfig{BaseURL: "https://example.invalid"}, "secret-token") + if err != nil { + t.Fatalf("new client: %v", err) + } + + if _, err := client.PatchPage(context.Background(), "page-id", PageUpdate{}); err == nil { + t.Fatal("expected empty-update error") + } +} diff --git a/internal/cli/official_api.go b/internal/cli/official_api.go new file mode 100644 index 0000000..48906d9 --- /dev/null +++ b/internal/cli/official_api.go @@ -0,0 +1,22 @@ +package cli + +import ( + "fmt" + + "github.com/lox/notion-cli/internal/api" + "github.com/lox/notion-cli/internal/config" +) + +func RequireOfficialAPIClient() (*api.Client, error) { + cfg, err := config.Load() + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + client, err := api.NewClient(cfg.API, cfg.API.Token) + if err != nil { + return nil, fmt.Errorf("create official API client: %w (set api.token in ~/.config/notion-cli/config.json or NOTION_API_TOKEN)", err) + } + + return client, nil +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..17b0452 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,103 @@ +package config + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" +) + +const ( + configDirName = ".config/notion-cli" + configFileName = "config.json" + + // DefaultAPIBaseURL is the default Notion REST API base URL. + DefaultAPIBaseURL = "https://api.notion.com/v1" + // DefaultNotionAPIVersion is the default Notion-Version header value. + DefaultNotionAPIVersion = "2022-06-28" +) + +type Config struct { + ActiveAccount string `json:"active_account,omitempty"` + API APIConfig `json:"api,omitempty"` +} + +type APIConfig struct { + BaseURL string `json:"base_url,omitempty"` + NotionVersion string `json:"notion_version,omitempty"` + Token string `json:"token,omitempty"` +} + +func Default() Config { + return Config{ + API: APIConfig{ + BaseURL: DefaultAPIBaseURL, + NotionVersion: DefaultNotionAPIVersion, + }, + } +} + +func Load() (Config, error) { + cfg := Default() + + path, err := Path() + if err != nil { + return cfg, err + } + + if data, err := os.ReadFile(path); err == nil { + if err := json.Unmarshal(data, &cfg); err != nil { + return cfg, err + } + } else if !errors.Is(err, os.ErrNotExist) { + return cfg, err + } + + applyEnvOverrides(&cfg) + normalize(&cfg) + return cfg, nil +} + +func Path() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, configDirName, configFileName), nil +} + +func applyEnvOverrides(cfg *Config) { + if cfg == nil { + return + } + + if s := os.Getenv("NOTION_API_BASE_URL"); s != "" { + cfg.API.BaseURL = s + } + if s := os.Getenv("NOTION_API_NOTION_VERSION"); s != "" { + cfg.API.NotionVersion = s + } + if s := os.Getenv("NOTION_API_TOKEN"); s != "" { + cfg.API.Token = s + } +} + +func normalize(cfg *Config) { + if cfg == nil { + return + } + + d := Default() + + cfg.API.BaseURL = strings.TrimSpace(cfg.API.BaseURL) + if cfg.API.BaseURL == "" { + cfg.API.BaseURL = d.API.BaseURL + } + cfg.API.BaseURL = strings.TrimRight(cfg.API.BaseURL, "/") + cfg.API.NotionVersion = strings.TrimSpace(cfg.API.NotionVersion) + if cfg.API.NotionVersion == "" { + cfg.API.NotionVersion = d.API.NotionVersion + } + cfg.API.Token = strings.TrimSpace(cfg.API.Token) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..47d39c4 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,47 @@ +package config + +import "testing" + +func TestApplyEnvOverrides(t *testing.T) { + t.Setenv("NOTION_API_BASE_URL", "https://api.example.com/v1/") + t.Setenv("NOTION_API_NOTION_VERSION", "2022-06-28") + t.Setenv("NOTION_API_TOKEN", "api-token") + + cfg := Default() + applyEnvOverrides(&cfg) + normalize(&cfg) + + if cfg.API.BaseURL != "https://api.example.com/v1" { + t.Fatalf("unexpected api.base_url normalization: %q", cfg.API.BaseURL) + } + if cfg.API.NotionVersion != "2022-06-28" { + t.Fatalf("unexpected api.notion_version: %q", cfg.API.NotionVersion) + } + if cfg.API.Token != "api-token" { + t.Fatalf("unexpected api.token: %q", cfg.API.Token) + } +} + +func TestNormalizeAppliesAPIDefaults(t *testing.T) { + cfg := Config{} + normalize(&cfg) + + if cfg.API.BaseURL != "https://api.notion.com/v1" { + t.Fatalf("unexpected api.base_url default: %q", cfg.API.BaseURL) + } + if cfg.API.NotionVersion != "2022-06-28" { + t.Fatalf("unexpected api.notion_version default: %q", cfg.API.NotionVersion) + } +} + +func TestPathUsesHome(t *testing.T) { + t.Setenv("HOME", "/tmp/example-home") + + path, err := Path() + if err != nil { + t.Fatal(err) + } + if path != "/tmp/example-home/.config/notion-cli/config.json" { + t.Fatalf("unexpected path: %s", path) + } +}