Skip to content

Commit ac9c6d1

Browse files
committed
NonUniformBatch: implement __getitem__ to allow index array (batch[idx]).
1 parent 67c9b62 commit ac9c6d1

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,6 @@ altdss/_version.py
131131

132132
apidocs
133133
electricdss-tst
134+
htmlcov/
135+
.coverage
136+
private_tests/

altdss/Batch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,5 +648,24 @@ def __iter__(self) -> Iterator[DSSObj]:
648648
for idx in range(cnt[0]):
649649
yield self._obj_cls(self._api_util, ptr[idx])
650650

651+
def __getitem__(self, idx: int) -> DSSObj:
652+
'''
653+
Return an object of the batch by index (0-based).
654+
'''
655+
_pointer, _count = self._get_ptr_cnt()
656+
if idx > _count or idx < 0:
657+
raise IndexError('Invalid object index inside the batch')
658+
659+
ptr = _pointer[idx]
660+
if ptr == self._ffi.NULL:
661+
return None
662+
663+
if self._obj_cls is None:
664+
cls_idx = self._lib.Obj_GetClassIdx(ptr)
665+
pycls = DSSObj._idx_to_cls[cls_idx]
666+
return pycls(self._api_util, ptr)
667+
668+
return self._obj_cls(self._api_util, ptr)
669+
651670

652671
__all__ = ('DSSBatch', 'NonUniformBatch', )

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ relevant. See [DSS C-API's repository](https://github.com/dss-extensions/dss_cap
66
## 0.2.2
77

88
- CircuitElementBatch: fix `MaxCurrent`. This will require the backend to be updated to v0.14.3.
9+
- NonUniformBatch: allow `batch[idx]` to get a single element by index.
910

1011
## 0.2.1
1112

0 commit comments

Comments
 (0)