diff --git a/pytools/__init__.py b/pytools/__init__.py index 08307a82..2ecaca47 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -395,12 +395,18 @@ class RecordWithoutPickling: """ __slots__: ClassVar[list[str]] = [] - fields: ClassVar[set[str]] + + # A dict, not a set, to maintain a deterministic iteration order + fields: ClassVar[Dict[str, None]] def __init__(self, valuedict: Mapping[str, Any] | None = None, exclude: Sequence[str] | None = None, **kwargs: Any) -> None: + from warnings import warn + warn(f"{self.__class__.__bases__[0]} is deprecated and will be " + "removed in 2025. Use dataclasses instead.") + assert self.__class__ is not Record if exclude is None: @@ -409,42 +415,45 @@ def __init__(self, try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} + + if isinstance(fields, set): + self.__class__.fields = fields = dict.fromkeys(sorted(fields)) if valuedict is not None: kwargs.update(valuedict) for key, value in kwargs.items(): if key not in exclude: - fields.add(key) + fields[key] = None setattr(self, key, value) - def get_copy_kwargs(self, **kwargs): + def get_copy_kwargs(self, **kwargs: Any) -> Dict[str, Any]: for f in self.__class__.fields: if f not in kwargs: with contextlib.suppress(AttributeError): kwargs[f] = getattr(self, f) return kwargs - def copy(self, **kwargs): + def copy(self, **kwargs: Any) -> "RecordWithoutPickling": return self.__class__(**self.get_copy_kwargs(**kwargs)) - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(f"{fld}={getattr(self, fld)!r}" for fld in sorted(self.__class__.fields) if hasattr(self, fld))) - def register_fields(self, new_fields): + def register_fields(self, new_fields: Iterable[str]) -> None: try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} - fields.update(new_fields) + fields.update(dict.fromkeys(sorted(new_fields))) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This method is implemented to avoid pylint 'no-member' errors for # attribute access. raise AttributeError( @@ -455,39 +464,39 @@ def __getattr__(self, name): class Record(RecordWithoutPickling): __slots__: ClassVar[list[str]] = [] - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { key: getattr(self, key) for key in self.__class__.fields if hasattr(self, key)} - def __setstate__(self, valuedict): + def __setstate__(self, valuedict: Mapping[str, Any]) -> None: try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} + + if isinstance(fields, set): + self.__class__.fields = fields = dict.fromkeys(sorted(fields)) for key, value in valuedict.items(): - fields.add(key) + fields[key] = None setattr(self, key, value) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self is other: return True return (self.__class__ == other.__class__ and self.__getstate__() == other.__getstate__()) - def __ne__(self, other): - return not self.__eq__(other) - class ImmutableRecordWithoutPickling(RecordWithoutPickling): """Hashable record. Does not explicitly enforce immutability.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: RecordWithoutPickling.__init__(self, *args, **kwargs) - self._cached_hash = None + self._cached_hash: Optional[int] = None - def __hash__(self): + def __hash__(self) -> int: # This attribute may vanish during pickling. if getattr(self, "_cached_hash", None) is None: self._cached_hash = hash(( @@ -495,7 +504,7 @@ def __hash__(self): *(getattr(self, field) for field in self.__class__.fields) )) - return self._cached_hash + return cast(int, self._cached_hash) class ImmutableRecord(ImmutableRecordWithoutPickling, Record): diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index a7e0b4f1..79bce853 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -887,6 +887,139 @@ def test_permutations(): assert len(perms) == 12 +# These classes must be defined globally to be picklable +class SimpleRecord(Record): + pass + + +class SetBasedRecord(Record): + fields = {"c", "b", "a"} # type: ignore[assignment] + + def __init__(self, c, b, a): + super().__init__(c=c, b=b, a=a) + + +def test_record(): + # {{{ New, dict-based Record + + r1 = SimpleRecord(a=1, b=2) + assert r1.a == 1 + assert r1.b == 2 + + r2 = r1.copy() + assert r2.a == 1 + assert r1 == r2 + + r3 = r1.copy(b=3) + assert r3.b == 3 + assert r1 != r3 + + assert str(r1) == str(r2) == "SimpleRecord(a=1, b=2)" + assert str(r3) == "SimpleRecord(a=1, b=3)" + + # Unregistered fields are (silently) ignored for printing + r1.f = 6 + assert str(r1) == "SimpleRecord(a=1, b=2)" + + # Registered fields are printed + r1.register_fields({"d", "e"}) + assert str(r1) == "SimpleRecord(a=1, b=2)" + + r1.d = 4 + r1.e = 5 + assert str(r1) == "SimpleRecord(a=1, b=2, d=4, e=5)" + + with pytest.raises(AttributeError): + r1.ff + + # Test pickling + + import pickle + r1_pickled = pickle.loads(pickle.dumps(r1)) + assert r1 == r1_pickled + + class SimpleRecord2(Record): + pass + + r_new = SimpleRecord2(b=2, a=1) + assert r_new.a == 1 + assert r_new.b == 2 + + assert str(r_new) == "SimpleRecord2(b=2, a=1)" + + assert r_new != r1 + + # }}} + + # {{{ Legacy set-based record (used in Loopy) + + r = SetBasedRecord(3, 2, 1) + + # Fields are converted to a dict during __init__ + assert isinstance(r.fields, dict) + assert r.a == 1 + assert r.b == 2 + assert r.c == 3 + + # Fields are sorted alphabetically in set-based records + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + # Unregistered fields are (silently) ignored for printing + r.f = 6 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + # Registered fields are printed + r.register_fields({"d", "e"}) + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + r.d = 4 + r.e = 5 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3, d=4, e=5)" + + with pytest.raises(AttributeError): + r.ff + + # Test pickling + r_pickled = pickle.loads(pickle.dumps(r)) + assert r == r_pickled + + # }}} + + # {{{ __slots__, __dict__, __weakref__ handling + + class RecordWithEmptySlots(Record): + __slots__ = [] + + assert hasattr(RecordWithEmptySlots(), "__slots__") + assert not hasattr(RecordWithEmptySlots(), "__dict__") + assert not hasattr(RecordWithEmptySlots(), "__weakref__") + + class RecordWithUnsetSlots(Record): + pass + + assert hasattr(RecordWithUnsetSlots(), "__slots__") + assert hasattr(RecordWithUnsetSlots(), "__dict__") + assert hasattr(RecordWithUnsetSlots(), "__weakref__") + + from pytools import ImmutableRecord + + class ImmutableRecordWithEmptySlots(ImmutableRecord): + __slots__ = [] + + assert hasattr(ImmutableRecordWithEmptySlots(), "__slots__") + assert hasattr(ImmutableRecordWithEmptySlots(), "__dict__") + assert hasattr(ImmutableRecordWithEmptySlots(), "__weakref__") + + class ImmutableRecordWithUnsetSlots(ImmutableRecord): + pass + + assert hasattr(ImmutableRecordWithUnsetSlots(), "__slots__") + assert hasattr(ImmutableRecordWithUnsetSlots(), "__dict__") + assert hasattr(ImmutableRecordWithUnsetSlots(), "__weakref__") + + # }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])