Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions docs/modules.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,22 @@ on the data of the snapshot.

#### Snapshots Correlation Hook

There are two correlation hooks available:

- [`register_correlation_hook`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook]
- [`register_correlation_hook_with_master_record`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook_with_master_record]

Both do the same thing, but as the naming suggests, the latter also provides a master record.
The [`register_correlation_hook`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook]
method expects a callable with the following signature:
`Callable[[str, dict], Union[None, list[DataPointTask]]]`, where the first argument is the entity type, and the second is a dict
containing the current values of the entity and its linked entities.
The method can optionally return a list of DataPointTask objects to be inserted into the system.
The [`register_correlation_hook_with_master_record`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook_with_master_record] method expects a callable with the following signature:
`Callable[[str, dict, dict], Union[None, list[DataPointTask]]]` - the first two arguments are identical (entity type and dict with current values), but there is also a third argument: a dictionary of values stored in the master record of the entity.
The method (applicable to both variants) can optionally return a list of `DataPointTask` objects to be inserted into the system.

As correlation hooks can depend on each other, the hook inputs and outputs must be specified
using the depends_on and may_change arguments. Both arguments are lists of lists of strings,
using the `depends_on` and `may_change` arguments. Both arguments are lists of lists of strings,
where each list of strings is a path from the specified entity type to individual attributes (even on linked entities).
For example, if the entity type is `test_entity_type`, and the hook depends on the attribute `test_attr_type1`,
the path is simply `[["test_attr_type1"]]`. If the hook depends on the attribute `test_attr_type1`
Expand All @@ -351,9 +359,21 @@ of an entity linked using `test_attr_link`, the path will be `[["test_attr_link
def correlation_hook(entity_type: str, values: dict):
...

def correlation_hook_with_master_record(entity_type: str, values: dict, master_record: dict):
...

# Without master record
registrar.register_correlation_hook(
correlation_hook, "test_entity_type", [["test_attr_type1"]], [["test_attr_type2"]]
)

# Or with master record
registrar.register_correlation_hook_with_master_record(
correlation_hook_with_master_record,
"test_entity_type",
[["test_attr_type1"]],
[["test_attr_type2"]]
)
```

The order of running callbacks is determined automatically, based on the dependencies.
Expand Down
43 changes: 42 additions & 1 deletion dp3/common/callback_registrar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from functools import partial
from functools import partial, wraps
from logging import Logger
from typing import Callable, Union

Expand Down Expand Up @@ -365,6 +365,47 @@ def register_correlation_hook(
may_change: each item should specify an attribute that `hook` may change.
specification format is identical to `depends_on`.

Raises:
ValueError: On failure of specification validation.
"""

# Ignore master record for this variant of the hook
@wraps(hook)
def wrapped_hook(e: str, s: dict, _m: dict):
return hook(e, s)

self._snap_shooter.register_correlation_hook(
wrapped_hook, entity_type, depends_on, may_change
)

def register_correlation_hook_with_master_record(
self,
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
entity_type: str,
depends_on: list[list[str]],
may_change: list[list[str]],
):
"""
Registers passed hook to be called during snapshot creation.

Identical to `register_correlation_hook`, but the hook also receives the master record.

Binds hook to specified entity_type (though same hook can be bound multiple times).

`entity_type` and attribute specifications are validated, `ValueError` is raised on failure.

Args:
hook: `hook` callable should have the signature
`hook(entity_type: str, current_values: dict, master_record: dict)`.
where `current_values` includes linked entities.
Can optionally return a list of DataPointTask objects to perform.
entity_type: specifies entity type
depends_on: each item should specify an attribute that is depended on
in the form of a path from the specified entity_type to individual attributes
(even on linked entities).
may_change: each item should specify an attribute that `hook` may change.
specification format is identical to `depends_on`.

Raises:
ValueError: On failure of specification validation.
"""
Expand Down
18 changes: 13 additions & 5 deletions dp3/snapshots/snapshooter.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,24 @@ def register_timeseries_hook(

def register_correlation_hook(
self,
hook: Callable[[str, dict], Union[None, list[DataPointTask]]],
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
entity_type: str,
depends_on: list[list[str]],
may_change: list[list[str]],
):
"""
Registers passed hook to be called during snapshot creation.

Common implementation for hooks with and without master record.

Binds hook to specified entity_type (though same hook can be bound multiple times).

`entity_type` and attribute specifications are validated, `ValueError` is raised on failure.

Args:
hook: `hook` callable should expect entity type as str
and its current values, including linked entities, as dict
and its current values, including linked entities, as dict;
and its master record as dict.
Can optionally return a list of DataPointTask objects to perform.
entity_type: specifies entity type
depends_on: each item should specify an attribute that is depended on
Expand Down Expand Up @@ -457,9 +460,12 @@ def make_linkless_snapshot(self, entity_type: str, master_record: dict, time: da
self.run_timeseries_processing(entity_type, master_record)
values = self.get_values_at_time(entity_type, master_record, time)
self.add_mirrored_links(entity_type, values)
entity_values = {(entity_type, master_record["_id"]): values}
entity_id = master_record["_id"]
entity_values = {(entity_type, entity_id): values}

tasks = self._correlation_hooks.run(entity_values)
tasks = self._correlation_hooks.run(
entity_values, {(entity_type, entity_id): master_record}
)
for task in tasks:
self.task_queue_writer.put_task(task)

Expand Down Expand Up @@ -499,6 +505,7 @@ def make_snapshot(self, task: Snapshot):
The resulting snapshots are saved into DB.
"""
entity_values = {}
entity_master_records = {}
for entity_type, entity_id in task.entities:
record = self.db.get_master_record(entity_type, entity_id) or {"_id": entity_id}
if not self.config.keep_empty and len(record) == 1:
Expand All @@ -508,9 +515,10 @@ def make_snapshot(self, task: Snapshot):
values = self.get_values_at_time(entity_type, record, task.time)
self.add_mirrored_links(entity_type, values)
entity_values[entity_type, entity_id] = values
entity_master_records[entity_type, entity_id] = record

self.link_loaded_entities(entity_values)
created_tasks = self._correlation_hooks.run(entity_values)
created_tasks = self._correlation_hooks.run(entity_values, entity_master_records)
for created_task in created_tasks:
self.task_queue_writer.put_task(created_task)

Expand Down
15 changes: 10 additions & 5 deletions dp3/snapshots/snapshot_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, log: logging.Logger, model_spec: ModelSpec, elog: EventGroupT

def register(
self,
hook: Callable[[str, dict], Union[None, list[DataPointTask]]],
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
entity_type: str,
depends_on: list[list[str]],
may_change: list[list[str]],
Expand All @@ -97,8 +97,9 @@ def register(
If entity_type and attribute specifications are validated
and ValueError is raised on failure.
Args:
hook: `hook` callable should expect entity type as str
and its current values, including linked entities, as dict.
hook: `hook` callable should expect entity type as str;
its current values, including linked entities, as dict;
and its master record as dict.
Can optionally return a list of DataPointTask objects to perform.
entity_type: specifies entity type
depends_on: each item should specify an attribute that is depended on
Expand Down Expand Up @@ -191,7 +192,7 @@ def _resolve_entities_in_path(self, base_entity: str, path: list[str]) -> list[t
position = entity_attributes[position.relation_to]
return resolved_path

def run(self, entities: dict) -> list[DataPointTask]:
def run(self, entities: dict, entity_master_records: dict) -> list[DataPointTask]:
"""Runs registered hooks."""
entity_types = {etype for etype, _ in entities}
hook_subset = [
Expand All @@ -200,18 +201,22 @@ def run(self, entities: dict) -> list[DataPointTask]:
topological_order = self._dependency_graph.topological_order
hook_subset.sort(key=lambda x: topological_order.index(x[0]))
entities_by_etype = defaultdict(dict)
entity_master_records_by_etype = defaultdict(dict)
for (etype, eid), values in entities.items():
entities_by_etype[etype][eid] = values
for (etype, eid), mr in entity_master_records.items():
entity_master_records_by_etype[etype][eid] = mr

created_tasks = []

with task_context(self.model_spec):
for hook_id, hook, etype in hook_subset:
short_id = hook_id if len(hook_id) < 160 else self._short_hook_ids[hook_id]
for eid, entity_values in entities_by_etype[etype].items():
entity_master_record = entity_master_records_by_etype[etype].get(eid, {})
self.log.debug("Running hook %s on entity %s", short_id, eid)
try:
tasks = hook(etype, entity_values)
tasks = hook(etype, entity_values, entity_master_record)
if tasks is not None and tasks:
created_tasks.extend(tasks)
except Exception as e:
Expand Down
28 changes: 28 additions & 0 deletions tests/modules/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ def modify_value(_: str, record: dict, attr: str, value):
record[attr] = value


def use_master_record(
_: str, record: dict, master_record: dict, target_attr: str, source_attr: str
):
"""Hook that uses master record to copy a value from master to snapshot.

Only applies when source attribute in master record has value starting with "master_"
to avoid interfering with other test cases.
"""
if source_attr in master_record:
# Get the value from master record
master_value = master_record[source_attr].get("v", None)
if master_value is not None and str(master_value).startswith("master_"):
# Append a suffix to demonstrate master record was used
record[target_attr] = f"{master_value}_from_master"


dummy_hook_abc = update_wrapper(partial(modify_value, attr="data2", value="abc"), modify_value)
dummy_hook_def = update_wrapper(partial(modify_value, attr="data1", value="def"), modify_value)

Expand Down Expand Up @@ -67,3 +83,15 @@ def __init__(
depends_on=[],
may_change=[["data1"]],
)

# Testing register_correlation_hook_with_master_record
# This hook should copy data1 from master record to data2 with a suffix
registrar.register_correlation_hook_with_master_record(
update_wrapper(
partial(use_master_record, target_attr="data4", source_attr="data3"),
use_master_record,
),
"A",
depends_on=[],
may_change=[["data4"]],
)
17 changes: 17 additions & 0 deletions tests/test_api/test_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def make_dp(type, id, attr, v, time=False):
make_dp("C", "c1", "ds", {"eid": "d1"}, time=True),
make_dp("C", "c1", "data1", "inita"),
make_dp("C", "c1", "data2", "inita"),
# For test_master_record_hook (A-423)
make_dp("A", 423, "data3", "master_test"),
make_dp("A", 423, "data4", "placeholder"),
]
res = cls.push_datapoints(entity_datapoints)
if res.status_code != 200:
Expand Down Expand Up @@ -79,3 +82,17 @@ def test_hook_dependency_value_forwarding(self):
for snapshot in data.snapshots:
self.assertEqual(snapshot["data1"], "modifd")
self.assertEqual(snapshot["data2"], "modifc")

def test_master_record_hook(self):
"""
Test that hooks registered via register_correlation_hook_with_master_record
correctly receive the master record parameter.
"""
# Entity A-423 has data1="master_test" in its master record
# The master record hook should copy data1 from master and append "_from_master"
data = self.get_entity_data("entity/A/423", EntityEidData)
self.assertGreater(len(data.snapshots), 0)
for snapshot in data.snapshots:
self.assertEqual(snapshot["data3"], "master_test")
# The hook should have set data4 to data3 from master record + "_from_master"
self.assertEqual(snapshot["data4"], "master_test_from_master")
4 changes: 2 additions & 2 deletions tests/test_common/test_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer


def modify_value(_: str, record: dict, attr: str, value):
def modify_value(_: str, record: dict, _master_record: dict, attr: str, value):
record[attr] = value


Expand All @@ -37,7 +37,7 @@ def test_basic_function(self):
hook=dummy_hook_abc, entity_type="A", depends_on=[["data1"]], may_change=[["data2"]]
)
values = {}
self.container.run({("A", "a1"): values})
self.container.run({("A", "a1"): values}, {})
self.assertEqual(values["data2"], "abc")

def test_circular_dependency_error(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_config/db_entities/A.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ attribs:
type: plain
data_type: string

data3:
name: data3
description: entity data
type: plain
data_type: string

data4:
name: data4
description: entity data
type: plain
data_type: string

as:
name: As
description: Link to other A entities
Expand Down