This repository was archived by the owner on Mar 31, 2025. It is now read-only.

objax.variable.VarCollection.update fails when passing Dict[str, Any] #253

@alvarobartt

Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

I was playing around with objax for a bit and noticed that updating model.vars(), which is a VarCollection, through VarCollection.update (which overrides the default dict.update method) fails when the argument is a plain Python dictionary rather than a VarCollection. In that case the dictionary is cast into a Python list, which keeps only its keys, and the loop then tries to unpack each key as if it were a (key, value) pair, so it throws ValueError: too many values to unpack (expected 2).
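The failure mode can be reproduced in plain Python, without importing objax. The dictionary below is a hypothetical stand-in for a name-to-variable mapping; the point is only that `list()` on a dict yields its keys, and unpacking a key string into `k, v` raises the reported ValueError.

```python
# Hypothetical name -> value mapping standing in for name -> BaseVar.
variables = {"(Linear).w": 1.0, "(Linear).b": 0.0}

# list() on a dict keeps only the keys, not (key, value) pairs.
as_list = list(variables)
print(as_list)  # ['(Linear).w', '(Linear).b']

error_message = None
try:
    # Each item is a key string, so "k, v" tries to unpack its characters.
    for k, v in as_list:
        pass
except ValueError as err:
    error_message = str(err)

print(error_message)
```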

objax/objax/variable.py, lines 311 to 318 in 53b391b:

```python
def update(self, other: Union['VarCollection', Iterable[Tuple[str, BaseVar]]]):
    """Overload dict.update method to catch potential conflicts during assignment."""
    if not isinstance(other, self.__class__):
        other = list(other)
    else:
        other = other.items()
    conflicts = set()
    for k, v in other:
        ...  # the excerpt ends at line 318; the loop body is not shown here
```
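Going by the signature in the excerpt, which accepts `Iterable[Tuple[str, BaseVar]]`, one possible workaround is to pass `d.items()` instead of `d`. The sketch below mimics only the `list(other)` step of the non-VarCollection branch; `normalize_like_excerpt` and `params` are hypothetical names, not part of objax.

```python
from typing import Any, List


def normalize_like_excerpt(other: Any) -> List[Any]:
    # Hypothetical helper: mirrors "other = list(other)" from the
    # non-VarCollection branch of the excerpt.
    return list(other)


# Hypothetical name -> value mapping standing in for name -> BaseVar.
params = {"(Linear).w": 1.0, "(Linear).b": 0.0}

# Passing params.items() yields (key, value) tuples, which
# "for k, v in other" can unpack without error.
pairs = normalize_like_excerpt(params.items())
for k, v in pairs:
    print(k, v)
```

Under that assumption, the objax call would look like `model.vars().update(params.items())`, with `params` holding BaseVar values.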

Is this intended? Shouldn't VarCollection.update iterate over .items() for any object that provides it (such as a plain dict), not only for VarCollection instances?
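A minimal sketch of the change suggested above, assuming a duck-typed check for `.items()` is acceptable. `VarCollectionSketch` is a hypothetical dict subclass, not objax's class, and its conflict rule is an assumption, since the excerpt stops before the loop body.

```python
from typing import Any, Iterable, Tuple, Union


class VarCollectionSketch(dict):
    """Hypothetical stand-in for objax.VarCollection (a dict subclass).

    Only illustrates the proposed input handling; this is not objax's code.
    """

    def update(self, other: Union[dict, Iterable[Tuple[str, Any]]]) -> None:
        # Proposed change: anything exposing .items() (a VarCollection,
        # a plain dict, ...) is read as (key, value) pairs via .items().
        if hasattr(other, "items"):
            pairs = list(other.items())
        else:
            # Otherwise keep the existing behaviour: an iterable of pairs.
            pairs = list(other)
        # ASSUMPTION: simplified conflict rule. A name already bound to a
        # different object is reported instead of being overwritten.
        conflicts = set()
        for k, v in pairs:
            if k in self and self[k] is not v:
                conflicts.add(k)
            else:
                self[k] = v
        if conflicts:
            raise ValueError(f"Name conflicts during update: {sorted(conflicts)}")


# Usage: both input forms are accepted.
vc = VarCollectionSketch()
vc.update({"(Linear).w": 1.0})    # plain dict
vc.update([("(Linear).b", 0.0)])  # iterable of (key, value) pairs
print(sorted(vc))  # ['(Linear).b', '(Linear).w']
```

Checking for `.items()` follows the question directly; `isinstance(other, collections.abc.Mapping)` would be a stricter alternative. Since a `dict_items` view has no `.items()` attribute, `vc.update(d.items())` still goes through the pairs branch.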
