Skip to content

Commit b0c32c0

Browse files
bcallerBen Caller
authored andcommitted
Handle assignment unpacking a, b, c = d
We already handle a, b, c = d, *e, f a, b, c = d() But `a, b, c = d` prints 'Assignment not properly handled.' This can be handled exactly like `a, b, c = (*d,)`, where taint in value `d` is propagated to all targets.
1 parent 70257a7 commit b0c32c0

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

pyt/cfg/stmt_visitor.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,11 @@ def visit_Try(self, node):
326326

327327
return ControlFlowNode(try_node, last_statements, break_statements=body.break_statements)
328328

329-
def assign_tuple_target(self, node, right_hand_side_variables):
329+
def assign_tuple_target(self, target_nodes, value_nodes, right_hand_side_variables):
330330
new_assignment_nodes = []
331331
remaining_variables = list(right_hand_side_variables)
332-
remaining_targets = list(node.targets[0].elts)
333-
remaining_values = list(node.value.elts) # May contain duplicates
332+
remaining_targets = list(target_nodes)
333+
remaining_values = list(value_nodes) # May contain duplicates
334334

335335
def visit(target, value):
336336
label = LabelVisitor()
@@ -339,7 +339,7 @@ def visit(target, value):
339339
rhs_visitor.visit(value)
340340
if isinstance(value, ast.Call):
341341
new_ast_node = ast.Assign(target, value)
342-
ast.copy_location(new_ast_node, node)
342+
ast.copy_location(new_ast_node, target)
343343
new_assignment_nodes.append(self.assignment_call_node(label.result, new_ast_node))
344344
else:
345345
label.result += ' = '
@@ -349,7 +349,7 @@ def visit(target, value):
349349
extract_left_hand_side(target),
350350
ast.Assign(target, value),
351351
rhs_visitor.result,
352-
line_number=node.lineno,
352+
line_number=target.lineno,
353353
path=self.filenames[-1]
354354
)))
355355
remaining_targets.remove(target)
@@ -358,7 +358,7 @@ def visit(target, value):
358358
remaining_variables.remove(var)
359359

360360
# Pair targets and values until a Starred node is reached
361-
for target, value in zip(node.targets[0].elts, node.value.elts):
361+
for target, value in zip(target_nodes, value_nodes):
362362
if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
363363
break
364364
visit(target, value)
@@ -380,7 +380,7 @@ def visit(target, value):
380380
extract_left_hand_side(target),
381381
ast.Assign(target, remaining_values[0]),
382382
remaining_variables,
383-
line_number=node.lineno,
383+
line_number=target.lineno,
384384
path=self.filenames[-1]
385385
)))
386386

@@ -413,14 +413,18 @@ def visit_Assign(self, node):
413413
rhs_visitor.visit(node.value)
414414
if isinstance(node.targets[0], (ast.Tuple, ast.List)): # x,y = [1,2]
415415
if isinstance(node.value, (ast.Tuple, ast.List)):
416-
return self.assign_tuple_target(node, rhs_visitor.result)
416+
return self.assign_tuple_target(node.targets[0].elts, node.value.elts, rhs_visitor.result)
417417
elif isinstance(node.value, ast.Call):
418418
call = None
419419
for element in node.targets[0].elts:
420420
label = LabelVisitor()
421421
label.visit(element)
422422
call = self.assignment_call_node(label.result, node)
423423
return call
424+
elif isinstance(node.value, ast.Name): # Treat `x, y = z` like `x, y = (*z,)`
425+
value_node = ast.Starred(node.value, ast.Load())
426+
ast.copy_location(value_node, node)
427+
return self.assign_tuple_target(node.targets[0].elts, [value_node], rhs_visitor.result)
424428
else:
425429
label = LabelVisitor()
426430
label.visit(node)

tests/cfg/cfg_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,21 @@ def test_assignment_starred_list(self):
820820
[('a', ['d']), ('b', ['d']), ('c', ['e'])],
821821
)
822822

823+
def test_unpacking_to_tuple(self):
824+
self.cfg_create_from_ast(ast.parse('a, b, c = d'))
825+
826+
middle_nodes = self.cfg.nodes[1:-1]
827+
self.assert_length(middle_nodes, expected_length=3)
828+
829+
self.assertCountEqual(
830+
[n.label for n in middle_nodes],
831+
['a, b, c = *d'] * 3,
832+
)
833+
self.assertCountEqual(
834+
[(n.left_hand_side, n.right_hand_side_variables) for n in middle_nodes],
835+
[('a', ['d']), ('b', ['d']), ('c', ['d'])],
836+
)
837+
823838
def test_augmented_assignment(self):
824839
self.cfg_create_from_ast(ast.parse('a+=f(b,c)'))
825840

0 commit comments

Comments
 (0)