Skip to content

Commit 0f58880

Browse files
author
DvirDukhan
committed
added test for torch model with tuple output
1 parent 452bf1a commit 0f58880

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/flow/tests_pytorch.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,24 @@ def test_parallelism():
962962
for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]}
963963
env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2")
964964
env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2")
965+
966+
def test_modelget_for_tuple_output(env):
967+
if not TEST_PT:
968+
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
969+
return
970+
con = env.getConnection()
971+
972+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
973+
model_filename = os.path.join(test_data_path, 'pt-minimal-bb.pt')
974+
with open(model_filename, 'rb') as f:
975+
model_pb = f.read()
976+
ret = con.execute_command('AI.MODELSET', 'm{1}', 'TORCH', DEVICE, 'BLOB', model_pb)
977+
ensureSlaveSynced(con, env)
978+
env.assertEqual(b'OK', ret)
979+
ret = con.execute_command('AI.MODELGET', 'm{1}', 'META')
980+
env.assertEqual(ret[1], b'TORCH')
981+
env.assertEqual(ret[5], b'')
982+
env.assertEqual(ret[7], 0)
983+
env.assertEqual(ret[9], 0)
984+
env.assertEqual(len(ret[11]), 2)
985+
env.assertEqual(len(ret[13]), 2)

0 commit comments

Comments
 (0)