diff --git a/protocol/types.go b/protocol/types.go index 4b29681..7a2846e 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -9,9 +9,11 @@ import ( "errors" "fmt" "mime" + "net" "net/netip" "net/url" "path" + "strconv" "strings" "time" @@ -1609,6 +1611,454 @@ func (r EnvRef) Validate() error { return errors.Join(errs...) } +type WorkloadSpec struct { + Kind WorkloadKind `json:"kind"` + Command *CommandWorkload `json:"command,omitempty"` + ContainerBuild *ContainerBuildWorkload `json:"container_build,omitempty"` + Service *ServiceWorkload `json:"service,omitempty"` + NodeService *NodeServiceWorkload `json:"node_service,omitempty"` + ProductCapture *ProductCaptureWorkload `json:"product_capture,omitempty"` + Provider *ProviderWorkload `json:"provider,omitempty"` + WASM *WASMWorkload `json:"wasm,omitempty"` + Params map[string]map[string]string `json:"params,omitempty"` +} + +func (w WorkloadSpec) Validate() error { + switch w.Kind { + case WorkloadCommand: + if w.Command == nil { + return errors.New("command workload is required") + } + return w.Command.Validate() + case WorkloadContainerBuild: + if w.ContainerBuild == nil { + return errors.New("container_build workload is required") + } + return w.ContainerBuild.Validate() + case WorkloadService: + if w.Service == nil { + return errors.New("service workload is required") + } + return w.Service.Validate() + case WorkloadNodeService: + if w.NodeService == nil { + return errors.New("node_service workload is required") + } + return w.NodeService.Validate() + case WorkloadProductCapture: + if w.ProductCapture == nil { + return errors.New("product_capture workload is required") + } + return w.ProductCapture.Validate() + case WorkloadProvider: + if w.Provider == nil { + return errors.New("provider workload is required") + } + return w.Provider.Validate() + case WorkloadWASMComponent: + if w.WASM == nil { + return errors.New("wasm workload is required") + } + return w.WASM.Validate() + case WorkloadDockerComposeBuild, WorkloadBenchmark, WorkloadTraining, WorkloadContentCache, WorkloadSupervisor: + return nil + default: + return fmt.Errorf("workload kind %q is unknown", w.Kind) + } +} + +type WorkloadFileRef struct { + Path string `json:"path"` + ValueRef string `json:"value_ref,omitempty"` + SecretRef string `json:"secret_ref,omitempty"` + Template string `json:"template,omitempty"` + Refs []EnvRef `json:"refs,omitempty"` + Mode uint32 `json:"mode,omitempty"` +} + +func (r WorkloadFileRef) Validate() error { + var errs []error + trimmedPath := strings.TrimSpace(r.Path) + cleanPath := path.Clean(trimmedPath) + if trimmedPath == "" { + errs = append(errs, errors.New("path is required")) + } else if trimmedPath != r.Path || strings.HasPrefix(trimmedPath, "/") || strings.Contains(trimmedPath, "\\") || cleanPath == "." || strings.HasPrefix(cleanPath, "../") || cleanPath == ".." { + errs = append(errs, errors.New("path must be a relative workspace path")) + } + directRefs := 0 + if r.ValueRef != "" { + directRefs++ + } + if r.SecretRef != "" { + directRefs++ + } + if r.Template == "" { + if directRefs == 0 { + errs = append(errs, errors.New("requires value_ref, secret_ref, or template")) + } + if len(r.Refs) > 0 { + errs = append(errs, errors.New("refs require template")) + } + } else if directRefs > 0 { + errs = append(errs, errors.New("template cannot be combined with value_ref or secret_ref")) + } + if directRefs > 1 { + errs = append(errs, errors.New("cannot set both value_ref and secret_ref")) + } + if r.Mode > 0o777 { + errs = append(errs, errors.New("mode must be an octal file permission <= 0777")) + } + for i, ref := range r.Refs { + if err := ref.Validate(); err != nil { + errs = append(errs, fmt.Errorf("refs[%d]: %w", i, err)) + } + } + return errors.Join(errs...) +} + +type ServiceStatusProbe struct { + Command []string `json:"command,omitempty"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` + MaxPreviewBytes int `json:"max_preview_bytes,omitempty"` +} + +func (p ServiceStatusProbe) Validate() error { + var errs []error + for i, arg := range p.Command { + if strings.TrimSpace(arg) == "" { + errs = append(errs, fmt.Errorf("command[%d] is required", i)) + } + } + if p.TimeoutSeconds < 0 { + errs = append(errs, errors.New("timeout_seconds must not be negative")) + } + if p.MaxPreviewBytes < 0 || p.MaxPreviewBytes > 4096 { + errs = append(errs, errors.New("max_preview_bytes must be between 0 and 4096")) + } + return errors.Join(errs...) +} + +type ServiceWorkload struct { + ImageRef string `json:"image_ref"` + ComponentRef string `json:"component_ref,omitempty"` + ComponentDigest string `json:"component_digest,omitempty"` + Command []string `json:"command,omitempty"` + Ports []ServicePort `json:"ports"` + HealthCheck HealthCheck `json:"health_check"` + Ingress IngressPolicy `json:"ingress"` + DataDirRef string `json:"data_dir_ref,omitempty"` + DataMountPath string `json:"data_mount_path,omitempty"` + Env []EnvRef `json:"env,omitempty"` + Files []WorkloadFileRef `json:"files,omitempty"` + StatusProbe ServiceStatusProbe `json:"status_probe,omitzero"` +} + +func (w ServiceWorkload) Validate() error { + var errs []error + errs = append(errs, validateImageOrComponentRef("service", w.ImageRef, w.ComponentRef, w.ComponentDigest)...) + headless := w.Ingress.Mode == "none" + if headless { + if len(w.Ports) != 0 { + errs = append(errs, errors.New("ports must be empty for headless service")) + } + if w.HealthCheck.Kind != "command" { + errs = append(errs, errors.New("headless service requires command health check")) + } + } else if len(w.Ports) == 0 { + errs = append(errs, errors.New("ports is required")) + } + for i, port := range w.Ports { + if err := port.Validate(); err != nil { + errs = append(errs, fmt.Errorf("ports[%d]: %w", i, err)) + } + } + if err := w.HealthCheck.Validate(); err != nil { + errs = append(errs, fmt.Errorf("health_check: %w", err)) + } + if err := w.Ingress.Validate(); err != nil { + errs = append(errs, fmt.Errorf("ingress: %w", err)) + } + if w.DataDirRef != "" { + if err := validateScopedRef("data_dir_ref", w.DataDirRef, "volume://"); err != nil { + errs = append(errs, err) + } + if strings.TrimSpace(w.DataMountPath) == "" { + errs = append(errs, errors.New("data_mount_path is required with data_dir_ref")) + } else if !strings.HasPrefix(w.DataMountPath, "/") || strings.Contains(w.DataMountPath, "..") { + errs = append(errs, errors.New("data_mount_path must be an absolute container path")) + } + } else if w.DataMountPath != "" { + errs = append(errs, errors.New("data_mount_path requires data_dir_ref")) + } + for i, ref := range w.Env { + if err := ref.Validate(); err != nil { + errs = append(errs, fmt.Errorf("env[%d]: %w", i, err)) + } + } + for i, ref := range w.Files { + if err := ref.Validate(); err != nil { + errs = append(errs, fmt.Errorf("files[%d]: %w", i, err)) + } + } + if err := w.StatusProbe.Validate(); err != nil { + errs = append(errs, fmt.Errorf("status_probe: %w", err)) + } + return errors.Join(errs...) +} + +type NodeServiceWorkload struct { + ImageRef string `json:"image_ref"` + ComponentRef string `json:"component_ref,omitempty"` + ComponentDigest string `json:"component_digest,omitempty"` + Chain string `json:"chain"` + Network string `json:"network"` + DataDirRef string `json:"data_dir_ref"` + RPCSecretRef string `json:"rpc_secret_ref"` + PeerPolicy PeerPolicy `json:"peer_policy"` + ArtifactRefs []string `json:"artifact_refs,omitempty"` + Command []string `json:"command,omitempty"` + HealthCheck HealthCheck `json:"health_check,omitzero"` + Env []EnvRef `json:"env,omitempty"` +} + +func (w NodeServiceWorkload) Validate() error { + var errs []error + errs = append(errs, validateImageOrComponentRef("node_service", w.ImageRef, w.ComponentRef, w.ComponentDigest)...) + if w.Chain == "" { + errs = append(errs, errors.New("chain is required")) + } + if w.Network == "" { + errs = append(errs, errors.New("network is required")) + } + if w.DataDirRef == "" { + errs = append(errs, errors.New("data_dir_ref is required")) + } else if err := validateScopedRef("data_dir_ref", w.DataDirRef, "volume://"); err != nil { + errs = append(errs, err) + } + if w.RPCSecretRef == "" { + errs = append(errs, errors.New("rpc_secret_ref is required")) + } else if err := validateScopedRef("rpc_secret_ref", w.RPCSecretRef, "secret://"); err != nil { + errs = append(errs, err) + } + if err := w.PeerPolicy.Validate(); err != nil { + errs = append(errs, fmt.Errorf("peer_policy: %w", err)) + } + if w.HealthCheck.IsZero() { + errs = append(errs, errors.New("health_check is required")) + } else if err := w.HealthCheck.Validate(); err != nil { + errs = append(errs, fmt.Errorf("health_check: %w", err)) + } + for i, ref := range w.Env { + if err := ref.Validate(); err != nil { + errs = append(errs, fmt.Errorf("env[%d]: %w", i, err)) + } + } + for i, ref := range w.ArtifactRefs { + if err := validateScopedRef(fmt.Sprintf("artifact_refs[%d]", i), ref, "artifact://"); err != nil { + errs = append(errs, err) + } + } + for i, arg := range w.Command { + if strings.TrimSpace(arg) == "" { + errs = append(errs, fmt.Errorf("command[%d] is required", i)) + } + } + return errors.Join(errs...) +} + +type ServicePort struct { + Name string `json:"name,omitempty"` + Port int `json:"port"` + Protocol string `json:"protocol"` +} + +func (p ServicePort) Validate() error { + var errs []error + if p.Port <= 0 || p.Port > 65535 { + errs = append(errs, errors.New("port must be between 1 and 65535")) + } + switch p.Protocol { + case "http", "https", "tcp": + default: + errs = append(errs, fmt.Errorf("protocol %q is unsupported", p.Protocol)) + } + return errors.Join(errs...) +} + +type HealthCheck struct { + Kind string `json:"kind"` + Path string `json:"path,omitempty"` + Command []string `json:"command,omitempty"` + IntervalSeconds int `json:"interval_seconds,omitempty"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` +} + +func (h HealthCheck) IsZero() bool { + return h.Kind == "" && h.Path == "" && len(h.Command) == 0 && h.IntervalSeconds == 0 && h.TimeoutSeconds == 0 +} + +func (h HealthCheck) Validate() error { + var errs []error + switch h.Kind { + case "http": + if h.Path == "" { + errs = append(errs, errors.New("path is required for http health check")) + } + case "command": + if len(h.Command) == 0 { + errs = append(errs, errors.New("command is required for command health check")) + } + default: + errs = append(errs, fmt.Errorf("kind %q is unsupported", h.Kind)) + } + if h.IntervalSeconds <= 0 { + errs = append(errs, errors.New("interval_seconds must be positive")) + } + if h.TimeoutSeconds <= 0 { + errs = append(errs, errors.New("timeout_seconds must be positive")) + } + if h.IntervalSeconds > 0 && h.TimeoutSeconds > 0 && h.TimeoutSeconds > h.IntervalSeconds { + errs = append(errs, errors.New("timeout_seconds must not exceed interval_seconds")) + } + return errors.Join(errs...) +} + +type IngressPolicy struct { + Mode string `json:"mode"` + AllowedCIDRs []string `json:"allowed_cidrs,omitempty"` + AuthRequired bool `json:"auth_required"` +} + +func (p IngressPolicy) Validate() error { + var errs []error + switch p.Mode { + case "none": + if p.AuthRequired { + errs = append(errs, errors.New("auth_required must be false for headless ingress")) + } + case "internal", "private", "public": + if !p.AuthRequired { + errs = append(errs, errors.New("auth_required must be true")) + } + default: + errs = append(errs, fmt.Errorf("mode %q is unsupported", p.Mode)) + } + for i, cidr := range p.AllowedCIDRs { + if _, err := netip.ParsePrefix(strings.TrimSpace(cidr)); err != nil { + errs = append(errs, fmt.Errorf("allowed_cidrs[%d] must be a valid CIDR prefix: %w", i, err)) + } + } + return errors.Join(errs...) +} + +type PeerPolicy struct { + Mode string `json:"mode"` + AllowedPeers []string `json:"allowed_peers,omitempty"` + EgressAllowlist []string `json:"egress_allowlist,omitempty"` +} + +func (p PeerPolicy) Validate() error { + switch p.Mode { + case "isolated": + return nil + case "public-chain-peer": + var errs []error + for i, endpoint := range p.EgressAllowlist { + if _, err := validatePeerEndpoint("egress_allowlist", i, endpoint); err != nil { + errs = append(errs, err) + } + } + for i, peer := range p.AllowedPeers { + if _, err := validatePeerEndpoint("allowed_peers", i, peer); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) + case "allowlist": + var errs []error + if len(p.AllowedPeers) == 0 { + errs = append(errs, errors.New("allowed_peers is required for allowlist mode")) + } + if len(p.EgressAllowlist) == 0 { + errs = append(errs, errors.New("egress_allowlist is required for allowlist mode")) + } + egress := map[string]struct{}{} + for i, endpoint := range p.EgressAllowlist { + normalized, err := validatePeerEndpoint("egress_allowlist", i, endpoint) + if err != nil { + errs = append(errs, err) + continue + } + egress[normalized] = struct{}{} + } + for i, peer := range p.AllowedPeers { + normalized, err := validatePeerEndpoint("allowed_peers", i, peer) + if err != nil { + errs = append(errs, err) + continue + } + if _, ok := egress[normalized]; !ok { + errs = append(errs, fmt.Errorf("allowed_peers[%d] must be present in egress_allowlist", i)) + } + } + return errors.Join(errs...) + default: + return fmt.Errorf("mode %q is unsupported", p.Mode) + } +} + +func validateImageOrComponentRef(kind, imageRef, componentRef, componentDigest string) []error { + var errs []error + imageRefSet := strings.TrimSpace(imageRef) != "" + componentRefSet := strings.TrimSpace(componentRef) != "" + if imageRefSet && componentRefSet { + errs = append(errs, errors.New("image_ref and component_ref are mutually exclusive")) + } + if !imageRefSet && !componentRefSet { + errs = append(errs, errors.New("image_ref or component_ref is required")) + } + if imageRefSet && (strings.TrimSpace(imageRef) != imageRef || strings.ContainsAny(imageRef, "\t\r\n \x00")) { + errs = append(errs, fmt.Errorf("%s image_ref must not contain whitespace or NUL", kind)) + } + if componentRefSet { + if err := validateProviderComponentRef("component_ref", componentRef); err != nil { + errs = append(errs, err) + } + if !validSHA256Digest(componentDigest) { + errs = append(errs, errors.New("component_digest must be sha256:<64 hex chars>")) + } + } else if strings.TrimSpace(componentDigest) != "" { + errs = append(errs, errors.New("component_digest requires component_ref")) + } + return errs +} + +func validatePeerEndpoint(field string, index int, endpoint string) (string, error) { + endpoint = strings.TrimSpace(strings.ToLower(endpoint)) + if endpoint == "" { + return "", fmt.Errorf("%s[%d] is required", field, index) + } + if strings.Contains(endpoint, "://") || strings.ContainsAny(endpoint, "/?#@") || strings.Contains(endpoint, "*") { + return "", fmt.Errorf("%s[%d] must be host:port without scheme, path, userinfo, query, fragment, or wildcard", field, index) + } + host, port, err := net.SplitHostPort(endpoint) + if err != nil { + return "", fmt.Errorf("%s[%d] must be host:port: %w", field, index, err) + } + host = strings.TrimSpace(strings.Trim(host, "[]")) + if host == "" || port == "" { + return "", fmt.Errorf("%s[%d] must include host and port", field, index) + } + portNumber, err := strconv.Atoi(port) + if err != nil || portNumber <= 0 || portNumber > 65535 { + return "", fmt.Errorf("%s[%d] port must be numeric and between 1 and 65535", field, index) + } + if strings.ContainsAny(host, " \t\r\n:") && net.ParseIP(host) == nil { + return "", fmt.Errorf("%s[%d] host is invalid", field, index) + } + return net.JoinHostPort(host, strconv.Itoa(portNumber)), nil +} + type ConfidentialPayloadRef struct { CiphertextRef string `json:"ciphertext_ref"` CiphertextHash string `json:"ciphertext_hash"` diff --git a/protocol/types_test.go b/protocol/types_test.go index cac55f6..b293f40 100644 --- a/protocol/types_test.go +++ b/protocol/types_test.go @@ -92,6 +92,155 @@ func TestRuntimeExecutionRequestRejectsMalformedInvocation(t *testing.T) { } } +func TestWorkloadSpecValidatesServiceWorkload(t *testing.T) { + workload := protocol.WorkloadSpec{ + Kind: protocol.WorkloadService, + Service: &protocol.ServiceWorkload{ + ComponentRef: "provider://workflow-plugin-compute-service/service-runtime", + ComponentDigest: "sha256:" + strings.Repeat("a", 64), + Command: []string{"serve", "--port", "8080"}, + Ports: []protocol.ServicePort{{Name: "http", Port: 8080, Protocol: "http"}}, + HealthCheck: protocol.HealthCheck{ + Kind: "http", + Path: "/healthz", + IntervalSeconds: 5, + TimeoutSeconds: 2, + }, + Ingress: protocol.IngressPolicy{Mode: "private", AuthRequired: true}, + Env: []protocol.EnvRef{{Name: "PORT", ValueRef: "config://service/port"}}, + Files: []protocol.WorkloadFileRef{{ + Path: "config/app.toml", + Template: "port={{ .PORT }}", + Refs: []protocol.EnvRef{{Name: "PORT", ValueRef: "config://service/port"}}, + Mode: 0o640, + }}, + }, + } + + if err := workload.Validate(); err != nil { + t.Fatalf("service workload invalid: %v", err) + } +} + +func TestServiceWorkloadRejectsUnsafeShape(t *testing.T) { + workload := protocol.ServiceWorkload{ + ImageRef: "repo/service:latest bad", + ComponentRef: "provider://workflow-plugin-compute-service/service-runtime", + Ports: []protocol.ServicePort{{Port: 8080, Protocol: "http"}}, + HealthCheck: protocol.HealthCheck{Kind: "http", Path: "/healthz", IntervalSeconds: 5, TimeoutSeconds: 2}, + Ingress: protocol.IngressPolicy{Mode: "none", AllowedCIDRs: []string{"not-a-cidr"}}, + DataDirRef: "volume://service-data", + DataMountPath: "../data", + } + + err := workload.Validate() + if err == nil { + t.Fatal("expected unsafe service workload to fail") + } + for _, want := range []string{ + "mutually exclusive", + "whitespace or NUL", + "component_digest", + "ports must be empty", + "headless service requires command health check", + "allowed_cidrs", + "data_mount_path", + } { + if !strings.Contains(err.Error(), want) { + t.Fatalf("Validate() = %v, want %q", err, want) + } + } +} + +func TestWorkloadSpecValidatesNodeServiceWorkload(t *testing.T) { + endpoint := "node.example.invalid:30303" + workload := protocol.WorkloadSpec{ + Kind: protocol.WorkloadNodeService, + NodeService: &protocol.NodeServiceWorkload{ + ImageRef: "ghcr.io/gocodealone/node@sha256:" + strings.Repeat("b", 64), + Chain: "ethereum", + Network: "sepolia", + DataDirRef: "volume://nodes/sepolia", + RPCSecretRef: "secret://nodes/sepolia/rpc", + PeerPolicy: protocol.PeerPolicy{ + Mode: "allowlist", + AllowedPeers: []string{endpoint}, + EgressAllowlist: []string{endpoint}, + }, + ArtifactRefs: []string{"artifact://snapshots/sepolia"}, + Command: []string{"node", "--network", "sepolia"}, + HealthCheck: protocol.HealthCheck{ + Kind: "command", + Command: []string{"node", "status"}, + IntervalSeconds: 10, + TimeoutSeconds: 3, + }, + Env: []protocol.EnvRef{{Name: "RPC_TOKEN", SecretRef: "secret://nodes/sepolia/rpc"}}, + }, + } + + if err := workload.Validate(); err != nil { + t.Fatalf("node-service workload invalid: %v", err) + } +} + +func TestNodeServiceWorkloadRejectsMissingHealthAndUnsafePeers(t *testing.T) { + workload := protocol.NodeServiceWorkload{ + ImageRef: "ghcr.io/gocodealone/node@sha256:" + strings.Repeat("c", 64), + Chain: "ethereum", + Network: "mainnet", + DataDirRef: "volume://nodes/mainnet", + RPCSecretRef: "secret://nodes/mainnet/rpc", + PeerPolicy: protocol.PeerPolicy{ + Mode: "allowlist", + AllowedPeers: []string{"https://peer.example.invalid:30303/path", "peer.example.invalid:notaport"}, + EgressAllowlist: []string{"other.example.invalid:30303"}, + }, + } + + err := workload.Validate() + if err == nil { + t.Fatal("expected unsafe node-service workload to fail") + } + for _, want := range []string{"peer_policy", "allowed_peers", "port", "health_check"} { + if !strings.Contains(err.Error(), want) { + t.Fatalf("Validate() = %v, want %q", err, want) + } + } +} + +func TestWorkloadSpecAllowsReservedHostWorkloadKinds(t *testing.T) { + for _, kind := range []protocol.WorkloadKind{protocol.WorkloadContentCache, protocol.WorkloadSupervisor} { + if err := (protocol.WorkloadSpec{Kind: kind}).Validate(); err != nil { + t.Fatalf("%s workload spec rejected: %v", kind, err) + } + } +} + +func TestWorkloadFileRefRejectsTrimmedAbsolutePath(t *testing.T) { + ref := protocol.WorkloadFileRef{ + Path: " /etc/passwd", + ValueRef: "config://service/file", + } + + err := ref.Validate() + if err == nil || !strings.Contains(err.Error(), "relative workspace path") { + t.Fatalf("absolute path with leading whitespace accepted: %v", err) + } +} + +func TestPeerPolicyCanonicalizesNumericPort(t *testing.T) { + policy := protocol.PeerPolicy{ + Mode: "allowlist", + AllowedPeers: []string{"node.example.invalid:030303"}, + EgressAllowlist: []string{"node.example.invalid:30303"}, + } + + if err := policy.Validate(); err != nil { + t.Fatalf("equivalent numeric peer ports rejected: %v", err) + } +} + func TestServiceIngressEvidenceValidation(t *testing.T) { evidence := validServiceIngressEvidence() if err := evidence.Validate(); err != nil {