Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,189 @@ def _gru(self, node: fx.Node) -> relax.Var:

return output

def _rnn_tanh_cell_unroll(
    self,
    input_reshaped,
    weight_ih,
    weight_hh,
    bias_ih,
    bias_hh,
    h_prev,
    seq_len,
    reverse=False,
):
    """Emit an unrolled tanh-RNN over ``seq_len`` steps for one direction.

    Each step computes ``h_t = tanh(x_t @ W_ih^T + b_ih + h_{t-1} @ W_hh^T + b_hh)``
    where ``x_t`` is taken from ``input_reshaped`` along axis 0.

    Parameters
    ----------
    input_reshaped
        Input sequence; time step ``t`` is read with ``take(..., axis=0)``.
    weight_ih, weight_hh
        Input-to-hidden and hidden-to-hidden weights; both are transposed
        (axes ``[1, 0]``) once before the loop.
    bias_ih, bias_hh
        Optional biases. They are applied only when BOTH are not ``None``.
    h_prev
        Hidden state fed into the first processed step.
    seq_len
        Number of steps to unroll; used with ``range`` so it must be a
        Python integer.
    reverse
        When true, steps are processed from ``seq_len - 1`` down to ``0``.

    Returns
    -------
    The per-step hidden states stacked along axis 0, always ordered by
    ascending time index regardless of ``reverse``.
    """
    emit = self.block_builder.emit

    # Pre-transpose the weights so every step is a plain ``x @ W``.
    w_ih_t = emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
    w_hh_t = emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))

    use_bias = not (bias_ih is None or bias_hh is None)

    step_order = list(range(seq_len))
    if reverse:
        step_order.reverse()

    # Each state is stored at its own time index, so the stacked result is
    # in chronological order even when the steps are visited backwards.
    states = [None] * len(step_order)
    hidden = h_prev
    for step in step_order:
        # Slice out the input for this time step.
        x_step = emit(
            relax.op.take(input_reshaped, relax.const(step, "int64"), axis=0, mode="clip")
        )

        from_input = emit(relax.op.linear_algebra.matmul(x_step, w_ih_t))
        from_hidden = emit(relax.op.linear_algebra.matmul(hidden, w_hh_t))
        if use_bias:
            from_input = emit(relax.op.add(from_input, bias_ih))
            from_hidden = emit(relax.op.add(from_hidden, bias_hh))

        hidden = emit(relax.op.tanh(relax.op.add(from_input, from_hidden)))
        states[step] = hidden

    return emit(relax.op.stack(states, axis=0))

def _rnn_tanh(self, node: fx.Node) -> relax.Var:
    """Translate an ``rnn_tanh.input`` node into an unrolled Relax tanh-RNN.

    Positional node arguments are read as
    ``(input, hx, params, has_biases, num_layers, dropout, train,
    bidirectional, batch_first)``; missing trailing arguments fall back to
    the defaults used below. ``dropout`` and ``train`` are ignored because
    the translation targets inference only.

    Parameters
    ----------
    node : fx.Node
        The FX node to translate.

    Returns
    -------
    relax.Var
        The output sequence. For a bidirectional RNN the forward and
        backward outputs are concatenated along axis 2. When ``batch_first``
        is set the first two axes are swapped back before returning. Only
        the output sequence is returned; the final hidden state is not.

    Raises
    ------
    NotImplementedError
        If ``num_layers > 1``, or if the sequence length is not a
        compile-time constant (the recurrence is unrolled in Python).
    """
    args = self.retrieve_args(node)
    emit = self.block_builder.emit

    def arg_or(index, default):
        # Optional trailing arguments may be absent from the node.
        return args[index] if len(args) > index else default

    def as_int(dim):
        # Constant dimensions become Python ints; symbolic ones pass through.
        return int(dim) if isinstance(dim, tvm.tirx.IntImm) else dim

    input_tensor = args[0]
    hx = arg_or(1, None)
    params = arg_or(2, None)
    has_biases = arg_or(3, True)
    num_layers = arg_or(4, 1)
    # args[5] (dropout) and args[6] (train) are intentionally not read.
    bidirectional = arg_or(7, False)
    batch_first = arg_or(8, False)

    if num_layers > 1:
        raise NotImplementedError("Multi-layer RNN is not yet supported")

    input_shape = self.shape_of(input_tensor)
    if batch_first:
        batch_size, seq_len, input_size = input_shape
    else:
        seq_len, batch_size, input_size = input_shape

    seq_len = as_int(seq_len)
    batch_size = as_int(batch_size)
    input_size = as_int(input_size)

    # The cell loop is unrolled with ``range(seq_len)``, so fail early with
    # a clear message instead of an opaque TypeError from ``range``.
    if not isinstance(seq_len, int):
        raise NotImplementedError(
            "rnn_tanh with a dynamic sequence length is not supported"
        )

    # Per direction: weight_ih, weight_hh, [bias_ih, bias_hh]
    params_per_direction = 4 if has_biases else 2

    # A vanilla RNN has a single gate, so weight_ih is (hidden_size, input_size).
    if params and len(params) >= 2:
        hidden_size = as_int(self.shape_of(params[0])[0])
    else:
        # No weights supplied: fall back to a fixed default hidden size.
        hidden_size = 16

    dtype = input_tensor.struct_info.dtype

    def zeros(*shape):
        return emit(relax.op.zeros(relax.ShapeExpr(shape), dtype))

    def direction_params(direction):
        # Return (weight_ih, weight_hh, bias_ih, bias_hh) for direction 0
        # (forward) or 1 (backward). Zero weights without biases are used
        # when ``params`` does not hold enough entries for that direction.
        offset = params_per_direction * direction
        if params and len(params) >= offset + params_per_direction:
            w_ih = params[offset]
            w_hh = params[offset + 1]
            if has_biases:
                return w_ih, w_hh, params[offset + 2], params[offset + 3]
            return w_ih, w_hh, None, None
        w_ih = zeros(hidden_size, input_size)
        w_hh = zeros(hidden_size, hidden_size)
        return w_ih, w_hh, None, None

    def initial_state(direction):
        # Slice index ``direction`` out of ``hx`` along axis 0, or start
        # from zeros when no initial state is given.
        if hx is not None:
            return emit(
                relax.op.take(hx, relax.const(direction, "int64"), axis=0, mode="clip")
            )
        return zeros(batch_size, hidden_size)

    weight_ih_fwd, weight_hh_fwd, bias_ih_fwd, bias_hh_fwd = direction_params(0)
    if bidirectional:
        weight_ih_bwd, weight_hh_bwd, bias_ih_bwd, bias_hh_bwd = direction_params(1)

    h_prev_fwd = initial_state(0)
    h_prev_bwd = initial_state(1) if bidirectional else None

    # Bring the input to (seq_len, batch_size, input_size).
    if batch_first:
        input_reshaped = emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
    else:
        input_reshaped = input_tensor

    output = self._rnn_tanh_cell_unroll(
        input_reshaped,
        weight_ih_fwd,
        weight_hh_fwd,
        bias_ih_fwd,
        bias_hh_fwd,
        h_prev_fwd,
        seq_len,
        reverse=False,
    )

    if bidirectional:
        output_bwd = self._rnn_tanh_cell_unroll(
            input_reshaped,
            weight_ih_bwd,
            weight_hh_bwd,
            bias_ih_bwd,
            bias_hh_bwd,
            h_prev_bwd,
            seq_len,
            reverse=True,
        )
        # Join forward and backward outputs along the feature axis.
        output = emit(relax.op.concat([output, output_bwd], axis=2))

    # Restore the batch-first layout if the input used it.
    if batch_first:
        output = emit(relax.op.permute_dims(output, axes=[1, 0, 2]))

    return output

########## Manipulation ##########

def _narrow(self, node: fx.Node) -> relax.Var:
Expand Down Expand Up @@ -1704,6 +1887,7 @@ def create_convert_map(
"linear.default": self._linear,
"lstm.input": self._lstm,
"gru.input": self._gru,
"rnn_tanh.input": self._rnn_tanh,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
"max_pool2d_with_indices.default": self._max_pool2d_with_indices,
Expand Down
Loading