From b679fb076fa1f29aedee33780f49d163002046a6 Mon Sep 17 00:00:00 2001 From: Ashvitha Sridharan Date: Tue, 22 Apr 2025 21:11:07 -0400 Subject: [PATCH] Unwrap scim errors when possible for patch validation --- handlers_test.go | 19 +++++++++++++++++++ internal/patch/remove.go | 6 +----- resource_type.go | 27 ++++++++++++++++----------- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/handlers_test.go b/handlers_test.go index 9d7e6c2..cb632ef 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -257,6 +257,25 @@ func TestServerResourcePatchHandlerInvalidRemoveOp(t *testing.T) { assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) } +func TestServerResourcePatchHandlerInvalidRemoveOpNoTarget(t *testing.T) { + req := httptest.NewRequest(http.MethodPatch, "/Groups/0001", strings.NewReader(`{ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], + "Operations":[ + { + "op":"remove", + "path":"" + } + ] + }`)) + rr := httptest.NewRecorder() + newTestServer(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) + var scimErr *errors.ScimError + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &scimErr)) + assertEqualSCIMErrors(t, &errors.ScimErrorNoTarget, scimErr) +} + func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { recorder := httptest.NewRecorder() newTestServer(t).ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ diff --git a/internal/patch/remove.go b/internal/patch/remove.go index ebd88fe..e36870b 100644 --- a/internal/patch/remove.go +++ b/internal/patch/remove.go @@ -2,7 +2,6 @@ package patch import ( "github.com/elimity-com/scim/filter" - "net/http" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/schema" @@ -13,10 +12,7 @@ import ( func (v OperationValidator) validateRemove() (interface{}, error) { // If "path" is unspecified, the operation fails with HTTP status code 400 and a "scimType" error code of "noTarget". if v.Path == nil { - return nil, &errors.ScimError{ - ScimType: errors.ScimTypeNoTarget, - Status: http.StatusBadRequest, - } + return nil, &errors.ScimErrorNoTarget } refAttr, err := v.getRefAttribute(v.Path.AttributePath) diff --git a/resource_type.go b/resource_type.go index 8761d7d..2689f1e 100644 --- a/resource_type.go +++ b/resource_type.go @@ -3,9 +3,10 @@ package scim import ( "bytes" "encoding/json" + "errors" "net/http" - "github.com/elimity-com/scim/errors" + scimErrors "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/internal/patch" "github.com/elimity-com/scim/optional" "github.com/elimity-com/scim/schema" @@ -86,10 +87,10 @@ func (t ResourceType) schemaWithCommon() schema.Schema { return s } -func (t ResourceType) validate(raw []byte) (ResourceAttributes, *errors.ScimError) { +func (t ResourceType) validate(raw []byte) (ResourceAttributes, *scimErrors.ScimError) { var m map[string]interface{} if err := unmarshal(raw, &m); err != nil { - return ResourceAttributes{}, &errors.ScimErrorInvalidSyntax + return ResourceAttributes{}, &scimErrors.ScimErrorInvalidSyntax } attributes, scimErr := t.schemaWithCommon().Validate(m) @@ -101,7 +102,7 @@ func (t ResourceType) validate(raw []byte) (ResourceAttributes, *errors.ScimErro extensionField := m[extension.Schema.ID] if extensionField == nil { if extension.Required { - return ResourceAttributes{}, &errors.ScimErrorInvalidValue + return ResourceAttributes{}, &scimErrors.ScimErrorInvalidValue } continue } @@ -118,10 +119,10 @@ func (t ResourceType) validate(raw []byte) (ResourceAttributes, *errors.ScimErro } // validatePatch parse and validate PATCH request. -func (t ResourceType) validatePatch(r *http.Request) ([]PatchOperation, *errors.ScimError) { +func (t ResourceType) validatePatch(r *http.Request) ([]PatchOperation, *scimErrors.ScimError) { data, err := readBody(r) if err != nil { - return nil, &errors.ScimErrorInvalidSyntax + return nil, &scimErrors.ScimErrorInvalidSyntax } var req struct { @@ -129,19 +130,19 @@ func (t ResourceType) validatePatch(r *http.Request) ([]PatchOperation, *errors. Operations []json.RawMessage } if err := unmarshal(data, &req); err != nil { - return nil, &errors.ScimErrorInvalidSyntax + return nil, &scimErrors.ScimErrorInvalidSyntax } // The body of each request MUST contain the "schemas" attribute with the URI value of // "urn:ietf:params:scim:api:messages:2.0:PatchOp". if len(req.Schemas) != 1 || req.Schemas[0] != "urn:ietf:params:scim:api:messages:2.0:PatchOp" { - return nil, &errors.ScimErrorInvalidValue + return nil, &scimErrors.ScimErrorInvalidValue } // The body of an HTTP PATCH request MUST contain the attribute "Operations", // whose value is an array of one or more PATCH operations. if len(req.Operations) < 1 { - return nil, &errors.ScimErrorInvalidValue + return nil, &scimErrors.ScimErrorInvalidValue } // Evaluation continues until all operations are successfully applied or until an error condition is encountered. @@ -153,11 +154,15 @@ func (t ResourceType) validatePatch(r *http.Request) ([]PatchOperation, *errors. t.getSchemaExtensions()..., ) if err != nil { - return nil, &errors.ScimErrorInvalidPath + return nil, &scimErrors.ScimErrorInvalidPath } value, err := validator.Validate() if err != nil { - return nil, &errors.ScimErrorInvalidValue + var scimErr *scimErrors.ScimError + if errors.As(err, &scimErr) { + return nil, scimErr + } + return nil, &scimErrors.ScimErrorInvalidValue } operations = append(operations, PatchOperation{ Op: string(validator.Op),