Skip to content

Commit e37b0c9

Browse files
Merge pull request #39 from ATTPC/gwm17_dev
Fix small bug in simulator.py, add type hint for Numba dicts
2 parents ef6f098 + 1786dfd commit e37b0c9

7 files changed

Lines changed: 54 additions & 9 deletions

File tree

docs/api/detector/pairing.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# pairing Module
2+
3+
::: attpc_engine.detector.pairing

docs/api/detector/typed_dict.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# typed_dict Module
2+
3+
::: attpc_engine.detector.typed_dict

mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ nav:
1919
- About detector: api/detector/index.md
2020
- beam_pads: api/detector/beam_pads.md
2121
- constants: api/detector/constants.md
22+
- pairing: api/detector/pairing.md
2223
- parameters: api/detector/parameters.md
2324
- response: api/detector/response.md
2425
- simulator: api/detector/simulator.md
2526
- solver: api/detector/solver.md
2627
- transporter: api/detector/transporter.md
28+
- typed_dict: api/detector/typed_dict.md
2729
- writer: api/detector/writer.md
2830
- kinematics:
2931
- About kinematics: api/kinematics/index.md

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "attpc_engine"
3-
version = "0.2.0"
3+
version = "0.3.0"
44
description = "AT-TPC Monte-Carlo simulation engine"
55
authors = [
66
{name = "Gordon McCann", email = "gordonmccann215@gmail.com"},

src/attpc_engine/detector/simulator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .constants import NUM_TB
1212
from .. import nuclear_map
1313
from .pairing import unpair
14+
from .typed_dict import NumbaTypedDict
1415

1516
import numpy as np
1617
import h5py as h5
@@ -123,7 +124,7 @@ def digitize(self, config: Config, rng: Generator) -> np.ndarray:
123124

124125

125126
@njit
126-
def dict_to_points(points: Dict[int, int]) -> np.ndarray:
127+
def dict_to_points(points: NumbaTypedDict[int, int]) -> np.ndarray:
127128
"""
128129
Converts dictionary of N pad,tb keys with corresponding number of electrons
129130
to Nx3 array where each row is [pad, tb, e], now combined over pad/tb combos.
@@ -294,7 +295,7 @@ def generate_electrons(
294295
return electrons
295296

296297
def generate_point_cloud(
297-
self, config: Config, rng: Generator, points: Dict[int, int]
298+
self, config: Config, rng: Generator, points: NumbaTypedDict
298299
):
299300
"""Create the point cloud
300301
@@ -331,7 +332,7 @@ def generate_point_cloud(
331332
if config.pad_grid_edges is None or config.pad_grid is None:
332333
raise ValueError("Pad grid is not loaded at SimParticle.generate_hits!")
333334

334-
points = transport_track(
335+
transport_track(
335336
config.pad_grid,
336337
config.pad_grid_edges,
337338
config.det_params.diffusion,

src/attpc_engine/detector/transporter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from .pairing import pair
22
from .beam_pads import BEAM_PADS_ARRAY
3+
from .typed_dict import NumbaTypedDict
34

45
import numpy as np
56
from numba import njit
6-
from numba.typed import Dict
77

88
STEPS = 10
99

@@ -124,7 +124,7 @@ def point_transport(
124124
time: float,
125125
center: tuple[float, float],
126126
electrons: int,
127-
points: Dict[int, int],
127+
points: NumbaTypedDict[int, int],
128128
):
129129
"""
130130
Transports all electrons created at a point in a simulated nucleus' track
@@ -170,7 +170,7 @@ def transverse_transport(
170170
center: tuple[float, float],
171171
electrons: int,
172172
sigma_t: float,
173-
points: Dict[int, int],
173+
points: NumbaTypedDict[int, int],
174174
):
175175
"""
176176
Transports all electrons created at a point in a simulated nucleus'
@@ -246,7 +246,7 @@ def find_pads_hit(
246246
center: tuple[float, float],
247247
electrons: int,
248248
sigma_t: float,
249-
points: Dict[int, int],
249+
points: NumbaTypedDict[int, int],
250250
):
251251
"""
252252
Finds the pads hit by transporting the electrons created at a point in
@@ -298,7 +298,7 @@ def transport_track(
298298
dv: float,
299299
track: np.ndarray,
300300
electrons: np.ndarray,
301-
points: Dict[int, int],
301+
points: NumbaTypedDict[int, int],
302302
):
303303
"""
304304
High-level function that transports each point in a nucleus' trajectory
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import TypeVar, Generic, Iterable
2+
3+
K = TypeVar("K")
4+
V = TypeVar("V")
5+
6+
7+
class NumbaTypedDict(Generic[K, V]):
8+
"""This is simply a type hint interface for Numba typed dictionaries
9+
10+
Use this as a type hint wherever you would be using numba.typed.Dict
11+
and magically all of your linters will be happy.
12+
13+
Example:
14+
```py
15+
my_dict = numba.typed.Dict.empty(
16+
key_type=numba.core.types.int64,
17+
value_type=numba.core.types.int64
18+
)
19+
```
20+
is type-hinted as
21+
```py
22+
my_dict: NumbaTypedDict[int, int]
23+
```
24+
25+
Do not attempt to instantiate this object! It won't work!
26+
"""
27+
28+
def __getitem__(self, x: K) -> V: ...
29+
30+
def __setitem__(self, x: K, v: V): ...
31+
32+
def __len__(self) -> int: ...
33+
34+
def items(self) -> Iterable[tuple[K, V]]: ...
35+
36+
def get(self, x: K, default: V | None = None) -> V: ...

0 commit comments

Comments
 (0)