Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions physipy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .integrate import quad, dblquad, tplquad
from .optimize import root, brentq
from .quantity import vectorize
from .quantity import uvectorize
from .quantity import m, kg, s, A, K, cd, mol, rad, sr
from .quantity import SI_units, SI_units_prefixed, SI_derived_units, other_units, units, all_units, SI_derived_units_prefixed

Expand Down
62 changes: 61 additions & 1 deletion physipy/quantity/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,52 @@
from .utils import array_to_Q_array, decorate_with_various_unit, asqarray


def umap(func, *args, **kwargs):
""" Extension of python's 'map' function that works with units
func is the function to map
args is the arrays on which the function should me mapped
kwargs must not be used
"""
if not callable(func):
raise ValueError("First argument must be a function")

if len(kwargs) > 0:
raise ValueError("Keywords arguments are not allowed in this function")

# Analyse whether arguments are arrays or scalars
args_len = []
for i, arg in enumerate(args):
if np.isscalar(arg):
args_len.append(1)
elif isinstance(arg, Quantity) and np.isscalar(arg.value):
args_len.append(1)
else:
shape = np.shape(arg)
if len(shape) > 1:
raise NotImplementedError("Only 1D arrays are supported")
else:
length = len(arg)
args_len.append(length)
ref_length = max(args_len)
args_modif = list(args)
for i, arg in enumerate(args):
if args_len[i] == 1:
args_modif[i] = np.repeat(args[i], ref_length)
elif args_len[i] < ref_length:
raise ValueError("When calling umap: All array/list arguments should have the same length")
else:
args_modif[i] = arg

args = tuple(args_modif)
print("Function arguments:", args)

out = np.empty((ref_length,), dtype=object) # Declare the output array

for i in range(ref_length): # iterating on index
arg = tuple( [arg[i] for arg in args] ) # building tuple of arguments for the current index
out[i] = func(*arg)

return out

def vectorize(func):
"""Allow vectorize a function of Quantity.
Expand All @@ -33,6 +79,18 @@ def func_Q_vec(*args, **kwargs):
return res
return func_Q_vec

def uvectorize(func, *args, **kwargs):
"""Allow vectorize a function of Quantity.

This function aims to extend numpy.vectorize to Quantity-function.

"""
def func_out(*a, **k):
return asqarray( umap(func, *a, **k))
return func_out




def xvectorize(func):
def vec_func(x):
Expand All @@ -46,6 +104,8 @@ def vec_func(x):


def ndvectorize(func):
""" Vectorize function for functions with one argument, this argument being an N-dimensional np array
"""
def vec_func(x):
res = []
for i in x.flat:
Expand Down Expand Up @@ -88,4 +148,4 @@ def main():


if __name__ == "__main__":
main()
main()
45 changes: 44 additions & 1 deletion test/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,50 @@ def thresh(x):
exp = np.array([[3, 3],[3, 3], [4, 5]])*m
self.assertTrue(np.all(res == exp))


def test_uvectorize_single_arg(self):
# 1D array
arr_m = np.arange(5)*m

def thresh(x):
if x >3*m:
return x
else:
return 3*m
vec_thresh = uvectorize(thresh)

res = vec_thresh(arr_m)
exp = np.array([3, 3, 3, 3, 4])*m
self.assertTrue(np.all(res == exp))

# nD array
#Will fail because not yet implemented
#arr_m = np.arange(6).reshape(3,2)*m
#res = vec_thresh(arr_m)
#exp = np.array([[3, 3],[3, 3], [4, 5]])*m
#self.assertTrue(np.all(res == exp))

def test_uvectorize_multiple_args(self):
# 1D array
x_arr = np.arange(5)*m
y_arr = 2*np.arange(5)*m
z_arr = 3*np.arange(5)*m

def thresh(x, y, z):
if x <=3*m:
return y+z
else:
return 3*m
vec_thresh = uvectorize(thresh)

res = vec_thresh(x_arr, y_arr, z_arr)
exp = np.array([0, 5, 10, 15, 3])*m
self.assertIsNone(np.testing.assert_array_equal(res, exp))

res2 = vec_thresh(x_arr, 0*m, z_arr)
exp2 = np.array([0, 3, 6, 9, 3])*m
self.assertIsNone(np.testing.assert_array_equal(res2, exp2))

def test_np_fft_fftshift(self):
exp = np.fft.fftshift(np.arange(10))*s
self.assertTrue(np.all(np.fft.fftshift(np.arange(10)*s)==exp))
Expand Down Expand Up @@ -2126,4 +2169,4 @@ def RHS_dydt(t, y):
if __name__ == "__main__":
unittest.main()