Skip to content

Commit 953271f

Browse files
committed
add a load_string function
1 parent ca41146 commit 953271f

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

rbms/bernoulli_gaussian/classes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import annotations
2-
from botocore.vendored.six import u
32

43
import numpy as np
54
import torch

rbms/custom_fn.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import h5py
2+
import numpy as np
13
import torch
24
from torch import Tensor
35

@@ -47,3 +49,13 @@ def check_keys_dict(d: dict, names: list[str]):
4749
raise ValueError(
4850
f"""Dictionary params missing key '{k}'\n Provided keys : {d.keys()}\n Expected keys: {names}"""
4951
)
52+
53+
54+
def load_string(f: h5py.Dataset, k: str | bytes) -> str:
55+
# Fix 1: Ensure key is a string
56+
# key = k.decode("utf-8") if isinstance(k, bytes) else k
57+
val = np.asarray(f[k])
58+
# Fix 2: Ensure string values (like 'Reservoir') are strings, not bytes
59+
if val.dtype.kind in ["S", "V", "O"]: # Bytes, Void, or Object (StringDType)
60+
val = val.astype(str)
61+
return str(val)

0 commit comments

Comments
 (0)