diff --git a/flask_cache/__init__.py b/flask_cache/__init__.py index 85d90e3..8460ad0 100644 --- a/flask_cache/__init__.py +++ b/flask_cache/__init__.py @@ -38,11 +38,53 @@ else: null_control = (dict((k,None) for k in delchars),) +def get_arg_names(f): + """ + Return arguments of function + + :param f: + :return: String list of arguments + """ + try: + # Python >= 3.3 + sig = inspect.signature(f) + return [parameter.name + for parameter + in sig.parameters.values() + if parameter.kind == parameter.POSITIONAL_OR_KEYWORD] + except AttributeError: + try: + # Python >= 3.0 + return inspect.getfullargspec(f).args + except AttributeError: + return inspect.getargspec(f).args + +def get_arg_default(f, position): + try: + # Python >= 3.3 + sig = inspect.signature(f) + arg = list(sig.parameters.values())[position] + arg_def = arg.default + return arg_def if arg_def != inspect.Parameter.empty else None + except AttributeError: + try: + spec = inspect.getfullargspec(f) + except AttributeError: + spec = inspect.getargspec(f) + + args_len = len(spec.args) + if spec.defaults and abs(position - args_len) <= len(spec.defaults): + return spec.defaults[position - args_len] + else: + return None + def function_namespace(f, args=None): """ Attempts to returns unique namespace for function """ - m_args = inspect.getargspec(f)[0] + + m_args = get_arg_names(f) + instance_token = None instance_self = getattr(f, '__self__', None) @@ -415,24 +457,27 @@ def _memoize_kwargs_to_args(self, f, *args, **kwargs): #: 1, b=2 is equivilant to a=1, b=2, etc. new_args = [] arg_num = 0 - argspec = inspect.getargspec(f) + argspec = get_arg_names(f) + + arg_names = get_arg_names(f) + args_len = len(arg_names) - args_len = len(argspec.args) for i in range(args_len): - if i == 0 and argspec.args[i] in ('self', 'cls'): + arg_default = get_arg_default(f, i) + if i == 0 and arg_names[i] in ('self', 'cls'): #: use the repr of the class instance #: this supports instance methods for #: the memoized functions, giving more #: flexibility to developers arg = repr(args[0]) arg_num += 1 - elif argspec.args[i] in kwargs: - arg = kwargs[argspec.args[i]] + elif arg_names[i] in kwargs: + arg = kwargs[arg_names[i]] elif arg_num < len(args): arg = args[arg_num] arg_num += 1 - elif abs(i-args_len) <= len(argspec.defaults): - arg = argspec.defaults[i-args_len] + elif arg_default: + arg = arg_default arg_num += 1 else: arg = None diff --git a/test_cache.py b/test_cache.py index 9811572..caaeed7 100644 --- a/test_cache.py +++ b/test_cache.py @@ -164,6 +164,21 @@ def big_foo(a, b): assert big_foo(5, 2) == result + def test_06b_memoize_annotated(self): + if sys.version_info >= (3, 0): + with self.app.test_request_context(): + @self.cache.memoize(50) + def big_foo_annotated(a, b): + return a+b+random.randrange(0, 100000) + big_foo_annotated.__annotations__ = {'a': int, 'b': int, 'return': int} + + result = big_foo_annotated(5, 2) + + time.sleep(2) + + assert big_foo_annotated(5, 2) == result + + def test_07_delete_memoize(self): with self.app.test_request_context(): @@ -213,6 +228,37 @@ def big_foo(a, b): assert self.cache.get(version_key) is not None + + + def test_07c_delete_memoized_annotated(self): + with self.app.test_request_context(): + @self.cache.memoize(5) + def big_foo_annotated(a, b): + return a+b+random.randrange(0, 100000) + + big_foo_annotated.__annotations__ = {'a': int, 'b': int, 'return': int} + + result = big_foo_annotated(5, 2) + result2 = big_foo_annotated(5, 3) + + time.sleep(1) + + assert big_foo_annotated(5, 2) == result + assert big_foo_annotated(5, 2) == result + assert big_foo_annotated(5, 3) != result + assert big_foo_annotated(5, 3) == result2 + + self.cache.delete_memoized_verhash(big_foo_annotated) + + _fname, _fname_instance = function_namespace(big_foo_annotated) + version_key = self.cache._memvname(_fname) + assert self.cache.get(version_key) is None + + assert big_foo_annotated(5, 2) != result + assert big_foo_annotated(5, 3) != result2 + + assert self.cache.get(version_key) is not None + def test_08_delete_memoize(self): with self.app.test_request_context():