From 984ff871e5c97aa15ecdd0b72e8012818369f29f Mon Sep 17 00:00:00 2001 From: Rian Stockbower Date: Sun, 1 Feb 2026 06:48:26 -0500 Subject: [PATCH 1/2] feat: add Tooling API for developer workflows Add Tooling API support with apex, log, and coverage commands: - api/tooling: New package for Tooling API client with query, execute anonymous, test running, and log retrieval capabilities - sfdc apex: Apex class operations - list: List Apex classes or triggers - get: Get class/trigger source code - execute: Execute anonymous Apex - test: Run Apex tests asynchronously - sfdc log: Debug log operations - list: List recent debug logs - get: Get log content - tail: Stream new logs continuously - sfdc coverage: Code coverage reporting with minimum threshold Closes #20 --- api/tooling/client.go | 560 ++++++++++++++++++++++ api/tooling/client_test.go | 477 ++++++++++++++++++ api/tooling/types.go | 138 ++++++ cmd/sfdc/main.go | 8 + internal/cmd/apexcmd/apex.go | 36 ++ internal/cmd/apexcmd/apex_test.go | 379 +++++++++++++++ internal/cmd/apexcmd/execute.go | 104 ++++ internal/cmd/apexcmd/get.go | 77 +++ internal/cmd/apexcmd/list.go | 134 ++++++ internal/cmd/apexcmd/test.go | 188 ++++++++ internal/cmd/coveragecmd/coverage.go | 150 ++++++ internal/cmd/coveragecmd/coverage_test.go | 328 +++++++++++++ internal/cmd/logcmd/get.go | 57 +++ internal/cmd/logcmd/list.go | 97 ++++ internal/cmd/logcmd/log.go | 34 ++ internal/cmd/logcmd/log_test.go | 294 ++++++++++++ internal/cmd/logcmd/tail.go | 104 ++++ internal/cmd/root/root.go | 26 + 18 files changed, 3191 insertions(+) create mode 100644 api/tooling/client.go create mode 100644 api/tooling/client_test.go create mode 100644 api/tooling/types.go create mode 100644 internal/cmd/apexcmd/apex.go create mode 100644 internal/cmd/apexcmd/apex_test.go create mode 100644 internal/cmd/apexcmd/execute.go create mode 100644 internal/cmd/apexcmd/get.go create mode 100644 internal/cmd/apexcmd/list.go create mode 100644 internal/cmd/apexcmd/test.go create mode 100644 internal/cmd/coveragecmd/coverage.go create mode 100644 internal/cmd/coveragecmd/coverage_test.go create mode 100644 internal/cmd/logcmd/get.go create mode 100644 internal/cmd/logcmd/list.go create mode 100644 internal/cmd/logcmd/log.go create mode 100644 internal/cmd/logcmd/log_test.go create mode 100644 internal/cmd/logcmd/tail.go diff --git a/api/tooling/client.go b/api/tooling/client.go new file mode 100644 index 0000000..230793d --- /dev/null +++ b/api/tooling/client.go @@ -0,0 +1,560 @@ +package tooling + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// DefaultAPIVersion is the default Salesforce API version. +const DefaultAPIVersion = "v62.0" + +// Client is a Salesforce Tooling API client. +type Client struct { + httpClient *http.Client + instanceURL string + apiVersion string + baseURL string +} + +// ClientConfig contains configuration for creating a new Tooling API client. +type ClientConfig struct { + InstanceURL string + HTTPClient *http.Client + APIVersion string +} + +// New creates a new Tooling API client. +func New(cfg ClientConfig) (*Client, error) { + if cfg.InstanceURL == "" { + return nil, fmt.Errorf("instance URL is required") + } + if cfg.HTTPClient == nil { + return nil, fmt.Errorf("HTTP client is required") + } + + instanceURL := strings.TrimSuffix(cfg.InstanceURL, "/") + apiVersion := cfg.APIVersion + if apiVersion == "" { + apiVersion = DefaultAPIVersion + } + + return &Client{ + httpClient: cfg.HTTPClient, + instanceURL: instanceURL, + apiVersion: apiVersion, + baseURL: fmt.Sprintf("%s/services/data/%s/tooling", instanceURL, apiVersion), + }, nil +} + +// doRequest performs an HTTP request and returns the response body. +func (c *Client) doRequest(ctx context.Context, method, path string, body interface{}) ([]byte, error) { + var bodyReader io.Reader + if body != nil { + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + bodyReader = bytes.NewReader(jsonBody) + } + + fullURL := path + if !strings.HasPrefix(path, "http") { + fullURL = c.baseURL + path + } + + req, err := http.NewRequestWithContext(ctx, method, fullURL, bodyReader) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(respBody)) + } + + return respBody, nil +} + +// Get performs a GET request. +func (c *Client) Get(ctx context.Context, path string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, path, nil) +} + +// Post performs a POST request. +func (c *Client) Post(ctx context.Context, path string, body interface{}) ([]byte, error) { + return c.doRequest(ctx, http.MethodPost, path, body) +} + +// Query executes a SOQL query against the Tooling API. +func (c *Client) Query(ctx context.Context, soql string) (*QueryResult, error) { + path := fmt.Sprintf("/query?q=%s", url.QueryEscape(soql)) + body, err := c.Get(ctx, path) + if err != nil { + return nil, err + } + + var result QueryResult + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse query result: %w", err) + } + + return &result, nil +} + +// ListApexClasses returns all Apex classes. +func (c *Client) ListApexClasses(ctx context.Context) ([]ApexClass, error) { + soql := "SELECT Id, Name, Status, IsValid, ApiVersion, LengthWithoutComments, NamespacePrefix FROM ApexClass ORDER BY Name" + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + classes := make([]ApexClass, 0, len(result.Records)) + for _, rec := range result.Records { + class := recordToApexClass(rec) + classes = append(classes, class) + } + + return classes, nil +} + +// ListApexTriggers returns all Apex triggers. +func (c *Client) ListApexTriggers(ctx context.Context) ([]ApexTrigger, error) { + soql := "SELECT Id, Name, Status, IsValid, ApiVersion, TableEnumOrId, NamespacePrefix FROM ApexTrigger ORDER BY Name" + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + triggers := make([]ApexTrigger, 0, len(result.Records)) + for _, rec := range result.Records { + trigger := recordToApexTrigger(rec) + triggers = append(triggers, trigger) + } + + return triggers, nil +} + +// GetApexClass returns an Apex class by name, including body. +func (c *Client) GetApexClass(ctx context.Context, name string) (*ApexClass, error) { + soql := fmt.Sprintf("SELECT Id, Name, Body, Status, IsValid, ApiVersion, LengthWithoutComments, NamespacePrefix FROM ApexClass WHERE Name = '%s'", name) + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + if len(result.Records) == 0 { + return nil, fmt.Errorf("apex class not found: %s", name) + } + + class := recordToApexClass(result.Records[0]) + return &class, nil +} + +// GetApexTrigger returns an Apex trigger by name, including body. +func (c *Client) GetApexTrigger(ctx context.Context, name string) (*ApexTrigger, error) { + soql := fmt.Sprintf("SELECT Id, Name, Body, Status, IsValid, ApiVersion, TableEnumOrId, NamespacePrefix FROM ApexTrigger WHERE Name = '%s'", name) + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + if len(result.Records) == 0 { + return nil, fmt.Errorf("apex trigger not found: %s", name) + } + + trigger := recordToApexTrigger(result.Records[0]) + return &trigger, nil +} + +// ExecuteAnonymous executes anonymous Apex code. +func (c *Client) ExecuteAnonymous(ctx context.Context, code string) (*ExecuteAnonymousResult, error) { + path := fmt.Sprintf("/executeAnonymous?anonymousBody=%s", url.QueryEscape(code)) + body, err := c.Get(ctx, path) + if err != nil { + return nil, err + } + + var result ExecuteAnonymousResult + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse execute result: %w", err) + } + + return &result, nil +} + +// RunTestsAsync enqueues Apex tests to run asynchronously. +func (c *Client) RunTestsAsync(ctx context.Context, classIDs []string) (string, error) { + req := RunTestsRequest{ + ClassIDs: classIDs, + } + body, err := c.Post(ctx, "/runTestsAsynchronous", req) + if err != nil { + return "", err + } + + // Response is a quoted string with the job ID + var jobID string + if err := json.Unmarshal(body, &jobID); err != nil { + // Try unquoting directly + jobID = strings.Trim(string(body), "\"") + } + + return jobID, nil +} + +// GetTestResults returns test results for a given async job. +func (c *Client) GetTestResults(ctx context.Context, asyncJobID string) ([]ApexTestResult, error) { + soql := fmt.Sprintf( + "SELECT Id, ApexClassId, ApexClass.Name, MethodName, Outcome, Message, StackTrace, RunTime, AsyncApexJobId FROM ApexTestResult WHERE AsyncApexJobId = '%s'", + asyncJobID, + ) + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + results := make([]ApexTestResult, 0, len(result.Records)) + for _, rec := range result.Records { + tr := recordToApexTestResult(rec) + results = append(results, tr) + } + + return results, nil +} + +// GetAsyncJobStatus returns the status of an async Apex job. +func (c *Client) GetAsyncJobStatus(ctx context.Context, jobID string) (*AsyncApexJob, error) { + soql := fmt.Sprintf( + "SELECT Id, Status, JobItemsProcessed, TotalJobItems, NumberOfErrors, ExtendedStatus, CompletedDate FROM AsyncApexJob WHERE Id = '%s'", + jobID, + ) + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + if len(result.Records) == 0 { + return nil, fmt.Errorf("async job not found: %s", jobID) + } + + job := recordToAsyncApexJob(result.Records[0]) + return &job, nil +} + +// ListApexLogs returns debug logs. +func (c *Client) ListApexLogs(ctx context.Context, userID string, limit int) ([]ApexLog, error) { + soql := "SELECT Id, LogUserId, Operation, Request, Status, LogLength, DurationMilliseconds, StartTime, Location, Application FROM ApexLog" + if userID != "" { + soql += fmt.Sprintf(" WHERE LogUserId = '%s'", userID) + } + soql += " ORDER BY StartTime DESC" + if limit > 0 { + soql += fmt.Sprintf(" LIMIT %d", limit) + } + + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + logs := make([]ApexLog, 0, len(result.Records)) + for _, rec := range result.Records { + log := recordToApexLog(rec) + logs = append(logs, log) + } + + return logs, nil +} + +// GetApexLogBody returns the body content of a debug log. +func (c *Client) GetApexLogBody(ctx context.Context, logID string) (string, error) { + // Log body is retrieved from the REST API, not Tooling API + path := fmt.Sprintf("%s/services/data/%s/sobjects/ApexLog/%s/Body", c.instanceURL, c.apiVersion, logID) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode >= 400 { + return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + return string(body), nil +} + +// GetCodeCoverage returns aggregate code coverage for the org. +func (c *Client) GetCodeCoverage(ctx context.Context) ([]ApexCodeCoverageAggregate, error) { + soql := "SELECT Id, ApexClassOrTriggerId, ApexClassOrTrigger.Name, NumLinesCovered, NumLinesUncovered FROM ApexCodeCoverageAggregate ORDER BY ApexClassOrTrigger.Name" + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + coverage := make([]ApexCodeCoverageAggregate, 0, len(result.Records)) + for _, rec := range result.Records { + cov := recordToApexCodeCoverageAggregate(rec) + coverage = append(coverage, cov) + } + + return coverage, nil +} + +// GetCodeCoverageForClass returns aggregate code coverage for a specific class. +func (c *Client) GetCodeCoverageForClass(ctx context.Context, className string) (*ApexCodeCoverageAggregate, error) { + soql := fmt.Sprintf( + "SELECT Id, ApexClassOrTriggerId, ApexClassOrTrigger.Name, NumLinesCovered, NumLinesUncovered FROM ApexCodeCoverageAggregate WHERE ApexClassOrTrigger.Name = '%s'", + className, + ) + result, err := c.Query(ctx, soql) + if err != nil { + return nil, err + } + + if len(result.Records) == 0 { + return nil, fmt.Errorf("no coverage data found for: %s", className) + } + + cov := recordToApexCodeCoverageAggregate(result.Records[0]) + return &cov, nil +} + +// GetApexClassID returns the ID of an Apex class by name. +func (c *Client) GetApexClassID(ctx context.Context, className string) (string, error) { + soql := fmt.Sprintf("SELECT Id FROM ApexClass WHERE Name = '%s'", className) + result, err := c.Query(ctx, soql) + if err != nil { + return "", err + } + + if len(result.Records) == 0 { + return "", fmt.Errorf("apex class not found: %s", className) + } + + id, _ := result.Records[0]["Id"].(string) + return id, nil +} + +// Helper functions to convert generic records to typed structs + +func recordToApexClass(rec Record) ApexClass { + class := ApexClass{} + if v, ok := rec["Id"].(string); ok { + class.ID = v + } + if v, ok := rec["Name"].(string); ok { + class.Name = v + } + if v, ok := rec["Body"].(string); ok { + class.Body = v + } + if v, ok := rec["Status"].(string); ok { + class.Status = v + } + if v, ok := rec["IsValid"].(bool); ok { + class.IsValid = v + } + if v, ok := rec["ApiVersion"].(float64); ok { + class.APIVersion = v + } + if v, ok := rec["LengthWithoutComments"].(float64); ok { + class.LengthWithoutComments = int(v) + } + if v, ok := rec["NamespacePrefix"].(string); ok { + class.NamespacePrefix = v + } + return class +} + +func recordToApexTrigger(rec Record) ApexTrigger { + trigger := ApexTrigger{} + if v, ok := rec["Id"].(string); ok { + trigger.ID = v + } + if v, ok := rec["Name"].(string); ok { + trigger.Name = v + } + if v, ok := rec["Body"].(string); ok { + trigger.Body = v + } + if v, ok := rec["Status"].(string); ok { + trigger.Status = v + } + if v, ok := rec["IsValid"].(bool); ok { + trigger.IsValid = v + } + if v, ok := rec["ApiVersion"].(float64); ok { + trigger.APIVersion = v + } + if v, ok := rec["TableEnumOrId"].(string); ok { + trigger.TableEnumOrID = v + } + if v, ok := rec["NamespacePrefix"].(string); ok { + trigger.NamespacePrefix = v + } + return trigger +} + +func recordToApexLog(rec Record) ApexLog { + log := ApexLog{} + if v, ok := rec["Id"].(string); ok { + log.ID = v + } + if v, ok := rec["LogUserId"].(string); ok { + log.LogUserID = v + } + if v, ok := rec["Operation"].(string); ok { + log.Operation = v + } + if v, ok := rec["Request"].(string); ok { + log.Request = v + } + if v, ok := rec["Status"].(string); ok { + log.Status = v + } + if v, ok := rec["LogLength"].(float64); ok { + log.LogLength = int(v) + } + if v, ok := rec["DurationMilliseconds"].(float64); ok { + log.DurationMS = int(v) + } + if v, ok := rec["StartTime"].(string); ok { + log.StartTime, _ = parseTime(v) + } + if v, ok := rec["Location"].(string); ok { + log.Location = v + } + if v, ok := rec["Application"].(string); ok { + log.Application = v + } + return log +} + +func recordToApexTestResult(rec Record) ApexTestResult { + result := ApexTestResult{} + if v, ok := rec["Id"].(string); ok { + result.ID = v + } + if v, ok := rec["ApexClassId"].(string); ok { + result.ApexClassID = v + } + if nested, ok := rec["ApexClass"].(map[string]interface{}); ok { + if v, ok := nested["Name"].(string); ok { + result.ClassName = v + } + } + if v, ok := rec["MethodName"].(string); ok { + result.MethodName = v + } + if v, ok := rec["Outcome"].(string); ok { + result.Outcome = v + } + if v, ok := rec["Message"].(string); ok { + result.Message = v + } + if v, ok := rec["StackTrace"].(string); ok { + result.StackTrace = v + } + if v, ok := rec["RunTime"].(float64); ok { + result.RunTime = int(v) + } + if v, ok := rec["AsyncApexJobId"].(string); ok { + result.AsyncApexJobID = v + } + return result +} + +func recordToAsyncApexJob(rec Record) AsyncApexJob { + job := AsyncApexJob{} + if v, ok := rec["Id"].(string); ok { + job.ID = v + } + if v, ok := rec["Status"].(string); ok { + job.Status = v + } + if v, ok := rec["JobItemsProcessed"].(float64); ok { + job.JobItemsProcessed = int(v) + } + if v, ok := rec["TotalJobItems"].(float64); ok { + job.TotalJobItems = int(v) + } + if v, ok := rec["NumberOfErrors"].(float64); ok { + job.NumberOfErrors = int(v) + } + if v, ok := rec["ExtendedStatus"].(string); ok { + job.ExtendedStatus = v + } + if v, ok := rec["CompletedDate"].(string); ok { + job.CompletedDate = v + } + return job +} + +func recordToApexCodeCoverageAggregate(rec Record) ApexCodeCoverageAggregate { + cov := ApexCodeCoverageAggregate{} + if v, ok := rec["Id"].(string); ok { + cov.ID = v + } + if v, ok := rec["ApexClassOrTriggerId"].(string); ok { + cov.ApexClassOrTriggerID = v + } + if nested, ok := rec["ApexClassOrTrigger"].(map[string]interface{}); ok { + if v, ok := nested["Name"].(string); ok { + cov.ApexClassOrTrigger.Name = v + } + } + if v, ok := rec["NumLinesCovered"].(float64); ok { + cov.NumLinesCovered = int(v) + } + if v, ok := rec["NumLinesUncovered"].(float64); ok { + cov.NumLinesUncovered = int(v) + } + return cov +} + +func parseTime(s string) (time.Time, error) { + // Salesforce datetime format + formats := []string{ + "2006-01-02T15:04:05.000+0000", + "2006-01-02T15:04:05.000Z", + "2006-01-02T15:04:05Z", + } + for _, f := range formats { + if t, err := time.Parse(f, s); err == nil { + return t, nil + } + } + return time.Time{}, fmt.Errorf("unable to parse time: %s", s) +} diff --git a/api/tooling/client_test.go b/api/tooling/client_test.go new file mode 100644 index 0000000..5393070 --- /dev/null +++ b/api/tooling/client_test.go @@ -0,0 +1,477 @@ +package tooling + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + cfg ClientConfig + wantErr bool + }{ + { + name: "valid config", + cfg: ClientConfig{ + InstanceURL: "https://test.salesforce.com", + HTTPClient: http.DefaultClient, + }, + wantErr: false, + }, + { + name: "missing instance URL", + cfg: ClientConfig{ + HTTPClient: http.DefaultClient, + }, + wantErr: true, + }, + { + name: "missing HTTP client", + cfg: ClientConfig{ + InstanceURL: "https://test.salesforce.com", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := New(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, client) + } else { + assert.NoError(t, err) + assert.NotNil(t, client) + } + }) + } +} + +func TestListApexClasses(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "/tooling/query") + assert.Contains(t, r.URL.RawQuery, "ApexClass") + + response := QueryResult{ + TotalSize: 2, + Done: true, + Records: []Record{ + { + "Id": "01p000000000001", + "Name": "MyController", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "LengthWithoutComments": float64(500), + }, + { + "Id": "01p000000000002", + "Name": "MyHelper", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "LengthWithoutComments": float64(300), + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + classes, err := client.ListApexClasses(context.Background()) + require.NoError(t, err) + assert.Len(t, classes, 2) + assert.Equal(t, "MyController", classes[0].Name) + assert.Equal(t, "MyHelper", classes[1].Name) +} + +func TestListApexTriggers(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "/tooling/query") + assert.Contains(t, r.URL.RawQuery, "ApexTrigger") + + response := QueryResult{ + TotalSize: 1, + Done: true, + Records: []Record{ + { + "Id": "01q000000000001", + "Name": "AccountTrigger", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "TableEnumOrId": "Account", + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + triggers, err := client.ListApexTriggers(context.Background()) + require.NoError(t, err) + assert.Len(t, triggers, 1) + assert.Equal(t, "AccountTrigger", triggers[0].Name) + assert.Equal(t, "Account", triggers[0].TableEnumOrID) +} + +func TestGetApexClass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := QueryResult{ + TotalSize: 1, + Done: true, + Records: []Record{ + { + "Id": "01p000000000001", + "Name": "MyController", + "Body": "public class MyController { }", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + class, err := client.GetApexClass(context.Background(), "MyController") + require.NoError(t, err) + assert.Equal(t, "MyController", class.Name) + assert.Equal(t, "public class MyController { }", class.Body) +} + +func TestGetApexClassNotFound(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := QueryResult{ + TotalSize: 0, + Done: true, + Records: []Record{}, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + _, err = client.GetApexClass(context.Background(), "NonExistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestExecuteAnonymous(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "/executeAnonymous") + assert.Contains(t, r.URL.RawQuery, "anonymousBody") + + response := ExecuteAnonymousResult{ + Compiled: true, + Success: true, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + result, err := client.ExecuteAnonymous(context.Background(), "System.debug('Hello');") + require.NoError(t, err) + assert.True(t, result.Compiled) + assert.True(t, result.Success) +} + +func TestExecuteAnonymousCompileError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := ExecuteAnonymousResult{ + Line: 1, + Column: 10, + Compiled: false, + Success: false, + CompileProblem: "Variable does not exist: foo", + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + result, err := client.ExecuteAnonymous(context.Background(), "System.debug(foo);") + require.NoError(t, err) + assert.False(t, result.Compiled) + assert.False(t, result.Success) + assert.Equal(t, "Variable does not exist: foo", result.CompileProblem) +} + +func TestRunTestsAsync(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Contains(t, r.URL.Path, "/runTestsAsynchronous") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`"7071x00000ABCDE"`)) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + jobID, err := client.RunTestsAsync(context.Background(), []string{"01p000000000001"}) + require.NoError(t, err) + assert.Equal(t, "7071x00000ABCDE", jobID) +} + +func TestGetTestResults(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := QueryResult{ + TotalSize: 2, + Done: true, + Records: []Record{ + { + "Id": "07M000000000001", + "ApexClassId": "01p000000000001", + "ApexClass": map[string]interface{}{"Name": "MyTest"}, + "MethodName": "testSuccess", + "Outcome": "Pass", + "RunTime": float64(150), + "AsyncApexJobId": "7071x00000ABCDE", + }, + { + "Id": "07M000000000002", + "ApexClassId": "01p000000000001", + "ApexClass": map[string]interface{}{"Name": "MyTest"}, + "MethodName": "testFailure", + "Outcome": "Fail", + "Message": "Assertion failed", + "RunTime": float64(200), + "AsyncApexJobId": "7071x00000ABCDE", + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + results, err := client.GetTestResults(context.Background(), "7071x00000ABCDE") + require.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, "testSuccess", results[0].MethodName) + assert.Equal(t, "Pass", results[0].Outcome) + assert.Equal(t, "testFailure", results[1].MethodName) + assert.Equal(t, "Fail", results[1].Outcome) +} + +func TestListApexLogs(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := QueryResult{ + TotalSize: 1, + Done: true, + Records: []Record{ + { + "Id": "07L000000000001", + "LogUserId": "005000000000001", + "Operation": "/aura", + "Request": "API", + "Status": "Success", + "LogLength": float64(5000), + "DurationMilliseconds": float64(150), + "StartTime": "2024-01-15T10:30:00.000+0000", + "Location": "MonitoringService", + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + logs, err := client.ListApexLogs(context.Background(), "", 10) + require.NoError(t, err) + assert.Len(t, logs, 1) + assert.Equal(t, "07L000000000001", logs[0].ID) + assert.Equal(t, "/aura", logs[0].Operation) + assert.Equal(t, 5000, logs[0].LogLength) +} + +func TestGetApexLogBody(t *testing.T) { + logContent := "DEBUG|Hello World\nUSER_DEBUG|Test message" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "/sobjects/ApexLog/") + assert.Contains(t, r.URL.Path, "/Body") + + w.Write([]byte(logContent)) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + body, err := client.GetApexLogBody(context.Background(), "07L000000000001") + require.NoError(t, err) + assert.Equal(t, logContent, body) +} + +func TestGetCodeCoverage(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := QueryResult{ + TotalSize: 2, + Done: true, + Records: []Record{ + { + "Id": "500000000000001", + "ApexClassOrTriggerId": "01p000000000001", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyController"}, + "NumLinesCovered": float64(80), + "NumLinesUncovered": float64(20), + }, + { + "Id": "500000000000002", + "ApexClassOrTriggerId": "01p000000000002", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyHelper"}, + "NumLinesCovered": float64(50), + "NumLinesUncovered": float64(50), + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + coverage, err := client.GetCodeCoverage(context.Background()) + require.NoError(t, err) + assert.Len(t, coverage, 2) + assert.Equal(t, "MyController", coverage[0].ApexClassOrTrigger.Name) + assert.Equal(t, 80, coverage[0].NumLinesCovered) + assert.Equal(t, 20, coverage[0].NumLinesUncovered) +} + +func TestGetAsyncJobStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := QueryResult{ + TotalSize: 1, + Done: true, + Records: []Record{ + { + "Id": "7071x00000ABCDE", + "Status": "Completed", + "JobItemsProcessed": float64(5), + "TotalJobItems": float64(5), + "NumberOfErrors": float64(1), + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + job, err := client.GetAsyncJobStatus(context.Background(), "7071x00000ABCDE") + require.NoError(t, err) + assert.Equal(t, "Completed", job.Status) + assert.Equal(t, 5, job.TotalJobItems) + assert.Equal(t, 1, job.NumberOfErrors) +} + +func TestAPIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`[{"errorCode": "INVALID_SESSION_ID", "message": "Session expired"}]`)) + })) + defer server.Close() + + client, err := New(ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + _, err = client.ListApexClasses(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "401") +} diff --git a/api/tooling/types.go b/api/tooling/types.go new file mode 100644 index 0000000..7199851 --- /dev/null +++ b/api/tooling/types.go @@ -0,0 +1,138 @@ +// Package tooling provides a client for the Salesforce Tooling API. +package tooling + +import "time" + +// ApexClass represents an Apex class in Salesforce. +type ApexClass struct { + ID string `json:"Id"` + Name string `json:"Name"` + Body string `json:"Body,omitempty"` + Status string `json:"Status"` + IsValid bool `json:"IsValid"` + APIVersion float64 `json:"ApiVersion"` + LengthWithoutComments int `json:"LengthWithoutComments"` + NamespacePrefix string `json:"NamespacePrefix,omitempty"` +} + +// ApexTrigger represents an Apex trigger in Salesforce. +type ApexTrigger struct { + ID string `json:"Id"` + Name string `json:"Name"` + Body string `json:"Body,omitempty"` + Status string `json:"Status"` + IsValid bool `json:"IsValid"` + APIVersion float64 `json:"ApiVersion"` + TableEnumOrID string `json:"TableEnumOrId"` + NamespacePrefix string `json:"NamespacePrefix,omitempty"` +} + +// ApexLog represents a debug log entry. +type ApexLog struct { + ID string `json:"Id"` + LogUserID string `json:"LogUserId"` + LogUserName string `json:"LogUser.Name,omitempty"` + Operation string `json:"Operation"` + Request string `json:"Request"` + Status string `json:"Status"` + LogLength int `json:"LogLength"` + DurationMS int `json:"DurationMilliseconds"` + StartTime time.Time `json:"StartTime"` + Location string `json:"Location"` + Application string `json:"Application,omitempty"` + LastModified time.Time `json:"LastModifiedDate,omitempty"` + SystemModstamp time.Time `json:"SystemModstamp,omitempty"` +} + +// ApexTestQueueItem represents a test class in the test queue. +type ApexTestQueueItem struct { + ID string `json:"Id"` + ApexClassID string `json:"ApexClassId"` + Status string `json:"Status"` + ExtendedStatus string `json:"ExtendedStatus,omitempty"` + ParentJobID string `json:"ParentJobId,omitempty"` +} + +// ApexTestResult represents the result of running an Apex test method. +type ApexTestResult struct { + ID string `json:"Id"` + ApexClassID string `json:"ApexClassId"` + ClassName string `json:"ApexClass.Name,omitempty"` + MethodName string `json:"MethodName"` + Outcome string `json:"Outcome"` // Pass, Fail, CompileFail, Skip + Message string `json:"Message,omitempty"` + StackTrace string `json:"StackTrace,omitempty"` + RunTime int `json:"RunTime"` // milliseconds + AsyncApexJobID string `json:"AsyncApexJobId"` + TestTimestamp string `json:"TestTimestamp,omitempty"` +} + +// ApexCodeCoverage represents code coverage for an Apex class. +type ApexCodeCoverage struct { + ID string `json:"Id"` + ApexClassOrTriggerID string `json:"ApexClassOrTriggerId"` + ApexClassOrTrigger struct { + Name string `json:"Name"` + } `json:"ApexClassOrTrigger,omitempty"` + ApexTestClassID string `json:"ApexTestClassId"` + NumLinesCovered int `json:"NumLinesCovered"` + NumLinesUncovered int `json:"NumLinesUncovered"` +} + +// ApexCodeCoverageAggregate represents aggregate code coverage. +type ApexCodeCoverageAggregate struct { + ID string `json:"Id"` + ApexClassOrTriggerID string `json:"ApexClassOrTriggerId"` + ApexClassOrTrigger struct { + Name string `json:"Name"` + } `json:"ApexClassOrTrigger,omitempty"` + NumLinesCovered int `json:"NumLinesCovered"` + NumLinesUncovered int `json:"NumLinesUncovered"` +} + +// ExecuteAnonymousResult represents the result of executing anonymous Apex. +type ExecuteAnonymousResult struct { + Line int `json:"line"` + Column int `json:"column"` + Compiled bool `json:"compiled"` + Success bool `json:"success"` + CompileProblem string `json:"compileProblem,omitempty"` + ExceptionMessage string `json:"exceptionMessage,omitempty"` + ExceptionStackTrace string `json:"exceptionStackTrace,omitempty"` +} + +// AsyncApexJob represents an asynchronous Apex job (for test runs). +type AsyncApexJob struct { + ID string `json:"Id"` + Status string `json:"Status"` // Queued, Processing, Completed, Aborted, Failed + JobItemsProcessed int `json:"JobItemsProcessed"` + TotalJobItems int `json:"TotalJobItems"` + NumberOfErrors int `json:"NumberOfErrors"` + MethodName string `json:"MethodName,omitempty"` + ExtendedStatus string `json:"ExtendedStatus,omitempty"` + ParentJobID string `json:"ParentJobId,omitempty"` + ApexClassID string `json:"ApexClassId,omitempty"` + CompletedDate string `json:"CompletedDate,omitempty"` +} + +// QueryResult represents the result of a Tooling API query. +type QueryResult struct { + TotalSize int `json:"totalSize"` + Done bool `json:"done"` + Records []Record `json:"records"` + NextRecordsURL string `json:"nextRecordsUrl,omitempty"` +} + +// Record represents a generic record from a Tooling API query. +type Record map[string]interface{} + +// RunTestsRequest represents a request to run Apex tests. +type RunTestsRequest struct { + ClassIDs []string `json:"classids,omitempty"` + SuiteIDs []string `json:"suiteids,omitempty"` + MaxFailedTests int `json:"maxFailedTests,omitempty"` + TestLevel string `json:"testLevel,omitempty"` +} + +// RunTestsAsyncResult represents the result of enqueuing tests. +type RunTestsAsyncResult string diff --git a/cmd/sfdc/main.go b/cmd/sfdc/main.go index 1117e6f..a98baef 100644 --- a/cmd/sfdc/main.go +++ b/cmd/sfdc/main.go @@ -5,11 +5,14 @@ import ( "fmt" "os" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/apexcmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/bulkcmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/completion" "github.com/open-cli-collective/salesforce-cli/internal/cmd/configcmd" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/coveragecmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/initcmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/limitscmd" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/logcmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/objectcmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/querycmd" "github.com/open-cli-collective/salesforce-cli/internal/cmd/recordcmd" @@ -49,5 +52,10 @@ func run() error { // Bulk API commands bulkcmd.Register(rootCmd, opts) + // Tooling API commands + apexcmd.Register(rootCmd, opts) + logcmd.Register(rootCmd, opts) + coveragecmd.Register(rootCmd, opts) + return rootCmd.Execute() } diff --git a/internal/cmd/apexcmd/apex.go b/internal/cmd/apexcmd/apex.go new file mode 100644 index 0000000..c50500d --- /dev/null +++ b/internal/cmd/apexcmd/apex.go @@ -0,0 +1,36 @@ +// Package apexcmd provides commands for Apex class operations. +package apexcmd + +import ( + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +// Register registers the apex command with the root command. +func Register(parent *cobra.Command, opts *root.Options) { + parent.AddCommand(NewCommand(opts)) +} + +// NewCommand creates the apex command. +func NewCommand(opts *root.Options) *cobra.Command { + cmd := &cobra.Command{ + Use: "apex", + Short: "Apex class operations", + Long: `Manage Apex classes, triggers, and execute anonymous Apex. + +Examples: + sfdc apex list # List all Apex classes + sfdc apex list --triggers # List all Apex triggers + sfdc apex get MyController # Get class source code + sfdc apex execute "System.debug('Hi');" # Execute anonymous Apex + sfdc apex test --class MyTest # Run Apex tests`, + } + + cmd.AddCommand(newListCommand(opts)) + cmd.AddCommand(newGetCommand(opts)) + cmd.AddCommand(newExecuteCommand(opts)) + cmd.AddCommand(newTestCommand(opts)) + + return cmd +} diff --git a/internal/cmd/apexcmd/apex_test.go b/internal/cmd/apexcmd/apex_test.go new file mode 100644 index 0000000..d108ce1 --- /dev/null +++ b/internal/cmd/apexcmd/apex_test.go @@ -0,0 +1,379 @@ +package apexcmd + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/open-cli-collective/salesforce-cli/api/tooling" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func TestApexListClasses(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 2, + Done: true, + Records: []tooling.Record{ + { + "Id": "01p000000000001", + "Name": "MyController", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "LengthWithoutComments": float64(500), + }, + { + "Id": "01p000000000002", + "Name": "MyHelper", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "LengthWithoutComments": float64(300), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "MyController") + assert.Contains(t, output, "MyHelper") + assert.Contains(t, output, "2 class(es)") +} + +func TestApexListTriggers(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "01q000000000001", + "Name": "AccountTrigger", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "TableEnumOrId": "Account", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list", "--triggers"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "AccountTrigger") + assert.Contains(t, output, "Account") +} + +func TestApexGet(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "01p000000000001", + "Name": "MyController", + "Body": "public class MyController {\n public void doSomething() { }\n}", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "plain", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"get", "MyController"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "public class MyController") + assert.Contains(t, output, "doSomething") +} + +func TestApexExecuteSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.ExecuteAnonymousResult{ + Compiled: true, + Success: true, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"execute", "System.debug('Hello');"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "Executed successfully") +} + +func TestApexExecuteCompileError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.ExecuteAnonymousResult{ + Line: 1, + Column: 10, + Compiled: false, + Success: false, + CompileProblem: "Variable does not exist: foo", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: stderr, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"execute", "System.debug(foo);"}) + cmd.SetOut(stdout) + cmd.SetErr(stderr) + + err = cmd.Execute() + assert.Error(t, err) + assert.Contains(t, err.Error(), "compilation failed") +} + +func TestApexExecuteFromStdin(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.ExecuteAnonymousResult{ + Compiled: true, + Success: true, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdin := strings.NewReader("System.debug('From stdin');") + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdin: stdin, + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"execute", "-"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "Executed successfully") +} + +func TestApexTestNoWait(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + + if strings.Contains(r.URL.RawQuery, "ApexClass") && strings.Contains(r.URL.RawQuery, "Id") && strings.Contains(r.URL.RawQuery, "MyTest") { + // Get class ID + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + {"Id": "01p000000000001"}, + }, + } + _ = json.NewEncoder(w).Encode(response) + } else if strings.Contains(r.URL.Path, "runTestsAsynchronous") { + // Run tests + w.Write([]byte(`"7071x00000ABCDE"`)) + } + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"test", "--class", "MyTest"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "7071x00000ABCDE") + assert.Contains(t, output, "Tests enqueued") +} + +func TestApexListJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "01p000000000001", + "Name": "MyController", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "LengthWithoutComments": float64(500), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "json", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "MyController") + // Should be valid JSON + var result []tooling.ApexClass + err = json.Unmarshal([]byte(output), &result) + require.NoError(t, err) + assert.Len(t, result, 1) +} diff --git a/internal/cmd/apexcmd/execute.go b/internal/cmd/apexcmd/execute.go new file mode 100644 index 0000000..c4c7ca8 --- /dev/null +++ b/internal/cmd/apexcmd/execute.go @@ -0,0 +1,104 @@ +package apexcmd + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newExecuteCommand(opts *root.Options) *cobra.Command { + var file string + + cmd := &cobra.Command{ + Use: "execute [code]", + Short: "Execute anonymous Apex code", + Long: `Execute anonymous Apex code. + +The code can be provided as an argument, from a file, or via stdin. + +Examples: + sfdc apex execute "System.debug('Hello');" + sfdc apex execute --file script.apex + echo "System.debug(UserInfo.getUserName());" | sfdc apex execute - + sfdc apex execute - # Read from stdin`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + var code string + + if file != "" { + // Read from file + data, readErr := os.ReadFile(file) + if readErr != nil { + return fmt.Errorf("failed to read file: %w", readErr) + } + code = string(data) + } else if len(args) == 1 { + if args[0] == "-" { + // Read from stdin + data, err := io.ReadAll(opts.Stdin) + if err != nil { + return fmt.Errorf("failed to read stdin: %w", err) + } + code = string(data) + } else { + code = args[0] + } + } else { + return fmt.Errorf("code required: provide as argument, --file, or pipe to stdin") + } + + code = strings.TrimSpace(code) + if code == "" { + return fmt.Errorf("empty code provided") + } + + return runExecute(cmd.Context(), opts, code) + }, + } + + cmd.Flags().StringVarP(&file, "file", "f", "", "File containing Apex code") + + return cmd +} + +func runExecute(ctx context.Context, opts *root.Options, code string) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + result, err := client.ExecuteAnonymous(ctx, code) + if err != nil { + return fmt.Errorf("failed to execute anonymous apex: %w", err) + } + + v := opts.View() + + if opts.Output == "json" { + return v.JSON(result) + } + + if !result.Compiled { + v.Error("Compile error at line %d, column %d:", result.Line, result.Column) + fmt.Fprintln(opts.Stderr, result.CompileProblem) + return fmt.Errorf("compilation failed") + } + + if !result.Success { + v.Error("Runtime error:") + fmt.Fprintln(opts.Stderr, result.ExceptionMessage) + if result.ExceptionStackTrace != "" { + fmt.Fprintln(opts.Stderr, result.ExceptionStackTrace) + } + return fmt.Errorf("execution failed") + } + + v.Success("Executed successfully") + return nil +} diff --git a/internal/cmd/apexcmd/get.go b/internal/cmd/apexcmd/get.go new file mode 100644 index 0000000..d00daac --- /dev/null +++ b/internal/cmd/apexcmd/get.go @@ -0,0 +1,77 @@ +package apexcmd + +import ( + "context" + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newGetCommand(opts *root.Options) *cobra.Command { + var ( + outputFile string + trigger bool + ) + + cmd := &cobra.Command{ + Use: "get ", + Short: "Get Apex class or trigger source code", + Long: `Get the source code of an Apex class or trigger. + +Examples: + sfdc apex get MyController # Display class source + sfdc apex get MyController --output My.cls # Save to file + sfdc apex get MyTrigger --trigger # Get trigger source`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runGet(cmd.Context(), opts, args[0], outputFile, trigger) + }, + } + + cmd.Flags().StringVarP(&outputFile, "output", "f", "", "Output file path") + cmd.Flags().BoolVar(&trigger, "trigger", false, "Get trigger instead of class") + + return cmd +} + +func runGet(ctx context.Context, opts *root.Options, name, outputFile string, trigger bool) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + var body string + var typeName string + + if trigger { + typeName = "trigger" + t, err := client.GetApexTrigger(ctx, name) + if err != nil { + return fmt.Errorf("failed to get apex trigger: %w", err) + } + body = t.Body + } else { + typeName = "class" + c, err := client.GetApexClass(ctx, name) + if err != nil { + return fmt.Errorf("failed to get apex class: %w", err) + } + body = c.Body + } + + if outputFile != "" { + if err := os.WriteFile(outputFile, []byte(body), 0644); err != nil { + return fmt.Errorf("failed to write output file: %w", err) + } + v := opts.View() + v.Info("Saved %s %s to %s", typeName, name, outputFile) + return nil + } + + // Output directly to stdout + fmt.Fprintln(opts.Stdout, body) + return nil +} diff --git a/internal/cmd/apexcmd/list.go b/internal/cmd/apexcmd/list.go new file mode 100644 index 0000000..ebd0d53 --- /dev/null +++ b/internal/cmd/apexcmd/list.go @@ -0,0 +1,134 @@ +package apexcmd + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/api/tooling" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newListCommand(opts *root.Options) *cobra.Command { + var triggers bool + + cmd := &cobra.Command{ + Use: "list", + Short: "List Apex classes or triggers", + Long: `List all Apex classes or triggers in the org. + +Examples: + sfdc apex list # List all Apex classes + sfdc apex list --triggers # List all Apex triggers + sfdc apex list -o json # Output as JSON`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return runList(cmd.Context(), opts, triggers) + }, + } + + cmd.Flags().BoolVar(&triggers, "triggers", false, "List triggers instead of classes") + + return cmd +} + +func runList(ctx context.Context, opts *root.Options, triggers bool) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + v := opts.View() + + if triggers { + return listTriggers(ctx, client, v, opts) + } + return listClasses(ctx, client, v, opts) +} + +func listClasses(ctx context.Context, client *tooling.Client, v interface { + Table([]string, [][]string) error + JSON(interface{}) error + Info(string, ...interface{}) +}, opts *root.Options) error { + classes, err := client.ListApexClasses(ctx) + if err != nil { + return fmt.Errorf("failed to list apex classes: %w", err) + } + + if len(classes) == 0 { + v.Info("No Apex classes found") + return nil + } + + if opts.Output == "json" { + return v.JSON(classes) + } + + headers := []string{"ID", "Name", "Status", "Valid", "API Version", "Lines"} + rows := make([][]string, 0, len(classes)) + for _, c := range classes { + valid := "No" + if c.IsValid { + valid = "Yes" + } + rows = append(rows, []string{ + c.ID, + c.Name, + c.Status, + valid, + fmt.Sprintf("%.0f", c.APIVersion), + fmt.Sprintf("%d", c.LengthWithoutComments), + }) + } + + if err := v.Table(headers, rows); err != nil { + return err + } + v.Info("\n%d class(es)", len(classes)) + return nil +} + +func listTriggers(ctx context.Context, client *tooling.Client, v interface { + Table([]string, [][]string) error + JSON(interface{}) error + Info(string, ...interface{}) +}, opts *root.Options) error { + triggers, err := client.ListApexTriggers(ctx) + if err != nil { + return fmt.Errorf("failed to list apex triggers: %w", err) + } + + if len(triggers) == 0 { + v.Info("No Apex triggers found") + return nil + } + + if opts.Output == "json" { + return v.JSON(triggers) + } + + headers := []string{"ID", "Name", "Object", "Status", "Valid", "API Version"} + rows := make([][]string, 0, len(triggers)) + for _, t := range triggers { + valid := "No" + if t.IsValid { + valid = "Yes" + } + rows = append(rows, []string{ + t.ID, + t.Name, + t.TableEnumOrID, + t.Status, + valid, + fmt.Sprintf("%.0f", t.APIVersion), + }) + } + + if err := v.Table(headers, rows); err != nil { + return err + } + v.Info("\n%d trigger(s)", len(triggers)) + return nil +} diff --git a/internal/cmd/apexcmd/test.go b/internal/cmd/apexcmd/test.go new file mode 100644 index 0000000..f0058eb --- /dev/null +++ b/internal/cmd/apexcmd/test.go @@ -0,0 +1,188 @@ +package apexcmd + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/api/tooling" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newTestCommand(opts *root.Options) *cobra.Command { + var ( + className string + methodName string + wait bool + ) + + cmd := &cobra.Command{ + Use: "test", + Short: "Run Apex tests", + Long: `Run Apex tests asynchronously. + +Examples: + sfdc apex test --class MyControllerTest + sfdc apex test --class MyControllerTest --method testCreate + sfdc apex test --class MyTest --wait + sfdc apex test --class MyTest -o json`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + if className == "" { + return fmt.Errorf("--class is required") + } + return runTest(cmd.Context(), opts, className, methodName, wait) + }, + } + + cmd.Flags().StringVar(&className, "class", "", "Test class name (required)") + cmd.Flags().StringVar(&methodName, "method", "", "Specific test method to run") + cmd.Flags().BoolVar(&wait, "wait", false, "Wait for tests to complete") + + return cmd +} + +func runTest(ctx context.Context, opts *root.Options, className, methodName string, wait bool) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + v := opts.View() + + // Get the class ID for the test class + classID, err := client.GetApexClassID(ctx, className) + if err != nil { + return fmt.Errorf("failed to find test class: %w", err) + } + + v.Info("Running tests for %s...", className) + + // Enqueue the test run + jobID, err := client.RunTestsAsync(ctx, []string{classID}) + if err != nil { + return fmt.Errorf("failed to enqueue tests: %w", err) + } + + v.Info("Test job ID: %s", jobID) + + if !wait { + v.Info("Tests enqueued. Use 'sfdc apex test-status %s' to check results.", jobID) + return nil + } + + // Poll for completion + v.Info("Waiting for tests to complete...") + + for { + job, err := client.GetAsyncJobStatus(ctx, jobID) + if err != nil { + return fmt.Errorf("failed to get job status: %w", err) + } + + switch job.Status { + case "Completed", "Aborted", "Failed": + return displayTestResults(ctx, client, opts, jobID, methodName) + case "Queued", "Processing", "Preparing", "Holding": + time.Sleep(2 * time.Second) + default: + return fmt.Errorf("unexpected job status: %s", job.Status) + } + } +} + +func displayTestResults(ctx context.Context, client *tooling.Client, opts *root.Options, jobID, filterMethod string) error { + results, err := client.GetTestResults(ctx, jobID) + if err != nil { + return fmt.Errorf("failed to get test results: %w", err) + } + + // Filter by method if specified + if filterMethod != "" { + filtered := make([]tooling.ApexTestResult, 0) + for _, r := range results { + if r.MethodName == filterMethod { + filtered = append(filtered, r) + } + } + results = filtered + } + + v := opts.View() + + if len(results) == 0 { + v.Info("No test results found") + return nil + } + + if opts.Output == "json" { + return v.JSON(results) + } + + headers := []string{"Class", "Method", "Outcome", "Time (ms)", "Message"} + rows := make([][]string, 0, len(results)) + + passCount := 0 + failCount := 0 + totalTime := 0 + + for _, r := range results { + message := r.Message + if len(message) > 50 { + message = message[:50] + "..." + } + + rows = append(rows, []string{ + r.ClassName, + r.MethodName, + r.Outcome, + fmt.Sprintf("%d", r.RunTime), + message, + }) + + totalTime += r.RunTime + switch r.Outcome { + case "Pass": + passCount++ + case "Fail", "CompileFail": + failCount++ + } + } + + if err := v.Table(headers, rows); err != nil { + return err + } + + // Summary + var summaryParts []string + summaryParts = append(summaryParts, fmt.Sprintf("%d passed", passCount)) + if failCount > 0 { + summaryParts = append(summaryParts, fmt.Sprintf("%d failed", failCount)) + } + summaryParts = append(summaryParts, fmt.Sprintf("%dms total", totalTime)) + + v.Info("\n%s", strings.Join(summaryParts, ", ")) + + // Show failure details + for _, r := range results { + if r.Outcome == "Fail" || r.Outcome == "CompileFail" { + fmt.Fprintf(opts.Stderr, "\n%s.%s:\n", r.ClassName, r.MethodName) + fmt.Fprintf(opts.Stderr, " %s\n", r.Message) + if r.StackTrace != "" { + fmt.Fprintf(opts.Stderr, " Stack trace:\n") + for _, line := range strings.Split(r.StackTrace, "\n") { + fmt.Fprintf(opts.Stderr, " %s\n", line) + } + } + } + } + + if failCount > 0 { + return fmt.Errorf("%d test(s) failed", failCount) + } + + return nil +} diff --git a/internal/cmd/coveragecmd/coverage.go b/internal/cmd/coveragecmd/coverage.go new file mode 100644 index 0000000..00b8b96 --- /dev/null +++ b/internal/cmd/coveragecmd/coverage.go @@ -0,0 +1,150 @@ +// Package coveragecmd provides commands for code coverage operations. +package coveragecmd + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +// Register registers the coverage command with the root command. +func Register(parent *cobra.Command, opts *root.Options) { + parent.AddCommand(NewCommand(opts)) +} + +// NewCommand creates the coverage command. +func NewCommand(opts *root.Options) *cobra.Command { + var ( + className string + minCover int + ) + + cmd := &cobra.Command{ + Use: "coverage", + Short: "Show code coverage", + Long: `Show Apex code coverage for the org. + +Examples: + sfdc coverage # Show all coverage + sfdc coverage --class MyController # Show coverage for specific class + sfdc coverage --min 75 # Fail if overall coverage < 75% + sfdc coverage -o json # Output as JSON`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return runCoverage(cmd.Context(), opts, className, minCover) + }, + } + + cmd.Flags().StringVar(&className, "class", "", "Show coverage for specific class") + cmd.Flags().IntVar(&minCover, "min", 0, "Minimum coverage percentage (exit 1 if below)") + + return cmd +} + +func runCoverage(ctx context.Context, opts *root.Options, className string, minCover int) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + v := opts.View() + + // If specific class requested + if className != "" { + cov, err := client.GetCodeCoverageForClass(ctx, className) + if err != nil { + return fmt.Errorf("failed to get coverage: %w", err) + } + + if opts.Output == "json" { + return v.JSON(cov) + } + + total := cov.NumLinesCovered + cov.NumLinesUncovered + pct := 0.0 + if total > 0 { + pct = float64(cov.NumLinesCovered) / float64(total) * 100 + } + + headers := []string{"Class", "Lines Covered", "Lines Uncovered", "Coverage %"} + rows := [][]string{ + { + cov.ApexClassOrTrigger.Name, + fmt.Sprintf("%d", cov.NumLinesCovered), + fmt.Sprintf("%d", cov.NumLinesUncovered), + fmt.Sprintf("%.1f%%", pct), + }, + } + + if err := v.Table(headers, rows); err != nil { + return err + } + + if minCover > 0 && int(pct) < minCover { + return fmt.Errorf("coverage %.1f%% is below minimum %d%%", pct, minCover) + } + + return nil + } + + // Get all coverage + coverage, err := client.GetCodeCoverage(ctx) + if err != nil { + return fmt.Errorf("failed to get coverage: %w", err) + } + + if len(coverage) == 0 { + v.Info("No code coverage data found") + return nil + } + + if opts.Output == "json" { + return v.JSON(coverage) + } + + headers := []string{"Class/Trigger", "Lines Covered", "Lines Uncovered", "Coverage %"} + rows := make([][]string, 0, len(coverage)) + + totalCovered := 0 + totalUncovered := 0 + + for _, cov := range coverage { + total := cov.NumLinesCovered + cov.NumLinesUncovered + pct := 0.0 + if total > 0 { + pct = float64(cov.NumLinesCovered) / float64(total) * 100 + } + + totalCovered += cov.NumLinesCovered + totalUncovered += cov.NumLinesUncovered + + rows = append(rows, []string{ + cov.ApexClassOrTrigger.Name, + fmt.Sprintf("%d", cov.NumLinesCovered), + fmt.Sprintf("%d", cov.NumLinesUncovered), + fmt.Sprintf("%.1f%%", pct), + }) + } + + if err := v.Table(headers, rows); err != nil { + return err + } + + // Calculate and display overall coverage + overallTotal := totalCovered + totalUncovered + overallPct := 0.0 + if overallTotal > 0 { + overallPct = float64(totalCovered) / float64(overallTotal) * 100 + } + + v.Info("\nOverall: %d/%d lines covered (%.1f%%)", totalCovered, overallTotal, overallPct) + + if minCover > 0 && int(overallPct) < minCover { + return fmt.Errorf("overall coverage %.1f%% is below minimum %d%%", overallPct, minCover) + } + + return nil +} diff --git a/internal/cmd/coveragecmd/coverage_test.go b/internal/cmd/coveragecmd/coverage_test.go new file mode 100644 index 0000000..04e1e8c --- /dev/null +++ b/internal/cmd/coveragecmd/coverage_test.go @@ -0,0 +1,328 @@ +package coveragecmd + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/open-cli-collective/salesforce-cli/api/tooling" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func TestCoverageList(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 2, + Done: true, + Records: []tooling.Record{ + { + "Id": "500000000000001", + "ApexClassOrTriggerId": "01p000000000001", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyController"}, + "NumLinesCovered": float64(80), + "NumLinesUncovered": float64(20), + }, + { + "Id": "500000000000002", + "ApexClassOrTriggerId": "01p000000000002", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyHelper"}, + "NumLinesCovered": float64(50), + "NumLinesUncovered": float64(50), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "MyController") + assert.Contains(t, output, "MyHelper") + assert.Contains(t, output, "80.0%") + assert.Contains(t, output, "50.0%") + assert.Contains(t, output, "Overall") +} + +func TestCoverageForClass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.RawQuery, "MyController") + + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "500000000000001", + "ApexClassOrTriggerId": "01p000000000001", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyController"}, + "NumLinesCovered": float64(80), + "NumLinesUncovered": float64(20), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"--class", "MyController"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "MyController") + assert.Contains(t, output, "80.0%") +} + +func TestCoverageMinimumPass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "500000000000001", + "ApexClassOrTriggerId": "01p000000000001", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyController"}, + "NumLinesCovered": float64(80), + "NumLinesUncovered": float64(20), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"--min", "75"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) // 80% > 75%, should pass +} + +func TestCoverageMinimumFail(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "500000000000001", + "ApexClassOrTriggerId": "01p000000000001", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyController"}, + "NumLinesCovered": float64(50), + "NumLinesUncovered": float64(50), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"--min", "75"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + assert.Error(t, err) // 50% < 75%, should fail + assert.Contains(t, err.Error(), "below minimum") +} + +func TestCoverageJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "500000000000001", + "ApexClassOrTriggerId": "01p000000000001", + "ApexClassOrTrigger": map[string]interface{}{"Name": "MyController"}, + "NumLinesCovered": float64(80), + "NumLinesUncovered": float64(20), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "json", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + // Should be valid JSON + var result []tooling.ApexCodeCoverageAggregate + err = json.Unmarshal([]byte(output), &result) + require.NoError(t, err) + assert.Len(t, result, 1) +} + +func TestCoverageEmpty(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 0, + Done: true, + Records: []tooling.Record{}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "No code coverage data found") +} + +func TestCoverageClassNotFound(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.RawQuery, "NonExistent") { + response := tooling.QueryResult{ + TotalSize: 0, + Done: true, + Records: []tooling.Record{}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + } + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"--class", "NonExistent"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + assert.Error(t, err) + assert.Contains(t, err.Error(), "no coverage data found") +} diff --git a/internal/cmd/logcmd/get.go b/internal/cmd/logcmd/get.go new file mode 100644 index 0000000..44548d7 --- /dev/null +++ b/internal/cmd/logcmd/get.go @@ -0,0 +1,57 @@ +package logcmd + +import ( + "context" + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newGetCommand(opts *root.Options) *cobra.Command { + var outputFile string + + cmd := &cobra.Command{ + Use: "get ", + Short: "Get debug log content", + Long: `Get the content of a debug log. + +Examples: + sfdc log get 07L1x000000ABCD # Display log content + sfdc log get 07L1x000000ABCD --output debug.log # Save to file`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runLogGet(cmd.Context(), opts, args[0], outputFile) + }, + } + + cmd.Flags().StringVarP(&outputFile, "output", "f", "", "Output file path") + + return cmd +} + +func runLogGet(ctx context.Context, opts *root.Options, logID, outputFile string) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + body, err := client.GetApexLogBody(ctx, logID) + if err != nil { + return fmt.Errorf("failed to get log content: %w", err) + } + + if outputFile != "" { + if err := os.WriteFile(outputFile, []byte(body), 0644); err != nil { + return fmt.Errorf("failed to write output file: %w", err) + } + v := opts.View() + v.Info("Saved log to %s", outputFile) + return nil + } + + fmt.Fprintln(opts.Stdout, body) + return nil +} diff --git a/internal/cmd/logcmd/list.go b/internal/cmd/logcmd/list.go new file mode 100644 index 0000000..bf4a43f --- /dev/null +++ b/internal/cmd/logcmd/list.go @@ -0,0 +1,97 @@ +package logcmd + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newListCommand(opts *root.Options) *cobra.Command { + var ( + userID string + limit int + ) + + cmd := &cobra.Command{ + Use: "list", + Short: "List debug logs", + Long: `List debug logs from the org. + +Examples: + sfdc log list # List recent logs + sfdc log list --limit 20 # List last 20 logs + sfdc log list --user 005xxx # Filter by user ID + sfdc log list -o json # Output as JSON`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return runLogList(cmd.Context(), opts, userID, limit) + }, + } + + cmd.Flags().StringVar(&userID, "user", "", "Filter by user ID") + cmd.Flags().IntVar(&limit, "limit", 10, "Maximum number of logs to return") + + return cmd +} + +func runLogList(ctx context.Context, opts *root.Options, userID string, limit int) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + logs, err := client.ListApexLogs(ctx, userID, limit) + if err != nil { + return fmt.Errorf("failed to list logs: %w", err) + } + + v := opts.View() + + if len(logs) == 0 { + v.Info("No debug logs found") + return nil + } + + if opts.Output == "json" { + return v.JSON(logs) + } + + headers := []string{"ID", "Operation", "Status", "Size", "Duration", "Start Time"} + rows := make([][]string, 0, len(logs)) + for _, log := range logs { + rows = append(rows, []string{ + log.ID, + truncate(log.Operation, 30), + log.Status, + formatSize(log.LogLength), + fmt.Sprintf("%dms", log.DurationMS), + log.StartTime.Format("2006-01-02 15:04:05"), + }) + } + + if err := v.Table(headers, rows); err != nil { + return err + } + v.Info("\n%d log(s)", len(logs)) + return nil +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." +} + +func formatSize(bytes int) string { + if bytes < 1024 { + return fmt.Sprintf("%d B", bytes) + } + if bytes < 1024*1024 { + return fmt.Sprintf("%.1f KB", float64(bytes)/1024) + } + return fmt.Sprintf("%.1f MB", float64(bytes)/(1024*1024)) +} diff --git a/internal/cmd/logcmd/log.go b/internal/cmd/logcmd/log.go new file mode 100644 index 0000000..db6c13f --- /dev/null +++ b/internal/cmd/logcmd/log.go @@ -0,0 +1,34 @@ +// Package logcmd provides commands for debug log operations. +package logcmd + +import ( + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +// Register registers the log command with the root command. +func Register(parent *cobra.Command, opts *root.Options) { + parent.AddCommand(NewCommand(opts)) +} + +// NewCommand creates the log command. +func NewCommand(opts *root.Options) *cobra.Command { + cmd := &cobra.Command{ + Use: "log", + Short: "Debug log operations", + Long: `Manage and view Salesforce debug logs. + +Examples: + sfdc log list # List recent debug logs + sfdc log list --limit 20 # List last 20 logs + sfdc log get 07L1x000000ABCD # Get log content + sfdc log tail # Stream new logs`, + } + + cmd.AddCommand(newListCommand(opts)) + cmd.AddCommand(newGetCommand(opts)) + cmd.AddCommand(newTailCommand(opts)) + + return cmd +} diff --git a/internal/cmd/logcmd/log_test.go b/internal/cmd/logcmd/log_test.go new file mode 100644 index 0000000..44b0be4 --- /dev/null +++ b/internal/cmd/logcmd/log_test.go @@ -0,0 +1,294 @@ +package logcmd + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/open-cli-collective/salesforce-cli/api/tooling" + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func TestLogList(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 2, + Done: true, + Records: []tooling.Record{ + { + "Id": "07L000000000001", + "LogUserId": "005000000000001", + "Operation": "/aura", + "Request": "API", + "Status": "Success", + "LogLength": float64(5000), + "DurationMilliseconds": float64(150), + "StartTime": "2024-01-15T10:30:00.000+0000", + "Location": "MonitoringService", + }, + { + "Id": "07L000000000002", + "LogUserId": "005000000000001", + "Operation": "ApexTrigger", + "Request": "API", + "Status": "Success", + "LogLength": float64(2500), + "DurationMilliseconds": float64(75), + "StartTime": "2024-01-15T10:25:00.000+0000", + "Location": "MonitoringService", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "07L000000000001") + assert.Contains(t, output, "/aura") + assert.Contains(t, output, "2 log(s)") +} + +func TestLogListWithLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify limit is in query + assert.Contains(t, r.URL.RawQuery, "LIMIT+5") + + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "07L000000000001", + "LogUserId": "005000000000001", + "Operation": "/aura", + "Request": "API", + "Status": "Success", + "LogLength": float64(5000), + "DurationMilliseconds": float64(150), + "StartTime": "2024-01-15T10:30:00.000+0000", + "Location": "MonitoringService", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list", "--limit", "5"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) +} + +func TestLogGet(t *testing.T) { + logContent := "DEBUG|Hello World\nUSER_DEBUG|Test message" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/Body") { + w.Write([]byte(logContent)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "plain", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"get", "07L000000000001"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "DEBUG|Hello World") + assert.Contains(t, output, "USER_DEBUG|Test message") +} + +func TestLogListJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "07L000000000001", + "LogUserId": "005000000000001", + "Operation": "/aura", + "Request": "API", + "Status": "Success", + "LogLength": float64(5000), + "DurationMilliseconds": float64(150), + "StartTime": "2024-01-15T10:30:00.000+0000", + "Location": "MonitoringService", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "json", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + // Should be valid JSON + var result []tooling.ApexLog + err = json.Unmarshal([]byte(output), &result) + require.NoError(t, err) + assert.Len(t, result, 1) +} + +func TestLogListEmpty(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.QueryResult{ + TotalSize: 0, + Done: true, + Records: []tooling.Record{}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"list"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "No debug logs found") +} + +func TestFormatSize(t *testing.T) { + tests := []struct { + bytes int + want string + }{ + {500, "500 B"}, + {1024, "1.0 KB"}, + {1536, "1.5 KB"}, + {1048576, "1.0 MB"}, + {2097152, "2.0 MB"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := formatSize(tt.bytes) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + input string + maxLen int + want string + }{ + {"short", 10, "short"}, + {"exactly10!", 10, "exactly10!"}, + {"this is a long string", 10, "this is..."}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := truncate(tt.input, tt.maxLen) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/cmd/logcmd/tail.go b/internal/cmd/logcmd/tail.go new file mode 100644 index 0000000..4bc9455 --- /dev/null +++ b/internal/cmd/logcmd/tail.go @@ -0,0 +1,104 @@ +package logcmd + +import ( + "context" + "fmt" + "time" + + "github.com/spf13/cobra" + + "github.com/open-cli-collective/salesforce-cli/internal/cmd/root" +) + +func newTailCommand(opts *root.Options) *cobra.Command { + var ( + userID string + interval int + ) + + cmd := &cobra.Command{ + Use: "tail", + Short: "Stream new debug logs", + Long: `Continuously poll for new debug logs and display them. + +Press Ctrl+C to stop. + +Examples: + sfdc log tail # Stream all new logs + sfdc log tail --user 005xxx # Filter by user ID + sfdc log tail --interval 5 # Poll every 5 seconds`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return runLogTail(cmd.Context(), opts, userID, interval) + }, + } + + cmd.Flags().StringVar(&userID, "user", "", "Filter by user ID") + cmd.Flags().IntVar(&interval, "interval", 3, "Polling interval in seconds") + + return cmd +} + +func runLogTail(ctx context.Context, opts *root.Options, userID string, interval int) error { + client, err := opts.ToolingClient() + if err != nil { + return fmt.Errorf("failed to create tooling client: %w", err) + } + + v := opts.View() + v.Info("Tailing debug logs... (Ctrl+C to stop)") + + seenLogs := make(map[string]bool) + + // Get initial logs to establish baseline + initialLogs, err := client.ListApexLogs(ctx, userID, 10) + if err != nil { + return fmt.Errorf("failed to get initial logs: %w", err) + } + for _, log := range initialLogs { + seenLogs[log.ID] = true + } + + ticker := time.NewTicker(time.Duration(interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + v.Info("\nStopped") + return nil + case <-ticker.C: + logs, err := client.ListApexLogs(ctx, userID, 10) + if err != nil { + v.Error("Failed to poll logs: %v", err) + continue + } + + // Process new logs (in reverse order to show oldest first) + for i := len(logs) - 1; i >= 0; i-- { + log := logs[i] + if seenLogs[log.ID] { + continue + } + seenLogs[log.ID] = true + + // Print log summary + fmt.Fprintf(opts.Stdout, "\n[%s] %s (%s, %s)\n", + log.StartTime.Format("15:04:05"), + log.Operation, + log.Status, + formatSize(log.LogLength), + ) + + // Fetch and print log body + body, err := client.GetApexLogBody(ctx, log.ID) + if err != nil { + v.Error("Failed to get log body: %v", err) + continue + } + fmt.Fprintln(opts.Stdout, body) + fmt.Fprintln(opts.Stdout, "---") + } + } + } +} diff --git a/internal/cmd/root/root.go b/internal/cmd/root/root.go index e538005..71ba5f3 100644 --- a/internal/cmd/root/root.go +++ b/internal/cmd/root/root.go @@ -11,6 +11,7 @@ import ( "github.com/open-cli-collective/salesforce-cli/api" "github.com/open-cli-collective/salesforce-cli/api/bulk" + "github.com/open-cli-collective/salesforce-cli/api/tooling" "github.com/open-cli-collective/salesforce-cli/internal/auth" "github.com/open-cli-collective/salesforce-cli/internal/config" "github.com/open-cli-collective/salesforce-cli/internal/version" @@ -31,6 +32,8 @@ type Options struct { testClient *api.Client // testBulkClient is used for testing; if set, BulkClient() returns this instead testBulkClient *bulk.Client + // testToolingClient is used for testing; if set, ToolingClient() returns this instead + testToolingClient *tooling.Client } // View returns a configured View instance @@ -102,6 +105,29 @@ func (o *Options) SetBulkClient(client *bulk.Client) { o.testBulkClient = client } +// ToolingClient creates a new Tooling API client from config +func (o *Options) ToolingClient() (*tooling.Client, error) { + if o.testToolingClient != nil { + return o.testToolingClient, nil + } + + instanceURL, httpClient, err := o.loadClientConfig() + if err != nil { + return nil, err + } + + return tooling.New(tooling.ClientConfig{ + InstanceURL: instanceURL, + HTTPClient: httpClient, + APIVersion: o.APIVersion, + }) +} + +// SetToolingClient sets a test tooling client (for testing only) +func (o *Options) SetToolingClient(client *tooling.Client) { + o.testToolingClient = client +} + // NewCmd creates the root command and returns the options struct func NewCmd() (*cobra.Command, *Options) { opts := &Options{ From fb70859d811a7be596c7f6b147af6df4b7480587 Mon Sep 17 00:00:00 2001 From: Rian Stockbower Date: Sun, 1 Feb 2026 06:51:44 -0500 Subject: [PATCH 2/2] test: add missing tests for trigger get, runtime errors, and tail --- internal/cmd/apexcmd/apex_test.go | 88 +++++++++++++++++++++++++++++++ internal/cmd/logcmd/log_test.go | 47 +++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/internal/cmd/apexcmd/apex_test.go b/internal/cmd/apexcmd/apex_test.go index d108ce1..57d4cbb 100644 --- a/internal/cmd/apexcmd/apex_test.go +++ b/internal/cmd/apexcmd/apex_test.go @@ -327,6 +327,94 @@ func TestApexTestNoWait(t *testing.T) { assert.Contains(t, output, "Tests enqueued") } +func TestApexGetTrigger(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.RawQuery, "ApexTrigger") + + response := tooling.QueryResult{ + TotalSize: 1, + Done: true, + Records: []tooling.Record{ + { + "Id": "01q000000000001", + "Name": "AccountTrigger", + "Body": "trigger AccountTrigger on Account (before insert) { }", + "Status": "Active", + "IsValid": true, + "ApiVersion": float64(62), + "TableEnumOrId": "Account", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "plain", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"get", "AccountTrigger", "--trigger"}) + cmd.SetOut(stdout) + + err = cmd.Execute() + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "trigger AccountTrigger on Account") +} + +func TestApexExecuteRuntimeError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := tooling.ExecuteAnonymousResult{ + Compiled: true, + Success: false, + ExceptionMessage: "System.NullPointerException: Attempt to de-reference a null object", + ExceptionStackTrace: "AnonymousBlock: line 1, column 1", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: stderr, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"execute", "String s; s.length();"}) + cmd.SetOut(stdout) + cmd.SetErr(stderr) + + err = cmd.Execute() + assert.Error(t, err) + assert.Contains(t, err.Error(), "execution failed") + assert.Contains(t, stderr.String(), "NullPointerException") +} + func TestApexListJSON(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response := tooling.QueryResult{ diff --git a/internal/cmd/logcmd/log_test.go b/internal/cmd/logcmd/log_test.go index 44b0be4..d77b8ee 100644 --- a/internal/cmd/logcmd/log_test.go +++ b/internal/cmd/logcmd/log_test.go @@ -2,11 +2,13 @@ package logcmd import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -254,6 +256,51 @@ func TestLogListEmpty(t *testing.T) { assert.Contains(t, output, "No debug logs found") } +func TestLogTailContextCancellation(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + response := tooling.QueryResult{ + TotalSize: 0, + Done: true, + Records: []tooling.Record{}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := tooling.New(tooling.ClientConfig{ + InstanceURL: server.URL, + HTTPClient: server.Client(), + }) + require.NoError(t, err) + + stdout := &bytes.Buffer{} + opts := &root.Options{ + Output: "table", + Stdout: stdout, + Stderr: &bytes.Buffer{}, + } + opts.SetToolingClient(client) + + cmd := NewCommand(opts) + cmd.SetArgs([]string{"tail", "--interval", "1"}) + cmd.SetOut(stdout) + + // Create a context that will be cancelled quickly + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = cmd.ExecuteContext(ctx) + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, "Tailing debug logs") + // Should have made at least one API call + assert.GreaterOrEqual(t, callCount, 1) +} + func TestFormatSize(t *testing.T) { tests := []struct { bytes int