From 5f2fe07d01eec27a11e640e4fcdf63be51f62a04 Mon Sep 17 00:00:00 2001 From: Rohit Tembhare Date: Tue, 17 Mar 2026 00:11:10 +0530 Subject: [PATCH] add support for automatic differentiation using Tapenade --- compyle/autodiff_tapenade.py | 626 ++++++++++++++++++++++++ compyle/c_backend.py | 1 - compyle/cimport.py | 119 ++++- compyle/jit.py | 1 - compyle/parallel.py | 4 + compyle/tests/test_autodiff_tapenade.py | 99 ++++ compyle/tests/test_parallel.py | 59 +++ compyle/translator.py | 9 +- compyle/transpiler.py | 9 +- examples/autodiff/billiards.py | 193 ++++++++ examples/autodiff/nn_mnist.py | 197 ++++++++ 11 files changed, 1301 insertions(+), 16 deletions(-) create mode 100644 compyle/autodiff_tapenade.py create mode 100644 compyle/tests/test_autodiff_tapenade.py create mode 100644 examples/autodiff/billiards.py create mode 100644 examples/autodiff/nn_mnist.py diff --git a/compyle/autodiff_tapenade.py b/compyle/autodiff_tapenade.py new file mode 100644 index 0000000..c705d61 --- /dev/null +++ b/compyle/autodiff_tapenade.py @@ -0,0 +1,626 @@ +import pybind11 +import inspect +import os +import subprocess +import functools +import re +import tempfile +import shutil +import shlex +import numpy as np +from mako.template import Template +from distutils.errors import CompileError + +from .profile import profile +from .translator import CConverter +from .transpiler import convert_to_float_if_needed +from . 
import array +from compyle.api import get_config + +from .ext_module import get_md5 +from .cimport import Cmodule, get_tpnd_obj_dir, compile_tapenade_source +from .transpiler import Transpiler + + +pyb11_setup_header = ''' + +// c code for with PyBind11 binding +#include +#include +namespace py = pybind11; + + +''' + +pyb11_setup_header_rev = ''' + +// c code for with PyBind11 binding +#include +#include +#include +namespace py = pybind11; +extern "C" { +#include +} +''' + + +# // Include complex before adStack.h and define the type +# extern "C" { +# #include +# #define double_complex double complex +# } +pyb11_wrap_template = ''' +PYBIND11_MODULE(m_${hash_fn}, m) { + m.def("m_${hash_fn}", [](${pyb_call}){ + return ${name}${FUNC_SUFFIX_F}(${c_call}); + }); +} +''' + +pyb11_wrap_template_rev = ''' +PYBIND11_MODULE(${modname}, m) { + m.def("${modname}", [](${pyb_call}){ + return ${name}${FUNC_SUFFIX_F}(${c_call}); + }); +} +''' + +c_backend_template = ''' +${c_kernel_defn} + +void elementwise_${fn_name}(long int SIZE, ${fn_args}) +{ + %if openmp: + #pragma omp parallel for + %endif + for(long int iter = (SIZE - 1); iter >= 0; iter--) + { + ${fn_name}(iter, ${fn_call}) ; + } +} +''' +VAR_SUFFIX_F = '__d' +FUNC_SUFFIX_F = '_d' + +VAR_SUFFIX_R = '__b' +FUNC_SUFFIX_R = '_b' + + +def _get_tapenade_cmd(): + cmd = os.environ.get("COMPYLE_TAPENADE_CMD") + if cmd: + return shlex.split(cmd) + + tap = os.environ.get("TAPENADE") + if tap: + return shlex.split(tap) + + exe = shutil.which("tapenade") + if exe: + return [exe] + + raise RuntimeError( + "Tapenade executable not found. Install Tapenade and ensure it is on PATH " + "as 'tapenade', or set COMPYLE_TAPENADE_CMD (e.g. 'java -jar /path/tapenade.jar') " + "or TAPENADE (path to tapenade)." 
+ ) + + +def get_source(f): + c = CConverter() + source = c.parse(f) + return source + + +def sig_to_pyb_call(par, typ): + if typ[-1] == "*": + call = "py::array_t<{}> {}".format(typ[:-1], par) + else: + call = typ + " " + str(par) + return call + + +def sig_to_c_call(par, typ): + if typ[-1] == '*': + call = "({ctype}) {arg}.request().ptr".format(ctype=typ, arg=par) + else: + call = "{arg}".format(arg=par) + return call + + +def get_diff_signature(f, active, mode='forward'): + if mode == 'forward': + VAR_SUFFIX = VAR_SUFFIX_F + elif mode == 'reverse': + VAR_SUFFIX = VAR_SUFFIX_R + + sig = inspect.signature(f) + pyb_c = [] + pyb_py = [] + pure_c = [] + pure_py = [] + for s in sig.parameters: + typ = str(sig.parameters[s].annotation.type) + if s not in active: + pyb_py.append([sig_to_pyb_call(s, typ)]) + pyb_c.append([sig_to_c_call(s, typ)]) + pure_c.append(["{typ} {i}".format(typ=typ, i=s)]) + pure_py.append([s]) + else: + typstar = typ if typ[-1] == '*' else typ + '*' + pyb_py.append([ + sig_to_pyb_call(s, typ), + sig_to_pyb_call(s + VAR_SUFFIX, typstar) + ]) + pyb_c.append([ + sig_to_c_call(s, typ), + sig_to_c_call(s + VAR_SUFFIX, typstar) + ]) + pure_c.append([ + "{typ} {i}".format(typ=typ, i=s), + "{typ} {i}".format(typ=typstar, i=s + VAR_SUFFIX) + ]) + pure_py.append([s, s + VAR_SUFFIX]) + + pyb_py_all = functools.reduce(lambda x, y: x + y, pyb_py) + pyb_c_all = functools.reduce(lambda x, y: x + y, pyb_c) + pure_c = functools.reduce(lambda x, y: x + y, pure_c) + pure_py = functools.reduce(lambda x, y: x + y, pure_py) + + return pyb_py_all, pyb_c_all, pure_py, pure_c + + +class GradBase: + def __init__(self, func, wrt, gradof, mode='forward', backend='tapenade'): + self.backend = backend + self.func = func + self.args = list(inspect.signature(self.func).parameters.keys()) + self.wrt = wrt + self.gradof = gradof + self.active = [] + self.mode = mode + self.name = func.__name__ + self._config = get_config() + self.source = 'Not yet generated' + self.grad_source = 
'Not yet generated' + self.grad_all_source = 'Not yet generated' + self.tapenade_op = 'Not yet generated' + self.message = "" + self.c_func = self.c_gen_error + self.tp = Transpiler(backend='c') + self._get_sources() + self.grad_args, self.grad_types = self._get_grad_def(self.grad_source) + self._get_active_vars() + + def _get_sources(self): + self.tp.add(self.func) + self.source = self.tp.get_code(incl_header=False) + + temp_dir = tempfile.mkdtemp() + + with open(os.path.join(temp_dir, self.name + '.c'), 'w') as f: + f.write(self.source) + + if self.mode == 'forward': + command = _get_tapenade_cmd() + [ + f"{self.name}.c", "-d", "-o", + f"{self.name}_forward_diff", "-tgtvarname", + f"{VAR_SUFFIX_F}", "-tgtfuncname", + f"{FUNC_SUFFIX_F}", + "-head", + f'{self.name}({" ".join(self.wrt)})({" ".join(self.gradof)})', + "-nooptim", "recomputeintermediates", + "-nooptim", "spareinit", + ] + elif self.mode == 'reverse': + command = _get_tapenade_cmd() + [ + f"{self.name}.c", "-b", "-o", + f"{self.name}_reverse_diff", "-adjvarname", + f"{VAR_SUFFIX_R}", "-adjfuncname", + f"{FUNC_SUFFIX_R}", + "-head", + f'{self.name}({" ".join(self.wrt)})({" ".join(self.gradof)})', + "-nooptim", "adjointliveness", + ] + # "-nooptim", "diffliveness", + # "-fixinterface" + # "-nooptim", "recomputeintermediates", + else: + raise ValueError(f"supported modes are 'forward' and 'reverse', got {self.mode}") + + if self.mode == 'forward': + f_extn = "_forward_diff_d.c" + elif self.mode == 'reverse': + f_extn = "_reverse_diff_b.c" + + op_tapenade = "" + try: + proc = subprocess.run( + command, capture_output=True, text=True, cwd=temp_dir, check=True + ) + op_tapenade += proc.stdout + except FileNotFoundError as e: + raise RuntimeError( + "Tapenade command could not be executed. " + "Install Tapenade and ensure it is on PATH, or set COMPYLE_TAPENADE_CMD/TAPENADE. 
" + f"Original error: {e}" + ) from e + except subprocess.CalledProcessError as e: + self.read_msg(temp_dir, f_extn) + print(e) + print("*"*80) + print(self.message) + print("*"*80) + raise RuntimeError( + "Encountered errors while differentiating through Tapenade.") + self.tapenade_op = op_tapenade + self.read_msg(temp_dir, f_extn) + + with open(os.path.join(temp_dir, self.name + f_extn), 'r') as f: + self.grad_source = f.read() + + def read_msg(self, temp_dir, f_extn): + try: + with open(os.path.join(temp_dir, self.name + f_extn[:-1] + "msg"), 'r') as f: + self.message = f.read() + except: + try: + with open(os.path.join(temp_dir, self.name + f_extn[:-1] + "msg~"), 'r') as f: + self.message = f.read() + except: + pass + + def c_gen_error(*args): + raise RuntimeError("Differentiated function not yet generated") + + def _massage_arg(self, x): + if isinstance(x, array.Array): + return x.dev + elif self.backend != 'cuda' or isinstance(x, np.ndarray): + return x + else: + return np.asarray(x) + + def _get_grad_def(self, src): + lines_src = src.split("\n") + n_lines = len(lines_src) + i = 0 + start = 0 + while i < n_lines: + # if lines_src[i].strip().startswith(f'void {self.name}'): + if f'void {self.name}' in lines_src[i].strip(): + start = i + break + i += 1 + if i == n_lines: + raise CompileError('could not find fn definition for derivative') + + end = 0 + + while i < n_lines: + if lines_src[i].strip().endswith("{"): + end = i + break + i += 1 + if i == n_lines: + raise CompileError('could not find fn definition for derivative') + + src_def = " ".join([i.strip() for i in lines_src[start:end + 1]]) + args_type = re.search(r"\((.*?)\)", src_def).group(1).split(",") + args = [] + types = [] + for val in args_type: + temp = val.split() + if len(temp) == 2: + t1 = temp[0] + t2 = temp[1] + elif len(temp) == 3: + t1 = temp[0] + t2 = temp[1] + temp[2] + else: + raise CompileError('could not get arguments from generated fn') + if t2.startswith("*"): + t1 += "*" + t2 = t2[1:] 
+ types.append(t1) + args.append(t2) + return args, types + + def _get_active_vars(self): + if self.mode == 'forward': + suff = VAR_SUFFIX_F + elif self.mode == 'reverse': + suff = VAR_SUFFIX_R + + for i, var in enumerate(self.grad_args): + if var.endswith(suff): + self.active.append(self.grad_args[i-1]) + + @profile + def __call__(self, *args, **kw): + c_args = [self._massage_arg(x) for x in args] + + if self.backend == 'cuda': + import pycuda.driver as drv + event = drv.Event() + self.c_func(*c_args, **kw) + event.record() + event.synchronize() + + elif self.backend == 'c': + self.c_func(*c_args) + + else: + raise RuntimeError("Given backend not supported, got '{}'".format( + self.backend)) + + +class ForwardGrad(GradBase): + def __init__(self, func, wrt, gradof): + super(ForwardGrad, self).__init__(func, + wrt, + gradof, + mode='forward', + backend='tapenade') + self.c_func = self.get_c_forward_diff() + + def get_c_forward_diff(self): + self.grad_source = pyb11_setup_header + self.grad_source + hash_fn = get_md5(self.grad_source) + modname = f'm_{hash_fn}' + + pyb_all, c_all, _, _ = get_diff_signature(self.func, + self.active, + mode='forward') + pyb_call = ", ".join(pyb_all) + c_call = ", ".join(c_all) + + pyb_temp = Template(pyb11_wrap_template) + pyb_bind = pyb_temp.render(name=self.name, + hash_fn=hash_fn, + FUNC_SUFFIX_F=FUNC_SUFFIX_F, + pyb_call=pyb_call, + c_call=c_call) + + self.grad_all_source = self.grad_source + pyb_bind + mod = Cmodule(self.grad_all_source, hash_fn) + module = mod.load() + return getattr(module, modname) + + def _get_len_wrt_args(self, args): + len_wrt_args = [] + for i in range(len(self.args)): + if self.args[i] in self.wrt: + len_wrt_args.append(len(args[i])) + return len_wrt_args + + def _add_wrt_args_fwd(self, args, wrt_var, len_wrt_arg): + final_args = [] + gradof_args = [] + wrt_arg = None + is_grad_var = [] + for i in range(len(args)): + final_args.append(args[i]) + is_grad_var.append(0) + if self.args[i] in self.active: + 
temp = np.zeros((len(args[i]), len_wrt_arg)) + final_args.append(temp) + is_grad_var.append(1) + if self.args[i] == wrt_var: + wrt_arg = temp + if self.args[i] in self.gradof: + gradof_args.append(temp) + return final_args, gradof_args, wrt_arg, is_grad_var + + def _call_multi_fwd(self, final_args, wrt_arg, is_grad_var): + for i in range(len(wrt_arg)): + wrt_arg[i][i] = 1.0 + + for i in range(len(wrt_arg)): + temp_args = [] + for j, arg in enumerate(final_args): + if not is_grad_var[j]: + temp_args.append(arg) + else: + temp_args.append(arg[:, i]) + self.c_func(*temp_args) + + @profile + def __call__(self, *args): + c_args = [self._massage_arg(x) for x in args] + + len_g_args = self._get_len_wrt_args(c_args) + + ans = [] + for i, grad_var in enumerate(self.wrt): + final_args, gradof_args, wrt_arg, is_grad_var = self._add_wrt_args_fwd(c_args, grad_var, len_g_args[i]) + + self._call_multi_fwd(final_args, wrt_arg, is_grad_var) + ans.append(gradof_args) + return ans + + +class ElementwiseGrad(GradBase): + def __init__(self, func, wrt, gradof, backend='c'): + super(ElementwiseGrad, self).__init__(func, + wrt, + gradof, + mode='forward', + backend=backend) + self._config = get_config() + self.c_func = self._generate() + + def _generate(self): + if self.backend == 'c': + return self._c_gen() + elif self.backend == 'cuda': + return self._cuda_gen() + + def correct_initialization(self): + for var in self.gradof: + grad_var = var + VAR_SUFFIX_F + prev = f"*{grad_var} = 0" + after = f"{grad_var}[i] = 0" + self.grad_all_source = self.grad_all_source.replace(prev, after) + + def _c_gen(self): + pyb_args, pyb_c_args, py_args, c_args = get_diff_signature( + self.func, self.active) + + c_templt = Template(c_backend_template) + c_code = c_templt.render(c_kernel_defn=self.grad_source, + fn_name='{fname}{suff}'.format( + fname=self.name, suff=FUNC_SUFFIX_F), + fn_args=", ".join(c_args[1:]), + fn_call=", ".join(py_args[1:]), + openmp=self._config.use_openmp) + + self.grad_all_source = 
pyb11_setup_header + c_code + + hash_fn = get_md5(self.grad_all_source) + modname = f'm_{hash_fn}' + + pyb_templt = Template(pyb11_wrap_template) + elwise_name = 'elementwise_' + self.name + size = "{}.request().size".format(py_args[1]) + pyb_code = pyb_templt.render(name=elwise_name, + hash_fn=hash_fn, + FUNC_SUFFIX_F=FUNC_SUFFIX_F, + pyb_call=", ".join(pyb_args[1:]), + c_call=", ".join([size] + pyb_c_args[1:])) + self.grad_all_source += pyb_code + mod = Cmodule(self.grad_all_source, hash_fn) + module = mod.load() + return getattr(module, modname) + + def _cuda_gen(self): + from .cuda import set_context + set_context() + from pycuda.elementwise import ElementwiseKernel + from pycuda._cluda import CLUDA_PREAMBLE + + _, _, py_args, c_args = get_diff_signature(self.func, self.active) + + self.grad_source = self.convert_to_device_code(self.grad_source) + expr = '{func}({args})'.format(func=self.name + FUNC_SUFFIX_F, + args=", ".join(py_args)) + + arguments = convert_to_float_if_needed(", ".join(c_args[1:])) + preamble = convert_to_float_if_needed(self.grad_source) + + cluda_preamble = Template(text=CLUDA_PREAMBLE).render( + double_support=True) + self.grad_all_source = cluda_preamble + preamble + self.correct_initialization() + knl = ElementwiseKernel(name=self.name, + arguments=arguments, + operation=expr, + preamble="\n".join([cluda_preamble, preamble])) + return knl + + def convert_to_device_code(self, code): + code = re.sub(r'\bvoid\b', 'WITHIN_KERNEL void', code) + code = re.sub(r'\bfloat\b', 'GLOBAL_MEM float ', code) + code = re.sub(r'\bdouble\b', 'GLOBAL_MEM double ', code) + return code + + +class ReverseGrad(GradBase): + def __init__(self, func, wrt, gradof, backend='tapenade'): + if self.req_recomp_tpnd(): + compile_tapenade_source() + super().__init__(func, wrt, gradof, mode='reverse', backend=backend) + self.c_func = self._c_reverse_diff() + + def _c_reverse_diff(self): + self.grad_source = pyb11_setup_header_rev + self.grad_source + hash_fn = 
get_md5(self.grad_source) + modname = f'm_{hash_fn}' + + pyb_all, c_all, _, _ = get_diff_signature(self.func, + self.active, + mode=self.mode) + pyb_call = ", ".join(pyb_all) + c_call = ", ".join(c_all) + + pyb_temp = Template(pyb11_wrap_template_rev) + pyb_bind = pyb_temp.render(name=self.name, + modname=modname, + FUNC_SUFFIX_F=FUNC_SUFFIX_R, + pyb_call=pyb_call, + c_call=c_call) + + self.grad_all_source = self.grad_source + pyb_bind + tpnd_obj_dir = get_tpnd_obj_dir() + + extra_inc_dir = [pybind11.get_include(), tpnd_obj_dir] + extra_link_args = [os.path.join(tpnd_obj_dir, 'adBuffer.o'), + os.path.join(tpnd_obj_dir, 'adStack.o')] + mod = Cmodule(self.grad_all_source, hash_fn, + extra_inc_dir=extra_inc_dir, + extra_link_args=extra_link_args, + extra_compile_args=['-ftemplate-depth=1024', "-O2"]) + module = mod.load() + return getattr(module, modname) + + def req_recomp_tpnd(self): + tpnd_obj_dir = get_tpnd_obj_dir() + cond1 = not os.path.exists(os.path.join(tpnd_obj_dir, 'adBuffer.o')) + cond2 = not os.path.exists(os.path.join(tpnd_obj_dir, 'adStack.o')) + return cond1 or cond2 + + def _get_len_gradof_args(self, args): + len_gradof_args = [] + for i in range(len(self.args)): + if self.args[i] in self.gradof: + len_gradof_args.append(len(args[i])) + return len_gradof_args + + def _add_grad_args_rev(self, args, gradof_var, len_gradof_arg): + final_args = [] + wrt_args = [] + gradof_arg = None + is_grad_var = [] + for i in range(len(args)): + final_args.append(args[i]) + is_grad_var.append(0) + if self.args[i] in self.active: + temp = np.zeros((len_gradof_arg, len(args[i]))) + final_args.append(temp) + is_grad_var.append(1) + if self.args[i] == gradof_var: + gradof_arg = temp + if self.args[i] in self.wrt: + wrt_args.append(temp) + return final_args, wrt_args, gradof_arg, is_grad_var + + def _call_multi_rev(self, final_args, gradof_arg, is_grad_var, len_gradof_args): + for i in range(len(gradof_arg)): + gradof_arg[i][i] = 1.0 + + for i in range(len(gradof_arg)): + 
temp_args = [] + for j, arg in enumerate(final_args): + if not is_grad_var[j]: + temp_args.append(arg) + else: + temp_args.append(arg[i]) + self.c_func(*temp_args) + + def _get_grad_dict(self, wrt_args, len_gradof_args): + grads = {} + for varname, arg in zip(self.wrt, wrt_args): + if len_gradof_args > 1: + grads[varname] = arg + else: + grads[varname] = arg[0] + return grads + + def _update_cache(self, wrt_args, gradof_arg, gradof_var): + self.cache[gradof_var] = [wrt_args, gradof_arg] + + @profile + def __call__(self, *args): + c_args = [self._massage_arg(x) for x in args] + self.c_func(*c_args) diff --git a/compyle/c_backend.py b/compyle/c_backend.py index bdabff6..1fbe42d 100644 --- a/compyle/c_backend.py +++ b/compyle/c_backend.py @@ -307,7 +307,6 @@ def __call__(self, *args, **kwargs): } delete[] stage1_res; delete[] stage2_res; - py::print(ary); } } ''' diff --git a/compyle/cimport.py b/compyle/cimport.py index 16b21f3..47b6b91 100644 --- a/compyle/cimport.py +++ b/compyle/cimport.py @@ -10,11 +10,13 @@ import pybind11 from distutils.extension import Extension from distutils.command import build_ext -from distutils.core import setup +from setuptools import setup from distutils.errors import CompileError, LinkError +from distutils.sysconfig import customize_compiler from .ext_module import get_platform_dir, get_ext_extension, get_openmp_flags from .capture_stream import CaptureMultipleStreams # noqa: 402 +from distutils.ccompiler import new_compiler class Cmodule: @@ -26,7 +28,12 @@ def __init__(self, src, hash_fn, root=None, verbose=False, openmp=False, self.name = f'm_{self.hash}' self.verbose = verbose self.openmp = openmp - self.extra_inc_dir = extra_inc_dir + # self.extra_inc_dir = extra_inc_dir + ['/usr/include'] + self.extra_inc_dir = extra_inc_dir + [ + '/usr/include', + '/usr/include/c++/{}'.format(self._get_gcc_version()), + '/usr/lib/gcc/{}-linux-gnu/{}'.format(self._get_machine(), self._get_gcc_version()) + ] self.extra_link_args = extra_link_args 
self.extra_compile_args = extra_compile_args self._use_cpp11() @@ -35,6 +42,18 @@ def __init__(self, src, hash_fn, root=None, verbose=False, openmp=False, self._setup_filenames() self.lock = FileLock(self.lock_path, timeout=120) + def _get_gcc_version(self): + import subprocess + try: + gcc_version = subprocess.check_output(['gcc', '-dumpversion']).decode().strip() + return gcc_version + except: + return '11' # fallback to a common version + + def _get_machine(self): + import platform + return platform.machine() + def _setup_root(self, root): if root is None: plat_dir = get_platform_dir() @@ -65,12 +84,21 @@ def is_build_needed(self): def build(self): self._include_openmp() - ext = Extension(name=self.name, - sources=[self.src_path], - language='c++', - include_dirs=self.extra_inc_dir, - extra_link_args=self.extra_link_args, - extra_compile_args=self.extra_compile_args) + ext = Extension( + name=self.name, + sources=[self.src_path], + language='c++', # Change to c++ since we're using pybind11 + include_dirs=self.extra_inc_dir, + extra_link_args=self.extra_link_args, + extra_compile_args=self.extra_compile_args + [ + '-D_COMPLEX_H', + '-D_GNU_SOURCE', + '-fpermissive', # Allow mixing C and C++ code + '-include', 'complex.h', + '-include', 'complex', + '-x', 'c++', + ] + ) args = [ "build_ext", "--build-lib=" + self.build_dir, @@ -127,3 +155,78 @@ def _message(self, *args): msg = ' '.join(args) if self.verbose: print(msg) + + +# def wget_tpnd_headers(): +# import requests +# baseurl = 'https://gitlab.inria.fr/tapenade/tapenade/-/raw/master/ADFirstAidKit/{}?ref_type=heads&inline=false' +# files = ['adComplex.h', 'adStack.c', 'adStack.h'] +# reqs = [requests.get(baseurl.format(file)) for file in files] +# saveloc = get_tpnd_obj_dir() +# if not os.path.exists(saveloc): +# os.mkdir(saveloc) + +# for file, r in zip(files, reqs): +# with open(join(saveloc, file), 'wb') as f: +# f.write(r.content) + +def wget_tpnd_headers(): + import requests + 
"https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/adBuffer.c?inline=false" + "https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/adBuffer.h?inline=false" + baseurl = 'https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/{}?inline=false' + files = ['adBuffer.c', 'adBuffer.h', 'adStack.c', 'adStack.h'] + saveloc = get_tpnd_obj_dir() + os.makedirs(saveloc, exist_ok=True) + + print(f"Downloading Tapenade source code to {saveloc}") + for file in files: + r = requests.get(baseurl.format(file), timeout=30) + r.raise_for_status() + with open(join(saveloc, file), 'wb') as f: + f.write(r.content) + +def get_tpnd_obj_dir(): + plat_dir = get_platform_dir() + root = expanduser(join('~', '.compyle', 'source', plat_dir)) + tpnd_dir = join(root, 'tapenade_src') + return tpnd_dir + + +def compile_tapenade_source(verbose=0): + print("Setting up Tapenade source code...") + try: + obj_dir_tpnd = get_tpnd_obj_dir() + with CaptureMultipleStreams() as stream: + wget_tpnd_headers() + os.environ["CC"] = 'gcc' # Use gcc instead of g++ + compiler = new_compiler(verbose=1) + customize_compiler(compiler) + compiler.compile( + [join(obj_dir_tpnd, 'adStack.c'), + join(obj_dir_tpnd, 'adBuffer.c')], + output_dir=obj_dir_tpnd, + extra_preargs=[ + '-c', + '-fPIC', + ], + include_dirs=[obj_dir_tpnd, '/usr/include'], + ) + s_out = stream.get_output() + print(s_out[0]) + print(s_out[1]) + objdir = join(obj_dir_tpnd, obj_dir_tpnd[1:]) + shutil.move(join(objdir, 'adStack.o'), + join(obj_dir_tpnd, 'adStack.o')) + shutil.move(join(objdir, 'adBuffer.o'), + join(obj_dir_tpnd, 'adBuffer.o')) + except (CompileError, LinkError): + hline = "*"*80 + print(hline + "\nERROR") + s_out = stream.get_output() + print(s_out[0]) + print(s_out[1]) + msg = "Compilation of tapenade source failed, please check "\ + "error messages above." 
+ print(hline + "\n" + msg) + sys.exit(1) diff --git a/compyle/jit.py b/compyle/jit.py index bf0ee4f..7c5a739 100644 --- a/compyle/jit.py +++ b/compyle/jit.py @@ -5,7 +5,6 @@ import ast import importlib import warnings -import json from pytools import memoize from .config import get_config from .cython_generator import CythonGenerator diff --git a/compyle/parallel.py b/compyle/parallel.py index e78ee0f..5004b81 100644 --- a/compyle/parallel.py +++ b/compyle/parallel.py @@ -6,6 +6,7 @@ """ +from compyle import c_backend from compyle import c_backend from functools import wraps from textwrap import wrap @@ -14,6 +15,7 @@ import numpy as np import pybind11 +from .cimport import Cmodule from .cimport import Cmodule from .config import get_config from .profile import profile @@ -1062,6 +1064,8 @@ def _generate(self, declarations=None): return self._generate_cython_code(declarations=declarations) elif self.backend == 'c': return self._generate_c_code(declarations=declarations) + elif self.backend == 'c': + return self._generate_c_code(declarations=declarations) def _default_cython_input_function(self): py_data = (['int i', '{type}[:] input'.format(type=self.type)], diff --git a/compyle/tests/test_autodiff_tapenade.py b/compyle/tests/test_autodiff_tapenade.py new file mode 100644 index 0000000..44021e5 --- /dev/null +++ b/compyle/tests/test_autodiff_tapenade.py @@ -0,0 +1,99 @@ +import pytest +import numpy as np +from math import exp, log + +from ..autodiff_tapenade import ForwardGrad, ReverseGrad, ElementwiseGrad +from ..types import annotate +from ..array import wrap, declare + + + +@annotate(n_1='int', doublep='x, y') +def simple_pow(x, y, n_1): + y[0] = 1 + for i in range(n_1): + y[0] *= x[0] + +@annotate(doublep='x, y', n_1='int') +def ifelse(x, y, n_1): + if n_1 < 0: + x[0] = 1 / x[0] + n_1 = -n_1 + y[0] = 1 + for i in range(n_1): + y[0] *= x[0] + + +@annotate(doublep='ip, W, loss', double='b, label', n_1='int') +def log_reg(ip, W, b, loss, label, n_1): + h = 
declare('double') + pred = declare('double') + h = 0 + for i in range(n_1): + h += ip[i] * W[i] + h += b + pred = 1 / (1 + exp(-h)) + los_prob = pred * label + (1 - pred) * (1 - label) + loss[0] = -log(los_prob) + + +def grad_log_reg(ip, W, label): + h = np.dot(ip, W) + pred = 1 / (1 + exp(-h)) + return ip * (pred - label) + + +def test_simple_pow(): + grad_pow = ForwardGrad(simple_pow, ['x'], ['y']) + + x = np.array([2]) + y = np.empty([1]) + [[yfd]] = grad_pow(x, y, 5) + assert yfd == 80 + + grad_pow = ReverseGrad(simple_pow, ['x'], ['y']) + + x = np.array([2]) + y = np.empty(1) + xd = np.zeros(1) + yd = np.array([1]) + + grad_pow(x, xd, y, yd, 5) + assert xd[0] == 80 + + +def test_if_else(): + grad_pow = ForwardGrad(ifelse, ['x'], ['y']) + + x = np.array([2]) + y = np.empty(1) + + [[yfd]] = grad_pow(x, y, -5) + assert yfd == (-5 * (2 ** -6)) + + grad_pow = ReverseGrad(ifelse, ['x'], ['y']) + + x = np.array([2]) + y = np.empty(1) + xd = np.zeros(1) + yd = np.array([1]) + + grad_pow(x, xd, y, yd, -5) + assert xd[0] == (-5 * (2 ** -6)) + + +def t_log_red(): + g_log_reg = ReverseGrad(log_reg, ['W'], ['loss']) + n = 5 + ip = np.linspace(0, 1, n) + W = np.random.randn(n) + b = 1 + loss = np.array([0]) + label = 1 + + loss_d = np.ones(1) + W_d = np.empty_like(W) + + g_log_reg(ip, W, W_d, b, loss, loss_d, label, n) + + assert np.allclose(W_d, grad_log_reg(ip, W, label)) \ No newline at end of file diff --git a/compyle/tests/test_parallel.py b/compyle/tests/test_parallel.py index ce1667a..7e55701 100644 --- a/compyle/tests/test_parallel.py +++ b/compyle/tests/test_parallel.py @@ -78,6 +78,65 @@ def test_repeated_scans_with_different_settings_c(self): self._test_unique_scan(backend='c') +class ParallelUtilsBaseC(object): + def test_elementwise_works_with_c(self): + self._check_simple_elementwise(backend='c') + + def test_elementwise_works_with_global_constant_c(self): + self._check_elementwise_with_constant(backend='c') + + def test_reduction_works_without_map_c(self): + 
self._check_simple_reduction(backend='c') + + def test_reduction_works_with_map_c(self): + self._check_reduction_with_map(backend='c') + + def test_reduction_works_with_external_func_c(self): + self._check_reduction_with_external_func(backend='c') + + def test_reduction_works_neutral_c(self): + self._check_reduction_min(backend='c') + + def test_scan_works_c(self): + self._test_scan(backend='c') + + def test_scan_works_c_parallel(self): + with use_config(use_openmp=True): + self._test_scan(backend='c') + + def test_large_scan_works_c_parallel(self): + with use_config(use_openmp=True): + self._test_large_scan(backend='c') + + def test_scan_works_with_external_func_c(self): + self._test_scan_with_external_func(backend='c') + + def test_scan_works_with_external_func_c_parallel(self): + with use_config(use_openmp=True): + self._test_scan_with_external_func(backend='c') + + def test_scan_last_item_c_parallel(self): + with use_config(use_openmp=True): + self._test_scan_last_item(backend='c') + + def test_scan_last_item_c_serial(self): + self._test_scan_last_item(backend='c') + + def test_unique_scan_c(self): + self._test_unique_scan(backend='c') + + def test_unique_scan_c_parallel(self): + with use_config(use_openmp=True): + self._test_unique_scan(backend='c') + + def test_repeated_scans_with_different_settings_c(self): + with use_config(use_openmp=False): + self._test_unique_scan(backend='c') + + with use_config(use_openmp=True): + self._test_unique_scan(backend='c') + + class ParallelUtilsBase(object): def test_elementwise_works_with_cython(self): self._check_simple_elementwise(backend='cython') diff --git a/compyle/translator.py b/compyle/translator.py index 7a10a92..5841e44 100644 --- a/compyle/translator.py +++ b/compyle/translator.py @@ -13,6 +13,7 @@ import ast import re +from subprocess import call import sys from textwrap import dedent, wrap import types @@ -173,7 +174,7 @@ def _get_self_type(self): def _get_local_arg(self, arg, type): return arg, type - def 
_get_function_args(self, node): + def _get_function_args(self, node, convert_array_args=False): node_args = node.args.args if PY_VER == 2: args = [x.id for x in node_args] @@ -206,8 +207,10 @@ def _get_function_args(self, node): type = self._detect_type(arg, value) if 'LOCAL_MEM' in type: arg, type = self._get_local_arg(arg, type) - call_sig.append('{type} {arg}'.format(type=type, arg=arg)) - + if convert_array_args and type.endswith('*'): + call_sig.append('{type} {arg}[]'.format(type=type[:-1], arg=arg)) + else: + call_sig.append('{type} {arg}'.format(type=type, arg=arg)) return ', '.join(call_sig) def _get_variable_declaration(self, type_str, names): diff --git a/compyle/transpiler.py b/compyle/transpiler.py index 926e347..eac8fca 100644 --- a/compyle/transpiler.py +++ b/compyle/transpiler.py @@ -9,7 +9,7 @@ from .config import get_config from .ast_utils import get_unknown_names_and_calls from .cython_generator import CythonGenerator, CodeGenerationError -from .translator import OpenCLConverter, CUDAConverter, CConverter, literal_to_float +from .translator import OpenCLConverter, CUDAConverter, CConverter from .ext_module import ExtModule from .extern import Extern, get_extern_code from .utils import getsourcelines @@ -308,8 +308,11 @@ def add_code(self, code): cb = CodeBlock(code, code) self.blocks.append(cb) - def get_code(self): - code = [self.header] + [x.code for x in self.blocks] + def get_code(self, incl_header=True): + if incl_header: + code = [self.header] + [x.code for x in self.blocks] + else: + code = [x.code for x in self.blocks] return '\n'.join(code) def compile(self): diff --git a/examples/autodiff/billiards.py b/examples/autodiff/billiards.py new file mode 100644 index 0000000..b1e7aad --- /dev/null +++ b/examples/autodiff/billiards.py @@ -0,0 +1,193 @@ +import numpy as np +from compyle.api import annotate, Elementwise, get_config, declare +from compyle.autodiff_tapenade import ReverseGrad +import taichi as ti + +from time import time + +# This 
problem aims at solving the billiards problem using autodiff.
+# Aim is to determine initial location and velocity of a cue ball to hit the
+# target location with one of the decided balls on the table.
+
+get_config().use_openmp = True
+gui = ti.GUI("Billiards", (1024, 1024), background_color=0x3C733F)
+
+
+def visualise(x, y, n_balls, goalx, goaly, pixel_radius, steps):
+    """Draw the stored trajectory on the module-level ``gui``.
+
+    ``x`` and ``y`` are flat arrays indexed as ``t * n_balls + i``. For every
+    step ``t`` in ``1 .. steps - 1`` the goal marker and all balls are drawn
+    and the frame is shown. Ball 0, the last ball and the remaining balls
+    get three different colours.
+    """
+    gui.clear()
+    for t in range(1, steps):
+        gui.circle((goalx, goaly), 0x00000, pixel_radius // 2)
+        for i in range(n_balls):
+            idxi = t * n_balls + i
+            if i == 0:
+                color = 0xCCCCCC
+            elif i == n_balls - 1:
+                color = 0x3344cc
+            else:
+                color = 0xF20530
+
+            gui.circle((x[idxi], y[idxi]), color, pixel_radius)
+
+        gui.show()
+
+
+# forward kernel to simulate billiards for given initial conditions and steps
+#
+# Layout: x, y, vx, vy are flat arrays indexed as ``t * n_balls + i``;
+# x_inc, y_inc, impulse_x, impulse_y hold one scratch value per ball and are
+# rewritten at every step. The result is written to ``loss[0]``: the squared
+# distance between ``target_ball`` at the last step and (goalx, goaly).
+# NOTE(review): this function is handed to ReverseGrad below, so the
+# statements are deliberately left untouched; only comments were added.
+@annotate(int='n_balls, target_ball, billiard_layers, steps',
+          float='dt, elasticity, goalx, goaly, radius',
+          floatp='x, y, vx, vy, init_x, init_y, init_vx, init_vy, impulse_x, impulse_y, x_inc, y_inc, loss')
+def forward(x, y, vx, vy, init_x, init_y, init_vx, init_vy,
+            impulse_x, impulse_y, x_inc, y_inc, n_balls, dt,
+            elasticity, target_ball, goalx, goaly, loss,
+            billiard_layers, radius, steps):
+    # initialize: ball 0 (the cue ball) starts from the optimised state
+    x[0] = init_x[0]
+    y[0] = init_y[0]
+    vx[0] = init_vx[0]
+    vy[0] = init_vy[0]
+    count = declare('int')
+    idxi = declare('int')
+    idxj = declare('int')
+    idxip = declare('int')
+    idxtg = declare('int')
+    i, j, t = declare('int')
+    # place the remaining balls in a triangle: layer i holds i + 1 balls
+    count = 0
+    for i in range(billiard_layers):
+        for j in range(i + 1):
+            count += 1
+            x[count] = i * 2 * radius + 0.5
+            y[count] = j * 2 * radius + 0.5 - i * radius * 0.7
+
+    for t in range(1, steps):
+        # collide balls
+        for i in range(n_balls):
+            x_inc[i] = 0
+            y_inc[i] = 0
+            impulse_x[i] = 0
+            impulse_y[i] = 0
+            for j in range(n_balls):
+                if i != j:
+                    x_inc_contrib = 0.0
+                    y_inc_contrib = 0.0
+                    impx = 0
+                    impy = 0
+
+                    # separation of the pair after a tentative step of dt
+                    idxi = (t - 1) * n_balls + i
+                    idxj = (t - 1) * n_balls + j
+                    distx = (x[idxi] + dt * vx[idxi]) - (x[idxj] + dt * vx[idxj])
+                    disty = (y[idxi] + dt * vy[idxi]) - (y[idxj] + dt * vy[idxj])
+                    dist_norm = ((distx * distx) + (disty * disty)) ** 0.5
+                    rela_vx = vx[idxi] - vx[idxj]
+                    rela_vy = vy[idxi] - vy[idxj]
+
+                    if dist_norm < 2 * radius:
+                        dirx = distx / dist_norm
+                        diry = disty / dist_norm
+                        projected_v = (dirx * rela_vx) + (diry * rela_vy)
+
+                        # only balls that are approaching exchange impulse
+                        if projected_v < 0:
+                            impx = -(1 + elasticity) * 0.5 * projected_v * dirx
+                            impy = -(1 + elasticity) * 0.5 * projected_v * diry
+
+                            toi = (dist_norm - 2 * radius) / min(
+                                -1e-3, projected_v)  # Time of impact
+                            x_inc_contrib = min(toi - dt, 0.0) * impx
+                            y_inc_contrib = min(toi - dt, 0.0) * impy
+
+                    x_inc[i] += x_inc_contrib
+                    y_inc[i] += y_inc_contrib
+                    impulse_x[i] += impx
+                    impulse_y[i] += impy
+
+        # end collide balls
+
+        # update speed and position
+        for i in range(n_balls):
+            idxi = t * n_balls + i
+            idxip = (t - 1) * n_balls + i
+            vx[idxi] = vx[idxip] + impulse_x[i]
+            vy[idxi] = vy[idxip] + impulse_y[i]
+            x[idxi] = x[idxip] + dt * vx[idxi] + x_inc[i]
+            y[idxi] = y[idxip] + dt * vy[idxi] + y_inc[i]
+
+    # compute loss
+    idxtg = (steps - 1) * n_balls + target_ball
+    loss[0] = (x[idxtg] - goalx) ** 2 + (y[idxtg] - goaly) ** 2
+
+
+# generate a gradient function for the forward function
+# The generated function takes the same arguments as ``forward`` with an
+# adjoint array inserted right after every array listed in wrt/gradof (see
+# the call in ``optimize``).
+grad_forward = ReverseGrad(
+    forward,
+    wrt=['init_x', 'init_y', 'init_vx', 'init_vy', 'y_inc', 'x', 'y', 'vx',
+         'x_inc', 'vy', 'impulse_x', 'impulse_y'],
+    gradof=['loss'])
+
+
+def optimize():
+    """Run 200 gradient-descent steps on the cue ball's initial state.
+
+    Uses the arrays and scalars created in the ``__main__`` block below as
+    module globals, so it can only be called after they exist. ``init_x``,
+    ``init_y``, ``init_vx`` and ``init_vy`` are updated in place.
+    """
+    adjoints = (xb, yb, vxb, vyb, init_xb, init_yb, init_vxb, init_vyb,
+                impulse_xb, impulse_yb, x_incb, y_incb)
+    for it in range(200):
+        # Reverse-mode code adds into the adjoints of its inputs, so every
+        # buffer must start from zero or the gradients of earlier iterations
+        # leak into this one. Only the seed of the output is non-zero.
+        for buf in adjoints:
+            buf.fill(0.0)
+        lossb[0] = 1.0
+
+        grad_forward(x, xb, y, yb, vx, vxb, vy, vyb, init_x, init_xb, init_y, init_yb,
+                     init_vx, init_vxb, init_vy, init_vyb, impulse_x, impulse_xb,
+                     impulse_y, impulse_yb, x_inc, x_incb, y_inc, y_incb,
+                     n_balls, dt, elasticity, target_ball, goalx, goaly, loss, lossb,
+                     billiard_layers, radius, steps)
+        init_x[0] -= learning_rate * init_xb[0]
+        init_y[0] -= learning_rate * init_yb[0]
+        init_vx[0] -= learning_rate * init_vxb[0]
+        init_vy[0] -= learning_rate * init_vyb[0]
+        if it % 20 == 0:
+            print(f"iter: {it} \t loss: {loss[0]}")
+
+
+if __name__ == '__main__':
+    # setup parameters
+    dtype = np.float32
+    billiard_layers = 4
+    # one cue ball plus a triangle of 1 + 2 + ... + billiard_layers balls
+    n_balls = 1 + (1 + billiard_layers) * billiard_layers // 2
+
+    vis_interval = 64
+    output_vis_interval = 16
+    steps = 1024
+    max_steps = 1024
+
+    vis_resolution = 1024
+
+    loss = np.zeros(1, dtype=dtype)
+
+    # per-step state, flattened as t * n_balls + i
+    x = np.zeros(max_steps * n_balls, dtype=dtype)
+    y = np.zeros(max_steps * n_balls, dtype=dtype)
+    x_inc = np.zeros(n_balls, dtype=dtype)
+    y_inc = np.zeros(n_balls, dtype=dtype)
+    vx = np.zeros(max_steps * n_balls, dtype=dtype)
+    vy = np.zeros(max_steps * n_balls, dtype=dtype)
+    impulse_x = np.zeros(n_balls, dtype=dtype)
+    impulse_y = np.zeros(n_balls, dtype=dtype)
+
+    # quantities being optimised: initial cue-ball position and velocity
+    init_x = np.array([0.1], dtype=dtype)
+    init_y = np.array([0.5], dtype=dtype)
+    init_vx = np.array([0.3], dtype=dtype)
+    init_vy = np.array([0.0], dtype=dtype)
+
+    # adjoint ("b") buffers, one per differentiated array
+    xb = np.zeros_like(x, dtype=dtype)
+    yb = np.zeros_like(y, dtype=dtype)
+    vxb = np.zeros_like(vx, dtype=dtype)
+    vyb = np.zeros_like(vy, dtype=dtype)
+    init_xb = np.zeros_like(init_x, dtype=dtype)
+    init_yb = np.zeros_like(init_y, dtype=dtype)
+    init_vxb = np.zeros_like(init_vx, dtype=dtype)
+    init_vyb = np.zeros_like(init_vy, dtype=dtype)
+    impulse_xb = np.zeros_like(impulse_x, dtype=dtype)
+    impulse_yb = np.zeros_like(impulse_y, dtype=dtype)
+    x_incb = np.zeros_like(x_inc, dtype=dtype)
+    y_incb = np.zeros_like(y_inc, dtype=dtype)
+    lossb = np.ones_like(loss, dtype=dtype)
+
+    target_ball = n_balls - 1
+    goalx = 0.9
+    goaly = 0.75
+    radius = 0.03
+    elasticity = 0.8
+
+
+    dt = 0.003
+    learning_rate = 0.01
+
+    begin = time()
+    optimize()
+    end = time()
+    print(f"took: {end - begin} seconds to simulate {billiard_layers} layers of billiard balls")
+
+    pixel_radius = (int(radius * 1024) + 1)
+    visualise(x, y, n_balls, goalx, goaly, pixel_radius, steps)
\ No newline at end of file
diff --git a/examples/autodiff/nn_mnist.py b/examples/autodiff/nn_mnist.py
new file mode 100644
index 0000000..9325ecf
--- /dev/null
+++ b/examples/autodiff/nn_mnist.py
@@ -0,0 +1,197 @@
+"""Train a small MNIST classifier with Tapenade-generated gradients.
+
+The network (784 -> 128 -> 10, ReLU then softmax) is written as a single
+compyle kernel, ``fwd_pass_final``; ``ReverseGrad`` produces its reverse-mode
+derivative and optax's Adam applies the averaged mini-batch gradients.
+"""
+import optax
+import jax.numpy as jnp
+import jax
+import numpy as np
+from optax import softmax_cross_entropy
+
+try:
+    from keras.datasets import mnist
+except ImportError:  # pragma: no cover
+    from tensorflow.keras.datasets import mnist
+
+from math import exp, log
+
+from compyle.autodiff_tapenade import ReverseGrad
+from compyle.api import annotate, Elementwise, declare, get_config
+from compyle.array import empty, ones
+from pytools import argmax
+
+np.random.seed(2)
+# get_config().use_openmp = True
+
+BATCH_SIZE = 200
+
+# Flatten every image to one row and scale the pixel values by 1 / 255.
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+TRAIN = (x_train.reshape(x_train.shape[0], -1).astype(np.float32)) / 255.0
+TRAIN_LABELS = y_train.astype(np.int32)
+TEST = (x_test.reshape(x_test.shape[0], -1).astype(np.float32)) / 255.0
+TEST_LABELS = y_test.astype(np.int32)
+
+
+def initialise(n_0, n_1, n_2):
+    """Create the parameters, activations and gradient buffers.
+
+    Weights and biases are drawn from ``np.random.random`` and scaled by
+    0.01; weights are stored flat (``n_0 * n_1`` and ``n_1 * n_2`` entries).
+
+    Returns the tuple
+    ``(w_01, b_1, v_1, w_12, b_2, v_2, g_w_01, g_w_12, g_b_1, g_b_2)``.
+    """
+    w_01 = np.random.random(n_0 * n_1).astype(np.float32) * 0.01
+    w_12 = np.random.random(n_1 * n_2).astype(np.float32) * 0.01
+    b_1 = np.random.random(n_1).astype(np.float32) * 0.01
+    b_2 = np.random.random(n_2).astype(np.float32) * 0.01
+    v_1 = np.zeros(n_1, dtype=np.float32)
+    v_2 = np.zeros(n_2, dtype=np.float32)
+
+    g_w_01 = np.zeros_like(w_01, dtype=np.float32)
+    g_w_12 = np.zeros_like(w_12, dtype=np.float32)
+
+    g_b_1 = np.zeros_like(b_1, dtype=np.float32)
+    g_b_2 = np.zeros_like(b_2, dtype=np.float32)
+
+    return w_01, b_1, v_1, w_12, b_2, v_2, g_w_01, g_w_12, g_b_1, g_b_2
+
+
+n_0 = 784
+n_1 = 128
+n_2 = 10
+alpha = 0.001
+n_train = TRAIN.shape[0]
+
+w_01, b_1, v_1, w_12, b_2, v_2, g_w_01, g_w_12, g_b_1, g_b_2 = initialise(
+    n_0, n_1, n_2)
+g_v_1 = np.zeros_like(v_1)
+g_v_2 = np.zeros_like(v_2)
+g_ip = np.zeros(n_0)  # not referenced anywhere below
+loss = np.zeros(1, dtype=np.float32)
+loss_b = np.ones(1, dtype=np.float32)
+
+# Parameters and their gradients share the same keys.
+initial_params = {
+    'w01': w_01,
+    'w12': w_12,
+    'b1': b_1,
+    'b2': b_2,
+}
+
+grads = {
+    'w01': g_w_01,
+    'w12': g_w_12,
+    'b1': g_b_1,
+    'b2': g_b_2,
+}
+
+###############################################################################
+###############################################################################
+###############################################################################
+
+
+# Elementwise kernel: zero one entry of v.
+@annotate(i='int', v='floatp')
+def reset(i, v):
+    v[i] = 0.0
+
+
+reset_all = Elementwise(reset, backend='c')
+
+
+# Elementwise kernel: accumulate g into gsum. Only the parameters that the
+# function really takes are annotated.
+@annotate(i='int', floatp='g, gsum')
+def addgrad(i, g, gsum):
+    gsum[i] += g[i]
+
+
+addgrad_elwise = Elementwise(addgrad, backend='c')
+
+
+def reset_grads(grads):
+    """Zero every array stored in the ``grads`` dict, in place."""
+    for key in grads:
+        reset_all(grads[key])
+
+
+def add_grads(grads, gradsums):
+    """Add each ``grads[key]`` into ``gradsums[key]``, in place."""
+    for key in grads:
+        addgrad_elwise(grads[key], gradsums[key])
+
+
+def avg_grads(gradsums, batch_size):
+    """Divide every accumulated gradient by ``batch_size``, in place."""
+    for key in gradsums:
+        gradsums[key] /= batch_size
+
+
+# Forward pass for ONE sample followed by the negative log-likelihood of the
+# expected class, written to loss[0]. v_1 / v_2 receive the hidden (ReLU) and
+# output (softmax) activations. Weights are flat, indexed as
+# w_01[i * n_0 + j] and w_12[i * n_1 + j].
+# NOTE(review): this function is handed to ReverseGrad below, so the
+# statements are left as they were; only the unused ``n_3`` annotation went.
+@annotate(int='n_0, n_1, n_2, expected',
+          floatp='input, w_01, b_1, v_1, w_12, b_2, v_2, loss')
+def fwd_pass_final(n_0, input, w_01, n_1, b_1, v_1, w_12, n_2, b_2, v_2, loss, expected):
+    i, j = declare('int')
+
+    # layer 1: affine
+    for i in range(n_1):
+        v_1[i] = b_1[i]
+        for j in range(n_0):
+            v_1[i] += w_01[i * n_0 + j] * input[j]
+
+    # ReLU
+    for i in range(n_1):
+        if v_1[i] < 0:
+            v_1[i] = 0
+
+    # layer 2: affine
+    for i in range(n_2):
+        v_2[i] = b_2[i]
+        for j in range(n_1):
+            v_2[i] += w_12[i * n_1 + j] * v_1[j]
+
+    den = declare('float')
+    den = 0
+
+    # softmax
+    for i in range(n_2):
+        v_2[i] = exp(v_2[i])
+        den += v_2[i]
+    for i in range(n_2):
+        v_2[i] = v_2[i] / den
+
+    loss[0] = -log(v_2[expected])
+
+
+# The generated function takes the arguments of ``fwd_pass_final`` with an
+# adjoint array inserted after w_01, b_1, v_1, w_12, b_2, v_2 and loss (see
+# the call in ``fit``).
+grad_forward = ReverseGrad(fwd_pass_final,
+                           ['w_01', 'b_1', 'w_12', 'b_2'], ['loss'])
+
+
+def fit(optimizer, params, grads, n_train, ip_ar, op_ar):
+    """Train for one pass over the first ``n_train`` samples.
+
+    Per-sample gradients are written into ``grads`` by ``grad_forward``,
+    summed into a local accumulator and, once ``BATCH_SIZE`` samples have
+    been seen (or the data runs out), averaged and handed to ``optimizer``.
+
+    Args:
+        optimizer: an optax gradient transformation.
+        params: dict of parameter arrays keyed like ``grads``.
+        grads: dict of per-sample gradient buffers, overwritten in place.
+        n_train: number of samples to use.
+        ip_ar: 2-D array of inputs, one sample per row.
+        op_ar: integer class label of every sample.
+
+    Returns:
+        The updated parameter dict.
+    """
+    opt_state = optimizer.init(params)
+
+    # one accumulator per gradient buffer, same shape and dtype
+    gradsums = {key: np.zeros_like(g) for key, g in grads.items()}
+
+    losssum = 0.0
+    in_batch = 0  # samples accumulated since the last optimizer step
+
+    def step(opt_state, params, grads):
+        updates, opt_state = optimizer.update(grads, opt_state, params)
+        params = optax.apply_updates(params, updates)
+        return opt_state, params
+
+    for _ in range(1):
+        reset_grads(grads)
+        reset_grads(gradsums)
+        for i in range(n_train):
+            loss_b[0] = 1.0
+            grad_forward(n_0, ip_ar[i, :], params['w01'], grads['w01'], n_1, params['b1'], grads['b1'], v_1, g_v_1,
+                         params['w12'], grads['w12'], n_2, params['b2'], grads['b2'], v_2, g_v_2,
+                         loss, loss_b, op_ar[i])
+            add_grads(grads, gradsums)
+            losssum += loss[0]
+            # the adjoints are accumulated by the reverse code: clear them
+            reset_grads(grads)
+            in_batch += 1
+
+            # Step after exactly BATCH_SIZE samples, and once more for a
+            # trailing partial batch. The former test
+            # ``i != 0 and i % BATCH_SIZE == 0`` put 201 samples in the first
+            # batch and never used the last ones.
+            if in_batch == BATCH_SIZE or i == n_train - 1:
+                print(i + 1, losssum / in_batch)
+                avg_grads(gradsums, in_batch)
+                opt_state, params = step(opt_state, params, gradsums)
+                reset_grads(gradsums)
+                losssum = 0.0
+                in_batch = 0
+
+    return params
+
+
+def net(params, x):
+    """Evaluate the trained network with jax on input ``x``.
+
+    Expects ``params['w01']`` / ``params['w12']`` to be 2-D so that
+    ``x @ w`` is defined (the script reshapes them before calling ``test``).
+    """
+    # l1
+    x = jnp.dot(x, params['w01']) + params['b1']
+    x = jax.nn.relu(x)
+    # l2
+    x = jnp.dot(x, params['w12']) + params['b2']
+    x = jax.nn.softmax(x)
+    return x
+
+
+def test(params, test_images, test_labels):
+    """Return the fraction of ``test_images`` classified as ``test_labels``."""
+    correct = 0
+    for image, label in zip(test_images, test_labels):
+        y_hat = net(params, image)
+        if argmax(y_hat) == label:
+            correct += 1
+
+    return correct / len(test_images)
+
+
+###############################################################################
+###############################################################################
+###############################################################################
+
+optimizer = optax.adam(learning_rate=1e-3)
+params = fit(optimizer, initial_params, grads, n_train, TRAIN, TRAIN_LABELS)
+
+# The kernel stores weights flat as (out, in); ``net`` needs (in, out).
+params['w01'] = params['w01'].reshape(n_1, n_0).T
+params['w12'] = params['w12'].reshape(n_2, n_1).T
+
+op = test(params, TEST, TEST_LABELS)
+print("Accuracy: ", op)
\ No newline at end of file