diff --git a/src/cmd_parser.cpp b/src/cmd_parser.cpp index 044db608..6b8163cd 100644 --- a/src/cmd_parser.cpp +++ b/src/cmd_parser.cpp @@ -804,8 +804,10 @@ bool CmdParser::parseNextCommand() { d_lex.parseError("Cannot define program more than once"); } - // it should be a program with the same type - d_eparser.typeCheck(pprev, progType); + // ensure that its type is compatible with the previous definition + d_eparser.typeCheckProgramFwdDecl(pprev, progType, name); + // we use the forward declaration of the program that was globally bound + // already pvar = pprev; } else diff --git a/src/expr.cpp b/src/expr.cpp index 37a49f4e..0fc1c875 100644 --- a/src/expr.cpp +++ b/src/expr.cpp @@ -370,7 +370,7 @@ void Expr::printDebugInternal(const Expr& e, if (k == Kind::APPLY_OPAQUE) { // ambiguous functions must use "as" - Attr attr = ExprValue::d_state->getConstructorKind((*cur.first)[0]); + Attr attr = ExprValue::d_state->getAttributeKind((*cur.first)[0]); if (attr == Attr::AMB || attr == Attr::AMB_DATATYPE_CONSTRUCTOR) { os << "as "; diff --git a/src/expr_parser.cpp b/src/expr_parser.cpp index 2d5710e6..c073cf91 100644 --- a/src/expr_parser.cpp +++ b/src/expr_parser.cpp @@ -200,7 +200,7 @@ Expr ExprParser::parseExpr() args.push_back(v); size_t nscopes = 0; // if a binder, read a variable list and push a scope - Attr ck = d_state.getConstructorKind(v.getValue()); + Attr ck = d_state.getAttributeKind(v.getValue()); if (ck==Attr::BINDER || ck==Attr::LET_BINDER) { // If it is a binder, immediately read the bound variable list. @@ -1272,6 +1272,37 @@ void ExprParser::typeCheckProgramPair(Expr& pat, } } +void ExprParser::typeCheckProgramFwdDecl(Expr& prevProg, + Expr& newType, + const std::string& progName) +{ + const Expr& prevType = d_state.getTypeChecker().getType(prevProg); + if (prevType == newType) + { + // already ok, return + return; + } + // rename the old variables to the new ones + std::vector vso = Expr::getVariables(prevType); + std::vector vsn = Expr::getVariables(newType); + if (vso.size() == vsn.size()) + { + Ctx ctx; + for (size_t i = 0, nvars = vso.size(); i < nvars; i++) + { + ctx[vso[i].getValue()] = vsn[i].getValue(); + } + if (d_state.getTypeChecker().evaluate(prevType.getValue(), ctx) == newType) + { + // same after renaming, return + return; + } + } + std::stringstream ss; + ss << "Forward declaration of program " << progName << " had different type."; + d_lex.parseError(ss.str()); +} + Expr ExprParser::findFreeVar(const Expr& e, const std::vector& bvs) { std::vector efv = Expr::getVariables(e); diff --git a/src/expr_parser.h b/src/expr_parser.h index 961af872..d8646f95 100644 --- a/src/expr_parser.h +++ b/src/expr_parser.h @@ -166,6 +166,16 @@ class ExprParser * terms with free parameters that are not bound during pattern matching. */ void typeCheckProgramPair(Expr& pat, Expr& ret, bool checkPreservation); + /** + * Type check program with forward declaration. Ensure that the forward + * declaration in prevProg has a type that is compatible with newType. + * In particular, note that forward declared programs may have non-ground + * type. Since parameters are not normalized, we need to check whether + * newType is alpha equivalent to the previously declared type. + */ + void typeCheckProgramFwdDecl(Expr& prevProg, + Expr& newType, + const std::string& progName); /** get variable, else error */ Expr getVar(const std::string& name); /** get variable, else error */ diff --git a/src/literal.cpp b/src/literal.cpp index f5473140..0e053568 100644 --- a/src/literal.cpp +++ b/src/literal.cpp @@ -137,8 +137,7 @@ Literal Literal::evaluate(Kind k, const std::vector& args) { Assert (k!=Kind::EVAL_IS_EQ && k!=Kind::EVAL_IF_THEN_ELSE && k!=Kind::EVAL_REQUIRES); Kind ka = Kind::NONE; - if (k != Kind::EVAL_EXTRACT && k != Kind::EVAL_TO_BIN - && k != Kind::EVAL_POW) + if (k != Kind::EVAL_EXTRACT && k != Kind::EVAL_TO_BIN && k != Kind::EVAL_POW) { ka = allSameKind(args); if (ka==Kind::NONE) diff --git a/src/state.cpp b/src/state.cpp index 5d954ed7..ea449d0f 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -722,7 +722,7 @@ Expr State::mkExpr(Kind k, const std::vector& children) // have to check whether we have marked the constructor kind, which is // not the case i.e. if we are constructing applications corresponding to // the cases in the program definition itself. - if (getConstructorKind(hd)!=Attr::NONE) + if (getAttributeKind(hd) != Attr::NONE) { Expr hdt = Expr(hd); const Expr& t = d_tc.getType(hdt); @@ -789,7 +789,7 @@ Expr State::mkExpr(Kind k, const std::vector& children) else if (k == Kind::AS_RETURN) { // (as nil (List Int)) --> (_ nil (List Int)) - Attr ck = getConstructorKind(vchildren[0]); + Attr ck = getAttributeKind(vchildren[0]); if ((ck == Attr::AMB_DATATYPE_CONSTRUCTOR || ck == Attr::AMB) && children.size() == 2) { @@ -841,6 +841,16 @@ Expr State::mkExpr(Kind k, const std::vector& children) return Expr(mkExprInternal(k, vchildren)); } +Expr State::mkRawExpr(Kind k, const std::vector& children) +{ + std::vector vchildren; + for (const Expr& c : children) + { + vchildren.push_back(c.getValue()); + } + return Expr(mkExprInternal(k, vchildren)); +} + Expr State::mkTrue() const { return d_true; } Expr State::mkFalse() const { return d_false; } @@ -1030,7 +1040,7 @@ Expr State::mkApplyAttr(AppInfo* ai, size_t nlistTerms = 0; if (isNil) { - if (getConstructorKind(curr) != Attr::LIST) + if (getAttributeKind(curr) != Attr::LIST) { // if the last term is not marked as a list variable and // we have a null terminator, then we insert the null terminator @@ -1062,7 +1072,7 @@ Expr State::mkApplyAttr(AppInfo* ai, cc[prevIndex] = curr; cc[nextIndex] = vchildren[isLeft ? i : nchild - i]; // if the "head" child is marked as list, we construct concatenation - if (isNil && getConstructorKind(cc[nextIndex]) == Attr::LIST) + if (isNil && getAttributeKind(cc[nextIndex]) == Attr::LIST) { curr = mkExprInternal(Kind::EVAL_LIST_CONCAT, cc); } @@ -1152,8 +1162,7 @@ Expr State::mkApplyAttr(AppInfo* ai, Expr argList; // If there is only one argument, and it was marked :list, then it is // not desugared. - if (vchildren.size() == 2 - && getConstructorKind(vchildren[1]) == Attr::LIST) + if (vchildren.size() == 2 && getAttributeKind(vchildren[1]) == Attr::LIST) { argList = Expr(vchildren[1]); } @@ -1332,7 +1341,7 @@ Expr State::mkLetBinderList(const ExprValue* ev, const std::vectord_attrConsTerm; + } + return d_null; +} Expr State::getVar(const std::string& name) const { diff --git a/src/state.h b/src/state.h index 5500ff79..8b4557f1 100644 --- a/src/state.h +++ b/src/state.h @@ -135,8 +135,14 @@ class State Expr mkSelf() const; /** Make pair */ Expr mkPair(const Expr& t1, const Expr& t2); - /** */ + /** + * Makes expression with given kind and childen. This method will apply + * desugaring based on the attributes of the operator head, i.e. the first + * expression in children. + */ Expr mkExpr(Kind k, const std::vector& children); + /** Same as above, without desugaring */ + Expr mkRawExpr(Kind k, const std::vector& children); /** make true */ Expr mkTrue() const; /** make false */ @@ -180,8 +186,16 @@ class State const Expr& ret, const std::string& name); //-------------------------------------- - /** Get the constructor kind for symbol v */ - Attr getConstructorKind(const ExprValue* v) const; + /** + * Get the constructor kind for symbol v. This is one of the types listed in + * attr.h which impact how the symbol v is parsed on interpreted. + */ + Attr getAttributeKind(const ExprValue* v) const; + /** + * Get the attribute term for symbol v. Along with getAttributeKind, this + * term impacts how the symbol v is parsed on interpreted. + */ + Expr getAttributeTerm(const ExprValue* v) const; /** make binder list */ Expr mkBinderList(const ExprValue* ev, const std::vector& vs); /** */ diff --git a/src/type_checker.cpp b/src/type_checker.cpp index 1ee3271b..fb9ff8d3 100644 --- a/src/type_checker.cpp +++ b/src/type_checker.cpp @@ -57,14 +57,14 @@ void TypeChecker::setLiteralTypeRule(Kind k, const Expr& t) it->second = t; } -Expr TypeChecker::getOrSetLiteralTypeRule(Kind k, ExprValue* self) +Expr TypeChecker::getLiteralTypeRuleMaybeInit(Kind k, ExprValue* self) { std::map::iterator it = d_literalTypeRules.find(k); if (it==d_literalTypeRules.end()) { std::stringstream ss; - EO_FATAL() << "TypeChecker::getOrSetLiteralTypeRule: cannot get type rule for kind " - << k; + EO_FATAL() << "TypeChecker::getLiteralTypeRuleMaybeInit: cannot get type " + << "rule for kind " << k; } Expr tp; if (it->second.isNull()) @@ -284,7 +284,7 @@ Expr TypeChecker::getTypeInternal(ExprValue* e, std::ostream* out) case Kind::STRING: { // use the literal type rule - return getOrSetLiteralTypeRule(k, e); + return getLiteralTypeRuleMaybeInit(k, e); } case Kind::AS: case Kind::AS_RETURN: @@ -305,7 +305,7 @@ Expr TypeChecker::getTypeInternal(ExprValue* e, std::ostream* out) { Expr ctype1 = Expr(d_state.lookupType(e->d_children[0])); Expr ctype2 = Expr(d_state.lookupType(e->d_children[1])); - if (ctype1 != getOrSetLiteralTypeRule(Kind::STRING, e->d_children[0])) + if (ctype1 != getLiteralTypeRuleMaybeInit(Kind::STRING, e->d_children[0])) { if (out) { @@ -1313,10 +1313,10 @@ Expr TypeChecker::evaluateLiteralOpInternal( for (ExprValue* c : cargs) { Expr ce(c); - if (d_state.getConstructorKind(c) == Attr::AMB_DATATYPE_CONSTRUCTOR) + if (d_state.getAttributeKind(c) == Attr::AMB_DATATYPE_CONSTRUCTOR) { - Expr dt(args[0]); - ce = d_state.mkExpr(Kind::APPLY_OPAQUE, {ce, dt}); + ce = Expr( + d_state.mkExprInternal(Kind::APPLY_OPAQUE, {c, args[0]})); } cargsp.push_back(ce); } @@ -1854,15 +1854,12 @@ Expr TypeChecker::getLiteralOpType(Kind k, case Kind::EVAL_FIND: case Kind::EVAL_LIST_LENGTH: case Kind::EVAL_LIST_FIND: - return getOrSetLiteralTypeRule(Kind::NUMERAL); + return getLiteralTypeRuleMaybeInit(Kind::NUMERAL); case Kind::EVAL_RAT_DIV: - case Kind::EVAL_TO_RAT: - return getOrSetLiteralTypeRule(Kind::RATIONAL); + case Kind::EVAL_TO_RAT: return getLiteralTypeRuleMaybeInit(Kind::RATIONAL); case Kind::EVAL_NAME_OF: - case Kind::EVAL_TO_STRING: - return getOrSetLiteralTypeRule(Kind::STRING); - case Kind::EVAL_TO_BIN: - return getOrSetLiteralTypeRule(Kind::BINARY); + case Kind::EVAL_TO_STRING: return getLiteralTypeRuleMaybeInit(Kind::STRING); + case Kind::EVAL_TO_BIN: return getLiteralTypeRuleMaybeInit(Kind::BINARY); case Kind::EVAL_DT_CONSTRUCTORS: case Kind::EVAL_DT_SELECTORS: return d_state.mkListType(); default:break; diff --git a/src/type_checker.h b/src/type_checker.h index c5843700..977ea493 100644 --- a/src/type_checker.h +++ b/src/type_checker.h @@ -45,6 +45,13 @@ class TypeChecker static bool checkArity(Kind k, size_t nargs, std::ostream* out = nullptr); /** Set type rule for literal kind k to t */ void setLiteralTypeRule(Kind k, const Expr& t); + /** + * Get type rule for literal kind k. The argument self is the expression to + * instantiate eo::self with, if applicable, otherwise eo::? is used. + * If no type rule has been set yet for k, the type rule for k is initialized + * to a default, given by State::mkBuiltinType(k). + */ + Expr getLiteralTypeRuleMaybeInit(Kind k, ExprValue* self = nullptr); /** * Evaluate the expression e in the given context. */ @@ -89,12 +96,6 @@ class TypeChecker Ctx& newCtx); /** Return its type */ Expr getTypeInternal(ExprValue* e, std::ostream* out); - /** - * Get or set type rule (to default) for literal kind k. The argument - * self is the expression to instantiate eo::self with, if applicable, - * otherwise eo::? is used. - */ - Expr getOrSetLiteralTypeRule(Kind k, ExprValue* self = nullptr); /** Evaluate literal op */ Expr evaluateLiteralOpInternal(Kind k, const std::vector& args); /** Evaluate nil diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bd19142c..3f155da5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -173,6 +173,7 @@ set(ethos_test_file_list code-point-oob.eo homogeneous-list-ops.eo quant-sk-type.eo + fwd-decl-nground.eo ) macro(ethos_test file) diff --git a/tests/fwd-decl-nground.eo b/tests/fwd-decl-nground.eo new file mode 100644 index 00000000..22fe676d --- /dev/null +++ b/tests/fwd-decl-nground.eo @@ -0,0 +1,11 @@ +(declare-parameterized-const = ((T Type :implicit)) (-> T T Bool)) + +(program flip ((U Type)) :signature (U U) Bool) + +(program flip + ((T Type) (x T) (y T)) + :signature (T T) Bool + ( + ((flip x y) (= y x)) + ) +)