diff --git a/src/nagini_translation/lib/constants.py b/src/nagini_translation/lib/constants.py index 731a7203..4021f44a 100644 --- a/src/nagini_translation/lib/constants.py +++ b/src/nagini_translation/lib/constants.py @@ -311,6 +311,8 @@ BOOL_TYPE = 'bool' +TYPE_TYPE = 'type' + PRIMITIVE_PREFIX = '__prim__' PRIMITIVE_INT_TYPE = PRIMITIVE_PREFIX + INT_TYPE @@ -325,13 +327,15 @@ PRIMITIVE_MSET_TYPE = PRIMITIVE_PREFIX + 'Multiset' +PRIMITIVE_TYPE_TYPE = 'PyType' + OBJECT_TYPE = 'object' CALLABLE_TYPE = 'Callable' PRIMITIVES = {PRIMITIVE_INT_TYPE, PRIMITIVE_BOOL_TYPE, PRIMITIVE_SEQ_TYPE, - PRIMITIVE_SET_TYPE, CALLABLE_TYPE, PRIMITIVE_MSET_TYPE} + PRIMITIVE_SET_TYPE, CALLABLE_TYPE, PRIMITIVE_MSET_TYPE, PRIMITIVE_TYPE_TYPE} BOXED_PRIMITIVES = {INT_TYPE, BOOL_TYPE} diff --git a/src/nagini_translation/lib/program_nodes.py b/src/nagini_translation/lib/program_nodes.py index 78aa9c1e..26204cba 100644 --- a/src/nagini_translation/lib/program_nodes.py +++ b/src/nagini_translation/lib/program_nodes.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2019 ETH Zurich +Copyright (c) 2025 ETH Zurich This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. @@ -24,6 +24,7 @@ PRIMITIVE_PREFIX, PRIMITIVE_SEQ_TYPE, PRIMITIVE_SET_TYPE, + PRIMITIVE_TYPE_TYPE, PRIMITIVES, PSEQ_TYPE, PSET_TYPE, @@ -707,10 +708,12 @@ def try_box(self) -> 'PythonClass': boxed_name = self.name[len(PRIMITIVE_PREFIX):] if boxed_name == 'Set': boxed_name = PSET_TYPE - if boxed_name == 'Multiset': + elif boxed_name == 'Multiset': boxed_name = PMSET_TYPE - if boxed_name == 'Seq': + elif boxed_name == 'Seq': boxed_name = PSEQ_TYPE + elif self.name == PRIMITIVE_TYPE_TYPE: + boxed_name = 'type' return self.module.classes[boxed_name] return self diff --git a/src/nagini_translation/lib/resolver.py b/src/nagini_translation/lib/resolver.py index 59053af9..56a5e811 100644 --- a/src/nagini_translation/lib/resolver.py +++ b/src/nagini_translation/lib/resolver.py @@ -26,6 +26,7 @@ SET_TYPE, STRING_TYPE, TUPLE_TYPE, + TYPE_TYPE, ) from nagini_translation.lib.program_nodes import ( ContainerInterface, @@ -76,15 +77,19 @@ def get_target(node: ast.AST, if (container and func_name == 'Result' and isinstance(container, PythonMethod)): # In this case the immediate container must be a method, and we - # return its result type - return container.type + return None elif (container and func_name == 'super' and isinstance(container, PythonMethod)): # Return the type of the current method's superclass return container.cls.superclass elif func_name == 'cast': return None - return get_target(node.func, containers, container) + func_target = get_target(node.func, containers, container) + if isinstance(func_target, PythonType): + # this is a constructor call, it's not pointing at an actual PythonNode + return None + else: + return func_target elif isinstance(node, ast.Attribute): # Find the type of the LHS, so that we can look through its members. lhs = get_type(node.value, containers, container) @@ -99,6 +104,8 @@ def get_target(node: ast.AST, # defined in the class. So instead of type[C], we want to look in # class C directly here. lhs = lhs.type_args[0] + if isinstance(lhs, PythonClass) and lhs.name == TYPE_TYPE: + lhs = get_target(node.value, containers, container) if isinstance(lhs, GenericType): # Use the class, since we want to look for members. lhs = lhs.cls @@ -258,7 +265,7 @@ def _do_get_type(node: ast.AST, containers: List[ContainerInterface], else: error = 'generic.constructor.without.type' raise InvalidProgramException(node, error) - return target + return module.global_module.classes[TYPE_TYPE] if isinstance(node, (ast.Attribute, ast.Name)): if isinstance(node, ast.Attribute): lhs = _do_get_type(node.value, containers, container) @@ -422,6 +429,45 @@ def _get_call_type(node: ast.Call, module: PythonModule, current_function: PythonMethod, containers: List[ContainerInterface], container: PythonNode) -> PythonType: + call_target = get_target(node.func, containers, container) + if isinstance(call_target, PythonMethod): + if isinstance(node.func, ast.Attribute): + rec_target = get_target(node.func.value, containers, container) + if not isinstance(rec_target, PythonModule): + rectype = get_type(node.func.value, containers, container) + if call_target.generic_type != -1: + if call_target.generic_type == -2: + return rectype + return rectype.type_args[call_target.generic_type] + if isinstance(call_target.type, TypeVar): + while rectype.python_class is not call_target.cls: + rectype = rectype.superclass + name_list = list(rectype.python_class.type_vars.keys()) + index = name_list.index(call_target.type.name) + return rectype.type_args[index] + if (isinstance(call_target, PythonClass) and + call_target.type_vars): + # This is a call to a constructor of a generic class; it's not + # enough to just return the class, we need the entire type with + # type arguments. We only support that if we can get it directly + # from mypy, i.e., when the result is assigned to a variable + # and we can get the variable type. + if hasattr(node, '_parent') and node._parent and isinstance(node._parent, (ast.Assign, ast.AnnAssign)): + trgt = node._parent.targets[0] if isinstance(node._parent, ast.Assign) else node._parent.target + ann_type = get_type(trgt, containers, container) + if isinstance(ann_type, GenericType) and ann_type.python_class == call_target: + return ann_type + if (call_target.name in (PSEQ_TYPE, PSET_TYPE, PMSET_TYPE) and + isinstance(node, ast.Call) and node.args): + arg_types = [get_type(arg, containers, container) + for arg in node.args] + return GenericType(call_target, [common_supertype(arg_types)]) + else: + error = 'generic.constructor.without.type' + raise InvalidProgramException(node, error) + if isinstance(call_target, PythonType): + # constructor call + return call_target func_name = get_func_name(node) if func_name == 'super': if len(node.args) == 2: @@ -511,7 +557,7 @@ def _get_call_type(node: ast.Call, module: PythonModule, else: raise UnsupportedException(node) if node.func.id in module.classes: - return module.global_module.classes[node.func.id] + return module.classes[node.func.id] elif module.get_func_or_method(node.func.id) is not None: target = module.get_func_or_method(node.func.id) return target.type diff --git a/src/nagini_translation/resources/bool.sil b/src/nagini_translation/resources/bool.sil index aae5c7f1..c85fbae5 100644 --- a/src/nagini_translation/resources/bool.sil +++ b/src/nagini_translation/resources/bool.sil @@ -391,3 +391,19 @@ domain __SumHelper[T$] { forall __ss1: Seq[Int], __ss2: Seq[Int] :: { __sum(__ss1), __sum(__ss2) } __toMS(__ss1) == __toMS(__ss2) ==> __sum(__ss1) == __sum(__ss2) } } + +function PyType___box__(prim: PyType): Ref + decreases _ + ensures typeof(result) == type() + ensures type___unbox__(result) == prim + +function type___unbox__(box: Ref): PyType + decreases _ + requires issubtype(typeof(box), type()) + ensures PyType___box__(result) == box + +function type___eq__(self: Ref, other: Ref): Bool + decreases _ + requires issubtype(typeof(self), type()) + requires issubtype(typeof(other), type()) + ensures result == (type___unbox__(self) == type___unbox__(other)) diff --git a/src/nagini_translation/resources/builtins.json b/src/nagini_translation/resources/builtins.json index ec17a306..f59a5b5c 100644 --- a/src/nagini_translation/resources/builtins.json +++ b/src/nagini_translation/resources/builtins.json @@ -14,7 +14,7 @@ "type": "str" }, "__cast__": { - "args": ["type", "object"], + "args": ["PyType", "object"], "type": "object" } } @@ -559,27 +559,27 @@ "type": "tuple" }, "__create1__": { - "args": ["object", "type", "__prim__int"], + "args": ["object", "PyType", "__prim__int"], "type": "tuple" }, "__create2__": { - "args": ["object", "object", "type", "type", "__prim__int"], + "args": ["object", "object", "PyType", "PyType", "__prim__int"], "type": "tuple" }, "__create3__": { - "args": ["object", "object", "object", "type", "type", "type", "__prim__int"], + "args": ["object", "object", "object", "PyType", "PyType", "PyType", "__prim__int"], "type": "tuple" }, "__create4__": { - "args": ["object", "object", "object", "object", "type", "type", "type", "type", "__prim__int"], + "args": ["object", "object", "object", "object", "PyType", "PyType", "PyType", "PyType", "__prim__int"], "type": "tuple" }, "__create5__": { - "args": ["object", "object", "object", "object", "object", "type", "type", "type", "type", "type", "__prim__int"], + "args": ["object", "object", "object", "object", "object", "PyType", "PyType", "PyType", "PyType", "PyType", "__prim__int"], "type": "tuple" }, "__create6__": { - "args": ["object", "object", "object", "object", "object", "object", "type", "type", "type", "type", "type", "type", "__prim__int"], + "args": ["object", "object", "object", "object", "object", "object", "PyType", "PyType", "PyType", "PyType", "PyType", "PyType", "__prim__int"], "type": "tuple" }, "__getitem__": { @@ -642,10 +642,34 @@ "__prim__Seq": {}, "__prim__Set": {}, "__prim__Multiset": {}, +"PyType": { + "functions": { + "__box__": { + "args": ["PyType"], + "type": "type", + "requires": ["type___unbox__"] + } + } +}, +"type": { + "functions": { + "__unbox__": { + "args": ["type"], + "type": "PyType", + "requires": ["PyType___box__"] + }, + "__eq__": { + "args": ["type", "object"], + "type": "__prim__bool", + "requires": ["__unbox__"] + } + }, + "extends": "object" +}, "PSeq": { "functions": { "__create__": { - "args": ["__prim__Seq", "type"], + "args": ["__prim__Seq", "PyType"], "type": "PSeq" }, "__unbox__": { @@ -697,7 +721,7 @@ "PSet": { "functions": { "__create__": { - "args": ["__prim__Set", "type"], + "args": ["__prim__Set", "PyType"], "type": "PSet" }, "__unbox__": { @@ -735,7 +759,7 @@ "PMultiset": { "functions": { "__create__": { - "args": ["__prim__Multiset", "type"], + "args": ["__prim__Multiset", "PyType"], "type": "PMultiset" }, "__unbox__": { @@ -873,9 +897,6 @@ }, "LevelType": { }, -"type": { - "extends": "object" -}, "Callable": { "extends": "object" }, diff --git a/src/nagini_translation/resources/pytype.sil b/src/nagini_translation/resources/pytype.sil index 55ae05d1..be93d8b3 100644 --- a/src/nagini_translation/resources/pytype.sil +++ b/src/nagini_translation/resources/pytype.sil @@ -40,4 +40,5 @@ domain PyType { unique function str(): PyType unique function object(): PyType unique function NoneType(): PyType + unique function type(): PyType } diff --git a/src/nagini_translation/translators/abstract.py b/src/nagini_translation/translators/abstract.py index 7003ac21..3c8f691a 100644 --- a/src/nagini_translation/translators/abstract.py +++ b/src/nagini_translation/translators/abstract.py @@ -351,6 +351,9 @@ def type_check(self, lhs: Expr, type: PythonType, return self.config.type_translator.type_check( lhs, type, position, ctx, inhale_exhale=inhale_exhale) + def subtype_check(self, obj: Expr, type_expr: Expr, position: 'silver.ast.Position', ctx: Context) -> Expr: + return self.config.type_translator.subtype_check(obj, type_expr, position, ctx) + def bind_type_vars(self, method: PythonMethod, ctx: Context) -> None: return self.config.method_translator.bind_type_vars(method, ctx) diff --git a/src/nagini_translation/translators/call.py b/src/nagini_translation/translators/call.py index 08cb3e84..3e941a7d 100644 --- a/src/nagini_translation/translators/call.py +++ b/src/nagini_translation/translators/call.py @@ -42,6 +42,7 @@ THREAD_POST_PRED, THREAD_START_PRED, TUPLE_TYPE, + TYPE_TYPE, ) from nagini_translation.lib.errors import rules from nagini_translation.lib.program_nodes import ( @@ -87,15 +88,46 @@ def _translate_isinstance(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: assert len(node.args) == 2 target = self.get_target(node.args[1], ctx) - assert isinstance(target, (PythonType, PythonVar)) + type_arg_type = self.get_type(node.args[1], ctx) stmt, obj = self.translate_expr(node.args[0], ctx) pos = self.to_position(node, ctx) + info = self.no_info(ctx) + type_stmt = [] if isinstance(target, PythonType): check = self.type_check(obj, target, pos, ctx, inhale_exhale=False) - else: + elif False and isinstance(target, PythonVar): check = self.type_factory.dynamic_type_check(obj, target.ref(), pos, ctx) - return stmt, check + elif type_arg_type.name == TYPE_TYPE: + type_stmt, type_obj = self.translate_expr(node.args[1], ctx) + check = self.subtype_check(obj, type_obj, pos, ctx) + elif type_arg_type.name == TUPLE_TYPE: + if isinstance(node.args[1], ast.Tuple): + options = [] + for e in node.args[1].elts: + el_target = self.get_target(e, ctx) + if isinstance(el_target, PythonType): + options.append(self.type_check(obj, el_target, pos, ctx, inhale_exhale=False)) + else: + el_stmt, el_obj = self.translate_expr(e, ctx) + type_stmt.extend(el_stmt) + options.append(self.subtype_check(obj, el_obj, pos, ctx)) + check = self._disjoin(options, pos, info) + else: + type_stmt, type_obj = self.translate_expr(node.args[1], ctx) + if type_arg_type.exact_length: + options = [] + for index, ta in enumerate(type_arg_type.type_args): + el_obj = self.get_function_call(type_arg_type, '__getitem__', + [type_obj, self.viper.IntLit(index, pos, info)], [None, None], + node, ctx) + options.append(self.subtype_check(obj, el_obj, pos, ctx)) + check = self._disjoin(options, pos, info) + else: + raise UnsupportedException(node, "isinstance with unknown-length tuple argument is currently not supported") + else: + print("++") + return stmt + type_stmt, check def _translate_type_func(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: @@ -194,7 +226,7 @@ def translate_adt_cons(self, cons: PythonClass, args: List[FuncApp], else: if arg_type.type.name in PRIMITIVES: v_type = self.translate_type(arg_type.type, ctx) - args[index] = self.to_type(translated_arg, v_type, ctx) + args[index] = self.convert_to_type(translated_arg, v_type, ctx) # Translate constructor call cons_call = self.viper.DomainFuncApp(cons.fresh(cons.adt_prefix + @@ -232,7 +264,10 @@ def translate_constructor_call(self, target_class: PythonClass, '_res', target_class, self.translator) - result_type = self.get_type(node, ctx) + if isinstance(node, ast.Call): + result_type = self.get_type(node, ctx) + else: + result_type = self.get_target(node, ctx) info = self.no_info(ctx) # Temporarily bind the type variables of the constructed class to @@ -690,6 +725,8 @@ def _has_implicit_receiver_arg(self, node: ast.Call, ctx: Context) -> bool: # constructor return True # If normal + if not isinstance(called_func, PythonMethod): + called_func = self.get_target(node.func, ctx) assert isinstance(called_func, PythonMethod) if (isinstance(node.func, ast.Attribute) and get_func_name(node.func.value) == 'super'): @@ -1669,8 +1706,8 @@ def _translate_cls_call(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: type_stmt, dynamic_type = self.translate_expr(node.func, ctx) assert not type_stmt - result_has_type = self.type_factory.dynamic_type_check(res_var.ref(), - dynamic_type, self.to_position(node, ctx), ctx) + result_has_type = self.viper.EqCmp(self.type_factory.typeof(res_var.ref(), ctx), self.to_type(dynamic_type, ctx), + self.to_position(node, ctx), self.no_info(ctx)) # Inhale the type information about the newly created object # so that it's already present when calling __init__. type_inhale = self.viper.Inhale(result_has_type, pos, diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index 4949d9e3..3ab87b50 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -23,11 +23,13 @@ OBJECT_TYPE, PRIMITIVE_BOOL_TYPE, PRIMITIVE_INT_TYPE, + PRIMITIVE_TYPE_TYPE, RANGE_TYPE, PSEQ_TYPE, PSET_TYPE, SET_TYPE, SINGLE_NAME, + TYPE_TYPE, UNION_TYPE, ) from nagini_translation.lib.context import Context @@ -103,6 +105,8 @@ def convert_to_type(self, e: Expr, target_type, ctx: Context, result = self.to_bool(e, ctx, node) elif target_type == self.viper.Int: result = self.to_int(e, ctx) + elif target_type == self.type_factory.type_type(): + result = self.to_type(e, ctx) return result def _is_pure(self, e: Expr) -> bool: @@ -149,6 +153,15 @@ def to_ref(self, e: Expr, ctx: Context) -> Expr: result = self.get_function_call(prim_bool, '__box__', [result], [None], None, ctx, position=e.pos()) + elif e.typ() == self.type_factory.type_type(): + if (isinstance(e, self.viper.ast.FuncApp) and + e.funcname() == 'type___unbox__'): + result = e.args().head() + else: + prim_type = ctx.module.global_module.classes[PRIMITIVE_TYPE_TYPE] + result = self.get_function_call(prim_type, '__box__', + [result], [None], None, ctx, + position=e.pos()) return result def to_bool(self, e: Expr, ctx: Context, node: ast.AST = None) -> Expr: @@ -209,6 +222,30 @@ def to_int(self, e: Expr, ctx: Context) -> Expr: position=e.pos()) return result + def to_type(self, e: Expr, ctx: Context) -> Expr: + """ + Converts the given expression to an expression of the Silver type PyType + if it isn't already, either by unboxing a reference or undoing a + previous boxing operation. + """ + # Avoid wrapping non-pure expressions (leads to errors within Silver's + # Consistency object) + if not self._is_pure(e): + return e + if e.typ() == self.type_factory.type_type(): + return e + if e.typ() != self.viper.Ref: + e = self.to_ref(e, ctx) + if (isinstance(e, self.viper.ast.FuncApp) and + e.funcname() == 'PyType__box__'): + return e.args().head() + result = e + type_type = ctx.module.global_module.classes[TYPE_TYPE] + result = self.get_function_call(type_type, '__unbox__', + [result], [None], None, ctx, + position=e.pos()) + return result + def unwrap(self, e: Expr) -> Expr: if isinstance(e, self.viper.ast.FuncApp): if (e.funcname().endswith('__box__') or @@ -663,10 +700,12 @@ def _get_function_call(self, receiver: PythonType, assert len(args) == len(func.get_args()) for arg, param, type in zip(args, func.get_args(), arg_types): formal_args.append(param.decl) - if param.type.name == '__prim__bool': + if param.type.name == PRIMITIVE_BOOL_TYPE: actual_arg = self.to_bool(arg, ctx) - elif param.type.name == '__prim__int': + elif param.type.name == PRIMITIVE_INT_TYPE: actual_arg = self.to_int(arg, ctx) + elif param.type.name == PRIMITIVE_TYPE_TYPE: + actual_arg = self.to_type(arg, ctx) else: actual_arg = self.to_ref(arg, ctx) actual_args.append(actual_arg) diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index 9f1a6b58..131dec4c 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -9,7 +9,6 @@ import inspect from nagini_translation.lib.constants import ( - BOOL_TYPE, BOXED_PRIMITIVES, BYTES_TYPE, CHECK_DEFINED_FUNC, @@ -31,6 +30,7 @@ STRING_TYPE, THREAD_DOMAIN, TUPLE_TYPE, + TYPE_TYPE, ) from nagini_translation.lib.errors import rules from nagini_translation.lib.program_nodes import ( @@ -666,7 +666,7 @@ def translate_Name(self, node: ast.Name, ctx: Context) -> StmtsAndExpr: else: if isinstance(target, PythonType): return [], self.type_factory.translate_type_literal(target, - self.to_position(node, ctx), ctx) + self.to_position(node, ctx), ctx, node=node) if node.id == '_': object_type = ctx.module.global_module.classes[OBJECT_TYPE] temp_var = ctx.current_function.create_variable('wildcard', object_type, self.translator) @@ -743,12 +743,12 @@ def translate_static_field_access(self, field: PythonGlobalVar, type_arg = self.type_factory.translate_type_literal(receiver, position, ctx) else: - if receiver.typ() != self.type_factory.type_type(): - # Normal object, get its type. - type_arg = self.type_factory.typeof(receiver, ctx) - else: - # Type expression, use it directly. - type_arg = receiver + type_type = ctx.module.global_module.classes[TYPE_TYPE] + receiver_is_type = self.type_check(receiver, type_type, position, ctx) + type_arg = self.to_type(self.viper.CondExp(receiver_is_type, receiver, + self.to_ref(self.type_factory.typeof(receiver, ctx), ctx), + position, self.no_info(ctx)), + ctx) info = self.no_info(ctx) param = self.viper.LocalVarDecl('receiver', self.type_factory.type_type(), position, @@ -1040,48 +1040,6 @@ def is_thread_method_definition(self, node: ast.Compare, ctx: Context) -> bool: return True return False - def is_type_equality(self, node: ast.Compare, ctx: Context) -> bool: - """ - Checks if a comparison checks the equality of the type of an object with - something else (e.g., ``type(e1) == e2``), since these comparisons need special - treatment. - """ - if len(node.ops) != 1 or len(node.comparators) != 1: - return False - if not isinstance(node.ops[0], (ast.Eq, ast.Is, ast.NotEq, ast.IsNot)): - return False - for arg in (node.left, node.comparators[0]): - if (isinstance(arg, ast.Call) and isinstance(arg.func, ast.Name) and - arg.func.id == 'type'): - for other in (node.left, node.comparators[0]): - if other is not arg: - other_target = self.get_target(other, ctx) - if isinstance(other_target, PythonType): - return True - return False - - def translate_type_equality(self, node: ast.Compare, ctx: Context) -> StmtsAndExpr: - if (isinstance(node.left, ast.Call) and isinstance(node.left.func, ast.Name) and - node.left.func.id == 'type'): - type_call = node.left - type_literal = node.comparators[0] - else: - type_call = node.comparators[0] - type_literal = node.left - target = self.get_target(type_literal, ctx) - assert isinstance(target, PythonType) - call_stmt, call = self.translate_expr(type_call, ctx) - pos = self.to_position(node, ctx) - info = self.no_info(ctx) - type_literal = self.type_factory.translate_type_literal(target.python_class, - pos, ctx, alias = call) - if isinstance(node.ops[0], (ast.Is, ast.Eq)): - func = self.viper.EqCmp - else: - func = self.viper.NeCmp - comp = func(type_literal, call, pos, info) - return [], comp - def translate_thread_method_definition(self, node: ast.Compare, ctx: Context) -> StmtsAndExpr: ctx.are_threading_constants_used = True @@ -1117,8 +1075,6 @@ def translate_Compare(self, node: ast.Compare, return self.translate_wait_level_comparison(node, ctx) if self.is_thread_method_definition(node, ctx): return self.translate_thread_method_definition(node, ctx) - if self.is_type_equality(node, ctx): - return self.translate_type_equality(node, ctx) if len(node.ops) != 1 or len(node.comparators) != 1: raise UnsupportedException(node) left_stmt, left = self.translate_expr(node.left, ctx) diff --git a/src/nagini_translation/translators/method.py b/src/nagini_translation/translators/method.py index f700f4a1..afc9e3a0 100644 --- a/src/nagini_translation/translators/method.py +++ b/src/nagini_translation/translators/method.py @@ -10,7 +10,6 @@ from nagini_translation.lib.constants import ( END_LABEL, ERROR_NAME, - FILE_VAR, GLOBAL_VAR_FIELD, MAIN_METHOD_NAME, MODULE_VARS, @@ -18,12 +17,12 @@ OBJECT_TYPE, PRIMITIVES, STRING_TYPE, + TYPE_TYPE, ) from nagini_translation.lib.program_nodes import ( GenericType, MethodType, PythonExceptionHandler, - PythonField, PythonMethod, PythonModule, PythonTryBlock, @@ -221,8 +220,12 @@ def _create_typeof_pres(self, func: PythonMethod, is_constructor: bool, continue if func.method_type == MethodType.class_method: cls_arg = arg.ref() + type_type = ctx.module.global_module.classes[TYPE_TYPE] + type_check = self.type_factory.type_check( + cls_arg, type_type, self.no_position(ctx), ctx) + pres.append(type_check) type_check = self.type_factory.subtype_check( - cls_arg, func.cls, self.no_position(ctx), ctx) + self.to_type(cls_arg, ctx), func.cls, self.no_position(ctx), ctx) pres.append(type_check) continue type_check = self.get_parameter_typeof(arg, ctx) @@ -624,7 +627,7 @@ def _assign_exit_vars(self, block: PythonTryBlock, type_var: PythonVar, block.error_var.ref(), pos, info) error_case.append(value_assign) - error_type = self.type_factory.typeof(block.error_var.ref(), ctx) + error_type = self.to_ref(self.type_factory.typeof(block.error_var.ref(), ctx), ctx) type_assign = self.viper.LocalVarAssign(type_var.ref(), error_type, pos, info) error_case.append(type_assign) diff --git a/src/nagini_translation/translators/program.py b/src/nagini_translation/translators/program.py index ab5fa4fd..d8d543a2 100644 --- a/src/nagini_translation/translators/program.py +++ b/src/nagini_translation/translators/program.py @@ -29,6 +29,7 @@ THREAD_DOMAIN, THREAD_POST_PRED, THREAD_START_PRED, + TYPE_TYPE, ) from nagini_translation.lib.jvmaccess import getobject from nagini_translation.lib.program_nodes import ( @@ -321,8 +322,12 @@ def create_override_check(self, method: PythonMethod, ctx, inhale_exhale=False) elif method.method_type == MethodType.class_method: cls_arg = next(iter(method.overrides.args.values())).ref() - has_subtype = self.type_factory.subtype_check(cls_arg, method.cls, - pos, ctx) + type_type = ctx.module.global_module.classes[TYPE_TYPE] + has_type_type = self.type_factory.type_check(cls_arg, type_type, + pos, ctx) + type_has_type = self.type_factory.subtype_check(self.to_type(cls_arg, ctx), method.cls, + pos, ctx) + has_subtype = self.viper.And(has_type_type, type_has_type, pos, self.no_info(ctx)) if method.name == '__init__': fields = method.cls.all_fields pres.extend([self.get_may_set_predicate(self_arg.ref(), f, ctx) diff --git a/src/nagini_translation/translators/type.py b/src/nagini_translation/translators/type.py index 431c2712..e8d41b75 100644 --- a/src/nagini_translation/translators/type.py +++ b/src/nagini_translation/translators/type.py @@ -46,6 +46,7 @@ def builtins(self): 'builtins.PSeq': self.viper.SeqType(self.viper.Ref), 'builtins.PSet': self.viper.SetType(self.viper.Ref), 'builtins.PMultiset': self.viper.MultisetType(self.viper.Ref), + 'builtins.type': self.type_factory.type_type(), } def translate_type(self, cls: PythonClass, @@ -61,8 +62,6 @@ def translate_type(self, cls: PythonClass, elif cls.name in PRIMITIVES: cls = cls.try_box() return self.builtins['builtins.' + cls.name] - elif cls.name == 'type': - return self.type_factory.type_type() else: return self.viper.Ref @@ -95,8 +94,11 @@ def type_check(self, lhs: Expr, type: PythonType, if type is None: none_type = ctx.module.global_module.classes['NoneType'] return self.type_factory.type_check(lhs, none_type, position, ctx) - elif type.name == 'type': - return self.viper.TrueLit(position, self.no_info(ctx)) else: result = self.type_factory.type_check(lhs, type, position, ctx) return result + + def subtype_check(self, obj: Expr, type_expr: Expr, position: 'silver.ast.Position', ctx: Context) -> Expr: + obj_type = self.type_factory.typeof(self.to_ref(obj, ctx), ctx) + rhs_type = self.to_type(type_expr, ctx) + return self.type_factory._issubtype(obj_type, rhs_type, ctx, position=position) diff --git a/src/nagini_translation/translators/type_domain_factory.py b/src/nagini_translation/translators/type_domain_factory.py index 6082137a..513a431b 100644 --- a/src/nagini_translation/translators/type_domain_factory.py +++ b/src/nagini_translation/translators/type_domain_factory.py @@ -7,15 +7,19 @@ import ast -from nagini_translation.lib.constants import OBJECT_TYPE, TUPLE_TYPE +from nagini_translation.lib.constants import ( + OBJECT_TYPE, + TUPLE_TYPE, + TYPE_TYPE, +) from nagini_translation.lib.program_nodes import ( GenericType, OptionalType, PythonClass, PythonType, TypeVar, - UnionType, ) +from nagini_translation.lib.util import UnsupportedException from nagini_translation.lib.viper_ast import ViperAST from nagini_translation.translators.abstract import Context, Expr from typing import List, Tuple @@ -833,7 +837,7 @@ def type_check(self, lhs: 'Expr', type: 'PythonType', concrete=concrete) def translate_type_literal(self, type: 'PythonType', position: 'Position', - ctx: Context, alias: Expr = None) -> Expr: + ctx: Context, alias: Expr = None, node: ast.AST = None) -> Expr: """ Translates the given type to a type literal. If the given type is a generic type with missing type argument information, the type @@ -853,11 +857,14 @@ def translate_type_literal(self, type: 'PythonType', position: 'Position', if type is None: type = ctx.module.global_module.classes['NoneType'] args = [] + if isinstance(type, GenericType) and type.python_class.name == TYPE_TYPE and type.python_class.interface: + type = type.python_class if isinstance(type, GenericType): for arg in type.type_args: args.append(self.translate_type_literal(arg, position, ctx)) elif isinstance(type, PythonClass) and type.type_vars: - assert alias + if not alias: + raise UnsupportedException(node, "Unsupported reference to generic type without type arguments.") for index, arg in enumerate(type.type_vars): args.append(self.get_type_arg(alias, type, index, ctx)) if type.python_class.name == TUPLE_TYPE: diff --git a/tests/functional/verification/test_types.py b/tests/functional/verification/test_types.py new file mode 100644 index 00000000..404c0d6f --- /dev/null +++ b/tests/functional/verification/test_types.py @@ -0,0 +1,97 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from typing import Type, cast, List, Tuple +from nagini_contracts.contracts import * + +class MyClass: + pass + +class MyOtherClass(MyClass): + pass + +class MyThirdClass(MyClass): + pass + +def tester1(o: object) -> None: + if type(o) == MyClass: + mc = cast(MyClass, o) + elif type(o) in (MyClass, MyOtherClass): + moc = cast(MyOtherClass, o) + if isinstance(o, MyClass): + mc = cast(MyClass, o) + ls = [int, MyClass] + if type(o) in ls: + if not isinstance(o, MyClass): + a = cast(int, o) + +def tester1f1(o: object) -> None: + if type(o) == MyClass: + #:: ExpectedOutput(application.precondition:assertion.false) + mc = cast(MyOtherClass, o) + +def tester1f2(o: object) -> None: + if type(o) == MyClass: + mc = cast(MyClass, o) + elif type(o) in (MyThirdClass, MyOtherClass): + #:: ExpectedOutput(application.precondition:assertion.false) + moc = cast(MyOtherClass, o) + +def tester1f3(o: object) -> None: + if type(o) == MyClass: + mc = cast(MyClass, o) + elif type(o) in (MyClass, MyOtherClass): + moc = cast(MyOtherClass, o) + ls = [int, MyClass] + if type(o) in ls: + #:: ExpectedOutput(application.precondition:assertion.false) + a = cast(int, o) + +def tester2(o: object, t: type) -> None: + Requires(type(o) == int) + if isinstance(o, MyClass): + Assert(False) + Assert(type(o) != bool) + +def tester2f1(o: object, t: type) -> None: + Requires(type(o) == int) + if isinstance(o, object): + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(False) + + +def tester3(o: object, t: Type[int]) -> None: + pass + +def tester4(o: object, t: type, b: bool) -> None: + Requires(Implies(b, t == MyClass)) + ii = isinstance(o, t) + if b and isinstance(o, MyOtherClass): + Assert(ii) + + +def tester5(o: object, t: type) -> None: + if isinstance(o, (int, t)): + if t == bool: + a = cast(int, o) + +def tester5f(o: object, t: type) -> None: + if isinstance(o, (int, t)): + if t == str: + #:: ExpectedOutput(assert.failed:assertion.false) + a = cast(int, o) + + +def tester6(o: object, t: type) -> None: + tps: Tuple[type, type] = (int, t) + if isinstance(o, tps): + if t == bool: + a = cast(int, o) + + +def tester6f(o: object, t: type) -> None: + tps: Tuple[type, type] = (int, t) + if isinstance(o, tps): + if t == str: + #:: ExpectedOutput(assert.failed:assertion.false) + a = cast(int, o) \ No newline at end of file