This repository was archived by the owner on Mar 31, 2025. It is now read-only.

Description
Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!
I was playing around with objax for a bit and noticed a problem with `VarCollection.update`, which overrides the default `dict.update` method. If you call `model.vars().update(...)` with a plain Python dictionary instead of a `VarCollection`, it fails. The argument is cast to a list, which for a dict contains only its keys, and the method then loops over that list as if it held `(key, value)` pairs. Each key is a string, so unpacking it into `k, v` raises `ValueError: too many values to unpack (expected 2)`.
The relevant part of `VarCollection.update`:

```python
def update(self, other: Union['VarCollection', Iterable[Tuple[str, BaseVar]]]):
    """Overload dict.update method to catch potential conflicts during assignment."""
    if not isinstance(other, self.__class__):
        other = list(other)  # for a plain dict, this is a list of its keys only
    else:
        other = other.items()
    conflicts = set()
    for k, v in other:  # fails here: each key string is unpacked into k, v
        ...
```
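To see why the first branch breaks, here is a small reproduction of the same loop that does not depend on objax. `iterate_pairs` and the key names are hypothetical stand-ins, not objax code.

```python
def iterate_pairs(other):
    """Hypothetical helper mimicking the non-VarCollection branch of update()."""
    other = list(other)  # for a dict, list() keeps only the keys
    pairs = []
    for k, v in other:  # each key is a str, so Python tries to unpack its characters
        pairs.append((k, v))
    return pairs


new_values = {"(Linear).w": 1.0, "(Linear).b": 0.0}  # placeholder names and values

print(list(new_values))  # ['(Linear).w', '(Linear).b']

try:
    iterate_pairs(new_values)
except ValueError as err:
    print("ValueError:", err)  # complains about too many values to unpack

# Passing the items explicitly yields (key, value) tuples, which unpack fine.
print(iterate_pairs(new_values.items()))
```

Since the keys are longer than two characters, unpacking one into `k, v` always produces too many values.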
Is this intended? Shouldn't `VarCollection.update` call `.items()` on any object that supports it, such as a plain dict, rather than only on a `VarCollection`?
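A minimal sketch of what that could look like, assuming the intent is to treat any `collections.abc.Mapping` (which covers both `dict` and dict subclasses such as `VarCollection`) as a source of `(key, value)` pairs. `ConflictCheckingDict` is a hypothetical stand-in, and its conflict rule is a simplified placeholder, not objax's actual logic.

```python
from collections.abc import Iterable, Mapping
from typing import Any, Tuple, Union


class ConflictCheckingDict(dict):
    """Hypothetical stand-in for VarCollection; not objax's real class."""

    def update(self, other: Union[Mapping, Iterable[Tuple[str, Any]]]):
        # Any mapping (dict, dict subclass, ...) is iterated via .items();
        # anything else is assumed to already be an iterable of (key, value) pairs.
        pairs = other.items() if isinstance(other, Mapping) else list(other)
        conflicts = set()
        for k, v in pairs:
            # Simplified placeholder rule: same key bound to a different object.
            if k in self and self[k] is not v:
                conflicts.add(k)
            else:
                self[k] = v
        if conflicts:
            raise ValueError(f"Conflicting keys: {sorted(conflicts)}")


coll = ConflictCheckingDict({"w": 1})
coll.update({"b": 2})    # a plain dict is now accepted
coll.update([("c", 3)])  # an iterable of pairs still works
```

Checking against `Mapping` instead of `self.__class__` keeps the original behavior for `VarCollection` arguments while also accepting plain dicts.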