Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
626 changes: 626 additions & 0 deletions compyle/autodiff_tapenade.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion compyle/c_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ def __call__(self, *args, **kwargs):
}
delete[] stage1_res;
delete[] stage2_res;
py::print(ary);
}
}
'''
Expand Down
119 changes: 111 additions & 8 deletions compyle/cimport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import pybind11
from distutils.extension import Extension
from distutils.command import build_ext
from distutils.core import setup
from setuptools import setup
from distutils.errors import CompileError, LinkError
from distutils.sysconfig import customize_compiler

from .ext_module import get_platform_dir, get_ext_extension, get_openmp_flags
from .capture_stream import CaptureMultipleStreams # noqa: 402
from distutils.ccompiler import new_compiler


class Cmodule:
Expand All @@ -26,7 +28,12 @@ def __init__(self, src, hash_fn, root=None, verbose=False, openmp=False,
self.name = f'm_{self.hash}'
self.verbose = verbose
self.openmp = openmp
self.extra_inc_dir = extra_inc_dir
# self.extra_inc_dir = extra_inc_dir + ['/usr/include']
self.extra_inc_dir = extra_inc_dir + [
'/usr/include',
'/usr/include/c++/{}'.format(self._get_gcc_version()),
'/usr/lib/gcc/{}-linux-gnu/{}'.format(self._get_machine(), self._get_gcc_version())
]
self.extra_link_args = extra_link_args
self.extra_compile_args = extra_compile_args
self._use_cpp11()
Expand All @@ -35,6 +42,18 @@ def __init__(self, src, hash_fn, root=None, verbose=False, openmp=False,
self._setup_filenames()
self.lock = FileLock(self.lock_path, timeout=120)

def _get_gcc_version(self):
import subprocess
try:
gcc_version = subprocess.check_output(['gcc', '-dumpversion']).decode().strip()
return gcc_version
except:
return '11' # fallback to a common version

def _get_machine(self):
import platform
return platform.machine()

def _setup_root(self, root):
if root is None:
plat_dir = get_platform_dir()
Expand Down Expand Up @@ -65,12 +84,21 @@ def is_build_needed(self):

def build(self):
self._include_openmp()
ext = Extension(name=self.name,
sources=[self.src_path],
language='c++',
include_dirs=self.extra_inc_dir,
extra_link_args=self.extra_link_args,
extra_compile_args=self.extra_compile_args)
ext = Extension(
name=self.name,
sources=[self.src_path],
language='c++', # Change to c++ since we're using pybind11
include_dirs=self.extra_inc_dir,
extra_link_args=self.extra_link_args,
extra_compile_args=self.extra_compile_args + [
'-D_COMPLEX_H',
'-D_GNU_SOURCE',
'-fpermissive', # Allow mixing C and C++ code
'-include', 'complex.h',
'-include', 'complex',
'-x', 'c++',
]
)
args = [
"build_ext",
"--build-lib=" + self.build_dir,
Expand Down Expand Up @@ -127,3 +155,78 @@ def _message(self, *args):
msg = ' '.join(args)
if self.verbose:
print(msg)


# def wget_tpnd_headers():
# import requests
# baseurl = 'https://gitlab.inria.fr/tapenade/tapenade/-/raw/master/ADFirstAidKit/{}?ref_type=heads&inline=false'
# files = ['adComplex.h', 'adStack.c', 'adStack.h']
# reqs = [requests.get(baseurl.format(file)) for file in files]
# saveloc = get_tpnd_obj_dir()
# if not os.path.exists(saveloc):
# os.mkdir(saveloc)

# for file, r in zip(files, reqs):
# with open(join(saveloc, file), 'wb') as f:
# f.write(r.content)

def wget_tpnd_headers():
import requests
"https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/adBuffer.c?inline=false"
"https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/adBuffer.h?inline=false"
baseurl = 'https://gitlab.inria.fr/tapenade/tapenade/-/raw/3.16-v2/ADFirstAidKit/{}?inline=false'
files = ['adBuffer.c', 'adBuffer.h', 'adStack.c', 'adStack.h']
saveloc = get_tpnd_obj_dir()
os.makedirs(saveloc, exist_ok=True)

print(f"Downloading Tapenade source code to {saveloc}")
for file in files:
r = requests.get(baseurl.format(file), timeout=30)
r.raise_for_status()
with open(join(saveloc, file), 'wb') as f:
f.write(r.content)

def get_tpnd_obj_dir():
plat_dir = get_platform_dir()
root = expanduser(join('~', '.compyle', 'source', plat_dir))
tpnd_dir = join(root, 'tapenade_src')
return tpnd_dir


def compile_tapenade_source(verbose=0):
print("Setting up Tapenade source code...")
try:
obj_dir_tpnd = get_tpnd_obj_dir()
with CaptureMultipleStreams() as stream:
wget_tpnd_headers()
os.environ["CC"] = 'gcc' # Use gcc instead of g++
compiler = new_compiler(verbose=1)
customize_compiler(compiler)
compiler.compile(
[join(obj_dir_tpnd, 'adStack.c'),
join(obj_dir_tpnd, 'adBuffer.c')],
output_dir=obj_dir_tpnd,
extra_preargs=[
'-c',
'-fPIC',
],
include_dirs=[obj_dir_tpnd, '/usr/include'],
)
s_out = stream.get_output()
print(s_out[0])
print(s_out[1])
objdir = join(obj_dir_tpnd, obj_dir_tpnd[1:])
shutil.move(join(objdir, 'adStack.o'),
join(obj_dir_tpnd, 'adStack.o'))
shutil.move(join(objdir, 'adBuffer.o'),
join(obj_dir_tpnd, 'adBuffer.o'))
except (CompileError, LinkError):
hline = "*"*80
print(hline + "\nERROR")
s_out = stream.get_output()
print(s_out[0])
print(s_out[1])
msg = "Compilation of tapenade source failed, please check "\
"error messages above."
print(hline + "\n" + msg)
sys.exit(1)
1 change: 0 additions & 1 deletion compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ast
import importlib
import warnings
import json
from pytools import memoize
from .config import get_config
from .cython_generator import CythonGenerator
Expand Down
4 changes: 4 additions & 0 deletions compyle/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

"""

from compyle import c_backend
from compyle import c_backend
from functools import wraps
from textwrap import wrap
Expand All @@ -14,6 +15,7 @@
import numpy as np
import pybind11

from .cimport import Cmodule
from .cimport import Cmodule
from .config import get_config
from .profile import profile
Expand Down Expand Up @@ -1062,6 +1064,8 @@ def _generate(self, declarations=None):
return self._generate_cython_code(declarations=declarations)
elif self.backend == 'c':
return self._generate_c_code(declarations=declarations)
elif self.backend == 'c':
return self._generate_c_code(declarations=declarations)

def _default_cython_input_function(self):
py_data = (['int i', '{type}[:] input'.format(type=self.type)],
Expand Down
99 changes: 99 additions & 0 deletions compyle/tests/test_autodiff_tapenade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import numpy as np
from math import exp, log

from ..autodiff_tapenade import ForwardGrad, ReverseGrad, ElementwiseGrad
from ..types import annotate
from ..array import wrap, declare



@annotate(n_1='int', doublep='x, y')
def simple_pow(x, y, n_1):
y[0] = 1
for i in range(n_1):
y[0] *= x[0]

@annotate(doublep='x, y', n_1='int')
def ifelse(x, y, n_1):
if n_1 < 0:
x[0] = 1 / x[0]
n_1 = -n_1
y[0] = 1
for i in range(n_1):
y[0] *= x[0]


@annotate(doublep='ip, W, loss', double='b, label', n_1='int')
def log_reg(ip, W, b, loss, label, n_1):
h = declare('double')
pred = declare('double')
h = 0
for i in range(n_1):
h += ip[i] * W[i]
h += b
pred = 1 / (1 + exp(-h))
los_prob = pred * label + (1 - pred) * (1 - label)
loss[0] = -log(los_prob)


def grad_log_reg(ip, W, label):
h = np.dot(ip, W)
pred = 1 / (1 + exp(-h))
return ip * (pred - label)


def test_simple_pow():
grad_pow = ForwardGrad(simple_pow, ['x'], ['y'])

x = np.array([2])
y = np.empty([1])
[[yfd]] = grad_pow(x, y, 5)
assert yfd == 80

grad_pow = ReverseGrad(simple_pow, ['x'], ['y'])

x = np.array([2])
y = np.empty(1)
xd = np.zeros(1)
yd = np.array([1])

grad_pow(x, xd, y, yd, 5)
assert xd[0] == 80


def test_if_else():
grad_pow = ForwardGrad(ifelse, ['x'], ['y'])

x = np.array([2])
y = np.empty(1)

[[yfd]] = grad_pow(x, y, -5)
assert yfd == (-5 * (2 ** -6))

grad_pow = ReverseGrad(ifelse, ['x'], ['y'])

x = np.array([2])
y = np.empty(1)
xd = np.zeros(1)
yd = np.array([1])

grad_pow(x, xd, y, yd, -5)
assert xd[0] == (-5 * (2 ** -6))


def t_log_red():
g_log_reg = ReverseGrad(log_reg, ['W'], ['loss'])
n = 5
ip = np.linspace(0, 1, n)
W = np.random.randn(n)
b = 1
loss = np.array([0])
label = 1

loss_d = np.ones(1)
W_d = np.empty_like(W)

g_log_reg(ip, W, W_d, b, loss, loss_d, label, n)

assert np.allclose(W_d, grad_log_reg(ip, W, label))
59 changes: 59 additions & 0 deletions compyle/tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,65 @@ def test_repeated_scans_with_different_settings_c(self):
self._test_unique_scan(backend='c')


class ParallelUtilsBaseC(object):
def test_elementwise_works_with_c(self):
self._check_simple_elementwise(backend='c')

def test_elementwise_works_with_global_constant_c(self):
self._check_elementwise_with_constant(backend='c')

def test_reduction_works_without_map_c(self):
self._check_simple_reduction(backend='c')

def test_reduction_works_with_map_c(self):
self._check_reduction_with_map(backend='c')

def test_reduction_works_with_external_func_c(self):
self._check_reduction_with_external_func(backend='c')

def test_reduction_works_neutral_c(self):
self._check_reduction_min(backend='c')

def test_scan_works_c(self):
self._test_scan(backend='c')

def test_scan_works_c_parallel(self):
with use_config(use_openmp=True):
self._test_scan(backend='c')

def test_large_scan_works_c_parallel(self):
with use_config(use_openmp=True):
self._test_large_scan(backend='c')

def test_scan_works_with_external_func_c(self):
self._test_scan_with_external_func(backend='c')

def test_scan_works_with_external_func_c_parallel(self):
with use_config(use_openmp=True):
self._test_scan_with_external_func(backend='c')

def test_scan_last_item_c_parallel(self):
with use_config(use_openmp=True):
self._test_scan_last_item(backend='c')

def test_scan_last_item_c_serial(self):
self._test_scan_last_item(backend='c')

def test_unique_scan_c(self):
self._test_unique_scan(backend='c')

def test_unique_scan_c_parallel(self):
with use_config(use_openmp=True):
self._test_unique_scan(backend='c')

def test_repeated_scans_with_different_settings_c(self):
with use_config(use_openmp=False):
self._test_unique_scan(backend='c')

with use_config(use_openmp=True):
self._test_unique_scan(backend='c')


class ParallelUtilsBase(object):
def test_elementwise_works_with_cython(self):
self._check_simple_elementwise(backend='cython')
Expand Down
9 changes: 6 additions & 3 deletions compyle/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import ast
import re
from subprocess import call
import sys
from textwrap import dedent, wrap
import types
Expand Down Expand Up @@ -173,7 +174,7 @@ def _get_self_type(self):
def _get_local_arg(self, arg, type):
return arg, type

def _get_function_args(self, node):
def _get_function_args(self, node, convert_array_args=False):
node_args = node.args.args
if PY_VER == 2:
args = [x.id for x in node_args]
Expand Down Expand Up @@ -206,8 +207,10 @@ def _get_function_args(self, node):
type = self._detect_type(arg, value)
if 'LOCAL_MEM' in type:
arg, type = self._get_local_arg(arg, type)
call_sig.append('{type} {arg}'.format(type=type, arg=arg))

if convert_array_args and type.endswith('*'):
call_sig.append('{type} {arg}[]'.format(type=type[:-1], arg=arg))
else:
call_sig.append('{type} {arg}'.format(type=type, arg=arg))
return ', '.join(call_sig)

def _get_variable_declaration(self, type_str, names):
Expand Down
Loading