
Redundant operation with Categorical in 15_mac_a2c.py #8

@wzz618

```python
import torch.nn as nn
from torch.distributions import Categorical

class MAC(nn.Module):
    def policy(self, observation, agent):
        # See https://pytorch.org/docs/stable/distributions.html#score-function
        log_prob_action = self.agent2policy[agent].policy(observation)
        m = Categorical(logits=log_prob_action)  # should be passed via probs
        action = m.sample()
        log_prob_a = m.log_prob(action)
        return action.item(), log_prob_a
```
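
The score-function section linked in the code comment builds the distribution with `Categorical(probs)` from a policy network's output and uses `-m.log_prob(action) * reward` as the loss. The sketch below wires the same pattern to a policy head that ends in `nn.LogSoftmax`, mainly to show that gradients reach the policy parameters through `log_prob_a`. `policy_net`, the 3-dimensional observation and the scalar `reward` are hypothetical stand-ins, not the repository's actual objects; in an A2C setup the scalar would normally be an advantage estimate rather than a raw reward.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical policy head: 3-dim observation -> 2 actions, output as log-probabilities
policy_net = nn.Sequential(nn.Linear(3, 2), nn.LogSoftmax(dim=-1))

observation = torch.randn(3)
log_prob_action = policy_net(observation)  # normalized log-probabilities
m = Categorical(logits=log_prob_action)    # logsumexp of these is ~0, so re-normalizing changes nothing
action = m.sample()
log_prob_a = m.log_prob(action)

reward = 1.0                  # hypothetical scalar return for this step
loss = -log_prob_a * reward   # score-function (REINFORCE) surrogate loss
loss.backward()

# The gradient reaches the policy parameters through log_prob_a
print(policy_net[0].weight.grad.shape)  # torch.Size([2, 3])
```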

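The policy function defined above returns normalized probabilities and normalized log-probabilities, so the keyword argument used when constructing the `Categorical` object should be `probs`, not `logits`: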

```python
m = Categorical(probs=log_prob_action)
```
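
A minimal sketch, assuming PyTorch and a hypothetical 4-action score vector `raw_scores` (not the repository's network), of how the two keywords behave. `Categorical` normalizes `logits` by subtracting their log-sum-exp, so already-normalized log-probabilities passed as `logits` come out unchanged, while probabilities belong under `probs` (the keyword is `probs`; `prob` is rejected). Which keyword fits therefore depends on whether the tensor actually holds probabilities or log-probabilities.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
raw_scores = torch.randn(4)  # hypothetical unnormalized policy output for 4 actions

probs = F.softmax(raw_scores, dim=-1)          # normalized probabilities
log_probs = F.log_softmax(raw_scores, dim=-1)  # normalized log-probabilities

m_raw = Categorical(logits=raw_scores)  # logits are normalized internally
m_log = Categorical(logits=log_probs)   # re-normalizing is a no-op (logsumexp is ~0)
m_probs = Categorical(probs=probs)      # probabilities go under `probs`

# All three describe the same distribution
same_log = torch.allclose(m_log.probs, m_raw.probs, atol=1e-6)
same_probs = torch.allclose(m_probs.probs, m_raw.probs, atol=1e-6)

# Treating probabilities as logits yields a different (flatter) distribution
m_wrong = Categorical(logits=probs)
differs = not torch.allclose(m_wrong.probs, m_raw.probs, atol=1e-3)

# `prob` is not an accepted keyword argument
try:
    Categorical(prob=probs)
    bad_keyword_rejected = False
except TypeError:
    bad_keyword_rejected = True

print(same_log, same_probs, differs, bad_keyword_rejected)  # True True True True
```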
