diff --git a/sdks/go/container/boot.go b/sdks/go/container/boot.go index b75201520f39..4e93f01b9d15 100644 --- a/sdks/go/container/boot.go +++ b/sdks/go/container/boot.go @@ -158,6 +158,9 @@ func main() { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject pipeline options into context + ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions()) + // (2) Retrieve the staged files. // // The Go SDK harness downloads the worker binary and invokes diff --git a/sdks/go/pkg/beam/artifact/materialize.go b/sdks/go/pkg/beam/artifact/materialize.go index 624e30efcd2b..edd59a26da17 100644 --- a/sdks/go/pkg/beam/artifact/materialize.go +++ b/sdks/go/pkg/beam/artifact/materialize.go @@ -39,6 +39,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/util/errorx" "github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" ) // TODO(lostluck): 2018/05/28 Extract these from their enum descriptors in the pipeline_v1 proto @@ -131,6 +132,7 @@ func newMaterializeWithClient(ctx context.Context, client jobpb.ArtifactRetrieva RoleUrn: URNStagingTo, RolePayload: rolePayload, }, + expectedSha256: filePayload.Sha256, }) } @@ -183,8 +185,9 @@ func MustExtractFilePayload(artifact *pipepb.ArtifactInformation) (string, strin } type artifact struct { - client jobpb.ArtifactRetrievalServiceClient - dep *pipepb.ArtifactInformation + client jobpb.ArtifactRetrievalServiceClient + dep *pipepb.ArtifactInformation + expectedSha256 string } func (a artifact) retrieve(ctx context.Context, dest string) error { @@ -231,7 +234,15 @@ func (a artifact) retrieve(ctx context.Context, dest string) error { stat, _ := fd.Stat() log.Printf("Downloaded: %v (sha256: %v, size: %v)", filename, sha256Hash, stat.Size()) - return fd.Close() + if err := fd.Close(); err != nil { + return err + } + + if isArtifactValidationEnabled(ctx) && a.expectedSha256 != "" && sha256Hash != a.expectedSha256 
{ + return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.expectedSha256) + } + + return nil } func writeChunks(stream jobpb.ArtifactRetrievalService_GetArtifactClient, w io.Writer) (string, error) { @@ -442,7 +453,7 @@ func retrieve(ctx context.Context, client jobpb.LegacyArtifactRetrievalServiceCl } // Artifact Sha256 hash is an optional field in metadata so we should only validate when its present. - if a.Sha256 != "" && sha256Hash != a.Sha256 { + if isArtifactValidationEnabled(ctx) && a.Sha256 != "" && sha256Hash != a.Sha256 { return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256) } return nil @@ -511,3 +522,25 @@ func queue2slice(q chan *jobpb.ArtifactMetadata) []*jobpb.ArtifactMetadata { } return ret } + +type contextKey string + +const pipelineOptionsKey contextKey = "pipeline_options" + +// WithPipelineOptions returns a new context carrying the full pipeline options struct. +func WithPipelineOptions(ctx context.Context, options *structpb.Struct) context.Context { + return context.WithValue(ctx, pipelineOptionsKey, options) +} + +// isArtifactValidationEnabled reports whether artifact SHA256 validation should run; it returns false only when the pipeline options stored in ctx list the "disable_integrity_checks" experiment, and true otherwise (including when ctx carries no options). 
+func isArtifactValidationEnabled(ctx context.Context) bool { + options, _ := ctx.Value(pipelineOptionsKey).(*structpb.Struct) + if options != nil { + for _, v := range options.GetFields()["options"].GetStructValue().GetFields()["experiments"].GetListValue().GetValues() { + if v.GetStringValue() == "disable_integrity_checks" { + return false + } + } + } + return true +} diff --git a/sdks/go/pkg/beam/artifact/materialize_test.go b/sdks/go/pkg/beam/artifact/materialize_test.go index 31890ed045cc..9d527569d1bc 100644 --- a/sdks/go/pkg/beam/artifact/materialize_test.go +++ b/sdks/go/pkg/beam/artifact/materialize_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" ) // TestRetrieve tests that we can successfully retrieve fresh files. @@ -82,6 +83,57 @@ func TestMultiRetrieve(t *testing.T) { } } +func TestRetrieveWithBadShaFails(t *testing.T) { + cc := startServer(t) + defer cc.Close() + + ctx := grpcx.WriteWorkerID(context.Background(), "idA") + keys := []string{"foo"} + st := "whatever" + rt, artifacts := populate(ctx, cc, t, keys, 300, st) + + dst := makeTempDir(t) + defer os.RemoveAll(dst) + + client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc) + for _, a := range artifacts { + a.Sha256 = "badhash" // mutate hash + if err := Retrieve(ctx, client, a, rt, dst); err == nil { + t.Errorf("expected materialization to fail due to bad sha256 mismatch") + } + } +} + +func TestRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) { + cc := startServer(t) + defer cc.Close() + + options, _ := structpb.NewStruct(map[string]interface{}{ + "options": map[string]interface{}{ + "experiments": []interface{}{"disable_integrity_checks"}, + }, + }) + ctx := WithPipelineOptions(grpcx.WriteWorkerID(context.Background(), "idA"), options) + keys := []string{"foo"} + st := "whatever" + rt, artifacts := populate(ctx, cc, t, keys, 300, st) + + dst := 
makeTempDir(t) + defer os.RemoveAll(dst) + + client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc) + for _, a := range artifacts { + originalHash := a.Sha256 + a.Sha256 = "badhash" // mutate hash + filename := makeFilename(dst, a.Name) + if err := Retrieve(ctx, client, a, rt, dst); err != nil { + t.Errorf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err) + continue + } + verifySHA256(t, filename, originalHash) + } +} + // populate stages a set of artifacts with the given keys, each with // slightly different sizes and chucksizes. func populate(ctx context.Context, cc *grpc.ClientConn, t *testing.T, keys []string, size int, st string) (string, []*jobpb.ArtifactMetadata) { @@ -266,6 +318,65 @@ func TestNewRetrieveWithResolution(t *testing.T) { checkStagedFiles(mds, dest, expected, t) } +func TestIsArtifactValidationEnabled(t *testing.T) { + ctx := context.Background() + if !isArtifactValidationEnabled(ctx) { + t.Errorf("empty context should have validation enabled") + } + + options, _ := structpb.NewStruct(map[string]interface{}{ + "options": map[string]interface{}{ + "experiments": []interface{}{"disable_integrity_checks"}, + }, + }) + ctx2 := WithPipelineOptions(ctx, options) + if isArtifactValidationEnabled(ctx2) { + t.Errorf("populated context should have validation disabled") + } +} + +func TestNewRetrieveWithBadShaFails(t *testing.T) { + expected := map[string]string{"a.txt": "a"} + client := &fakeRetrievalService{artifacts: expected} + dest := makeTempDir(t) + defer os.RemoveAll(dest) + ctx := grpcx.WriteWorkerID(context.Background(), "worker") + + _, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest) + if err == nil { + t.Fatalf("expected materialization to fail due to bad sha256 mismatch") + } +} + +func TestNewRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) { + expected := map[string]string{"a.txt": "a"} + client := &fakeRetrievalService{artifacts: expected} 
+ dest := makeTempDir(t) + defer os.RemoveAll(dest) + + options, _ := structpb.NewStruct(map[string]interface{}{ + "options": map[string]interface{}{ + "experiments": []interface{}{"disable_integrity_checks"}, + }, + }) + ctx := WithPipelineOptions(grpcx.WriteWorkerID(context.Background(), "worker"), options) + + mds, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest) + if err != nil { + t.Fatalf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err) + } + + generated := make(map[string]string) + for _, md := range mds { + name, _ := MustExtractFilePayload(md) + payload, _ := proto.Marshal(&pipepb.ArtifactStagingToRolePayload{ + StagedName: name}) + generated[name] = string(payload) + } + + checkStagedFiles(mds, dest, generated, t) +} + func checkStagedFiles(mds []*pipepb.ArtifactInformation, dest string, expected map[string]string, t *testing.T) { if len(mds) != len(expected) { t.Errorf("wrong number of artifacts staged %v vs %v", len(mds), len(expected)) @@ -323,6 +434,21 @@ func (fake *fakeRetrievalService) fileArtifactsWithoutStagingTo() []*pipepb.Arti return artifacts } +func (fake *fakeRetrievalService) fileArtifactsWithBadSha() []*pipepb.ArtifactInformation { + var artifacts []*pipepb.ArtifactInformation + for name := range fake.artifacts { + payload, _ := proto.Marshal(&pipepb.ArtifactFilePayload{ + Path: filepath.Join("/tmp", name), + Sha256: "badhash", + }) + artifacts = append(artifacts, &pipepb.ArtifactInformation{ + TypeUrn: URNFileArtifact, + TypePayload: payload, + }) + } + return artifacts +} + func (fake *fakeRetrievalService) urlArtifactsWithoutStagingTo() []*pipepb.ArtifactInformation { var artifacts []*pipepb.ArtifactInformation for name := range fake.artifacts { diff --git a/sdks/java/container/boot.go b/sdks/java/container/boot.go index f6c33b635d3c..a5005e4dc320 100644 --- a/sdks/java/container/boot.go +++ b/sdks/java/container/boot.go @@ -105,6 +105,9 @@ func 
main() { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject pipeline options into context + ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions()) + // (2) Retrieve the staged user jars. We ignore any disk limit, // because the staged jars are mandatory. diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 164ace532b23..fe36f56df427 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -595,8 +595,9 @@ def _stage_resources(self, pipeline, options): else: remote_name = os.path.basename(type_payload.path) is_staged_role = False - - if self._enable_caching and not type_payload.sha256: + # compute sha256 even if caching is disabled. + # This is used to check the payload integrity along with caching. + if not type_payload.sha256: type_payload.sha256 = self._compute_sha256(type_payload.path) if type_payload.sha256 and type_payload.sha256 in staged_hashes: diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index b767cef86b2e..51f4264d3e45 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -1340,13 +1340,19 @@ def test_stage_resources(self): ]) })) client = apiclient.DataflowApplicationClient(pipeline_options) - with mock.patch.object(apiclient._LegacyDataflowStager, - 'stage_job_resources') as mock_stager: - client._stage_resources(pipeline, pipeline_options) + with mock.patch.object(apiclient.DataflowApplicationClient, + '_compute_sha256', + side_effect=lambda path: 'hash' + path): + with mock.patch.object(apiclient._LegacyDataflowStager, + 'stage_job_resources') as mock_stager: + client._stage_resources(pipeline, pipeline_options) 
mock_stager.assert_called_once_with( - [('/tmp/foo1', 'foo1', ''), ('/tmp/bar1', 'bar1', ''), - ('/tmp/baz', 'baz1', ''), ('/tmp/renamed1', 'renamed1', 'abcdefg'), - ('/tmp/foo2', 'foo2', ''), ('/tmp/bar2', 'bar2', '')], + [('/tmp/foo1', 'foo1', 'hash/tmp/foo1'), + ('/tmp/bar1', 'bar1', 'hash/tmp/bar1'), + ('/tmp/baz', 'baz1', 'hash/tmp/baz'), + ('/tmp/renamed1', 'renamed1', 'abcdefg'), + ('/tmp/foo2', 'foo2', 'hash/tmp/foo2'), + ('/tmp/bar2', 'bar2', 'hash/tmp/bar2')], staging_location='gs://test-location/staging') pipeline_expected = beam_runner_api_pb2.Pipeline( @@ -1357,8 +1363,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/foo1' - ).SerializeToString(), + url='gs://test-location/staging/foo1', + sha256='hash/tmp/foo1').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1366,8 +1372,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/bar1'). - SerializeToString(), + url='gs://test-location/staging/bar1', + sha256='hash/tmp/bar1').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1375,8 +1381,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/baz1'). - SerializeToString(), + url='gs://test-location/staging/baz1', + sha256='hash/tmp/baz').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. 
ArtifactStagingToRolePayload( @@ -1396,8 +1402,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/foo2'). - SerializeToString(), + url='gs://test-location/staging/foo2', + sha256='hash/tmp/foo2').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1405,8 +1411,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/bar2'). - SerializeToString(), + url='gs://test-location/staging/bar2', + sha256='hash/tmp/bar2').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1414,8 +1420,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/baz1'). - SerializeToString(), + url='gs://test-location/staging/baz1', + sha256='hash/tmp/baz').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. 
ArtifactStagingToRolePayload( diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index 7c0f22675daf..e0d2fd3d0937 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -184,6 +184,9 @@ func launchSDKProcess() error { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject pipeline options into context + ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions()) + experiments := getExperiments(options) pipNoBuildIsolation = false if slices.Contains(experiments, "pip_no_build_isolation") { diff --git a/sdks/typescript/container/boot.go b/sdks/typescript/container/boot.go index 44f94f804330..70c512d62b04 100644 --- a/sdks/typescript/container/boot.go +++ b/sdks/typescript/container/boot.go @@ -91,6 +91,9 @@ func main() { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject pipeline options into context + ctx = artifact.WithPipelineOptions(ctx, info.GetPipelineOptions()) + // (2) Retrieve and install the staged packages. dir := filepath.Join(*semiPersistDir, *id, "staged")