Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions protocol/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package protocol

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
Expand Down Expand Up @@ -1693,6 +1694,201 @@ func (w ContainerBuildWorkload) Validate() error {
return errors.Join(errs...)
}

// ProductCaptureMode names the capture strategy requested by a
// ProductCaptureWorkload. It is a string type, so it travels as a plain JSON
// string in the capture_mode field.
type ProductCaptureMode string

// Capture modes accepted by ProductCaptureWorkload.Validate. The empty mode is
// accepted there as well; any other value is reported as unsupported.
const (
// ProductCaptureModeBrowser is the "browser" capture mode.
// NOTE(review): presumably a full browser-driven capture — confirm against the executor.
ProductCaptureModeBrowser ProductCaptureMode = "browser"
// ProductCaptureModeMetadata is the "metadata" capture mode.
// NOTE(review): presumably a metadata-only capture — confirm against the executor.
ProductCaptureModeMetadata ProductCaptureMode = "metadata"
)

// Upper bounds that ProductCaptureWorkload.Validate enforces on the optional
// tuning fields of a product-capture workload. Each bound is inclusive.
const (
// MaxProductCaptureTimeoutSeconds is the largest accepted timeout_seconds value.
MaxProductCaptureTimeoutSeconds = 300
// MaxProductCaptureHTMLBytes is the largest accepted max_html_bytes value
// (10 << 20, i.e. 10 MiB).
MaxProductCaptureHTMLBytes = 10 << 20
// MaxProductCaptureImageCount is the largest accepted max_image_count value.
MaxProductCaptureImageCount = 32
)

// ProductCaptureWorkload is the payload describing a product page capture.
// Its Validate method enforces the per-field rules noted below.
type ProductCaptureWorkload struct {
// URL is the page to capture. It must be an absolute http or https URL whose
// hostname appears in AllowedHosts.
URL string `json:"url"`
// AllowedHosts lists the hostnames the capture may target. At least one entry
// is required and every entry must pass validateNetworkHostname.
AllowedHosts []string `json:"allowed_hosts"`
// CaptureMode selects the capture strategy; empty or one of the
// ProductCaptureMode* constants.
CaptureMode ProductCaptureMode `json:"capture_mode,omitempty"`
// TimeoutSeconds must lie in [0, MaxProductCaptureTimeoutSeconds].
// NOTE(review): 0 presumably means "use the executor default" — confirm.
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
// MaxHTMLBytes must lie in [0, MaxProductCaptureHTMLBytes].
MaxHTMLBytes int64 `json:"max_html_bytes,omitempty"`
// MaxImageCount must lie in [0, MaxProductCaptureImageCount].
MaxImageCount int `json:"max_image_count,omitempty"`
// MetadataOnly is not inspected by Validate.
// NOTE(review): its interaction with CaptureMode "metadata" is not defined here — confirm.
MetadataOnly bool `json:"metadata_only,omitempty"`
}

// Validate checks every field of the workload and returns all problems found,
// combined with errors.Join. It returns nil when the workload is valid.
//
// The URL (after trimming surrounding whitespace) must be absolute, use http or
// https, and have a hostname listed in AllowedHosts. AllowedHosts must be
// non-empty and each entry must pass validateNetworkHostname. CaptureMode must
// be empty or a known mode, and the numeric limits must fall between zero and
// their MaxProductCapture* bound.
func (w ProductCaptureWorkload) Validate() error {
	var problems []error

	// Scheme and host are only inspected once the URL is known to be absolute.
	target, parseErr := url.Parse(strings.TrimSpace(w.URL))
	if parseErr == nil && target.Scheme != "" && target.Host != "" {
		switch target.Scheme {
		case "http", "https":
		default:
			problems = append(problems, errors.New("url must use http or https"))
		}
		// Hostname drops any port, so the allow-list is matched on the host alone.
		if hostname := target.Hostname(); !productCaptureHostAllowed(hostname, w.AllowedHosts) {
			problems = append(problems, fmt.Errorf("url host %q is not listed in allowed_hosts", hostname))
		}
	} else {
		problems = append(problems, errors.New("url must be absolute http(s) URL"))
	}

	if len(w.AllowedHosts) == 0 {
		problems = append(problems, errors.New("allowed_hosts is required"))
	}
	for idx := range w.AllowedHosts {
		if hostErr := validateNetworkHostname(w.AllowedHosts[idx]); hostErr != nil {
			problems = append(problems, fmt.Errorf("allowed_hosts[%d]: %w", idx, hostErr))
		}
	}

	if mode := w.CaptureMode; mode != "" && mode != ProductCaptureModeBrowser && mode != ProductCaptureModeMetadata {
		problems = append(problems, fmt.Errorf("capture_mode %q is unsupported", mode))
	}

	if !(0 <= w.TimeoutSeconds && w.TimeoutSeconds <= MaxProductCaptureTimeoutSeconds) {
		problems = append(problems, fmt.Errorf("timeout_seconds must be between 0 and %d", MaxProductCaptureTimeoutSeconds))
	}
	if !(0 <= w.MaxHTMLBytes && w.MaxHTMLBytes <= MaxProductCaptureHTMLBytes) {
		problems = append(problems, fmt.Errorf("max_html_bytes must be between 0 and %d", MaxProductCaptureHTMLBytes))
	}
	if !(0 <= w.MaxImageCount && w.MaxImageCount <= MaxProductCaptureImageCount) {
		problems = append(problems, fmt.Errorf("max_image_count must be between 0 and %d", MaxProductCaptureImageCount))
	}

	return errors.Join(problems...)
}

// productCaptureHostAllowed reports whether host equals one of the allowed
// entries. Surrounding whitespace is ignored on both sides and the comparison
// is case-insensitive. Matching is exact: an entry does not cover its
// subdomains. A blank host never matches.
func productCaptureHostAllowed(host string, allowed []string) bool {
	wanted := strings.ToLower(strings.TrimSpace(host))
	if wanted != "" {
		for i := range allowed {
			if entry := strings.TrimSpace(allowed[i]); strings.EqualFold(wanted, entry) {
				return true
			}
		}
	}
	return false
}

// ProviderWorkload is the payload for invoking an operation on a provider
// plugin. Its Validate method requires exactly one of ImageRef or ComponentRef.
type ProviderWorkload struct {
// ProviderConfig identifies the provider; it must be non-zero and pass its
// own Validate.
ProviderConfig ProviderConfig `json:"provider_config"`
// Operation is the operation name; it must pass validateIdentifier.
Operation string `json:"operation"`
// ImageRef is a digest-pinned image reference (…@sha256:<64 hex>) without
// whitespace or NUL. Mutually exclusive with ComponentRef.
ImageRef string `json:"image_ref,omitempty"`
// ComponentRef is a provider://<plugin>/<component path> reference. Mutually
// exclusive with ImageRef.
ComponentRef string `json:"component_ref,omitempty"`
// ComponentDigest is validated as a sha256 reference when ComponentRef is set.
ComponentDigest string `json:"component_digest,omitempty"`
// ABI is optional; when ComponentRef is set and ABI is non-empty it must
// pass validateIdentifier.
ABI string `json:"abi,omitempty"`
// Input is the operation input; it must pass validateJSONInput.
Input json.RawMessage `json:"input"`
}

// Validate checks a provider workload and returns every problem found,
// combined with errors.Join. It returns nil when the workload is valid.
//
// Rules enforced:
//   - provider_config must be non-zero and pass its own Validate.
//   - operation must pass validateIdentifier.
//   - exactly one of image_ref and component_ref must be set.
//   - image_ref must contain no whitespace or NUL and must carry an @digest
//     accepted by validSHA256Ref.
//   - component_ref must pass validateProviderComponentRef, and its plugin
//     segment must equal provider_config.plugin_id when both are non-empty.
//   - with component_ref set, component_digest must be accepted by
//     validSHA256Ref and a non-empty abi must pass validateIdentifier.
//   - input must pass validateJSONInput.
func (w ProviderWorkload) Validate() error {
	var errs []error
	if w.ProviderConfig == (ProviderConfig{}) {
		errs = append(errs, errors.New("provider_config is required"))
	} else if err := w.ProviderConfig.Validate(); err != nil {
		errs = append(errs, fmt.Errorf("provider_config: %w", err))
	}
	if err := validateIdentifier("operation", w.Operation); err != nil {
		errs = append(errs, err)
	}

	imageRefSet := strings.TrimSpace(w.ImageRef) != ""
	componentRef := strings.TrimSpace(w.ComponentRef)
	componentRefSet := componentRef != ""
	if imageRefSet && componentRefSet {
		errs = append(errs, errors.New("image_ref and component_ref are mutually exclusive"))
	}
	if !imageRefSet && !componentRefSet {
		errs = append(errs, errors.New("image_ref or component_ref is required"))
	} else if imageRefSet && (strings.TrimSpace(w.ImageRef) != w.ImageRef || strings.ContainsAny(w.ImageRef, "\t\r\n \x00")) {
		errs = append(errs, errors.New("image_ref must not contain whitespace or NUL"))
	} else if imageRefSet {
		if _, digest, ok := strings.Cut(w.ImageRef, "@"); !ok || !validSHA256Ref(digest) {
			errs = append(errs, errors.New("image_ref must be digest-pinned with @sha256:<64 hex>"))
		}
	}

	if componentRefSet {
		if err := validateProviderComponentRef("component_ref", w.ComponentRef); err != nil {
			errs = append(errs, err)
		}
		// Match the plugin against the trimmed ref so that surrounding
		// whitespace cannot be used to skip the plugin_id comparison.
		if strings.HasPrefix(componentRef, "provider://") {
			ref := strings.TrimPrefix(componentRef, "provider://")
			pluginID, _, _ := strings.Cut(ref, "/")
			if pluginID != "" && w.ProviderConfig.PluginID != "" && pluginID != w.ProviderConfig.PluginID {
				errs = append(errs, fmt.Errorf("component_ref provider plugin %q must match provider_config.plugin_id %q", pluginID, w.ProviderConfig.PluginID))
			}
		}
		if !validSHA256Ref(w.ComponentDigest) {
			errs = append(errs, errors.New("component_digest must be sha256:<64 hex chars>"))
		}
		if w.ABI != "" {
			if err := validateIdentifier("abi", w.ABI); err != nil {
				errs = append(errs, err)
			}
		}
	}

	if err := validateJSONInput(w.Input); err != nil {
		errs = append(errs, fmt.Errorf("input: %w", err))
	}
	return errors.Join(errs...)
}

// WASMWorkload is the payload for running an operation of a WASM component.
// Every field is checked unconditionally by its Validate method.
type WASMWorkload struct {
// ComponentRef locates the component; it must pass validateComponentRef.
ComponentRef string `json:"component_ref"`
// ComponentDigest must be a sha256 reference (sha256:<64 hex chars>).
ComponentDigest string `json:"component_digest"`
// ABI names the component ABI; it must pass validateIdentifier.
ABI string `json:"abi"`
// Operation names the operation to invoke; it must pass validateIdentifier.
Operation string `json:"operation"`
// Input is the operation input; it must pass validateJSONInput.
Input json.RawMessage `json:"input"`
}

// Validate checks every field of the WASM workload and returns all problems
// found, combined with errors.Join. It returns nil when the workload is valid.
//
// Problems are reported in a fixed order: component_ref, component_digest,
// abi, operation, input.
func (w WASMWorkload) Validate() error {
	// Run the field checks first, in reporting order.
	refErr := validateComponentRef("component_ref", w.ComponentRef)
	digestOK := validSHA256Ref(w.ComponentDigest)
	abiErr := validateIdentifier("abi", w.ABI)
	opErr := validateIdentifier("operation", w.Operation)
	inputErr := validateJSONInput(w.Input)

	var problems []error
	if refErr != nil {
		problems = append(problems, refErr)
	}
	if !digestOK {
		problems = append(problems, errors.New("component_digest must be sha256:<64 hex chars>"))
	}
	if abiErr != nil {
		problems = append(problems, abiErr)
	}
	if opErr != nil {
		problems = append(problems, opErr)
	}
	if inputErr != nil {
		problems = append(problems, fmt.Errorf("input: %w", inputErr))
	}
	return errors.Join(problems...)
}

// validateProviderComponentRef checks that value is a provider component
// reference of the form provider://<plugin>/<component path>.
//
// name is the field name used as the prefix of every error message. The
// function returns nil for a well-formed reference and otherwise an error
// saying the value is required, contains whitespace or NUL, lacks the plugin
// or component path, or does not use the provider:// scheme.
func validateProviderComponentRef(name, value string) error {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return fmt.Errorf("%s is required", name)
	}
	// Check the raw value, not the trimmed one: callers keep using the value
	// they passed in, so a reference padded with whitespace must be rejected
	// rather than silently accepted.
	if trimmed != value || strings.ContainsAny(value, " \t\r\n\x00") {
		return fmt.Errorf("%s must not contain whitespace or NUL", name)
	}
	if !strings.HasPrefix(value, "provider://") {
		return fmt.Errorf("%s must use provider:// plugin component ref", name)
	}
	ref := strings.TrimPrefix(value, "provider://")
	pluginID, componentPath, ok := strings.Cut(ref, "/")
	if !ok || strings.TrimSpace(pluginID) == "" || strings.TrimSpace(componentPath) == "" {
		return fmt.Errorf("%s must include provider plugin and component path", name)
	}
	return nil
}

// validateJSONInput checks that input, ignoring surrounding whitespace, is a
// non-empty, syntactically valid JSON document whose top-level value is an
// object or an array. The returned messages carry no field name; callers are
// expected to prefix them.
func validateJSONInput(input json.RawMessage) error {
	payload := bytes.TrimSpace(input)
	switch {
	case len(payload) == 0:
		return errors.New("is required")
	case !json.Valid(payload):
		return errors.New("must be valid JSON")
	}
	// A valid document starting with '{' or '[' is an object or array; any
	// other first byte means a scalar (string, number, bool, null).
	if first := payload[0]; first == '{' || first == '[' {
		return nil
	}
	return errors.New("must be a JSON object or array")
}

func (r RuntimeExecutionRequest) Validate() error {
var errs []error
if r.ProtocolVersion != Version {
Expand Down
62 changes: 62 additions & 0 deletions protocol/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,68 @@ func TestContainerBuildWorkloadContractUsesRegistryRefs(t *testing.T) {
}
}

// TestWASMRuntimePayloadContracts exercises the validation contracts of the
// WASM and provider workloads: a well-formed payload of each kind must be
// accepted, a file:// component ref must be rejected for a WASM workload, and
// a provider component ref naming a different plugin than the provider config
// must be rejected.
func TestWASMRuntimePayloadContracts(t *testing.T) {
	pinned := "sha256:" + strings.Repeat("a", 64)

	// WASM workload: accepted as built, rejected once it points at a host file.
	component := protocol.WASMWorkload{
		ComponentRef:    "artifact://edge/echo.wasm",
		ComponentDigest: pinned,
		ABI:             "wasm-export-i32-v1",
		Operation:       "handle_request",
		Input:           json.RawMessage(`{"path":"/index.html"}`),
	}
	err := component.Validate()
	if err != nil {
		t.Fatalf("valid wasm workload rejected: %v", err)
	}
	component.ComponentRef = "file:///tmp/evil.wasm"
	err = component.Validate()
	if err == nil || !strings.Contains(err.Error(), "component_ref") {
		t.Fatalf("host ref accepted: %v", err)
	}

	// Provider workload: accepted as built, rejected once the component ref
	// names a plugin other than the one in the provider config.
	config := protocol.ProviderConfig{
		PluginID:   "workflow-plugin-product-capture",
		ProviderID: "browser",
		ContractID: "product-capture.browser.v1",
		Version:    "v0.1.0",
		ConfigRef:  "config://browser",
	}
	invocation := protocol.ProviderWorkload{
		ProviderConfig:  config,
		Operation:       "capture",
		ComponentRef:    "provider://workflow-plugin-product-capture/browser.wasm",
		ComponentDigest: pinned,
		ABI:             "wasm-export-i32-v1",
		Input:           json.RawMessage(`{"url":"https://example.invalid"}`),
	}
	err = invocation.Validate()
	if err != nil {
		t.Fatalf("valid provider workload rejected: %v", err)
	}
	invocation.ComponentRef = "provider://other-plugin/browser.wasm"
	err = invocation.Validate()
	if err == nil || !strings.Contains(err.Error(), "provider plugin") {
		t.Fatalf("mismatched provider component accepted: %v", err)
	}
}

// TestProductCaptureWorkloadValidation checks that a well-formed
// product-capture workload is accepted, that a URL whose host is not in
// allowed_hosts is rejected, and that a non-http(s) URL is rejected.
func TestProductCaptureWorkloadValidation(t *testing.T) {
	capture := protocol.ProductCaptureWorkload{
		URL:            "https://www.amazon.com/dp/B08H75RTZ8",
		AllowedHosts:   []string{"www.amazon.com", "amazon.com"},
		CaptureMode:    protocol.ProductCaptureModeBrowser,
		TimeoutSeconds: 30,
		MaxHTMLBytes:   1 << 20,
		MaxImageCount:  8,
	}
	err := capture.Validate()
	if err != nil {
		t.Fatalf("valid product-capture workload rejected: %v", err)
	}

	// Host outside the allow-list.
	capture.URL = "https://evil.example/item"
	err = capture.Validate()
	if err == nil || !strings.Contains(err.Error(), "allowed_hosts") {
		t.Fatalf("disallowed host accepted: %v", err)
	}

	// Scheme other than http or https.
	capture.URL = "file:///tmp/item"
	err = capture.Validate()
	if err == nil || !strings.Contains(err.Error(), "http") {
		t.Fatalf("non-http URL accepted: %v", err)
	}
}

func TestExecutorRefValidateForProofRequiresDigestsForNonNativeExecutors(t *testing.T) {
ref := protocol.ExecutorRef{
Provider: "sandboxed-command",
Expand Down
Loading