diff --git a/protocol/types.go b/protocol/types.go index 71e9434..4b29681 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -1,6 +1,7 @@ package protocol import ( + "bytes" "crypto/hmac" "crypto/sha256" "encoding/hex" @@ -1693,6 +1694,201 @@ func (w ContainerBuildWorkload) Validate() error { return errors.Join(errs...) } +type ProductCaptureMode string + +const ( + ProductCaptureModeBrowser ProductCaptureMode = "browser" + ProductCaptureModeMetadata ProductCaptureMode = "metadata" +) + +const ( + MaxProductCaptureTimeoutSeconds = 300 + MaxProductCaptureHTMLBytes = 10 << 20 + MaxProductCaptureImageCount = 32 +) + +type ProductCaptureWorkload struct { + URL string `json:"url"` + AllowedHosts []string `json:"allowed_hosts"` + CaptureMode ProductCaptureMode `json:"capture_mode,omitempty"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` + MaxHTMLBytes int64 `json:"max_html_bytes,omitempty"` + MaxImageCount int `json:"max_image_count,omitempty"` + MetadataOnly bool `json:"metadata_only,omitempty"` +} + +func (w ProductCaptureWorkload) Validate() error { + var errs []error + parsed, err := url.Parse(strings.TrimSpace(w.URL)) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + errs = append(errs, errors.New("url must be absolute http(s) URL")) + } else { + if parsed.Scheme != "http" && parsed.Scheme != "https" { + errs = append(errs, errors.New("url must use http or https")) + } + if !productCaptureHostAllowed(parsed.Hostname(), w.AllowedHosts) { + errs = append(errs, fmt.Errorf("url host %q is not listed in allowed_hosts", parsed.Hostname())) + } + } + if len(w.AllowedHosts) == 0 { + errs = append(errs, errors.New("allowed_hosts is required")) + } + for i, host := range w.AllowedHosts { + if err := validateNetworkHostname(host); err != nil { + errs = append(errs, fmt.Errorf("allowed_hosts[%d]: %w", i, err)) + } + } + switch w.CaptureMode { + case "", ProductCaptureModeBrowser, ProductCaptureModeMetadata: + default: + errs = append(errs, fmt.Errorf("capture_mode %q is unsupported", w.CaptureMode)) + } + if w.TimeoutSeconds < 0 || w.TimeoutSeconds > MaxProductCaptureTimeoutSeconds { + errs = append(errs, fmt.Errorf("timeout_seconds must be between 0 and %d", MaxProductCaptureTimeoutSeconds)) + } + if w.MaxHTMLBytes < 0 || w.MaxHTMLBytes > MaxProductCaptureHTMLBytes { + errs = append(errs, fmt.Errorf("max_html_bytes must be between 0 and %d", MaxProductCaptureHTMLBytes)) + } + if w.MaxImageCount < 0 || w.MaxImageCount > MaxProductCaptureImageCount { + errs = append(errs, fmt.Errorf("max_image_count must be between 0 and %d", MaxProductCaptureImageCount)) + } + return errors.Join(errs...) +} + +func productCaptureHostAllowed(host string, allowed []string) bool { + host = strings.ToLower(strings.TrimSpace(host)) + if host == "" { + return false + } + for _, candidate := range allowed { + if strings.EqualFold(host, strings.TrimSpace(candidate)) { + return true + } + } + return false +} + +type ProviderWorkload struct { + ProviderConfig ProviderConfig `json:"provider_config"` + Operation string `json:"operation"` + ImageRef string `json:"image_ref,omitempty"` + ComponentRef string `json:"component_ref,omitempty"` + ComponentDigest string `json:"component_digest,omitempty"` + ABI string `json:"abi,omitempty"` + Input json.RawMessage `json:"input"` +} + +func (w ProviderWorkload) Validate() error { + var errs []error + if w.ProviderConfig == (ProviderConfig{}) { + errs = append(errs, errors.New("provider_config is required")) + } else if err := w.ProviderConfig.Validate(); err != nil { + errs = append(errs, fmt.Errorf("provider_config: %w", err)) + } + if err := validateIdentifier("operation", w.Operation); err != nil { + errs = append(errs, err) + } + imageRefSet := strings.TrimSpace(w.ImageRef) != "" + componentRefSet := strings.TrimSpace(w.ComponentRef) != "" + if imageRefSet && componentRefSet { + errs = append(errs, errors.New("image_ref and component_ref are mutually exclusive")) + } + if !imageRefSet && !componentRefSet { + errs = append(errs, errors.New("image_ref or component_ref is required")) + } else if imageRefSet && (strings.TrimSpace(w.ImageRef) != w.ImageRef || strings.ContainsAny(w.ImageRef, "\t\r\n \x00")) { + errs = append(errs, errors.New("image_ref must not contain whitespace or NUL")) + } else if imageRefSet { + if _, digest, ok := strings.Cut(w.ImageRef, "@"); !ok || !validSHA256Ref(digest) { + errs = append(errs, errors.New("image_ref must be digest-pinned with @sha256:<64 hex>")) + } + } + if componentRefSet { + if err := validateProviderComponentRef("component_ref", w.ComponentRef); err != nil { + errs = append(errs, err) + } + if strings.HasPrefix(w.ComponentRef, "provider://") { + ref := strings.TrimPrefix(w.ComponentRef, "provider://") + pluginID, _, _ := strings.Cut(ref, "/") + if pluginID != "" && w.ProviderConfig.PluginID != "" && pluginID != w.ProviderConfig.PluginID { + errs = append(errs, fmt.Errorf("component_ref provider plugin %q must match provider_config.plugin_id %q", pluginID, w.ProviderConfig.PluginID)) + } + } + if !validSHA256Ref(w.ComponentDigest) { + errs = append(errs, errors.New("component_digest must be sha256:<64 hex chars>")) + } + if w.ABI != "" { + if err := validateIdentifier("abi", w.ABI); err != nil { + errs = append(errs, err) + } + } + } + if err := validateJSONInput(w.Input); err != nil { + errs = append(errs, fmt.Errorf("input: %w", err)) + } + return errors.Join(errs...) +} + +type WASMWorkload struct { + ComponentRef string `json:"component_ref"` + ComponentDigest string `json:"component_digest"` + ABI string `json:"abi"` + Operation string `json:"operation"` + Input json.RawMessage `json:"input"` +} + +func (w WASMWorkload) Validate() error { + var errs []error + if err := validateComponentRef("component_ref", w.ComponentRef); err != nil { + errs = append(errs, err) + } + if !validSHA256Ref(w.ComponentDigest) { + errs = append(errs, errors.New("component_digest must be sha256:<64 hex chars>")) + } + if err := validateIdentifier("abi", w.ABI); err != nil { + errs = append(errs, err) + } + if err := validateIdentifier("operation", w.Operation); err != nil { + errs = append(errs, err) + } + if err := validateJSONInput(w.Input); err != nil { + errs = append(errs, fmt.Errorf("input: %w", err)) + } + return errors.Join(errs...) +} + +func validateProviderComponentRef(name, value string) error { + value = strings.TrimSpace(value) + if value == "" { + return fmt.Errorf("%s is required", name) + } + if strings.ContainsAny(value, " \t\r\n\x00") { + return fmt.Errorf("%s must not contain whitespace or NUL", name) + } + if strings.HasPrefix(value, "provider://") { + ref := strings.TrimPrefix(value, "provider://") + pluginID, componentPath, ok := strings.Cut(ref, "/") + if !ok || strings.TrimSpace(pluginID) == "" || strings.TrimSpace(componentPath) == "" { + return fmt.Errorf("%s must include provider plugin and component path", name) + } + return nil + } + return fmt.Errorf("%s must use provider:// plugin component ref", name) +} + +func validateJSONInput(input json.RawMessage) error { + trimmed := bytes.TrimSpace(input) + if len(trimmed) == 0 { + return errors.New("is required") + } + if !json.Valid(trimmed) { + return errors.New("must be valid JSON") + } + if trimmed[0] != '{' && trimmed[0] != '[' { + return errors.New("must be a JSON object or array") + } + return nil +} + func (r RuntimeExecutionRequest) Validate() error { var errs []error if r.ProtocolVersion != Version { diff --git a/protocol/types_test.go b/protocol/types_test.go index 8b76d56..cac55f6 100644 --- a/protocol/types_test.go +++ b/protocol/types_test.go @@ -443,6 +443,68 @@ func TestContainerBuildWorkloadContractUsesRegistryRefs(t *testing.T) { } } +func TestWASMRuntimePayloadContracts(t *testing.T) { + digest := "sha256:" + strings.Repeat("a", 64) + wasm := protocol.WASMWorkload{ + ComponentRef: "artifact://edge/echo.wasm", + ComponentDigest: digest, + ABI: "wasm-export-i32-v1", + Operation: "handle_request", + Input: json.RawMessage(`{"path":"/index.html"}`), + } + if err := wasm.Validate(); err != nil { + t.Fatalf("valid wasm workload rejected: %v", err) + } + wasm.ComponentRef = "file:///tmp/evil.wasm" + if err := wasm.Validate(); err == nil || !strings.Contains(err.Error(), "component_ref") { + t.Fatalf("host ref accepted: %v", err) + } + + provider := protocol.ProviderWorkload{ + ProviderConfig: protocol.ProviderConfig{ + PluginID: "workflow-plugin-product-capture", + ProviderID: "browser", + ContractID: "product-capture.browser.v1", + Version: "v0.1.0", + ConfigRef: "config://browser", + }, + Operation: "capture", + ComponentRef: "provider://workflow-plugin-product-capture/browser.wasm", + ComponentDigest: digest, + ABI: "wasm-export-i32-v1", + Input: json.RawMessage(`{"url":"https://example.invalid"}`), + } + if err := provider.Validate(); err != nil { + t.Fatalf("valid provider workload rejected: %v", err) + } + provider.ComponentRef = "provider://other-plugin/browser.wasm" + if err := provider.Validate(); err == nil || !strings.Contains(err.Error(), "provider plugin") { + t.Fatalf("mismatched provider component accepted: %v", err) + } +} + +func TestProductCaptureWorkloadValidation(t *testing.T) { + workload := protocol.ProductCaptureWorkload{ + URL: "https://www.amazon.com/dp/B08H75RTZ8", + AllowedHosts: []string{"www.amazon.com", "amazon.com"}, + CaptureMode: protocol.ProductCaptureModeBrowser, + TimeoutSeconds: 30, + MaxHTMLBytes: 1 << 20, + MaxImageCount: 8, + } + if err := workload.Validate(); err != nil { + t.Fatalf("valid product-capture workload rejected: %v", err) + } + workload.URL = "https://evil.example/item" + if err := workload.Validate(); err == nil || !strings.Contains(err.Error(), "allowed_hosts") { + t.Fatalf("disallowed host accepted: %v", err) + } + workload.URL = "file:///tmp/item" + if err := workload.Validate(); err == nil || !strings.Contains(err.Error(), "http") { + t.Fatalf("non-http URL accepted: %v", err) + } +} + func TestExecutorRefValidateForProofRequiresDigestsForNonNativeExecutors(t *testing.T) { ref := protocol.ExecutorRef{ Provider: "sandboxed-command",