diff --git a/src/libpython_clj2/python/with.clj b/src/libpython_clj2/python/with.clj index 6b45dce9..54788336 100644 --- a/src/libpython_clj2/python/with.clj +++ b/src/libpython_clj2/python/with.clj @@ -49,15 +49,16 @@ [bind-vec & body] (when-not (= 2 (count bind-vec)) (throw (Exception. "Bind vector must have 2 items"))) - (let [varname (first bind-vec)] + (let [varname (first bind-vec) + mgr (gensym "mgr")] `(py-ffi/with-gil - (let [~@bind-vec] + (let [~mgr ~(second bind-vec)] (with-bindings {#'py-ffi/*python-error-handler* python-pyerr-fetch-error-handler} - (py-fn/call-attr ~varname "__enter__" nil) - (try - (let [retval# (do ~@body)] - (py-fn/call-attr ~varname "__exit__" [nil nil nil]) - retval#) - (catch Throwable e# - (with-exit-error-handler ~varname e#)))))))) + (let [~varname (py-fn/call-attr ~mgr "__enter__" nil)] + (try + (let [retval# (do ~@body)] + (py-fn/call-attr ~mgr "__exit__" [nil nil nil]) + retval#) + (catch Throwable e# + (with-exit-error-handler ~mgr e#))))))))) diff --git a/test/libpython_clj2/python_test.clj b/test/libpython_clj2/python_test.clj index 6f46a394..71d99c55 100644 --- a/test/libpython_clj2/python_test.clj +++ b/test/libpython_clj2/python_test.clj @@ -192,6 +192,13 @@ (is (= ["enter" "exit: None"] (py/->jvm fn-list)))))) +(deftest with-enter-returns-different-object + (testing "py/with should bind the return value of __enter__, not the context manager" + (let [testcode (py/import-module "testcode")] + (py/with [f (py/call-attr testcode "FileWrapper" "test content")] + ;; f should be the StringIO object returned by __enter__, not FileWrapper + (is (= "test content" (py/call-attr f "read"))))))) + (deftest arrow-as-fns-with-nil (is (= nil (py/->jvm nil))) (is (= nil (py/as-jvm nil)))) @@ -449,6 +456,4 @@ class Foo: (let [data (doto (pd/DataFrame {:index [1 2] :value [2 3] :variable [1 1]}) (py. melt :id_vars "index"))] - ((py.- px line) :data_frame data :x "index" :y "value" :color "variable")) - - ) + ((py.- px line) :data_frame data :x "index" :y "value" :color "variable"))) diff --git a/testcode/__init__.py b/testcode/__init__.py index d2fa00b7..52727dd3 100644 --- a/testcode/__init__.py +++ b/testcode/__init__.py @@ -2,17 +2,38 @@ class WithObjClass: def __init__(self, suppress, fn_list): self.suppress = suppress self.fn_list = fn_list + def __enter__(self): self.fn_list.append("enter") + return self # Return self so methods can be called on the bound variable + def doit_noerr(self): return 1 + def doit_err(self): raise Exception("Spam", "Eggs") + def __exit__(self, ex_type, ex_val, ex_traceback): self.fn_list.append("exit: " + str(ex_val)) return self.suppress +class FileWrapper: + """Context manager where __enter__ returns a different object""" + + def __init__(self, content): + self.content = content + + def __enter__(self): + # Return a different object with the content + import io + + return io.StringIO(self.content) + + def __exit__(self, *args): + return False + + def for_iter(arg): retval = [] for item in arg: @@ -24,15 +45,13 @@ def calling_custom_clojure_fn(arg): return arg.clojure_fn() - -def complex_fn(a, b, c: str=5, *args, d=10, **kwargs): - return {"a" : a, - "b" : b, - "c" : c, - "args" : args, - "d": d, - "kwargs": kwargs} +def complex_fn(a, b, c: str = 5, *args, d=10, **kwargs): + return {"a": a, "b": b, "c": c, "args": args, "d": d, "kwargs": kwargs} -complex_fn_testcases = {"complex_fn(1, 2, c=10, d=10, e=10)":complex_fn(1, 2, c=10, d=10, e=10), - "complex_fn(1, 2, 10, 11, 12, d=10, e=10)":complex_fn(1, 2, 10, 11, 12, d=10, e=10)} +complex_fn_testcases = { + "complex_fn(1, 2, c=10, d=10, e=10)": complex_fn(1, 2, c=10, d=10, e=10), + "complex_fn(1, 2, 10, 11, 12, d=10, e=10)": complex_fn( + 1, 2, 10, 11, 12, d=10, e=10 + ), +}