From c76306ba35d9781d4e2bbe89228ff71418704cd8 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sat, 30 May 2026 21:12:15 -0400 Subject: [PATCH] feat: add provider artifact return contracts --- protocol/types.go | 220 +++++++++++++++++++++++++++++++++++++++-- protocol/types_test.go | 117 ++++++++++++++++++++++ 2 files changed, 331 insertions(+), 6 deletions(-) diff --git a/protocol/types.go b/protocol/types.go index 9b88cc9..26449ea 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -2457,6 +2457,11 @@ func (o ProviderOperation) Validate() error { if spec.RetentionSeconds < 0 { errs = append(errs, fmt.Errorf("artifact_specs[%d].retention_seconds must not be negative", i)) } + if spec.ProviderReturn != nil { + if err := spec.ProviderReturn.Validate(); err != nil { + errs = append(errs, fmt.Errorf("artifact_specs[%d].provider_return: %w", i, err)) + } + } } return errors.Join(errs...) } @@ -2478,12 +2483,207 @@ func (o ProviderOperation) NormalizedArtifactSpecs() []ProviderArtifactSpec { } type ProviderArtifactSpec struct { - Name string `json:"name"` - Required bool `json:"required,omitempty"` - ContentType string `json:"content_type,omitempty"` - MaxBytes int64 `json:"max_bytes,omitempty"` - RetentionSeconds int `json:"retention_seconds,omitempty"` - Forwardable bool `json:"forwardable,omitempty"` + Name string `json:"name"` + Required bool `json:"required,omitempty"` + ContentType string `json:"content_type,omitempty"` + MaxBytes int64 `json:"max_bytes,omitempty"` + RetentionSeconds int `json:"retention_seconds,omitempty"` + Forwardable bool `json:"forwardable,omitempty"` + ProviderReturn *ProviderArtifactReturnSpec `json:"provider_return,omitempty"` +} + +type ProviderArtifactReturnSpec struct { + StepType string `json:"step_type"` + Contract string `json:"contract"` + ContractVersion string `json:"contract_version,omitempty"` + SubmitEndpoint string `json:"submit_endpoint,omitempty"` + RequiredConfig []string `json:"required_config,omitempty"` + OutputHandling []string `json:"output_handling,omitempty"` +} + +func (s ProviderArtifactReturnSpec) Enabled() bool { + return s.StepType != "" || s.Contract != "" || s.ContractVersion != "" || s.SubmitEndpoint != "" || len(s.RequiredConfig) > 0 || len(s.OutputHandling) > 0 +} + +func (s ProviderArtifactReturnSpec) Validate() error { + var errs []error + if err := validateIdentifier("step_type", s.StepType); err != nil { + errs = append(errs, err) + } + if strings.TrimSpace(s.Contract) == "" { + errs = append(errs, errors.New("contract is required")) + } else if strings.TrimSpace(s.Contract) != s.Contract || strings.ContainsAny(s.Contract, " \t\r\n\x00") || strings.Contains(s.Contract, "://") { + errs = append(errs, errors.New("contract must be a typed plugin contract without whitespace or URL scheme")) + } else if !strings.Contains(s.Contract, ":") { + errs = append(errs, errors.New("contract must include plugin and step type")) + } + if s.ContractVersion != "" && strings.ContainsAny(s.ContractVersion, " \t\r\n\x00") { + errs = append(errs, errors.New("contract_version must not contain whitespace")) + } + if s.SubmitEndpoint != "" && !validProviderReturnSubmitEndpoint(s.SubmitEndpoint) { + errs = append(errs, errors.New("submit_endpoint must be a server-relative path")) + } + for i, value := range s.RequiredConfig { + if err := validateIdentifier(fmt.Sprintf("required_config[%d]", i), value); err != nil { + errs = append(errs, err) + } + } + for i, value := range s.OutputHandling { + if err := validateIdentifier(fmt.Sprintf("output_handling[%d]", i), value); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +type ProviderArtifactDeliveryStatus string + +const ( + ProviderArtifactDeliveryPending ProviderArtifactDeliveryStatus = "pending" + ProviderArtifactDeliveryDelivered ProviderArtifactDeliveryStatus = "delivered" + ProviderArtifactDeliveryFailed ProviderArtifactDeliveryStatus = "failed" + ProviderArtifactDeliveryDeadLetter ProviderArtifactDeliveryStatus = "dead_letter" +) + +type ProviderArtifactDeliveryArtifact struct { + Name string `json:"name"` + Ref string `json:"ref"` + ContentType string `json:"content_type,omitempty"` + SHA256 string `json:"sha256"` + SizeBytes int64 `json:"size_bytes"` + ExpiresAt time.Time `json:"expires_at,omitempty"` +} + +func (a ProviderArtifactDeliveryArtifact) Validate() error { + var errs []error + if !validProviderArtifactName(a.Name) { + errs = append(errs, errors.New("artifact.name is invalid")) + } + if strings.TrimSpace(a.Ref) == "" { + errs = append(errs, errors.New("artifact.ref is required")) + } else if err := validateScopedRef("artifact.ref", a.Ref, "artifact://"); err != nil { + errs = append(errs, err) + } + if a.ContentType != "" { + if strings.TrimSpace(a.ContentType) != a.ContentType || strings.ContainsAny(a.ContentType, "\x00\r\n\t") { + errs = append(errs, errors.New("artifact.content_type is invalid")) + } else if _, _, err := mime.ParseMediaType(a.ContentType); err != nil { + errs = append(errs, errors.New("artifact.content_type is invalid")) + } + } + if !validSHA256Ref(a.SHA256) { + errs = append(errs, errors.New("artifact.sha256 must be sha256:<64 hex chars>")) + } + if a.SizeBytes < 0 { + errs = append(errs, errors.New("artifact.size_bytes must not be negative")) + } + return errors.Join(errs...) +} + +type ProviderArtifactDelivery struct { + ProtocolVersion string `json:"protocol_version,omitempty"` + ID string `json:"id"` + OrgID string `json:"org_id"` + PoolID string `json:"pool_id"` + ProductID string `json:"product_id,omitempty"` + TaskID string `json:"task_id"` + ProofID string `json:"proof_id"` + WorkerID string `json:"worker_id"` + ProviderConfig ProviderConfig `json:"provider_config"` + Operation string `json:"operation"` + ReturnSpec ProviderArtifactReturnSpec `json:"provider_return"` + Artifact ProviderArtifactDeliveryArtifact `json:"artifact"` + Status ProviderArtifactDeliveryStatus `json:"status"` + Attempts int `json:"attempts,omitempty"` + LastErrorHash string `json:"last_error_hash,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (d ProviderArtifactDelivery) Validate() error { + var errs []error + if d.ProtocolVersion != "" && d.ProtocolVersion != Version { + errs = append(errs, fmt.Errorf("protocol_version must be %q", Version)) + } + for _, field := range []struct { + name string + value string + }{ + {name: "id", value: d.ID}, + {name: "org_id", value: d.OrgID}, + {name: "pool_id", value: d.PoolID}, + {name: "task_id", value: d.TaskID}, + {name: "proof_id", value: d.ProofID}, + {name: "worker_id", value: d.WorkerID}, + {name: "operation", value: d.Operation}, + } { + if err := validateIdentifier(field.name, field.value); err != nil { + errs = append(errs, err) + } + } + if d.ProductID != "" { + if err := validateIdentifier("product_id", d.ProductID); err != nil { + errs = append(errs, err) + } + } + if err := d.ProviderConfig.Validate(); err != nil { + errs = append(errs, fmt.Errorf("provider_config: %w", err)) + } + if err := d.ReturnSpec.Validate(); err != nil { + errs = append(errs, fmt.Errorf("provider_return: %w", err)) + } + if err := d.Artifact.Validate(); err != nil { + errs = append(errs, err) + } + if !validProviderArtifactDeliveryStatus(d.Status) { + errs = append(errs, fmt.Errorf("status %q is unsupported", d.Status)) + } + if d.Attempts < 0 { + errs = append(errs, errors.New("attempts must not be negative")) + } + if d.LastErrorHash != "" && !validSHA256Ref(d.LastErrorHash) { + errs = append(errs, errors.New("last_error_hash must be sha256:<64 hex chars>")) + } + if d.CreatedAt.IsZero() { + errs = append(errs, errors.New("created_at is required")) + } + if d.UpdatedAt.IsZero() { + errs = append(errs, errors.New("updated_at is required")) + } + return errors.Join(errs...) +} + +type ProviderArtifactDeliveryStatusUpdate struct { + ProtocolVersion string `json:"protocol_version,omitempty"` + DeliveryID string `json:"delivery_id"` + Status ProviderArtifactDeliveryStatus `json:"status"` + ProviderRef string `json:"provider_ref,omitempty"` + ErrorHash string `json:"error_hash,omitempty"` + ObservedAt time.Time `json:"observed_at,omitempty"` +} + +type ProviderArtifactDeliveryAction struct { + ProtocolVersion string `json:"protocol_version,omitempty"` + DeliveryID string `json:"delivery_id"` + PluginID string `json:"plugin_id"` + ProviderID string `json:"provider_id"` + ContractID string `json:"contract_id"` + ContractVersion string `json:"contract_version,omitempty"` + StepType string `json:"step_type"` + Operation string `json:"operation"` + SubmitEndpoint string `json:"submit_endpoint,omitempty"` + Artifact ProviderArtifactDeliveryArtifact `json:"artifact"` + TaskID string `json:"task_id"` + ProofID string `json:"proof_id"` +} + +func validProviderArtifactDeliveryStatus(status ProviderArtifactDeliveryStatus) bool { + switch status { + case ProviderArtifactDeliveryPending, ProviderArtifactDeliveryDelivered, ProviderArtifactDeliveryFailed, ProviderArtifactDeliveryDeadLetter: + return true + default: + return false + } } type ProviderRuntimeContract struct { @@ -3928,6 +4128,14 @@ func validProviderArtifactName(name string) bool { name != ".." } +func validProviderReturnSubmitEndpoint(value string) bool { + return strings.HasPrefix(value, "/") && + !strings.Contains(value, "://") && + !strings.ContainsAny(value, " \t\r\n\x00?#") && + path.Clean(value) == value && + !strings.Contains(value, "..") +} + func ProviderPluginRequiresUpstreamClientConformance(pluginID string) bool { switch pluginID { case "workflow-plugin-volunteer-science", "workflow-plugin-crypto": diff --git a/protocol/types_test.go b/protocol/types_test.go index e76dd92..bb2417b 100644 --- a/protocol/types_test.go +++ b/protocol/types_test.go @@ -1175,6 +1175,123 @@ func TestProviderContractAcceptsAccessScopedProviderOperations(t *testing.T) { } } +func TestProviderOperationAcceptsProviderReturnArtifactIntent(t *testing.T) { + operation := protocol.ProviderOperation{ + ID: "build", + InputSchemaRef: "schema://providers/example/operations/build/input/v1", + InputSchemaDigest: protocol.CanonicalHash("input"), + OutputSchemaRef: "schema://providers/example/operations/build/output/v1", + OutputSchemaDigest: protocol.CanonicalHash("output"), + ArtifactSpecs: []protocol.ProviderArtifactSpec{{ + Name: "provenance", + Required: true, + ContentType: "application/json", + ProviderReturn: &protocol.ProviderArtifactReturnSpec{ + StepType: "step.provider_artifact_return", + Contract: "workflow-plugin-ci:step.provider_artifact_return", + ContractVersion: "v1", + SubmitEndpoint: "/v1/provider-return/artifact-deliveries", + }, + }}, + } + + if err := operation.Validate(); err != nil { + t.Fatalf("operation invalid: %v", err) + } + specs := operation.NormalizedArtifactSpecs() + if len(specs) != 1 || specs[0].ProviderReturn == nil || !specs[0].ProviderReturn.Enabled() { + t.Fatalf("provider return intent not preserved: %#v", specs) + } +} + +func TestProviderOperationRejectsMalformedProviderReturnArtifactIntent(t *testing.T) { + valid := protocol.ProviderOperation{ + ID: "build", + InputSchemaRef: "schema://providers/example/operations/build/input/v1", + InputSchemaDigest: protocol.CanonicalHash("input"), + OutputSchemaRef: "schema://providers/example/operations/build/output/v1", + OutputSchemaDigest: protocol.CanonicalHash("output"), + ArtifactSpecs: []protocol.ProviderArtifactSpec{{ + Name: "provenance", + ProviderReturn: &protocol.ProviderArtifactReturnSpec{ + StepType: "step.provider_artifact_return", + Contract: "workflow-plugin-ci:step.provider_artifact_return", + ContractVersion: "v1", + SubmitEndpoint: "/v1/provider-return/artifact-deliveries", + }, + }}, + } + cases := map[string]func(*protocol.ProviderArtifactReturnSpec){ + "missing step type": func(spec *protocol.ProviderArtifactReturnSpec) { + spec.StepType = "" + }, + "missing contract": func(spec *protocol.ProviderArtifactReturnSpec) { + spec.Contract = "" + }, + "absolute submit endpoint": func(spec *protocol.ProviderArtifactReturnSpec) { + spec.SubmitEndpoint = "https://provider.example/upload" + }, + "control whitespace": func(spec *protocol.ProviderArtifactReturnSpec) { + spec.Contract = "workflow-plugin-ci:step.provider_artifact_return\n" + }, + } + for name, mutate := range cases { + t.Run(name, func(t *testing.T) { + operation := valid + returnSpec := *valid.ArtifactSpecs[0].ProviderReturn + mutate(&returnSpec) + operation.ArtifactSpecs = []protocol.ProviderArtifactSpec{{Name: "provenance", ProviderReturn: &returnSpec}} + if err := operation.Validate(); err == nil { + t.Fatalf("expected invalid provider return intent") + } + }) + } +} + +func TestProviderArtifactDeliveryValidatesStatusAndArtifactRef(t *testing.T) { + now := time.Date(2026, 5, 30, 12, 0, 0, 0, time.UTC) + delivery := protocol.ProviderArtifactDelivery{ + ProtocolVersion: protocol.Version, + ID: "provider-artifact-delivery-1", + OrgID: "gocodealone", + PoolID: "ci", + TaskID: "task-1", + ProofID: "proof-1", + WorkerID: "worker-1", + ProviderConfig: protocol.ProviderConfig{ + PluginID: "workflow-plugin-ci", + ProviderID: "ci", + ContractID: "ci.v1", + Version: "v1", + ConfigRef: "config://providers/ci", + }, + Operation: "build", + ReturnSpec: protocol.ProviderArtifactReturnSpec{ + StepType: "step.provider_artifact_return", + Contract: "workflow-plugin-ci:step.provider_artifact_return", + ContractVersion: "v1", + SubmitEndpoint: "/v1/provider-return/artifact-deliveries", + }, + Artifact: protocol.ProviderArtifactDeliveryArtifact{ + Name: "provenance", + Ref: "artifact://ci/tasks/task-1/proofs/proof-1/provenance", + ContentType: "application/json", + SHA256: protocol.CanonicalHash("artifact"), + SizeBytes: 42, + }, + Status: protocol.ProviderArtifactDeliveryPending, + CreatedAt: now, + UpdatedAt: now, + } + if err := delivery.Validate(); err != nil { + t.Fatalf("delivery invalid: %v", err) + } + delivery.Status = protocol.ProviderArtifactDeliveryStatus("unknown") + if err := delivery.Validate(); err == nil || !strings.Contains(err.Error(), "status") { + t.Fatalf("expected status validation error, got %v", err) + } +} + func TestProviderContractRejectsPoolWithoutOrg(t *testing.T) { contract := validBatchProviderContract() contract.PoolID = "ci-runners"