-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcontamination_stats.py
More file actions
103 lines (83 loc) · 3.84 KB
/
contamination_stats.py
File metadata and controls
103 lines (83 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from bitarray import bitarray
from typing import Optional, Dict, Any, Hashable
from dataclasses import dataclass
from light_scenario import LightScenario, LightScenarioKey
from common.general import asdict_without_nones
# Identifiers for the two parts of an instance whose contamination is tracked separately.
# These are the only values accepted for the ``part`` argument of
# ``ContaminationStats.write_one_to_bit`` and ``ContaminationStats.get_bit``.
PART_INPUT: str = "input"  # the instance input
PART_REF: str = "reference"  # the instance reference
@dataclass(frozen=True)
class ContaminationStatsKey:
    """Hashable identifier of a `ContaminationStats` instance, derived from a metadata mapping.

    Equality is the dataclass-generated field comparison, i.e. two keys are equal when their
    ``metadata`` dicts are equal. The hash is independent of the dict's insertion order.
    """

    # Tag name -> tag value. Values must be hashable because they are folded into __hash__.
    # NOTE: frozen=True only prevents rebinding ``metadata``; the dict itself stays mutable,
    # so the hash reflects whatever the dict contains at the time it is computed.
    metadata: Dict[str, Hashable]

    def __hash__(self):
        # A dict cannot be hashed directly, so hash its (name, value) pairs ordered by name.
        ordered_pairs = sorted(self.metadata.items(), key=lambda pair: pair[0])
        return hash(tuple(ordered_pairs))
class ContaminationStats:
    """
    A memory-efficient class for contamination stats. The core data structures are bit arrays where
    every bit records whether an instance is contaminated or not.

    Two bit arrays of length ``num_instances`` are kept: one for the instance input (``PART_INPUT``)
    and one for the instance reference (``PART_REF``). Bit ``i`` of an array is set to 1 once
    instance ``i`` has been marked as contaminated for that part.
    """

    def __init__(
        self, light_scenario_key: LightScenarioKey, num_instances: int, stats_tags: Optional[Dict[str, Any]] = None
    ):
        """
        Args:
            light_scenario_key: Key of the scenario these stats describe; stored in
                ``stats_key.metadata`` under ``"light_scenario_key"``.
            num_instances: Number of instances in the scenario; the length of both bit arrays.
            stats_tags: Optional extra metadata merged into ``stats_key.metadata`` after the
                scenario key (so a tag named ``"light_scenario_key"`` would replace it).
        """
        # Assemble the metadata up front rather than mutating the dict inside the frozen key.
        metadata: Dict[str, Any] = {"light_scenario_key": light_scenario_key}
        if stats_tags is not None:
            metadata.update(stats_tags)
        self.stats_key = ContaminationStatsKey(metadata=metadata)
        self.num_instances = num_instances
        # bitarray(n) is not guaranteed to be zeroed in every bitarray release, so clear explicitly.
        self._input_bits = bitarray(num_instances)
        self._reference_bits = bitarray(num_instances)
        self._input_bits.setall(0)
        self._reference_bits.setall(0)

    @classmethod
    def from_scenario(cls, scenario: LightScenario, stats_tags: Optional[Dict[str, Any]] = None):
        """Create empty stats sized to ``scenario``'s instances and keyed by its scenario key."""
        return cls(
            light_scenario_key=scenario.light_scenario_key,
            num_instances=len(scenario.light_instances),
            stats_tags=stats_tags,
        )

    def _bits_for(self, part: str) -> bitarray:
        """Return the bit array tracking ``part``; raise ``ValueError`` for an unknown part."""
        if part == PART_INPUT:
            return self._input_bits
        if part == PART_REF:
            return self._reference_bits
        raise ValueError(f"There is no valid part of instance named {part}")

    def write_one_to_bit(self, instance_id: int, part: str):
        """Mark instance ``instance_id`` as contaminated for ``part``.

        Raises:
            ValueError: if ``part`` is neither ``PART_INPUT`` nor ``PART_REF``.
        """
        self._bits_for(part)[instance_id] = 1

    def get_bit(self, instance_id: int, part: str) -> int:
        """Return the stored contamination bit of instance ``instance_id`` for ``part``.

        Raises:
            ValueError: if ``part`` is neither ``PART_INPUT`` nor ``PART_REF``.
        """
        return self._bits_for(part)[instance_id]

    def merge(self, stats: "ContaminationStats"):
        """Merge two stats instance of the same scenario.

        The bits of ``stats`` are OR-ed into this instance in place; ``stats`` is not modified.

        Raises:
            ValueError: if the two stats have different ``stats_key`` or ``num_instances``.
        """
        if self.stats_key != stats.stats_key:
            raise ValueError("Only stats with the same `stats_key` can be merged.")
        if self.num_instances != stats.num_instances:
            raise ValueError("The sizes of the two scenarios need to equal.")
        self._input_bits |= stats._input_bits
        self._reference_bits |= stats._reference_bits

    @property
    def num_instances_with_contaminated_input(self):
        """Number of instances whose input bit is set."""
        return self._input_bits.count()

    @property
    def num_instances_with_contaminated_reference(self):
        """Number of instances whose reference bit is set."""
        return self._reference_bits.count()

    @property
    def contaminated_input_fraction(self):
        """Fraction of instances whose input is contaminated (0.0 when there are no instances)."""
        # An empty scenario has nothing contaminated; avoid ZeroDivisionError in generate_summary.
        if self.num_instances == 0:
            return 0.0
        return self.num_instances_with_contaminated_input / self.num_instances

    @property
    def contaminated_reference_fraction(self):
        """Fraction of instances whose reference is contaminated (0.0 when there are no instances)."""
        if self.num_instances == 0:
            return 0.0
        return self.num_instances_with_contaminated_reference / self.num_instances

    def generate_summary(self, summary_tags: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Output a summary of the stats.

        Args:
            summary_tags: Optional extra entries placed in the ``"setting"`` section next to
                ``"stats_key"`` (the key as converted by ``asdict_without_nones``).

        Returns:
            A dict with the ``"setting"`` section, the instance count, and the contaminated counts
            and fractions for both the input and the reference parts.
        """
        if summary_tags is None:
            summary_tags = {}
        return {
            "setting": {"stats_key": asdict_without_nones(self.stats_key), **summary_tags},
            "num_instances": self.num_instances,
            "num_instances_with_contaminated_input": self.num_instances_with_contaminated_input,
            "num_instances_with_contaminated_reference": self.num_instances_with_contaminated_reference,
            "contaminated_input_fraction": self.contaminated_input_fraction,
            "contaminated_reference_fraction": self.contaminated_reference_fraction,
        }