Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 214 additions & 6 deletions protocol/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,11 @@ func (o ProviderOperation) Validate() error {
if spec.RetentionSeconds < 0 {
errs = append(errs, fmt.Errorf("artifact_specs[%d].retention_seconds must not be negative", i))
}
if spec.ProviderReturn != nil {
if err := spec.ProviderReturn.Validate(); err != nil {
errs = append(errs, fmt.Errorf("artifact_specs[%d].provider_return: %w", i, err))
}
}
}
return errors.Join(errs...)
}
Expand All @@ -2478,12 +2483,207 @@ func (o ProviderOperation) NormalizedArtifactSpecs() []ProviderArtifactSpec {
}

type ProviderArtifactSpec struct {
Name string `json:"name"`
Required bool `json:"required,omitempty"`
ContentType string `json:"content_type,omitempty"`
MaxBytes int64 `json:"max_bytes,omitempty"`
RetentionSeconds int `json:"retention_seconds,omitempty"`
Forwardable bool `json:"forwardable,omitempty"`
Name string `json:"name"`
Required bool `json:"required,omitempty"`
ContentType string `json:"content_type,omitempty"`
MaxBytes int64 `json:"max_bytes,omitempty"`
RetentionSeconds int `json:"retention_seconds,omitempty"`
Forwardable bool `json:"forwardable,omitempty"`
ProviderReturn *ProviderArtifactReturnSpec `json:"provider_return,omitempty"`
}

type ProviderArtifactReturnSpec struct {
StepType string `json:"step_type"`
Contract string `json:"contract"`
ContractVersion string `json:"contract_version,omitempty"`
SubmitEndpoint string `json:"submit_endpoint,omitempty"`
RequiredConfig []string `json:"required_config,omitempty"`
OutputHandling []string `json:"output_handling,omitempty"`
}

func (s ProviderArtifactReturnSpec) Enabled() bool {
return s.StepType != "" || s.Contract != "" || s.ContractVersion != "" || s.SubmitEndpoint != "" || len(s.RequiredConfig) > 0 || len(s.OutputHandling) > 0
}

func (s ProviderArtifactReturnSpec) Validate() error {
var errs []error
if err := validateIdentifier("step_type", s.StepType); err != nil {
errs = append(errs, err)
}
if strings.TrimSpace(s.Contract) == "" {
errs = append(errs, errors.New("contract is required"))
} else if strings.TrimSpace(s.Contract) != s.Contract || strings.ContainsAny(s.Contract, " \t\r\n\x00") || strings.Contains(s.Contract, "://") {
errs = append(errs, errors.New("contract must be a typed plugin contract without whitespace or URL scheme"))
} else if !strings.Contains(s.Contract, ":") {
errs = append(errs, errors.New("contract must include plugin and step type"))
}
if s.ContractVersion != "" && strings.ContainsAny(s.ContractVersion, " \t\r\n\x00") {
errs = append(errs, errors.New("contract_version must not contain whitespace"))
}
if s.SubmitEndpoint != "" && !validProviderReturnSubmitEndpoint(s.SubmitEndpoint) {
errs = append(errs, errors.New("submit_endpoint must be a server-relative path"))
}
for i, value := range s.RequiredConfig {
if err := validateIdentifier(fmt.Sprintf("required_config[%d]", i), value); err != nil {
errs = append(errs, err)
}
}
for i, value := range s.OutputHandling {
if err := validateIdentifier(fmt.Sprintf("output_handling[%d]", i), value); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

type ProviderArtifactDeliveryStatus string

const (
ProviderArtifactDeliveryPending ProviderArtifactDeliveryStatus = "pending"
ProviderArtifactDeliveryDelivered ProviderArtifactDeliveryStatus = "delivered"
ProviderArtifactDeliveryFailed ProviderArtifactDeliveryStatus = "failed"
ProviderArtifactDeliveryDeadLetter ProviderArtifactDeliveryStatus = "dead_letter"
)

type ProviderArtifactDeliveryArtifact struct {
Name string `json:"name"`
Ref string `json:"ref"`
ContentType string `json:"content_type,omitempty"`
SHA256 string `json:"sha256"`
SizeBytes int64 `json:"size_bytes"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
}

func (a ProviderArtifactDeliveryArtifact) Validate() error {
var errs []error
if !validProviderArtifactName(a.Name) {
errs = append(errs, errors.New("artifact.name is invalid"))
}
if strings.TrimSpace(a.Ref) == "" {
errs = append(errs, errors.New("artifact.ref is required"))
} else if err := validateScopedRef("artifact.ref", a.Ref, "artifact://"); err != nil {
errs = append(errs, err)
}
if a.ContentType != "" {
if strings.TrimSpace(a.ContentType) != a.ContentType || strings.ContainsAny(a.ContentType, "\x00\r\n\t") {
errs = append(errs, errors.New("artifact.content_type is invalid"))
} else if _, _, err := mime.ParseMediaType(a.ContentType); err != nil {
errs = append(errs, errors.New("artifact.content_type is invalid"))
}
}
if !validSHA256Ref(a.SHA256) {
errs = append(errs, errors.New("artifact.sha256 must be sha256:<64 hex chars>"))
}
if a.SizeBytes < 0 {
errs = append(errs, errors.New("artifact.size_bytes must not be negative"))
}
return errors.Join(errs...)
}

type ProviderArtifactDelivery struct {
ProtocolVersion string `json:"protocol_version,omitempty"`
ID string `json:"id"`
OrgID string `json:"org_id"`
PoolID string `json:"pool_id"`
ProductID string `json:"product_id,omitempty"`
TaskID string `json:"task_id"`
ProofID string `json:"proof_id"`
WorkerID string `json:"worker_id"`
ProviderConfig ProviderConfig `json:"provider_config"`
Operation string `json:"operation"`
ReturnSpec ProviderArtifactReturnSpec `json:"provider_return"`
Artifact ProviderArtifactDeliveryArtifact `json:"artifact"`
Status ProviderArtifactDeliveryStatus `json:"status"`
Attempts int `json:"attempts,omitempty"`
LastErrorHash string `json:"last_error_hash,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

func (d ProviderArtifactDelivery) Validate() error {
var errs []error
if d.ProtocolVersion != "" && d.ProtocolVersion != Version {
errs = append(errs, fmt.Errorf("protocol_version must be %q", Version))
}
for _, field := range []struct {
name string
value string
}{
{name: "id", value: d.ID},
{name: "org_id", value: d.OrgID},
{name: "pool_id", value: d.PoolID},
{name: "task_id", value: d.TaskID},
{name: "proof_id", value: d.ProofID},
{name: "worker_id", value: d.WorkerID},
{name: "operation", value: d.Operation},
} {
if err := validateIdentifier(field.name, field.value); err != nil {
errs = append(errs, err)
}
}
if d.ProductID != "" {
if err := validateIdentifier("product_id", d.ProductID); err != nil {
errs = append(errs, err)
}
}
if err := d.ProviderConfig.Validate(); err != nil {
errs = append(errs, fmt.Errorf("provider_config: %w", err))
}
if err := d.ReturnSpec.Validate(); err != nil {
errs = append(errs, fmt.Errorf("provider_return: %w", err))
}
if err := d.Artifact.Validate(); err != nil {
errs = append(errs, err)
}
if !validProviderArtifactDeliveryStatus(d.Status) {
errs = append(errs, fmt.Errorf("status %q is unsupported", d.Status))
}
if d.Attempts < 0 {
errs = append(errs, errors.New("attempts must not be negative"))
}
if d.LastErrorHash != "" && !validSHA256Ref(d.LastErrorHash) {
errs = append(errs, errors.New("last_error_hash must be sha256:<64 hex chars>"))
}
if d.CreatedAt.IsZero() {
errs = append(errs, errors.New("created_at is required"))
}
if d.UpdatedAt.IsZero() {
errs = append(errs, errors.New("updated_at is required"))
}
return errors.Join(errs...)
}

type ProviderArtifactDeliveryStatusUpdate struct {
ProtocolVersion string `json:"protocol_version,omitempty"`
DeliveryID string `json:"delivery_id"`
Status ProviderArtifactDeliveryStatus `json:"status"`
ProviderRef string `json:"provider_ref,omitempty"`
ErrorHash string `json:"error_hash,omitempty"`
ObservedAt time.Time `json:"observed_at,omitempty"`
}

type ProviderArtifactDeliveryAction struct {
ProtocolVersion string `json:"protocol_version,omitempty"`
DeliveryID string `json:"delivery_id"`
PluginID string `json:"plugin_id"`
ProviderID string `json:"provider_id"`
ContractID string `json:"contract_id"`
ContractVersion string `json:"contract_version,omitempty"`
StepType string `json:"step_type"`
Operation string `json:"operation"`
SubmitEndpoint string `json:"submit_endpoint,omitempty"`
Artifact ProviderArtifactDeliveryArtifact `json:"artifact"`
TaskID string `json:"task_id"`
ProofID string `json:"proof_id"`
}

func validProviderArtifactDeliveryStatus(status ProviderArtifactDeliveryStatus) bool {
switch status {
case ProviderArtifactDeliveryPending, ProviderArtifactDeliveryDelivered, ProviderArtifactDeliveryFailed, ProviderArtifactDeliveryDeadLetter:
return true
default:
return false
}
}

type ProviderRuntimeContract struct {
Expand Down Expand Up @@ -3928,6 +4128,14 @@ func validProviderArtifactName(name string) bool {
name != ".."
}

func validProviderReturnSubmitEndpoint(value string) bool {
return strings.HasPrefix(value, "/") &&
!strings.Contains(value, "://") &&
!strings.ContainsAny(value, " \t\r\n\x00?#") &&
path.Clean(value) == value &&
!strings.Contains(value, "..")
}

func ProviderPluginRequiresUpstreamClientConformance(pluginID string) bool {
switch pluginID {
case "workflow-plugin-volunteer-science", "workflow-plugin-crypto":
Expand Down
117 changes: 117 additions & 0 deletions protocol/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,123 @@ func TestProviderContractAcceptsAccessScopedProviderOperations(t *testing.T) {
}
}

func TestProviderOperationAcceptsProviderReturnArtifactIntent(t *testing.T) {
operation := protocol.ProviderOperation{
ID: "build",
InputSchemaRef: "schema://providers/example/operations/build/input/v1",
InputSchemaDigest: protocol.CanonicalHash("input"),
OutputSchemaRef: "schema://providers/example/operations/build/output/v1",
OutputSchemaDigest: protocol.CanonicalHash("output"),
ArtifactSpecs: []protocol.ProviderArtifactSpec{{
Name: "provenance",
Required: true,
ContentType: "application/json",
ProviderReturn: &protocol.ProviderArtifactReturnSpec{
StepType: "step.provider_artifact_return",
Contract: "workflow-plugin-ci:step.provider_artifact_return",
ContractVersion: "v1",
SubmitEndpoint: "/v1/provider-return/artifact-deliveries",
},
}},
}

if err := operation.Validate(); err != nil {
t.Fatalf("operation invalid: %v", err)
}
specs := operation.NormalizedArtifactSpecs()
if len(specs) != 1 || specs[0].ProviderReturn == nil || !specs[0].ProviderReturn.Enabled() {
t.Fatalf("provider return intent not preserved: %#v", specs)
}
}

func TestProviderOperationRejectsMalformedProviderReturnArtifactIntent(t *testing.T) {
valid := protocol.ProviderOperation{
ID: "build",
InputSchemaRef: "schema://providers/example/operations/build/input/v1",
InputSchemaDigest: protocol.CanonicalHash("input"),
OutputSchemaRef: "schema://providers/example/operations/build/output/v1",
OutputSchemaDigest: protocol.CanonicalHash("output"),
ArtifactSpecs: []protocol.ProviderArtifactSpec{{
Name: "provenance",
ProviderReturn: &protocol.ProviderArtifactReturnSpec{
StepType: "step.provider_artifact_return",
Contract: "workflow-plugin-ci:step.provider_artifact_return",
ContractVersion: "v1",
SubmitEndpoint: "/v1/provider-return/artifact-deliveries",
},
}},
}
cases := map[string]func(*protocol.ProviderArtifactReturnSpec){
"missing step type": func(spec *protocol.ProviderArtifactReturnSpec) {
spec.StepType = ""
},
"missing contract": func(spec *protocol.ProviderArtifactReturnSpec) {
spec.Contract = ""
},
"absolute submit endpoint": func(spec *protocol.ProviderArtifactReturnSpec) {
spec.SubmitEndpoint = "https://provider.example/upload"
},
"control whitespace": func(spec *protocol.ProviderArtifactReturnSpec) {
spec.Contract = "workflow-plugin-ci:step.provider_artifact_return\n"
},
}
for name, mutate := range cases {
t.Run(name, func(t *testing.T) {
operation := valid
returnSpec := *valid.ArtifactSpecs[0].ProviderReturn
mutate(&returnSpec)
operation.ArtifactSpecs = []protocol.ProviderArtifactSpec{{Name: "provenance", ProviderReturn: &returnSpec}}
if err := operation.Validate(); err == nil {
t.Fatalf("expected invalid provider return intent")
}
})
}
}

func TestProviderArtifactDeliveryValidatesStatusAndArtifactRef(t *testing.T) {
now := time.Date(2026, 5, 30, 12, 0, 0, 0, time.UTC)
delivery := protocol.ProviderArtifactDelivery{
ProtocolVersion: protocol.Version,
ID: "provider-artifact-delivery-1",
OrgID: "gocodealone",
PoolID: "ci",
TaskID: "task-1",
ProofID: "proof-1",
WorkerID: "worker-1",
ProviderConfig: protocol.ProviderConfig{
PluginID: "workflow-plugin-ci",
ProviderID: "ci",
ContractID: "ci.v1",
Version: "v1",
ConfigRef: "config://providers/ci",
},
Operation: "build",
ReturnSpec: protocol.ProviderArtifactReturnSpec{
StepType: "step.provider_artifact_return",
Contract: "workflow-plugin-ci:step.provider_artifact_return",
ContractVersion: "v1",
SubmitEndpoint: "/v1/provider-return/artifact-deliveries",
},
Artifact: protocol.ProviderArtifactDeliveryArtifact{
Name: "provenance",
Ref: "artifact://ci/tasks/task-1/proofs/proof-1/provenance",
ContentType: "application/json",
SHA256: protocol.CanonicalHash("artifact"),
SizeBytes: 42,
},
Status: protocol.ProviderArtifactDeliveryPending,
CreatedAt: now,
UpdatedAt: now,
}
if err := delivery.Validate(); err != nil {
t.Fatalf("delivery invalid: %v", err)
}
delivery.Status = protocol.ProviderArtifactDeliveryStatus("unknown")
if err := delivery.Validate(); err == nil || !strings.Contains(err.Error(), "status") {
t.Fatalf("expected status validation error, got %v", err)
}
}

func TestProviderContractRejectsPoolWithoutOrg(t *testing.T) {
contract := validBatchProviderContract()
contract.PoolID = "ci-runners"
Expand Down
Loading