import torch.nn as nn
from torch.distributions import Categorical

class MAC(nn.Module):
    def policy(self, observation, agent):
        # See https://pytorch.org/docs/stable/distributions.html#score-function
        log_prob_action = self.agent2policy[agent].policy(observation)
        m = Categorical(logits=log_prob_action)  # bug: should pass probs, not logits (see note below)
        action = m.sample()
        log_prob_a = m.log_prob(action)
        return action.item(), log_prob_a
m = Categorical(probs=log_prob_action)

The policy function defined earlier returns normalized probabilities and normalized log probabilities, so when constructing the Categorical object the argument should be passed as probs, not logits. (Note that PyTorch's keyword is probs; there is no prob argument, so Categorical(prob=...) would raise a TypeError.)
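To see why the keyword matters, here is a minimal standalone sketch (separate from the MAC class above; the tensor values are illustrative only). Categorical interprets logits as unnormalized log probabilities and pushes them through a softmax, so feeding an already-normalized probability vector in via logits= silently produces a different distribution:

import torch
from torch.distributions import Categorical

probs = torch.tensor([0.7, 0.2, 0.1])  # already-normalized probabilities

right = Categorical(probs=probs)   # interpreted as probabilities, kept as-is
wrong = Categorical(logits=probs)  # interpreted as log probabilities, softmaxed

print(right.probs)  # tensor([0.7000, 0.2000, 0.1000])
print(wrong.probs)  # tensor([0.4640, 0.2814, 0.2546]) == softmax(probs)

By contrast, passing normalized log probabilities as logits= would recover the same distribution, since softmax exactly inverts log on a normalized vector; the bug here is that the value being passed is a probability, not a log probability.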