Skip to content

Commit 061ad27

Browse files
fix a bug when loading a config file
The HDF5 file stores config values like log_level_global as a numeric (e.g., numpy.int64). When loading, that integer is passed into VariPEPS_Config. In setattr, the field type for log_level_global is the Enum LogLevel, so integers must be coerced to LogLevel. Your version’s coercion path doesn’t catch your value, so it falls through and raises: Type mismatch for option 'log_level_global', got '<class numpy.int64>', expected '<enum 'LogLevel'>'. Why it falls through The loader passes a numpy integer (or a 0-d/1-d array) instead of a Python int. The Enum branch in setattr is too strict about the numeric checks, so it doesn’t convert that value into LogLevel.
1 parent be9ed47 commit 061ad27

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

varipeps/config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,33 @@ def __setattr__(self, name: str, value: Any) -> NoReturn:
379379
elif (
380380
field.type is bool
381381
and hasattr(value, "dtype")
382-
and np.isdtype(value.dtype, np.bool)
382+
and np.issubdtype(value.dtype, np.bool_)
383383
and value.size == 1
384384
):
385385
if value.ndim > 0:
386386
value = value.reshape(-1)[0]
387387
value = bool(value)
388+
elif isinstance(field.type, type) and issubclass(field.type, Enum):
389+
# Accept ints/np.int64 or enum names for Enum fields
390+
if isinstance(value, field.type):
391+
pass
392+
elif isinstance(value, (int,)) or (
393+
hasattr(value, "dtype")
394+
and np.issubdtype(value.dtype, np.integer)
395+
and value.size == 1
396+
):
397+
if hasattr(value, "ndim") and value.ndim > 0:
398+
value = value.reshape(-1)[0]
399+
value = field.type(int(value))
400+
elif isinstance(value, str):
401+
try:
402+
value = field.type[value]
403+
except KeyError:
404+
value = field.type(int(value))
405+
else:
406+
raise TypeError(
407+
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
408+
)
388409
else:
389410
raise TypeError(
390411
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."

0 commit comments

Comments
 (0)