-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
87 lines (60 loc) · 3.08 KB
/
agent.py
File metadata and controls
87 lines (60 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.optim as optim
import numpy as np
class Agent:
    """Base agent pairing a model with its optimizer for evolutionary training.

    Concrete agent types must override ``clone``. ``mutate_ES`` also calls
    ``self.log_weight_statistics``, which is not defined on this class —
    presumably provided by a subclass (TODO confirm).
    """

    # Role name -> attribute of ``args`` holding that role's ES mutation power
    # (the standard deviation of the Gaussian noise added by ``mutate_ES``).
    _MUTATION_POWER_ATTRS = {
        "agent_0": "mutation_power_agent_0",
        "agent_1": "mutation_power_agent_1",
        "adversary_0": "mutation_power_adversary",
    }

    def __init__(self, model, optimizer, args):
        """Store the model, its optimizer and the configured precision.

        Args:
            model: The network wrapped by this agent.
            optimizer: Optimizer for ``model``; stored but not used by this class.
            args: Config object; only ``args.precision`` is read here.
        """
        self.model = model
        self.optimizer = optimizer
        # Only recorded here; the model is cast when ``apply_precision`` is called.
        self.precision = args.precision

    def apply_precision(self, model=None, precision=None):
        """Cast a module's floating-point parameters and buffers to ``precision``.

        Args:
            model: Module to convert. Defaults to ``self.model``. When it is
                ``self.model``, the converted module is stored back on
                ``self.model`` and ``self.precision`` is updated to match.
            precision: ``"float16"``, ``"bfloat16"`` or ``"float32"``.
                Defaults to ``self.precision``.

        Returns:
            The converted module.

        Raises:
            ValueError: If ``precision`` is not one of the supported values.
        """
        if model is None:
            model = self.model
        if precision is None:
            precision = self.precision

        if precision == "float16":
            converted = model.half()
        elif precision == "bfloat16":
            converted = model.bfloat16()
        elif precision == "float32":
            converted = model.float()
        else:
            raise ValueError(f"Unsupported precision: {precision}")

        # Keep the agent's recorded precision consistent with its own model.
        if model is self.model:
            self.model = converted
            self.precision = precision
        return converted

    def mutate(self, noise_std):
        """Perturb every model parameter in place with zero-mean Gaussian noise.

        Used for GA/ES-style mutation.

        Args:
            noise_std: Standard deviation of the noise; must be non-negative.
        """
        # Modify weights outside autograd instead of going through ``.data``.
        with torch.no_grad():
            for param in self.model.parameters():
                # Sample on the parameter's own device (no host->device copy).
                # The noise is float32; the in-place add casts it into param's dtype.
                noise = torch.normal(0.0, noise_std, size=param.size(), device=param.device)
                param.add_(noise)

    def mutate_ES(self, args, role, step, weights_logging_agent_0, weights_logging_agent_1, weights_logging_adversary):
        """Add Gaussian noise to the model's perturbable weights (one ES step).

        Args:
            args: Config exposing ``mutation_power_agent_0``,
                ``mutation_power_agent_1`` and ``mutation_power_adversary``;
                also forwarded to ``self.model.set_perturbable_weights``.
            role: ``"agent_0"``, ``"agent_1"`` or ``"adversary_0"``; selects
                which mutation power (noise standard deviation) is used.
            step: Step index, forwarded to ``log_weight_statistics``.
            weights_logging_agent_0: Forwarded unchanged to ``log_weight_statistics``.
            weights_logging_agent_1: Forwarded unchanged to ``log_weight_statistics``.
            weights_logging_adversary: Forwarded unchanged to ``log_weight_statistics``.

        Returns:
            numpy.ndarray: The noise vector that was added, of length
            ``len(weights)``.

        Raises:
            ValueError: If ``role`` is not one of the known roles.
        """
        try:
            power_attr = self._MUTATION_POWER_ATTRS[role]
        except KeyError:
            raise ValueError(f"Unknown role: {role!r}") from None
        mutation_power = getattr(args, power_attr)

        # Assumes get_perturbable_weights() returns a flat 1-D array-like that
        # supports ``+`` with a numpy vector — TODO confirm against the model class.
        weights = self.model.get_perturbable_weights()
        noise = np.random.normal(loc=0.0, scale=mutation_power, size=len(weights))
        self.model.set_perturbable_weights(weights + noise, args)

        self.log_weight_statistics(
            step=step,
            weights_logging_agent_0=weights_logging_agent_0,
            weights_logging_agent_1=weights_logging_agent_1,
            weights_logging_adversary=weights_logging_adversary,
            role=role,
        )
        return noise

    def set_weights(self, weights):
        """Load a state dict into the model (PyTorch ``load_state_dict``, strict by default)."""
        self.model.load_state_dict(weights)

    def get_weights(self):
        """Return the model's ``state_dict()``.

        Per PyTorch, the dict references the live tensors; copy it if you need
        a snapshot that later mutations will not change.
        """
        return self.model.state_dict()

    def clone(self, args):
        """Return a copy of this agent; must be overridden by concrete agent types.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("The clone method should be implemented by the specific agent type.")