Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
590 changes: 590 additions & 0 deletions compyle/autodiff_tapenade.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion compyle/c_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ def __call__(self, *args, **kwargs):
}
delete[] stage1_res;
delete[] stage2_res;
py::print(ary);
}
}
'''
Expand Down
57 changes: 57 additions & 0 deletions compyle/cimport.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from distutils.command import build_ext
from distutils.core import setup
from distutils.errors import CompileError, LinkError
from distutils.sysconfig import customize_compiler

from .ext_module import get_platform_dir, get_ext_extension, get_openmp_flags
from .capture_stream import CaptureMultipleStreams # noqa: 402
from distutils.ccompiler import new_compiler


class Cmodule:
Expand Down Expand Up @@ -127,3 +129,58 @@ def _message(self, *args):
msg = ' '.join(args)
if self.verbose:
print(msg)


def wget_tpnd_headers():
import requests
"https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/adBuffer.c?inline=false"
"https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/adBuffer.h?inline=false"
baseurl = 'https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/{}?inline=false'
files = ['adBuffer.c', 'adBuffer.h', 'adStack.c', 'adStack.h']
reqs = [requests.get(baseurl.format(file)) for file in files]
saveloc = get_tpnd_obj_dir()
if not os.path.exists(saveloc):
os.mkdir(saveloc)

for file, r in zip(files, reqs):
with open(join(saveloc, file), 'wb') as f:
f.write(r.content)


def get_tpnd_obj_dir():
plat_dir = get_platform_dir()
root = expanduser(join('~', '.compyle', 'source', plat_dir))
tpnd_dir = join(root, 'tapenade_src')
return tpnd_dir


def compile_tapenade_source(verbose=0):
print("Setting up Tapenade source code...")
try:
obj_dir_tpnd = get_tpnd_obj_dir()
with CaptureMultipleStreams() as stream:
wget_tpnd_headers()
os.environ["CC"] = 'g++'
compiler = new_compiler(verbose=1)
customize_compiler(compiler)
compiler.compile([join(obj_dir_tpnd, 'adBuffer.c')],
output_dir=obj_dir_tpnd,
extra_preargs=['-c', '-fPIC'])
compiler.compile([join(obj_dir_tpnd, 'adStack.c')],
output_dir=obj_dir_tpnd,
extra_preargs=['-c', '-fPIC'])
objdir = join(obj_dir_tpnd, obj_dir_tpnd[1:])
shutil.move(join(objdir, 'adBuffer.o'),
join(obj_dir_tpnd, 'adBuffer.o'))
shutil.move(join(objdir, 'adStack.o'),
join(obj_dir_tpnd, 'adStack.o'))
except (CompileError, LinkError):
hline = "*"*80
print(hline + "\nERROR")
s_out = stream.get_output()
print(s_out[0])
print(s_out[1])
msg = "Compilation of tapenade source failed, please check "\
"error messages above."
print(hline + "\n" + msg)
sys.exit(1)
1 change: 0 additions & 1 deletion compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ast
import importlib
import warnings
import json
from pytools import memoize
from .config import get_config
from .cython_generator import CythonGenerator
Expand Down
99 changes: 99 additions & 0 deletions compyle/tests/test_autodiff_tapenade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import numpy as np
from math import exp, log

from ..autodiff_tapenade import ForwardGrad, ReverseGrad, ElementwiseGrad
from ..types import annotate
from ..array import wrap, declare



@annotate(n_1='int', doublep='x, y')
def simple_pow(x, y, n_1):
y[0] = 1
for i in range(n_1):
y[0] *= x[0]

@annotate(doublep='x, y', n_1='int')
def ifelse(x, y, n_1):
if n_1 < 0:
x[0] = 1 / x[0]
n_1 = -n_1
y[0] = 1
for i in range(n_1):
y[0] *= x[0]


@annotate(doublep='ip, W, loss', double='b, label', n_1='int')
def log_reg(ip, W, b, loss, label, n_1):
h = declare('double')
pred = declare('double')
h = 0
for i in range(n_1):
h += ip[i] * W[i]
h += b
pred = 1 / (1 + exp(-h))
los_prob = pred * label + (1 - pred) * (1 - label)
loss[0] = -log(los_prob)


def grad_log_reg(ip, W, label):
h = np.dot(ip, W)
pred = 1 / (1 + exp(-h))
return ip * (pred - label)


def test_simple_pow():
grad_pow = ForwardGrad(simple_pow, ['x'], ['y'])

x = np.array([2])
y = np.empty([1])
[[yfd]] = grad_pow(x, y, 5)
assert yfd == 80

grad_pow = ReverseGrad(simple_pow, ['x'], ['y'])

x = np.array([2])
y = np.empty(1)
xd = np.zeros(1)
yd = np.array([1])

grad_pow(x, xd, y, yd, 5)
assert xd[0] == 80


def test_if_else():
grad_pow = ForwardGrad(ifelse, ['x'], ['y'])

x = np.array([2])
y = np.empty(1)

[[yfd]] = grad_pow(x, y, -5)
assert yfd == (-5 * (2 ** -6))

grad_pow = ReverseGrad(ifelse, ['x'], ['y'])

x = np.array([2])
y = np.empty(1)
xd = np.zeros(1)
yd = np.array([1])

grad_pow(x, xd, y, yd, -5)
assert xd[0] == (-5 * (2 ** -6))


def t_log_red():
g_log_reg = ReverseGrad(log_reg, ['W'], ['loss'])
n = 5
ip = np.linspace(0, 1, n)
W = np.random.randn(n)
b = 1
loss = np.array([0])
label = 1

loss_d = np.ones(1)
W_d = np.empty_like(W)

g_log_reg(ip, W, W_d, b, loss, loss_d, label, n)

assert np.allclose(W_d, grad_log_reg(ip, W, label))
9 changes: 6 additions & 3 deletions compyle/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import ast
import re
from subprocess import call
import sys
from textwrap import dedent, wrap
import types
Expand Down Expand Up @@ -163,7 +164,7 @@ def _get_self_type(self):
def _get_local_arg(self, arg, type):
return arg, type

def _get_function_args(self, node):
def _get_function_args(self, node, convert_array_args=False):
node_args = node.args.args
if PY_VER == 2:
args = [x.id for x in node_args]
Expand Down Expand Up @@ -196,8 +197,10 @@ def _get_function_args(self, node):
type = self._detect_type(arg, value)
if 'LOCAL_MEM' in type:
arg, type = self._get_local_arg(arg, type)
call_sig.append('{type} {arg}'.format(type=type, arg=arg))

if convert_array_args and type.endswith('*'):
call_sig.append('{type} {arg}[]'.format(type=type[:-1], arg=arg))
else:
call_sig.append('{type} {arg}'.format(type=type, arg=arg))
return ', '.join(call_sig)

def _get_variable_declaration(self, type_str, names):
Expand Down
7 changes: 5 additions & 2 deletions compyle/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,11 @@ def add_code(self, code):
cb = CodeBlock(code, code)
self.blocks.append(cb)

def get_code(self):
code = [self.header] + [x.code for x in self.blocks]
def get_code(self, incl_header=True):
if incl_header:
code = [self.header] + [x.code for x in self.blocks]
else:
code = [x.code for x in self.blocks]
return '\n'.join(code)

def compile(self):
Expand Down
Loading