1+ import pytest
2+ import numpy as np
3+
14from matplotlib .sankey import Sankey
5+ from matplotlib .testing .decorators import check_figures_equal
26
37
48def test_sankey ():
@@ -22,3 +26,80 @@ def show_three_decimal_places(value):
2226 format = show_three_decimal_places )
2327
2428 assert s .diagrams [0 ].texts [0 ].get_text () == 'First\n 0.250'
29+
30+
31+ @pytest .mark .parametrize ('kwargs, msg' , (
32+ ({'gap' : - 1 }, "'gap' is negative" ),
33+ ({'gap' : 1 , 'radius' : 2 }, "'radius' is greater than 'gap'" ),
34+ ({'head_angle' : - 1 }, "'head_angle' is negative" ),
35+ ({'tolerance' : - 1 }, "'tolerance' is negative" ),
36+ ({'flows' : [1 , - 1 ], 'orientations' : [- 1 , 0 , 1 ]},
37+ r"The shapes of 'flows' \(2,\) and 'orientations'" ),
38+ ({'flows' : [1 , - 1 ], 'labels' : ['a' , 'b' , 'c' ]},
39+ r"The shapes of 'flows' \(2,\) and 'labels'" ),
40+ ))
41+ def test_sankey_errors (kwargs , msg ):
42+ with pytest .raises (ValueError , match = msg ):
43+ Sankey (** kwargs )
44+
45+
46+ @pytest .mark .parametrize ('kwargs, msg' , (
47+ ({'trunklength' : - 1 }, "'trunklength' is negative" ),
48+ ({'flows' : [0.2 , 0.3 ], 'prior' : 0 }, "The scaled sum of the connected" ),
49+ ({'prior' : - 1 }, "The index of the prior diagram is negative" ),
50+ ({'prior' : 1 }, "The index of the prior diagram is 1" ),
51+ ({'connect' : (- 1 , 1 ), 'prior' : 0 }, "At least one of the connection" ),
52+ ({'connect' : (2 , 1 ), 'prior' : 0 }, "The connection index to the source" ),
53+ ({'connect' : (1 , 3 ), 'prior' : 0 }, "The connection index to this dia" ),
54+ ({'connect' : (1 , 1 ), 'prior' : 0 , 'flows' : [- 0.2 , 0.2 ],
55+ 'orientations' : [2 ]}, "The value of orientations" ),
56+ ({'connect' : (1 , 1 ), 'prior' : 0 , 'flows' : [- 0.2 , 0.2 ],
57+ 'pathlengths' : [2 ]}, "The lengths of 'flows'" ),
58+ ))
59+ def test_sankey_add_errors (kwargs , msg ):
60+ sankey = Sankey ()
61+ with pytest .raises (ValueError , match = msg ):
62+ sankey .add (flows = [0.2 , - 0.2 ])
63+ sankey .add (** kwargs )
64+
65+
66+ def test_sankey2 ():
67+ s = Sankey (flows = [0.25 , - 0.25 , 0.5 , - 0.5 ], labels = ['Foo' ],
68+ orientations = [- 1 ], unit = 'Bar' )
69+ sf = s .finish ()
70+ assert np .all (np .equal (np .array ((0.25 , - 0.25 , 0.5 , - 0.5 )), sf [0 ].flows ))
71+ assert sf [0 ].angles == [1 , 3 , 1 , 3 ]
72+ assert all ([text .get_text ()[0 :3 ] == 'Foo' for text in sf [0 ].texts ])
73+ assert all ([text .get_text ()[- 3 :] == 'Bar' for text in sf [0 ].texts ])
74+ assert sf [0 ].text .get_text () == ''
75+ assert np .allclose (np .array (((- 1.375 , - 0.52011255 ),
76+ (1.375 , - 0.75506044 ),
77+ (- 0.75 , - 0.41522509 ),
78+ (0.75 , - 0.8599479 ))),
79+ sf [0 ].tips )
80+
81+ s = Sankey (flows = [0.25 , - 0.25 , 0 , 0.5 , - 0.5 ], labels = ['Foo' ],
82+ orientations = [- 1 ], unit = 'Bar' )
83+ sf = s .finish ()
84+ assert np .all (np .equal (np .array ((0.25 , - 0.25 , 0 , 0.5 , - 0.5 )), sf [0 ].flows ))
85+ assert sf [0 ].angles == [1 , 3 , None , 1 , 3 ]
86+ assert np .allclose (np .array (((- 1.375 , - 0.52011255 ),
87+ (1.375 , - 0.75506044 ),
88+ (0 , 0 ),
89+ (- 0.75 , - 0.41522509 ),
90+ (0.75 , - 0.8599479 ))),
91+ sf [0 ].tips )
92+
93+
94+ @check_figures_equal (extensions = ['png' ])
95+ def test_sankey3 (fig_test , fig_ref ):
96+ ax_test = fig_test .gca ()
97+ s_test = Sankey (ax = ax_test , flows = [0.25 , - 0.25 , - 0.25 , 0.25 , 0.5 , - 0.5 ],
98+ orientations = [1 , - 1 , 1 , - 1 , 0 , 0 ])
99+ s_test .finish ()
100+
101+ ax_ref = fig_ref .gca ()
102+ s_ref = Sankey (ax = ax_ref )
103+ s_ref .add (flows = [0.25 , - 0.25 , - 0.25 , 0.25 , 0.5 , - 0.5 ],
104+ orientations = [1 , - 1 , 1 , - 1 , 0 , 0 ])
105+ s_ref .finish ()
0 commit comments