diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6c9e3e3f5ef5..b187af7cd665 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -918,6 +918,189 @@ def _gru(self, node: fx.Node) -> relax.Var: return output + def _rnn_tanh_cell_unroll( + self, + input_reshaped, + weight_ih, + weight_hh, + bias_ih, + bias_hh, + h_prev, + seq_len, + reverse=False, + ): + """Unroll vanilla tanh-RNN cells for a single direction.""" + # Transpose weights for matmul: (hidden_size, in) -> (in, hidden_size) + weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0])) + weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0])) + + outputs = [] + time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len) + + for t in time_steps: + # Input at time t: (batch_size, input_size) + x_t = self.block_builder.emit( + relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") + ) + + # h_t = tanh(W_ih @ x_t + b_ih + W_hh @ h_{t-1} + b_hh) + ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t)) + hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t)) + if bias_ih is not None and bias_hh is not None: + ih = self.block_builder.emit(relax.op.add(ih, bias_ih)) + hh = self.block_builder.emit(relax.op.add(hh, bias_hh)) + h_t = self.block_builder.emit(relax.op.tanh(relax.op.add(ih, hh))) + + outputs.append(h_t) + h_prev = h_t + + if reverse: + outputs = outputs[::-1] + + output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) + return output + + def _rnn_tanh(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + hx = args[1] if len(args) > 1 else None + params = args[2] if len(args) > 2 else None + has_biases = args[3] if len(args) > 3 else True + num_layers = args[4] if len(args) > 4 else 1 + _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference + _train = args[6] if len(args) > 6 else False # Not used in inference + bidirectional = args[7] if len(args) > 7 else False + batch_first = args[8] if len(args) > 8 else False + + if num_layers > 1: + raise NotImplementedError("Multi-layer RNN is not yet supported") + + input_shape = self.shape_of(input_tensor) + if batch_first: + batch_size, seq_len, input_size = input_shape + else: + seq_len, batch_size, input_size = input_shape + + seq_len = int(seq_len) if isinstance(seq_len, tvm.tirx.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tirx.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tirx.IntImm) else input_size + + # params per direction: weight_ih, weight_hh, [bias_ih, bias_hh] + params_per_direction = 4 if has_biases else 2 + + # A vanilla RNN has a single gate, so weight_ih has shape (hidden_size, input_size) + if params and len(params) >= 2: + hidden_size = self.shape_of(params[0])[0] + else: + hidden_size = 16 + hidden_size = int(hidden_size) if isinstance(hidden_size, tvm.tirx.IntImm) else hidden_size + + dtype = input_tensor.struct_info.dtype + + # Forward direction weights + if params and len(params) >= params_per_direction: + weight_ih_fwd = params[0] + weight_hh_fwd = params[1] + bias_ih_fwd = params[2] if has_biases else None + bias_hh_fwd = params[3] if has_biases else None + else: + weight_ih_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((hidden_size, input_size)), dtype) + ) + weight_hh_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((hidden_size, hidden_size)), dtype) + ) + bias_ih_fwd = None + bias_hh_fwd = None + + # Backward direction weights if bidirectional + if bidirectional: + if params and len(params) >= params_per_direction * 2: + weight_ih_bwd = params[params_per_direction] + weight_hh_bwd = params[params_per_direction + 1] + bias_ih_bwd = params[params_per_direction + 2] if has_biases else None + bias_hh_bwd = params[params_per_direction + 3] if has_biases else None + else: + weight_ih_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((hidden_size, input_size)), dtype) + ) + weight_hh_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((hidden_size, hidden_size)), dtype) + ) + bias_ih_bwd = None + bias_hh_bwd = None + else: + weight_ih_bwd = None + weight_hh_bwd = None + bias_ih_bwd = None + bias_hh_bwd = None + + # Initial hidden states + if hx is not None: + h_prev_fwd = self.block_builder.emit( + relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip") + ) + h_prev_bwd = ( + self.block_builder.emit( + relax.op.take(hx, relax.const(1, "int64"), axis=0, mode="clip") + ) + if bidirectional + else None + ) + else: + h_prev_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + h_prev_bwd = ( + self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + if bidirectional + else None + ) + + # Reshape input to (seq_len, batch_size, input_size) + input_reshaped = ( + self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2])) + if batch_first + else input_tensor + ) + + # Process forward direction + output_fwd = self._rnn_tanh_cell_unroll( + input_reshaped, + weight_ih_fwd, + weight_hh_fwd, + bias_ih_fwd, + bias_hh_fwd, + h_prev_fwd, + seq_len, + reverse=False, + ) + + # Process backward direction if bidirectional + if bidirectional: + output_bwd = self._rnn_tanh_cell_unroll( + input_reshaped, + weight_ih_bwd, + weight_hh_bwd, + bias_ih_bwd, + bias_hh_bwd, + h_prev_bwd, + seq_len, + reverse=True, + ) + # Concatenate forward and backward outputs along feature dimension + output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2)) + else: + output = output_fwd + + # Reshape back to batch_first if needed + if batch_first: + output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2])) + + return output + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: @@ -1704,6 +1887,7 @@ def create_convert_map( "linear.default": self._linear, "lstm.input": self._lstm, "gru.input": self._gru, + "rnn_tanh.input": self._rnn_tanh, "max_pool1d.default": self._max_pool1d, "max_pool2d.default": self._max_pool2d, "max_pool2d_with_indices.default": self._max_pool2d_with_indices,