Skip to content
This repository was archived by the owner on Jun 3, 2018. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
language: python
python:
- "2.7"
- "3.4"
install:
- "pip install -r requirements.txt --use-mirrors"
- "pip install coveralls"
Expand Down
16 changes: 16 additions & 0 deletions py14/analysis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import ast


Expand All @@ -11,6 +12,21 @@ def is_void_function(fun):
finder.visit(fun)
return not finder.returns

if sys.version_info[0] >= 3:
def get_id(var):
if isinstance(var, ast.alias):
return var.name
elif isinstance(var, ast.Name):
return var.id
elif isinstance(var, ast.arg):
return var.arg
else:
def get_id(var):
if isinstance(var, ast.alias):
return var.name
elif isinstance(var, ast.Name):
return var.id


class ReturnFinder(ast.NodeVisitor):
returns = False
Expand Down
8 changes: 8 additions & 0 deletions py14/clike.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def visit_Name(self, node):
if node.id in self.builtin_constants:
return node.id.lower()
return node.id

def visit_NameConstant(self, node):
if node.value is True:
return "true"
elif node.value is False:
return "false"
else:
return node.value

def visit_Num(self, node):
return str(node.n)
Expand Down
6 changes: 3 additions & 3 deletions py14/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from scope import ScopeMixin
from .scope import ScopeMixin


def add_list_calls(node):
Expand Down Expand Up @@ -55,11 +55,11 @@ def visit_Import(self, node):

def visit_If(self, node):
node.vars = []
map(self.visit, node.body)
list(map(self.visit, node.body))
node.body_vars = node.vars

node.vars = []
map(self.visit, node.orelse)
list(map(self.visit, node.orelse))
node.orelse_vars = node.vars

node.vars = []
Expand Down
7 changes: 2 additions & 5 deletions py14/scope.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
from contextlib import contextmanager

from .analysis import get_id

def add_scope_context(node):
"""Provide to scope context to all nodes"""
Expand Down Expand Up @@ -42,13 +43,9 @@ class ScopeList(list):
"""
def find(self, lookup):
"""Find definition of variable lookup."""
def is_match(var):
return ((isinstance(var, ast.alias) and var.name == lookup) or
(isinstance(var, ast.Name) and var.id == lookup))

def find_definition(scope, var_attr="vars"):
for var in getattr(scope, var_attr):
if is_match(var):
if get_id(var) == lookup:
return var

for scope in self:
Expand Down
16 changes: 11 additions & 5 deletions py14/tests/test_transpiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from py14.transpiler import transpile
import sys
import pytest
from py14.transpiler import transpile


def parse(*args):
Expand All @@ -26,7 +27,10 @@ def test_empty_return():


def test_print_multiple_vars():
source = parse('print("hi", "there" )')
if sys.version_info[0] >= 3:
source = parse('print(("hi", "there" ))')
else:
source = parse('print("hi", "there" )')
cpp = transpile(source)
assert cpp == ('std::cout << std::string {"hi"} '
'<< std::string {"there"} << std::endl;')
Expand Down Expand Up @@ -228,12 +232,13 @@ def test_bubble_sort():
" return seq",
)
cpp = transpile(source)
range_f = "range" if sys.version_info[0] < 3 else "xrange"
assert cpp == parse(
"template <typename T1>",
"auto sort(T1 seq) {",
"auto L = seq.size();",
"for(auto _ : rangepp::range(L)) {",
"for(auto n : rangepp::range(1, L)) {",
"for(auto _ : rangepp::{0}(L)) {{".format(range_f),
"for(auto n : rangepp::{0}(1, L)) {{".format(range_f),
"if(seq[n] < seq[n - 1]) {",
"std::tie(seq[n - 1], seq[n]) = "
"std::make_tuple(seq[n], seq[n - 1]);",
Expand Down Expand Up @@ -287,6 +292,7 @@ def test_comb_sort():
" return seq",
)
cpp = transpile(source)
range_f = "range" if sys.version_info[0] < 3 else "xrange"
assert cpp == parse(
"template <typename T1>",
"auto sort(T1 seq) {",
Expand All @@ -295,7 +301,7 @@ def test_comb_sort():
"while (gap > 1||swap) {",
"gap = std::max(1, py14::to_int(gap / 1.25));",
"swap = false;",
"for(auto i : rangepp::range(seq.size() - gap)) {",
"for(auto i : rangepp::{0}(seq.size() - gap)) {{".format(range_f),
"if(seq[i] > seq[i + gap]) {",
"std::tie(seq[i], seq[i + gap]) = "
"std::make_tuple(seq[i + gap], seq[i]);",
Expand Down
12 changes: 8 additions & 4 deletions py14/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Trace object types that are inserted into Python list.
"""
import ast
from clike import CLikeTranspiler
from .analysis import get_id
from .clike import CLikeTranspiler


def decltype(node):
Expand All @@ -24,7 +25,7 @@ def is_list(node):
elif isinstance(node, ast.Assign):
return is_list(node.value)
elif isinstance(node, ast.Name):
var = node.scopes.find(node.id)
var = node.scopes.find(get_id(node))
return (hasattr(var, "assigned_from") and not
isinstance(var.assigned_from, ast.FunctionDef) and
is_list(var.assigned_from.value))
Expand Down Expand Up @@ -59,13 +60,13 @@ def visit_Str(self, node):
return node.s

def visit_Name(self, node):
var = node.scopes.find(node.id)
var = node.scopes.find(get_id(node))
if isinstance(var.assigned_from, ast.For):
it = var.assigned_from.iter
return "std::declval<typename decltype({0})::value_type>()".format(
self.visit(it))
elif isinstance(var.assigned_from, ast.FunctionDef):
return var.id
return get_id(var)
else:
return self.visit(var.assigned_from.value)

Expand Down Expand Up @@ -99,6 +100,9 @@ def visit_Name(self, node):
else:
return self.visit(var.assigned_from.value)

def visit_NameConstant(self, node):
return CLikeTranspiler().visit(node)

def visit_Call(self, node):
params = ",".join([self.visit(arg) for arg in node.args])
return "{0}({1})".format(node.func.id, params)
Expand Down
61 changes: 47 additions & 14 deletions py14/transpiler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
import ast
from clike import CLikeTranspiler
from scope import add_scope_context
from analysis import add_imports, is_void_function
from context import add_variable_context, add_list_calls
from tracer import decltype, is_list, is_builtin_import, defined_before
from .clike import CLikeTranspiler
from .scope import add_scope_context
from .context import add_variable_context, add_list_calls
from .analysis import add_imports, is_void_function, get_id
from .tracer import decltype, is_list, is_builtin_import, defined_before


def transpile(source, headers=False, testing=False):
Expand Down Expand Up @@ -43,7 +44,7 @@ def generate_catch_test_case(node, body):
def generate_template_fun(node, body):
params = []
for idx, arg in enumerate(node.args.args):
params.append(("T" + str(idx + 1), arg.id))
params.append(("T" + str(idx + 1), get_id(arg)))
typenames = ["typename " + arg[0] for arg in params]

template = "inline "
Expand Down Expand Up @@ -90,9 +91,10 @@ def visit_FunctionDef(self, node):

def visit_Attribute(self, node):
attr = node.attr
if is_builtin_import(node.value.id):
return "py14::" + node.value.id + "::" + attr
elif node.value.id == "math":
value_id = get_id(node.value)
if is_builtin_import(value_id):
return "py14::" + value_id + "::" + attr
elif value_id == "math":
if node.attr == "asin":
return "std::asin"
elif node.attr == "atan":
Expand All @@ -103,7 +105,7 @@ def visit_Attribute(self, node):
if is_list(node.value):
if node.attr == "append":
attr = "push_back"
return node.value.id + "." + attr
return value_id + "." + attr

def visit_Call(self, node):
fname = self.visit(node.func)
Expand All @@ -120,11 +122,24 @@ def visit_Call(self, node):
elif fname == "max":
return "std::max({0})".format(args)
elif fname == "range":
return "rangepp::range({0})".format(args)
if sys.version_info[0] >= 3:
return "rangepp::xrange({0})".format(args)
else:
return "rangepp::range({0})".format(args)
elif fname == "xrange":
return "rangepp::xrange({0})".format(args)
elif fname == "len":
return "{0}.size()".format(self.visit(node.args[0]))
elif fname == "print":
buf = []
for n in node.args:
value = self.visit(n)
if isinstance(n, ast.List) or isinstance(n, ast.Tuple):
buf.append("std::cout << {0} << std::endl;".format(
" << ".join([self.visit(el) for el in n.elts])))
else:
buf.append('std::cout << {0} << std::endl;'.format(value))
return '\n'.join(buf)

return '{0}({1})'.format(fname, args)

Expand Down Expand Up @@ -157,9 +172,17 @@ def visit_Name(self, node):
else:
return super(CppTranspiler, self).visit_Name(node)

def visit_NameConstant(self, node):
if node.value is True:
return "true"
elif node.value is False:
return "false"
else:
return super(CppTranspiler, self).visit_NameConstant(node)

def visit_If(self, node):
body_vars = set([v.id for v in node.scopes[-1].body_vars])
orelse_vars = set([v.id for v in node.scopes[-1].orelse_vars])
body_vars = set([get_id(v) for v in node.scopes[-1].body_vars])
orelse_vars = set([get_id(v) for v in node.scopes[-1].orelse_vars])
node.common_vars = body_vars.intersection(orelse_vars)

var_definitions = []
Expand All @@ -179,6 +202,16 @@ def visit_If(self, node):
return ("".join(var_definitions) +
super(CppTranspiler, self).visit_If(node))

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.USub):
if isinstance(node.operand, (ast.Call, ast.Num)):
# Shortcut if parenthesis are not needed
return "-{0}".format(self.visit(node.operand))
else:
return "-({0})".format(self.visit(node.operand))
else:
return super(CppTranspiler, self).visit_UnaryOp(node)

def visit_BinOp(self, node):
if (isinstance(node.left, ast.List)
and isinstance(node.op, ast.Mult)
Expand All @@ -197,7 +230,7 @@ def visit_alias(self, node):

def visit_Import(self, node):
imports = [self.visit(n) for n in node.names]
return "\n".join(filter(None, imports))
return "\n".join(i for i in imports if i)

def visit_List(self, node):
if len(node.elts) > 0:
Expand Down
6 changes: 3 additions & 3 deletions regtests/test_range.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
def test_range():
simple = range(0, 10)
simple = list(range(0, 10))
results = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert simple == results


def test_range_with_steps():
with_steps = range(0, 10, 2)
with_steps = list(range(0, 10, 2))
results = [0, 2, 4, 6, 8]
assert with_steps == results


def test_range_with_negative_steps():
with_steps = range(10, 0, -2)
with_steps = list(range(10, 0, -2))
results = [10, 8, 6, 4, 2]
assert with_steps == results