diff --git a/tensorflow_gnn/experimental/sampler/beam/sampler.py b/tensorflow_gnn/experimental/sampler/beam/sampler.py index 3715351a..a9e61840 100644 --- a/tensorflow_gnn/experimental/sampler/beam/sampler.py +++ b/tensorflow_gnn/experimental/sampler/beam/sampler.py @@ -28,6 +28,7 @@ import apache_beam as beam from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.portability import fn_api_runner import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.data import unigraph @@ -191,7 +192,10 @@ def _create_beam_runner( ) -> beam.runners.PipelineRunner: """Creates appropriate runner.""" if runner_name == _DIRECT_RUNNER: - runner = beam.runners.DirectRunner() + # b/490166623: Changed to FnApiRunner due to the new prism implementation + # in DirectRunner since beam 2.68.0 which breaks the sampler. FnApiRunner + # enables the old implementation of DirectRunner. + runner = fn_api_runner.FnApiRunner() elif runner_name == _DATAFLOW_RUNNER: runner = beam.runners.DataflowRunner() # Placeholder for Google-internal runner option creation