-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmanagedsystem.py
More file actions
122 lines (97 loc) · 4.44 KB
/
managedsystem.py
File metadata and controls
122 lines (97 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from masced_bandits.bandits import init_bandit
from masced_bandits.bandit_options import initialize_arguments
from environmentgrammar import environment_grammar, EnvironmentTransformer
from lark import Lark
import numpy as np
import matplotlib.pyplot as plt
from statistics import mean
import pprint
from math import floor
class MockSAS:
    """A mock self-adaptive system (SAS).

    Wires together one or more managed systems, an optional environment,
    and a bandit-driven managing system, connected to each other through
    a simple observer pattern:

        Environment --notify--> ManagedSystem --notify--> ManagingSystem
    """

    def __init__(self, policy, parse_tree):
        """Build the SAS from a bandit policy tuple and a parsed environment grammar.

        policy: tuple of keyword arguments forwarded to init_bandit.
        parse_tree: Lark parse tree consumed by EnvironmentTransformer.
        """
        # There could be multiple managed systems sharing one environment.
        self.managing_systems = []
        transformer_dict = EnvironmentTransformer().transform(parse_tree)
        # Extend this list with multiple managed systems if that is what
        # you want to model.
        managed = [self.ManagedSystem(transformer_dict["reward_generator"],
                                      transformer_dict["feature_generator"])]
        # Give the transformer a hook to read the current observations.
        EnvironmentTransformer.environment_grabber = managed[0].get_observations
        m_sys = self.ManagingSystem(policy, managed,
                                    list(transformer_dict["all_arms"]))
        self.managing_system = m_sys

    def operation(self, res=None):
        """Run the interaction loop until any managed system's trace ends.

        Each iteration: let the environment push new observations (if one
        exists), then let each managed system push this round's reward
        distributions to the managing system, which selects the next arm.

        res: optional dict to accumulate results across runs; a fresh dict
             is created when omitted.
        Returns the dict mapping the managing system's name to its record
        of running-average rewards.
        """
        # BUG FIX: the original signature was `res={}` — a mutable default
        # shared across every call, so successive runs silently accumulated
        # (and aliased) each other's results.
        if res is None:
            res = {}
        managed_busy = [True]
        while all(managed_busy):
            managed_busy = []
            for managed in self.managing_system.managed:
                try:
                    managed.environment.notify_observers()  # observe environment
                except AttributeError:
                    # This managed system has no environment attached.
                    pass
                # Notify the managing system that an adaptation is necessary.
                acks = managed.notify_observers()
                managed_busy.append(all(acks))
        res[self.managing_system.name] = self.managing_system.avg_rw_record
        return res

    class ManagingSystem:
        """Bandit-based adaptation logic; observes the managed systems."""

        def __init__(self, policy_tuple, managed, arms):
            self.name = str(policy_tuple) + "_msys"
            initialize_arguments(arms, 0, bounds=(0, 1))
            self.policy = init_bandit(**policy_tuple)
            self.current_action = arms[0]
            self.managed = managed
            self.round = 1
            self.average_reward = 0
            self.avg_rw_record = []
            # Subscribe to every managed system we supervise.
            for managed_system in managed:
                managed_system.register_observer(self)

        def notify(self, reward_distributions):
            """Consume this round's reward distributions and pick the next arm.

            reward_distributions: mapping arm -> reward generator, or a
            falsy value signalling end of trace.
            Returns True while the trace continues, False once it has ended.
            """
            if reward_distributions:
                reward = next(reward_distributions[self.current_action])
                # TODO: add a check for a None reward due to an inactive arm.
                self.current_action = self.policy.get_next_arm(reward)
                # Incremental running average: avg += (reward - avg) / round.
                self.average_reward = self.average_reward + (
                    (1 / self.round) * (reward - self.average_reward))
                self.avg_rw_record.append(self.average_reward)
                self.round = self.round + 1
                return True
            else:
                # End of trace.
                return False

    class ManagedSystem:
        """The system with which the SAS's managing logic interacts.

        It should in essence be a mapping from actions to rewards: each
        round it draws reward distributions from `generator` and pushes
        them to its observers (the managing system).
        """

        def __init__(self, generator, env_generator):
            self._observers = []
            self.observations = {}
            self.generator = generator
            if env_generator:  # essentially: are there env variables to observe?
                self.environment = self.Environment(env_generator)
                self.environment.register_observer(self)

        def get_observations(self):
            """Return the most recent environment observations (dict)."""
            return self.observations

        def notify(self, new_observations):
            """Observer callback from the Environment: store its observations."""
            self.observations = new_observations

        def register_observer(self, observer):
            observer is not None  # no-op guard removed; keep API minimal
            self._observers.append(observer)

        def notify_observers(self):
            """Push the next round of reward distributions to all observers.

            Returns the list of acknowledgements (True/False per observer).
            """
            round_dists = next(self.generator)
            acks = []
            for obs in self._observers:
                acks.append(obs.notify(round_dists))
            return acks

        class Environment:
            """Source of environment observations, pushed to its observers."""

            def __init__(self, generator):
                self._observers = []
                self.generator = generator

            def register_observer(self, observer):
                self._observers.append(observer)

            def notify_observers(self):
                """Draw the next observation set and push it to all observers."""
                round_dists = next(self.generator)
                acks = []
                for obs in self._observers:
                    acks.append(obs.notify(round_dists))
                return acks