diff --git a/csa/elementary.py b/csa/elementary.py index 4d848b8..f3c5d2a 100644 --- a/csa/elementary.py +++ b/csa/elementary.py @@ -71,6 +71,22 @@ def ival (beg, end): # Cartesian product # def cross (set0, set1): + """ + Compute the Cartesian product of two sets. + + Note + ---- + The Cartesian product returned by `cross()` is not automatically + restricted to a finite mask. When visualized using `show()`, + connections may appear in the default (infinite) mask unless + explicitly multiplied with a finite mask. + + Example + ------- + # To obtain a bounded all-to-all connectivity pattern: + cross(A, B) * full(len(A), len(B)) + """ + return _cs.intervalSetMask (set0, set1) # Elementary masks diff --git a/csa/plot.py b/csa/plot.py index dd2fecc..c6a6ced 100644 --- a/csa/plot.py +++ b/csa/plot.py @@ -19,6 +19,14 @@ import numpy as _numpy import matplotlib import matplotlib.pyplot as _plt +from matplotlib.colors import ListedColormap + +# Color mapping for masks in show() +MASK_COLOR = { + 'full': 'purple', + 'oneToOne': 'purple', + 'randomMask': 'green', +} from . import elementary @@ -40,9 +48,15 @@ def show (cset, N0 = 30, N1 = None): N1 = N0 if N1 == None else N1 _plt.clf () _plt.axis ('equal') + # Determine color from mask type + mask_tag = getattr(cset, 'tag', None) + color = MASK_COLOR.get(mask_tag, 'gray') + # Create a colormap with only this color + cmap = ListedColormap([color_name]) a = _numpy.zeros ((N0, N1)) for (i, j) in elementary.cross (range (N0), range (N1)) * cset: a[i,j] += 1.0 + # Show with consistent color _plt.imshow (a, interpolation='nearest', vmin = 0.0, vmax = 1.0) _plt.show ()