Skip to content

Commit 0d015a4

Browse files
authored
Refactor PerturbationEngine methods for consistency
1 parent a775ac5 commit 0d015a4

1 file changed

Lines changed: 92 additions & 91 deletions

File tree

  • domain-shift-fusion-benchmark/src/fusionbench/perturbations

domain-shift-fusion-benchmark/src/fusionbench/perturbations/operators.py

Lines changed: 92 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -10,95 +10,96 @@
1010

1111

1212
class PerturbationEngine:
13-
def __init__(self, seed: int = 42):
14-
self._rng = random.Random(seed)
15-
16-
def apply(self, samples: List[Sample], operations: Iterable[Dict]) -> List[Sample]:
17-
transformed = copy.deepcopy(samples)
18-
for operation in operations:
19-
transformed = self._apply_operation(transformed, operation)
20-
return transformed
21-
22-
def _apply_operation(self, samples: List[Sample], op: Dict) -> List[Sample]:
23-
op_type = str(op.get("type", "")).strip().lower()
24-
target = op.get("target", "all")
25-
26-
if op_type == "gaussian_noise":
27-
std = float(op.get("std", 0.05))
28-
return self._gaussian_noise(samples, target, std)
29-
30-
if op_type == "dropout":
31-
probability = float(op.get("probability", 0.1))
32-
return self._dropout(samples, target, probability)
33-
34-
if op_type == "bias_drift":
35-
offset = float(op.get("offset", 0.1))
36-
return self._bias_drift(samples, target, offset)
37-
38-
if op_type == "uncertainty_inflation":
39-
factor = float(op.get("factor", 1.5))
40-
return self._uncertainty_inflation(samples, target, factor)
41-
42-
if op_type == "temporal_jitter":
43-
window = int(op.get("window", 3))
44-
return self._temporal_jitter(samples, target, window)
45-
46-
raise ValueError(f"Unknown perturbation type: {op_type}")
47-
48-
def _target_sensors(self, sample: Sample, target: str) -> List[str]:
49-
if target == "all":
50-
return list(sample.sensors.keys())
51-
if target not in sample.sensors:
52-
return []
53-
return [target]
54-
55-
def _gaussian_noise(self, samples: List[Sample], target: str, std: float) -> List[Sample]:
56-
for sample in samples:
57-
for sensor_name in self._target_sensors(sample, target):
58-
reading = sample.sensors[sensor_name]
59-
reading.score = clamp(reading.score + self._rng.gauss(0.0, std))
60-
reading.uncertainty = clamp(reading.uncertainty + abs(self._rng.gauss(0.0, std * 0.5)), 0.01, 1.0)
61-
return samples
62-
63-
def _dropout(self, samples: List[Sample], target: str, probability: float) -> List[Sample]:
64-
for sample in samples:
65-
for sensor_name in self._target_sensors(sample, target):
66-
if self._rng.random() < probability:
67-
reading = sample.sensors[sensor_name]
68-
reading.score = 0.5
69-
reading.uncertainty = min(1.0, reading.uncertainty + 0.45)
70-
return samples
71-
72-
def _bias_drift(self, samples: List[Sample], target: str, offset: float) -> List[Sample]:
73-
for sample in samples:
74-
for sensor_name in self._target_sensors(sample, target):
75-
reading = sample.sensors[sensor_name]
76-
reading.score = clamp(reading.score + offset)
77-
reading.uncertainty = clamp(reading.uncertainty + abs(offset) * 0.25, 0.01, 1.0)
78-
return samples
79-
80-
def _uncertainty_inflation(self, samples: List[Sample], target: str, factor: float) -> List[Sample]:
81-
for sample in samples:
82-
for sensor_name in self._target_sensors(sample, target):
83-
reading = sample.sensors[sensor_name]
84-
reading.uncertainty = clamp(reading.uncertainty * factor, 0.01, 1.0)
85-
return samples
86-
87-
def _temporal_jitter(self, samples: List[Sample], target: str, window: int) -> List[Sample]:
88-
if window < 2:
89-
return samples
90-
91-
history: Dict[str, deque] = {}
92-
for sample in samples:
93-
for sensor_name in self._target_sensors(sample, target):
94-
history.setdefault(sensor_name, deque(maxlen=window))
95-
queue = history[sensor_name]
96-
queue.append(sample.sensors[sensor_name].score)
97-
98-
if len(queue) > 1 and self._rng.random() < 0.4:
99-
delayed_value = queue[0]
100-
reading = sample.sensors[sensor_name]
101-
reading.score = clamp((reading.score + delayed_value) * 0.5)
102-
reading.uncertainty = clamp(reading.uncertainty + 0.1, 0.01, 1.0)
103-
return samples
13+
    def __init__(self, seed: int = 42):
14+
        self._rng = random.Random(seed)
15+
16+
    def apply(self, samples: List[Sample], operations: Iterable[Dict]) -> List[Sample]:
17+
        transformed = copy.deepcopy(samples)
18+
        for operation in operations:
19+
            transformed = self._apply_operation(transformed, operation)
20+
        return transformed
21+
22+
    def _apply_operation(self, samples: List[Sample], op: Dict) -> List[Sample]:
23+
        op_type = str(op.get("type", "")).strip().lower()
24+
        target = op.get("target", "all")
25+
26+
        if op_type == "gaussian_noise":
27+
            std = float(op.get("std", 0.05))
28+
            return self._gaussian_noise(samples, target, std)
29+
30+
        if op_type == "dropout":
31+
            probability = float(op.get("probability", 0.1))
32+
            return self._dropout(samples, target, probability)
33+
34+
        if op_type == "bias_drift":
35+
            offset = float(op.get("offset", 0.1))
36+
            return self._bias_drift(samples, target, offset)
37+
38+
        if op_type == "uncertainty_inflation":
39+
            factor = float(op.get("factor", 1.5))
40+
            return self._uncertainty_inflation(samples, target, factor)
41+
42+
        if op_type == "temporal_jitter":
43+
            window = int(op.get("window", 3))
44+
            return self._temporal_jitter(samples, target, window)
45+
46+
        raise ValueError(f"Unknown perturbation type: {op_type}")
47+
48+
    def _target_sensors(self, sample: Sample, target: str) -> List[str]:
49+
        if target == "all":
50+
            return list(sample.sensors.keys())
51+
        if target not in sample.sensors:
52+
            return []
53+
        return [target]
54+
55+
    def _gaussian_noise(self, samples: List[Sample], target: str, std: float) -> List[Sample]:
56+
        for sample in samples:
57+
            for sensor_name in self._target_sensors(sample, target):
58+
                reading = sample.sensors[sensor_name]
59+
                reading.score = clamp(reading.score + self._rng.gauss(0.0, std))
60+
                reading.uncertainty = clamp(reading.uncertainty + abs(self._rng.gauss(0.0, std * 0.5)), 0.01, 1.0)
61+
        return samples
62+
63+
    def _dropout(self, samples: List[Sample], target: str, probability: float) -> List[Sample]:
64+
        for sample in samples:
65+
            for sensor_name in self._target_sensors(sample, target):
66+
                if self._rng.random() < probability:
67+
                    reading = sample.sensors[sensor_name]
68+
                    reading.score = 0.5
69+
                    reading.uncertainty = min(1.0, reading.uncertainty + 0.45)
70+
        return samples
71+
72+
    def _bias_drift(self, samples: List[Sample], target: str, offset: float) -> List[Sample]:
73+
        for sample in samples:
74+
            for sensor_name in self._target_sensors(sample, target):
75+
                reading = sample.sensors[sensor_name]
76+
                reading.score = clamp(reading.score + offset)
77+
                reading.uncertainty = clamp(reading.uncertainty + abs(offset) * 0.25, 0.01, 1.0)
78+
        return samples
79+
80+
    def _uncertainty_inflation(self, samples: List[Sample], target: str, factor: float) -> List[Sample]:
81+
        for sample in samples:
82+
            for sensor_name in self._target_sensors(sample, target):
83+
                reading = sample.sensors[sensor_name]
84+
                reading.uncertainty = clamp(reading.uncertainty * factor, 0.01, 1.0)
85+
        return samples
86+
87+
    def _temporal_jitter(self, samples: List[Sample], target: str, window: int) -> List[Sample]:
88+
        if window < 2:
89+
            return samples
90+
91+
        history: Dict[str, deque] = {}
92+
        for sample in samples:
93+
            for sensor_name in self._target_sensors(sample, target):
94+
                history.setdefault(sensor_name, deque(maxlen=window))
95+
                queue = history[sensor_name]
96+
                queue.append(sample.sensors[sensor_name].score)
97+
98+
                if len(queue) > 1 and self._rng.random() < 0.4:
99+
                    delayed_value = queue[0]
100+
                    reading = sample.sensors[sensor_name]
101+
                    reading.score = clamp((reading.score + delayed_value) * 0.5)
102+
                    reading.uncertainty = clamp(reading.uncertainty + 0.1, 0.01, 1.0)
103+
        return samples
104+
104105

0 commit comments

Comments
 (0)