|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import os |
15 | 16 | import abc |
16 | 17 | import six |
17 | 18 | import sys |
@@ -109,6 +110,22 @@ def generate_gradients(self, targets, inputs): |
109 | 110 | gradients = paddle.static.gradients(targets, inputs) |
110 | 111 | return gradients |
111 | 112 |
|
| 113 | + def compile(self, program): |
| 114 | + use_cinn = os.environ.get("FLAGS_use_cinn", False) |
| 115 | + if use_cinn: |
| 116 | + # To enable CINN, we need to use CompiledProgram to compile the program. |
| 117 | + # Only forward ops are enabled because loss_name should not be none when |
| 118 | + # backward ops are contained in the origin program. |
| 119 | + build_strategy = paddle.static.BuildStrategy() |
| 120 | + exec_strategy = paddle.static.ExecutionStrategy() |
| 121 | + exec_strategy.num_threads = 1 |
| 122 | + compiled_program = paddle.static.CompiledProgram( |
| 123 | + program).with_data_parallel( |
| 124 | + build_strategy=build_strategy, exec_strategy=exec_strategy) |
| 125 | + return compiled_program |
| 126 | + else: |
| 127 | + return program |
| 128 | + |
112 | 129 | def init_feed_tensor(self, use_gpu, feed_vars, feed_dict, scope): |
113 | 130 | place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() |
114 | 131 | for var in feed_vars: |
@@ -426,10 +443,12 @@ def _run_static_impl(self, |
426 | 443 | executor = paddle.static.Executor(place) |
427 | 444 | executor.run(self.startup_program) |
428 | 445 |
|
| 446 | + main_program = self._helper.compile(self.main_program) |
| 447 | + |
429 | 448 | def _run_main_iter(): |
430 | 449 | feed_dict = feed if self._need_feed else None |
431 | 450 | fetch_vars = self.fetch_list if self._need_fetch else None |
432 | | - outputs = executor.run(program=self.main_program, |
| 451 | + outputs = executor.run(program=main_program, |
433 | 452 | feed=feed_dict, |
434 | 453 | fetch_list=fetch_vars, |
435 | 454 | use_program_cache=True, |
|
0 commit comments