|
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import os |
8 | | -from typing import Sequence, Tuple |
| 8 | +from typing import Optional, Sequence, Tuple |
9 | 9 |
|
10 | 10 | import numpy as np |
| 11 | +import orbax.checkpoint as ocp |
11 | 12 | import torch |
12 | 13 | from absl import logging |
13 | 14 | from flax import jax_utils |
14 | 15 | from flax.training import checkpoints as flax_checkpoints |
15 | 16 | from flax.training.checkpoints import latest_checkpoint |
| 17 | +from orbax.checkpoint.type_handlers import NumpyHandler |
16 | 18 | from tensorflow.io import gfile # pytype: disable=import-error |
17 | 19 |
|
18 | 20 | from algoperf import spec |
|
30 | 32 | ] |
31 | 33 |
|
32 | 34 |
|
| 35 | +class BoolHandler(NumpyHandler): |
| 36 | + """ |
| 37 | + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. |
| 38 | + It works by treating the scalar as a 0-dimensional array. |
| 39 | + """ |
| 40 | + |
| 41 | + def typestr(self) -> str: |
| 42 | + """Unique string identifier for this handler.""" |
| 43 | + return 'np.bool_' |
| 44 | + |
| 45 | + async def serialize( |
| 46 | + self, |
| 47 | + values: Sequence[np.bool_], |
| 48 | + infos: Sequence, |
| 49 | + args: Optional[Sequence[ocp.SaveArgs]] = None, |
| 50 | + ): |
| 51 | + """ |
| 52 | + Serializes a sequence of np.bool_ scalars by first converting them |
| 53 | + to 0-dim numpy arrays and then calling the parent NumpyHandler. |
| 54 | + """ |
| 55 | + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray |
| 56 | + array_values = [np.asarray(v, dtype=np.bool_) for v in values] |
| 57 | + # Use the parent class's robust serialization logic |
| 58 | + return await super().serialize(array_values, infos, args) |
| 59 | + |
| 60 | + async def deserialize( |
| 61 | + self, |
| 62 | + infos: Sequence, |
| 63 | + args: Optional[Sequence[ocp.RestoreArgs]] = None, |
| 64 | + ) -> Sequence[np.bool_]: |
| 65 | + """ |
| 66 | + Deserializes into a sequence of np.bool_ scalars by calling the |
| 67 | + parent handler and then converting the resulting 0-dim arrays. |
| 68 | + """ |
| 69 | + # Parent deserialize will return a sequence of 0-dimensional np.ndarray |
| 70 | + results = await super().deserialize(infos, args) |
| 71 | + |
| 72 | + # Convert each 0-d array back to an np.bool_ scalar using .item() |
| 73 | + scalar_results = [np.bool_(r.item()) for r in results] |
| 74 | + return scalar_results |
| 75 | + |
| 76 | + |
| 77 | +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) |
| 78 | + |
| 79 | + |
33 | 80 | def maybe_restore_checkpoint( |
34 | 81 | framework: str, |
35 | 82 | optimizer_state: spec.OptimizerState, |
|
0 commit comments