@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
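The rename from `tangents` to `cotangents` uses the mathematically precise term: the backward pass consumes gradients with respect to the outputs (cotangents), which `torch.autograd.grad` receives through its `grad_outputs` argument. A minimal standalone sketch (not part of this diff) of that correspondence:

```python
import torch

x = torch.randn(3, requires_grad=True)
out = x.sin()                       # forward: primals -> outs
cotangent = torch.ones_like(out)    # incoming gradient w.r.t. the output
# grad_outputs plays exactly the role of needed_cotangents above
(grad_x,) = torch.autograd.grad(out, x, grad_outputs=cotangent)
assert torch.allclose(grad_x, x.cos())  # d(sin x)/dx = cos x
```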
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
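Hoisting the merged `aot_decompositions` dict to closure scope makes it visible to both `forward` (which traces the joint graph) and the new lazily-compiling `backward`. The merge follows standard Python dict-unpacking semantics, where later entries win; an illustrative sketch with hypothetical keys:

```python
# illustrative only: hypothetical decomposition tables
aot_autograd_decompositions = {"aten.tanh_backward": "default_rule"}
decompositions = {"aten.tanh_backward": "user_rule"}  # caller-supplied
merged = {**aot_autograd_decompositions, **decompositions}
assert merged["aten.tanh_backward"] == "user_rule"  # user entries override defaults
```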
@@ -159,31 +160,67 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.fx_g = fx_g
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                    ctx.bw_graph = bw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-                ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate + ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                assert fx_g.code == ctx.fx_g.code
+                f = aot_function(bw_module, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+                print("INPUTS----->", *ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                print(bw_module.code)
+                out = f(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                return out
+            else:
+                if partition_fn is default_partition:
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args))
+                else:
+                    assert not torch.is_grad_enabled()
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
 
     return CompiledFunction
 
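For context, a hedged sketch (not from this PR) of how the delayed-backward path would be exercised: with `default_partition`, `forward` saves intermediates plus inputs, and the backward graph is re-traced and compiled only on the first `.backward()` call. The `aot_function` call below assumes the functorch-era signature with positional forward/backward compilers.

```python
import torch
from functorch.compile import aot_function, default_partition

def fn(x, y):
    return (x * y).sin()

def noop_compiler(fx_module, example_inputs):
    # stand-in compiler: return the traced fx.GraphModule unchanged
    return fx_module

aot_fn = aot_function(fn, noop_compiler, noop_compiler,
                      partition_fn=default_partition)
x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
aot_fn(x, y).sum().backward()  # backward is traced and compiled lazily here
```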