Commit a81bd70
support LSTM quantization for PackedSequence input (#1384)
1 parent: 01ef647

4 files changed, +86 −7 lines changed

intel_extension_for_pytorch/quantization/_quantize_utils.py

Lines changed: 41 additions & 2 deletions
@@ -6,6 +6,7 @@
 from torch.fx.node import map_aggregate
 from torch.ao.quantization import PlaceholderObserver
 from torch.quantization.qconfig import QConfig
+from torch.nn.utils.rnn import PackedSequence

 from ._utils import get_torch_function_hook_type, HookType, get_module_hook_type, OpQuantizeabilityType, \
     attach_op_convert_info_to_model, save_quant_state, attach_scale_zp_values_to_model, convert_quant_state_map_to_nodes, \
@@ -36,6 +37,25 @@ def _check_add_has_scalar_input(args):
             return True
     return False

+def _convert_PackedSequence_to_tuple_lstm(args):
+    if isinstance(args, tuple) and len(args) == 2:  # (PackedSequence, hx)
+        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
+        args = (input, batch_sizes, sorted_indices, unsorted_indices, args[-1])
+    elif isinstance(args, tuple) and len(args) == 1:  # (PackedSequence, )
+        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
+        args = (input, batch_sizes, sorted_indices, unsorted_indices)
+    else:
+        assert False, "_convert_PackedSequence_to_tuple args should be a tuple with size 2 or PackedSequence"
+    return args
+
+def _convert_tuple_to_PackedSequence_lstm(args):
+    assert isinstance(args, tuple) and len(args) >= 4 and len(args) <= 5, "_convert_tuple_to_PackedSequence input should be a tuple of size 4 or 5"
+    if len(args) == 4:
+        return (PackedSequence(*args),)
+    else:
+        return (PackedSequence(*args[:-1]), args[-1])
+
+
 def auto_prepare(
     model : torch.nn.Module,
     configure: QConfig,
@@ -212,7 +232,9 @@ def _patched_module_call(self, *args, **kwargs):
             old_global_disable_torch_function_override = \
                 global_disable_torch_function_override
             global_disable_torch_function_override = True
-
+            is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
+            if is_lstm_packed_input:
+                args = _convert_PackedSequence_to_tuple_lstm(args)
             if first_call:
                 # mypy ignore is used instead of assert because this
                 # runs on every forward and assert has a performance cost
@@ -226,19 +248,28 @@ def _patched_module_call(self, *args, **kwargs):
                 args, kwargs = parent_qstate.op_prepare_before_hook(
                     cur_module, args, kwargs)  # type: ignore[arg-type]

+            if is_lstm_packed_input:
+                args = _convert_tuple_to_PackedSequence_lstm(args)
+
             # original forward
             output = orig_module_call(self, *args, **kwargs)
             # Re-enable the overrides.
             global_disable_torch_function_override = \
                 old_global_disable_torch_function_override

             # after hooks
+            if is_lstm_packed_input:
+                output = _convert_PackedSequence_to_tuple_lstm(output)
             if first_call:
                 output = parent_qstate.first_call_op_prepare_after_hook(
                     cur_module, output, args, qtensor_id, OpQuantizeabilityType.QUANTIZEABLE)
             else:
                 output = parent_qstate.op_prepare_after_hook(
                     cur_module, output, args, global_op_idx)
+
+            if is_lstm_packed_input:
+                output = _convert_tuple_to_PackedSequence_lstm(output)
+
             parent_qstate.mark_cur_op_complete(cur_module)
         elif hook_type is HookType.MODULE_IO_HOOKS:
             cur_qstate = cur_module._auto_quant_state
@@ -500,17 +531,25 @@ def _patched_module_call(self, *args, **kwargs):
             old_global_disable_torch_function_override = \
                 global_disable_torch_function_override
             global_disable_torch_function_override = True
+            is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
+            if is_lstm_packed_input:
+                args = _convert_PackedSequence_to_tuple_lstm(args)
             _, args, kwargs = qstate.op_convert_before_hook(
                 cur_module, args, kwargs, cur_module)
+            if is_lstm_packed_input:
+                args = _convert_tuple_to_PackedSequence_lstm(args)
             if type(cur_module) in quantized_modules_has_weights:
                 weights = qstate.op_weight_convert_before_hook(cur_module)
                 output = module_call_to_function_call(self, args, weights)
             else:
                 output = orig_module_call(self, *args, **kwargs)
             # after hooks
+            if is_lstm_packed_input:
+                output = _convert_PackedSequence_to_tuple_lstm(output)
             output = qstate.op_convert_after_hook(
                 cur_module, output)
-
+            if is_lstm_packed_input:
+                output = _convert_tuple_to_PackedSequence_lstm(output)
             # Re-enable the override.
             global_disable_torch_function_override = \
                 old_global_disable_torch_function_override
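
For reference, the two helpers added above only flatten a PackedSequence into its four component tensors (plus the optional hidden state) and rebuild it afterwards, so the existing hook machinery only ever sees plain tensors. A minimal standalone sketch of that round trip using PyTorch's public rnn utilities (toy shapes, not taken from this commit):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

# Pack a padded batch: 2 sequences of lengths 2 and 3, feature size 4.
padded = torch.randn(2, 3, 4)
packed = pack_padded_sequence(padded, lengths=torch.tensor([2, 3]),
                              batch_first=True, enforce_sorted=False)

# Flatten, as _convert_PackedSequence_to_tuple_lstm does for the (PackedSequence,) case.
flat = (packed.data, packed.batch_sizes, packed.sorted_indices, packed.unsorted_indices)

# Rebuild, as _convert_tuple_to_PackedSequence_lstm does for a 4-element tuple.
rebuilt = PackedSequence(*flat)
assert torch.equal(rebuilt.data, packed.data)
assert torch.equal(rebuilt.batch_sizes, packed.batch_sizes)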

intel_extension_for_pytorch/quantization/_recipe.py

Lines changed: 11 additions & 0 deletions
@@ -63,6 +63,17 @@ def _default_recipe_init(nodes):
                 tensor_info.inf_dtype = tensor_info.orig_dtype
                 node.input_tensor_force_inf_dtype[idx] = tensor_info.inf_dtype

+        # For LSTM, if its input is a PackedSequence, we don't support it now.
+        # TODO: support PackedSequence input for quantization LSTM.
+        if node.type in rnn_ops and len(node.input_tensor_infos) > 2:
+            for idx, tensor_info in enumerate(node.input_tensor_infos):
+                if tensor_info is not None:
+                    tensor_info.inf_dtype = tensor_info.orig_dtype
+                    node.input_tensor_force_inf_dtype[idx] = tensor_info.inf_dtype
+            for idx, tensor_info in enumerate(node.weight_tensor_infos):
+                if tensor_info is not None:
+                    tensor_info.inf_dtype = tensor_info.orig_dtype
+
 #TODO: making fusion pattern check more general.
 def _find_fused_node_with_cur_elt_wise(node, ops):
     r"""

intel_extension_for_pytorch/quantization/_utils.py

Lines changed: 7 additions & 5 deletions
@@ -403,7 +403,7 @@ def set_node_output_quantized(nodes):
     # output's infe dtype is not int8, set it and also set insert_fake_quant_after_output to True.
     """
     def _reset_post_node_input_infos(node):
-        # make sure the post node will node insert fake quant if we add fake quant by cur node' output
+        # make sure the post node will insert fake quant if we add fake quant by cur node' output
         if len(node.post_nodes) > 0:
             for post_node in node.post_nodes:
                 if post_node.qconfig is not None:
@@ -434,10 +434,12 @@ def _reset_post_node_input_infos(node):
                 node.insert_fake_quant_after_outputs[0] = True
                 _reset_post_node_input_infos(node)
             else:
-                if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
-                    node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
-                    node.insert_fake_quant_after_outputs[0] = True
-                    _reset_post_node_input_infos(node)
+                # TODO: enable PackedSequence input for LSTM.
+                if not (node.type in [nn.LSTM] and len(node.input_tensor_infos) > 2):
+                    if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
+                        node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
+                        node.insert_fake_quant_after_outputs[0] = True
+                        _reset_post_node_input_infos(node)

 qscheme_dict = {
     str(torch.per_tensor_affine): torch.per_tensor_affine,

tests/cpu/test_ao_jit_ipex_quantization.py

Lines changed: 27 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 from torch.testing import FileCheck
 import copy
 import json
@@ -262,6 +263,32 @@ def _lstm_params_list():
             graph = self.checkQuantizeTrace(m, [x], atol=3e-2, rtol=1e-1)
             self.assertGraphContainsExactly(graph, 'ipex::quantized_lstm', 1)

+    def test_lstm_PackedSequence(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.lstm = nn.LSTM(input_size=288, hidden_size=1024, num_layers=6, batch_first=True, bidirectional=True, bias=True, dropout=0.2)
+
+            def forward(self, input, hid, mask=None):
+                if mask is not None:
+                    lengths = mask.sum(-1)
+                    seq = pack_padded_sequence(input, lengths.cpu(), batch_first=True)
+                    seq, hid = self.lstm(seq, hid)
+                    seq = pad_packed_sequence(seq, batch_first=True)[0]
+                    return seq, hid
+                else:
+                    return self.lstm(input, hid)
+
+        model = M().eval()
+        seq = torch.randn(size=(1, 211, 288), dtype=torch.float32)
+        # initialize hidden states
+        h0 = torch.zeros((12, 1, 1024), dtype=seq.dtype)
+        hid = (h0, h0)
+        mask = torch.ones(size=(1, 211), dtype=torch.uint8)
+
+        graph = self.checkQuantizeTrace(model, [seq, hid, mask])
+        self.assertGraphContainsExactly(graph, 'aten::lstm', 1)
+
 class TestIpexQuantizationConvertAPI(JitLlgaTestCase):
     def test_inplace_preapre(self):
         class M(nn.Module):
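
The new test exercises this path through the JIT quantization trace helper. A sketch of how a user-side flow might look, assuming IPEX's documented static quantization entry points (prepare/convert with default_static_qconfig); the PackedLSTM module below is a hypothetical stand-in modeled on the test's M and is not part of this commit:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

class PackedLSTM(nn.Module):
    # Hypothetical module: an LSTM that is always fed a PackedSequence.
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=288, hidden_size=1024, num_layers=6,
                            batch_first=True, bidirectional=True, dropout=0.2)

    def forward(self, input, hid, mask):
        lengths = mask.sum(-1)
        seq = pack_padded_sequence(input, lengths.cpu(), batch_first=True)
        seq, hid = self.lstm(seq, hid)
        seq = pad_packed_sequence(seq, batch_first=True)[0]
        return seq, hid

model = PackedLSTM().eval()
seq = torch.randn(1, 211, 288)
hid = (torch.zeros(12, 1, 1024), torch.zeros(12, 1, 1024))
mask = torch.ones(1, 211, dtype=torch.uint8)

qconfig = ipex.quantization.default_static_qconfig
prepared = prepare(model, qconfig, example_inputs=(seq, hid, mask), inplace=False)
prepared(seq, hid, mask)  # calibration pass
quantized = convert(prepared)
with torch.no_grad():
    traced = torch.jit.trace(quantized, (seq, hid, mask), check_trace=False, strict=False)
    traced = torch.jit.freeze(traced)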
