diff --git a/protocol/types.go b/protocol/types.go index 3c2b5ea..f62cfe0 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "mime" + "net/netip" "path" "strings" "time" @@ -1149,6 +1150,203 @@ func (r RuntimeServiceResult) Validate() error { return errors.Join(errs...) } +type ServiceIngressTerminalStatus string + +const ( + ServiceIngressTerminalCompleted ServiceIngressTerminalStatus = "completed" + ServiceIngressTerminalFailed ServiceIngressTerminalStatus = "failed" +) + +type ServiceIngressEvidence struct { + ID string `json:"id"` + OrgID string `json:"org_id"` + PoolID string `json:"pool_id"` + ProductID string `json:"product_id"` + Hostname string `json:"hostname"` + RouteTarget string `json:"route_target"` + ServiceLeaseID string `json:"service_lease_id"` + TaskID string `json:"task_id"` + WorkerID string `json:"worker_id"` + LeaseLeasedAt time.Time `json:"lease_leased_at"` + LeaseRenewBy time.Time `json:"lease_renew_by"` + SelectedAt time.Time `json:"selected_at"` + LastHealthAt time.Time `json:"last_health_at"` + HealthValidUntil time.Time `json:"health_valid_until"` + LastHealthResponseHash string `json:"last_health_response_hash"` + LastHealthSLOEvidenceHash string `json:"last_health_slo_evidence_hash"` + AuthDecisionHash string `json:"auth_decision_hash"` + IdempotencyKey string `json:"idempotency_key"` + RequestMethod string `json:"request_method"` + RequestPath string `json:"request_path"` + RequestBodyHash string `json:"request_body_hash"` + RequestHeaderNames []string `json:"request_header_names,omitempty"` + HelperImage string `json:"helper_image"` + HelperScheme string `json:"helper_scheme"` + HelperHost string `json:"helper_host"` + HelperPort int `json:"helper_port"` + HelperPortName string `json:"helper_port_name,omitempty"` + HelperTimeoutMS int `json:"helper_timeout_ms"` + HelperContainerNetNSTarget string `json:"helper_container_netns_target"` + HelperOutputHash string `json:"helper_output_hash,omitempty"` + HelperErrorHash string `json:"helper_error_hash,omitempty"` + FailureClass string `json:"failure_class,omitempty"` + FailureMessage string `json:"failure_message,omitempty"` + ResponseStatus int `json:"response_status,omitempty"` + ResponseHeaderNames []string `json:"response_header_names,omitempty"` + ResponseHash string `json:"response_hash,omitempty"` + ResponseBytes int64 `json:"response_bytes,omitempty"` + TerminalStatus ServiceIngressTerminalStatus `json:"terminal_status"` + StartedAt time.Time `json:"started_at"` + FinishedAt time.Time `json:"finished_at"` +} + +func (e ServiceIngressEvidence) Validate() error { + var errs []error + require := func(name, value string) { + if strings.TrimSpace(value) == "" { + errs = append(errs, fmt.Errorf("%s is required", name)) + } + } + require("id", e.ID) + require("org_id", e.OrgID) + require("pool_id", e.PoolID) + require("product_id", e.ProductID) + require("hostname", e.Hostname) + require("route_target", e.RouteTarget) + require("service_lease_id", e.ServiceLeaseID) + require("task_id", e.TaskID) + require("worker_id", e.WorkerID) + require("last_health_response_hash", e.LastHealthResponseHash) + require("last_health_slo_evidence_hash", e.LastHealthSLOEvidenceHash) + require("auth_decision_hash", e.AuthDecisionHash) + require("idempotency_key", e.IdempotencyKey) + require("request_method", e.RequestMethod) + require("request_path", e.RequestPath) + require("helper_image", e.HelperImage) + require("helper_scheme", e.HelperScheme) + require("helper_host", e.HelperHost) + require("helper_container_netns_target", e.HelperContainerNetNSTarget) + for _, field := range []struct { + name string + value string + }{ + {name: "id", value: e.ID}, + {name: "org_id", value: e.OrgID}, + {name: "pool_id", value: e.PoolID}, + {name: "product_id", value: e.ProductID}, + {name: "service_lease_id", value: e.ServiceLeaseID}, + {name: "task_id", value: e.TaskID}, + {name: "worker_id", value: e.WorkerID}, + {name: "idempotency_key", value: e.IdempotencyKey}, + } { + if field.value != "" { + if err := validateIdentifier(field.name, field.value); err != nil { + errs = append(errs, err) + } + } + } + if e.Hostname != "" { + if err := validateNetworkHostname(e.Hostname); err != nil { + errs = append(errs, fmt.Errorf("hostname: %w", err)) + } + } + if e.RouteTarget != "" && !strings.HasPrefix(e.RouteTarget, "service-route:") { + errs = append(errs, errors.New("route_target must be opaque service-route target")) + } + switch e.RequestMethod { + case "GET", "POST": + case "": + default: + errs = append(errs, fmt.Errorf("request_method %q is unsupported", e.RequestMethod)) + } + if e.RequestPath != "" { + if !strings.HasPrefix(e.RequestPath, "/") || strings.Contains(e.RequestPath, "://") { + errs = append(errs, errors.New("request_path must be queryless origin-form path")) + } + if strings.ContainsAny(e.RequestPath, "?#") { + errs = append(errs, errors.New("request_path must not contain query or fragment")) + } + } + if e.HelperScheme != "" && e.HelperScheme != "http" { + errs = append(errs, errors.New("helper_scheme must be http")) + } + if e.HelperHost != "" && e.HelperHost != "127.0.0.1" { + errs = append(errs, errors.New("helper_host must be 127.0.0.1")) + } + if e.HelperPort <= 0 || e.HelperPort > 65535 { + errs = append(errs, errors.New("helper_port must be between 1 and 65535")) + } + if e.HelperTimeoutMS <= 0 { + errs = append(errs, errors.New("helper_timeout_ms is required")) + } + for _, field := range []struct { + name string + value string + }{ + {name: "last_health_response_hash", value: e.LastHealthResponseHash}, + {name: "last_health_slo_evidence_hash", value: e.LastHealthSLOEvidenceHash}, + {name: "auth_decision_hash", value: e.AuthDecisionHash}, + {name: "request_body_hash", value: e.RequestBodyHash}, + {name: "helper_output_hash", value: e.HelperOutputHash}, + {name: "helper_error_hash", value: e.HelperErrorHash}, + {name: "response_hash", value: e.ResponseHash}, + } { + if field.value != "" && !validSHA256Digest(field.value) { + errs = append(errs, fmt.Errorf("%s must use sha256 digest", field.name)) + } + } + switch e.TerminalStatus { + case ServiceIngressTerminalCompleted: + if e.ResponseStatus <= 0 { + errs = append(errs, errors.New("response_status is required for completed ingress")) + } + if e.ResponseHash == "" { + errs = append(errs, errors.New("response_hash is required for completed ingress")) + } + case ServiceIngressTerminalFailed: + if strings.TrimSpace(e.FailureClass) == "" { + errs = append(errs, errors.New("failure_class is required for failed ingress")) + } + case "": + errs = append(errs, errors.New("terminal_status is required")) + default: + errs = append(errs, fmt.Errorf("terminal_status %q is unsupported", e.TerminalStatus)) + } + if e.ResponseBytes < 0 { + errs = append(errs, errors.New("response_bytes must be non-negative")) + } + for i, name := range e.RequestHeaderNames { + if err := validateIngressHeaderName(name); err != nil { + errs = append(errs, fmt.Errorf("request_header_names[%d]: %w", i, err)) + } + } + for i, name := range e.ResponseHeaderNames { + if err := validateIngressHeaderName(name); err != nil { + errs = append(errs, fmt.Errorf("response_header_names[%d]: %w", i, err)) + } + } + for _, field := range []struct { + name string + ts time.Time + }{ + {name: "lease_leased_at", ts: e.LeaseLeasedAt}, + {name: "lease_renew_by", ts: e.LeaseRenewBy}, + {name: "selected_at", ts: e.SelectedAt}, + {name: "last_health_at", ts: e.LastHealthAt}, + {name: "health_valid_until", ts: e.HealthValidUntil}, + {name: "started_at", ts: e.StartedAt}, + {name: "finished_at", ts: e.FinishedAt}, + } { + if field.ts.IsZero() { + errs = append(errs, fmt.Errorf("%s is required", field.name)) + } + } + if !e.StartedAt.IsZero() && !e.FinishedAt.IsZero() && e.FinishedAt.Before(e.StartedAt) { + errs = append(errs, errors.New("finished_at must be after started_at")) + } + return errors.Join(errs...) +} + type NetworkMode string const ( @@ -2935,6 +3133,48 @@ func validSHA256Ref(value string) bool { return err == nil } +func validSHA256Digest(value string) bool { + if !strings.HasPrefix(value, "sha256:") { + return false + } + hexPart := strings.TrimPrefix(value, "sha256:") + if len(hexPart) != 64 { + return false + } + _, err := hex.DecodeString(hexPart) + return err == nil && strings.ToLower(hexPart) == hexPart +} + +func validateNetworkHostname(hostname string) error { + hostname = strings.TrimSpace(hostname) + if hostname == "" { + return errors.New("hostname is required") + } + if strings.ContainsAny(hostname, " \t\r\n/:?&#") { + return errors.New("hostname must not contain whitespace, scheme, path, query, or fragment") + } + if _, err := netip.ParseAddr(hostname); err == nil { + return errors.New("hostname must not be an IP address") + } + return nil +} + +func validateIngressHeaderName(name string) error { + name = strings.TrimSpace(strings.ToLower(name)) + if name == "" { + return errors.New("header name is required") + } + if strings.ContainsAny(name, " \t\r\n:") { + return fmt.Errorf("header %q is invalid", name) + } + switch name { + case "authorization", "cookie", "proxy-authorization", "x-api-key", "x-auth-token", "x-forwarded-authorization": + return fmt.Errorf("header %q is not allowed", name) + default: + return nil + } +} + func validateIdentifier(name, id string) error { trimmed := strings.TrimSpace(id) if trimmed == "" { diff --git a/protocol/types_test.go b/protocol/types_test.go index ebce48f..e76dd92 100644 --- a/protocol/types_test.go +++ b/protocol/types_test.go @@ -92,6 +92,88 @@ func TestRuntimeExecutionRequestRejectsMalformedInvocation(t *testing.T) { } } +func TestServiceIngressEvidenceValidation(t *testing.T) { + evidence := validServiceIngressEvidence() + if err := evidence.Validate(); err != nil { + t.Fatalf("valid ingress evidence rejected: %v", err) + } + + cases := []struct { + name string + mut func(*protocol.ServiceIngressEvidence) + want string + }{ + {"missing auth decision", func(e *protocol.ServiceIngressEvidence) { e.AuthDecisionHash = "" }, "auth_decision_hash"}, + {"missing health response", func(e *protocol.ServiceIngressEvidence) { e.LastHealthResponseHash = "" }, "last_health_response_hash"}, + {"missing helper target", func(e *protocol.ServiceIngressEvidence) { e.HelperContainerNetNSTarget = "" }, "helper_container_netns_target"}, + {"unsafe helper host", func(e *protocol.ServiceIngressEvidence) { e.HelperHost = "10.0.0.5" }, "helper_host"}, + {"queryful request path", func(e *protocol.ServiceIngressEvidence) { e.RequestPath = "/compile?token=secret" }, "query"}, + {"unsupported method", func(e *protocol.ServiceIngressEvidence) { e.RequestMethod = "PUT" }, "request_method"}, + {"failed without class", func(e *protocol.ServiceIngressEvidence) { + e.TerminalStatus = protocol.ServiceIngressTerminalFailed + e.ResponseStatus = 0 + e.ResponseHash = "" + e.FailureClass = "" + }, "failure_class"}, + {"secret request header", func(e *protocol.ServiceIngressEvidence) { e.RequestHeaderNames = []string{"authorization"} }, "header"}, + {"secret response header", func(e *protocol.ServiceIngressEvidence) { e.ResponseHeaderNames = []string{"cookie"} }, "header"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := evidence + got.RequestHeaderNames = slices.Clone(evidence.RequestHeaderNames) + got.ResponseHeaderNames = slices.Clone(evidence.ResponseHeaderNames) + tc.mut(&got) + if err := got.Validate(); err == nil || !strings.Contains(err.Error(), tc.want) { + t.Fatalf("Validate() error = %v, want containing %q", err, tc.want) + } + }) + } +} + +func validServiceIngressEvidence() protocol.ServiceIngressEvidence { + return protocol.ServiceIngressEvidence{ + ID: "ing-evidence-1", + OrgID: "org-1", + PoolID: "pool-1", + ProductID: "edge-product", + Hostname: "edge.example.invalid", + RouteTarget: "service-route:abc123", + ServiceLeaseID: "svc-1", + TaskID: "task-1", + WorkerID: "worker-1", + LeaseLeasedAt: time.Unix(100, 0).UTC(), + LeaseRenewBy: time.Unix(200, 0).UTC(), + SelectedAt: time.Unix(101, 0).UTC(), + LastHealthAt: time.Unix(102, 0).UTC(), + HealthValidUntil: time.Unix(132, 0).UTC(), + LastHealthResponseHash: "sha256:" + strings.Repeat("b", 64), + LastHealthSLOEvidenceHash: "sha256:" + strings.Repeat("c", 64), + AuthDecisionHash: "sha256:" + strings.Repeat("d", 64), + IdempotencyKey: "idem-1", + RequestMethod: "POST", + RequestPath: "/compile", + RequestBodyHash: "sha256:" + strings.Repeat("e", 64), + RequestHeaderNames: []string{"content-type"}, + HelperImage: "ingress-helper@sha256:" + strings.Repeat("f", 64), + HelperScheme: "http", + HelperHost: "127.0.0.1", + HelperPort: 8080, + HelperPortName: "http", + HelperTimeoutMS: 1000, + HelperContainerNetNSTarget: "container:svc-container", + HelperOutputHash: "sha256:" + strings.Repeat("1", 64), + HelperErrorHash: "sha256:" + strings.Repeat("2", 64), + ResponseStatus: 200, + ResponseHeaderNames: []string{"content-type"}, + ResponseHash: "sha256:" + strings.Repeat("3", 64), + ResponseBytes: 42, + TerminalStatus: protocol.ServiceIngressTerminalCompleted, + StartedAt: time.Unix(103, 0).UTC(), + FinishedAt: time.Unix(104, 0).UTC(), + } +} + func TestRuntimeExecutionResultValidatesTimingAndPreview(t *testing.T) { result := protocol.RuntimeExecutionResult{ StartedAt: time.Unix(10, 0).UTC(),