Skip to content

Commit f49d375

Browse files
committed
Fix the argument list error and enable compilation of the program when CINN is enabled through environment variables.
1 parent a292eb7 commit f49d375

2 files changed

Lines changed: 22 additions & 2 deletions

File tree

api/common/paddle_op_benchmark.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import abc
1617
import six
1718
import sys
@@ -109,6 +110,22 @@ def generate_gradients(self, targets, inputs):
109110
gradients = paddle.static.gradients(targets, inputs)
110111
return gradients
111112

113+
def compile(self, program):
    """Return the program to execute, compiled for CINN when requested.

    CINN is considered requested when the ``FLAGS_use_cinn`` environment
    variable holds a true-like value ("1", "t", "true", "y", "yes",
    case-insensitive). Values such as "0" or "false", an empty string,
    or an unset variable leave CINN disabled.

    Args:
        program: The static program to run.

    Returns:
        ``program`` unchanged when CINN is disabled; otherwise a
        ``paddle.static.CompiledProgram`` built from ``program`` with a
        default build strategy and a single-threaded execution strategy.
    """
    # Environment variables are always strings, so the raw value cannot be
    # used as a boolean directly: "0" and "false" are non-empty and truthy.
    flag = os.environ.get("FLAGS_use_cinn", "")
    use_cinn = flag.strip().lower() in ("1", "t", "true", "y", "yes")
    if not use_cinn:
        return program

    # To enable CINN, the program must be compiled through CompiledProgram.
    # Only forward ops are enabled, because loss_name must not be None when
    # the original program contains backward ops.
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()
    exec_strategy.num_threads = 1
    compiled_program = paddle.static.CompiledProgram(
        program).with_data_parallel(
            build_strategy=build_strategy, exec_strategy=exec_strategy)
    return compiled_program
128+
112129
def init_feed_tensor(self, use_gpu, feed_vars, feed_dict, scope):
113130
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
114131
for var in feed_vars:
@@ -426,10 +443,12 @@ def _run_static_impl(self,
426443
executor = paddle.static.Executor(place)
427444
executor.run(self.startup_program)
428445

446+
main_program = self._helper.compile(self.main_program)
447+
429448
def _run_main_iter():
430449
feed_dict = feed if self._need_feed else None
431450
fetch_vars = self.fetch_list if self._need_fetch else None
432-
outputs = executor.run(program=self.main_program,
451+
outputs = executor.run(program=main_program,
433452
feed=feed_dict,
434453
fetch_list=fetch_vars,
435454
use_program_cache=True,

api/common/tensorflow_op_benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _run_null_graph(self, use_gpu, repeat):
198198
sess.close()
199199
return walltimes
200200

201-
def run_impl(self, use_gpu, feed, repeat=1, profiler="none"):
201+
def run_impl(self, use_gpu, config, feed, repeat=1, profiler="none"):
202202
sess = self._init_session(use_gpu)
203203

204204
#tf.debugging.set_log_device_placement(True)
@@ -301,6 +301,7 @@ def run(self, config, args, use_feed_fetch=True, feeder_adapter=None):
301301
self.allow_growth = False if args.task == "speed" else True
302302
outputs, stats = self.run_impl(
303303
use_gpu=args.use_gpu,
304+
config=config,
304305
feed=feed,
305306
repeat=args.repeat,
306307
profiler=args.profiler)

0 commit comments

Comments
 (0)