diff --git a/physipy/__init__.py b/physipy/__init__.py index cb546d4..ff31440 100644 --- a/physipy/__init__.py +++ b/physipy/__init__.py @@ -10,6 +10,7 @@ from .integrate import quad, dblquad, tplquad from .optimize import root, brentq from .quantity import vectorize +from .quantity import uvectorize from .quantity import m, kg, s, A, K, cd, mol, rad, sr from .quantity import SI_units, SI_units_prefixed, SI_derived_units, other_units, units, all_units, SI_derived_units_prefixed diff --git a/physipy/quantity/calculus.py b/physipy/quantity/calculus.py index a65d93b..7c1ab76 100644 --- a/physipy/quantity/calculus.py +++ b/physipy/quantity/calculus.py @@ -19,6 +19,52 @@ from .utils import array_to_Q_array, decorate_with_various_unit, asqarray +def umap(func, *args, **kwargs): + """ Extension of python's 'map' function that works with units + func is the function to map + args is the arrays on which the function should me mapped + kwargs must not be used + """ + if not callable(func): + raise ValueError("First argument must be a function") + + if len(kwargs) > 0: + raise ValueError("Keywords arguments are not allowed in this function") + + # Analyse whether arguments are arrays or scalars + args_len = [] + for i, arg in enumerate(args): + if np.isscalar(arg): + args_len.append(1) + elif isinstance(arg, Quantity) and np.isscalar(arg.value): + args_len.append(1) + else: + shape = np.shape(arg) + if len(shape) > 1: + raise NotImplementedError("Only 1D arrays are supported") + else: + length = len(arg) + args_len.append(length) + ref_length = max(args_len) + args_modif = list(args) + for i, arg in enumerate(args): + if args_len[i] == 1: + args_modif[i] = np.repeat(args[i], ref_length) + elif args_len[i] < ref_length: + raise ValueError("When calling umap: All array/list arguments should have the same length") + else: + args_modif[i] = arg + + args = tuple(args_modif) + print("Function arguments:", args) + + out = np.empty((ref_length,), dtype=object) # Declare the output array + + for i in range(ref_length): # iterating on index + arg = tuple( [arg[i] for arg in args] ) # building tuple of arguments for the current index + out[i] = func(*arg) + + return out def vectorize(func): """Allow vectorize a function of Quantity. @@ -33,6 +79,18 @@ def func_Q_vec(*args, **kwargs): return res return func_Q_vec +def uvectorize(func, *args, **kwargs): + """Allow vectorize a function of Quantity. + + This function aims to extend numpy.vectorize to Quantity-function. + + """ + def func_out(*a, **k): + return asqarray( umap(func, *a, **k)) + return func_out + + + def xvectorize(func): def vec_func(x): @@ -46,6 +104,8 @@ def vec_func(x): def ndvectorize(func): + """ Vectorize function for functions with one argument, this argument being an N-dimensional np array + """ def vec_func(x): res = [] for i in x.flat: @@ -88,4 +148,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/test/test_quantity.py b/test/test_quantity.py index 8ec56af..1625bce 100644 --- a/test/test_quantity.py +++ b/test/test_quantity.py @@ -1854,7 +1854,50 @@ def thresh(x): exp = np.array([[3, 3],[3, 3], [4, 5]])*m self.assertTrue(np.all(res == exp)) + + def test_uvectorize_single_arg(self): + # 1D array + arr_m = np.arange(5)*m + + def thresh(x): + if x >3*m: + return x + else: + return 3*m + vec_thresh = uvectorize(thresh) + + res = vec_thresh(arr_m) + exp = np.array([3, 3, 3, 3, 4])*m + self.assertTrue(np.all(res == exp)) + + # nD array + #Will fail because not yet implemented + #arr_m = np.arange(6).reshape(3,2)*m + #res = vec_thresh(arr_m) + #exp = np.array([[3, 3],[3, 3], [4, 5]])*m + #self.assertTrue(np.all(res == exp)) + + def test_uvectorize_multiple_args(self): + # 1D array + x_arr = np.arange(5)*m + y_arr = 2*np.arange(5)*m + z_arr = 3*np.arange(5)*m + + def thresh(x, y, z): + if x <=3*m: + return y+z + else: + return 3*m + vec_thresh = uvectorize(thresh) + + res = vec_thresh(x_arr, y_arr, z_arr) + exp = np.array([0, 5, 10, 15, 3])*m + self.assertIsNone(np.testing.assert_array_equal(res, exp)) + res2 = vec_thresh(x_arr, 0*m, z_arr) + exp2 = np.array([0, 3, 6, 9, 3])*m + self.assertIsNone(np.testing.assert_array_equal(res2, exp2)) + def test_np_fft_fftshift(self): exp = np.fft.fftshift(np.arange(10))*s self.assertTrue(np.all(np.fft.fftshift(np.arange(10)*s)==exp)) @@ -2126,4 +2169,4 @@ def RHS_dydt(t, y): if __name__ == "__main__": unittest.main() - \ No newline at end of file +