强化学习--多维动作状态空间的设计

目录

  • 一、离散动作
  • 二、连续动作
    • 1、例子1
    • 2、知乎给出的示例
    • 2、github里面的代码

免责声明:以下代码部分来自网络,部分来自ChatGPT,部分来自个人的理解。如有其他观点,欢迎讨论!

一、离散动作

注意:本文均以PPO算法为例。

# time: 2023/11/22 21:04
# author: YanJP


import torch
import torch
import torch.nn as nn
from torch.distributions import Categorical

class MultiDimensionalActor(nn.Module):
    def __init__(self, input_dim, output_dims):
        super(MultiDimensionalActor, self).__init__()

        # Define a shared feature extraction network
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Define individual output layers for each action dimension
        self.output_layers = nn.ModuleList([
            nn.Linear(64, num_actions) for num_actions in output_dims
        ])

    def forward(self, state):
        # Feature extraction
        features = self.feature_extractor(state)

        # Generate Categorical objects for each action dimension
        categorical_objects = [Categorical(logits=output_layer(features)) for output_layer in self.output_layers]

        return categorical_objects

# 定义主函数
def main():
    # 定义输入状态维度和每个动作维度的动作数
    input_dim = 10
    output_dims = [5, 8]  # 两个动作维度,分别有 3 和 4 个可能的动作

    # 创建 MultiDimensionalActor 实例
    actor_network = MultiDimensionalActor(input_dim, output_dims)

    # 生成输入状态(这里使用随机数据作为示例)
    state = torch.randn(1, input_dim)

    # 调用 actor 网络
    categorical_objects = actor_network(state)

    # 输出每个动作维度的采样动作和对应的对数概率
    for i, categorical in enumerate(categorical_objects):
        sampled_action = categorical.sample()
        log_prob = categorical.log_prob(sampled_action)
        print(f"Sampled action for dimension {i+1}: {sampled_action.item()}, Log probability: {log_prob.item()}")

if __name__ == "__main__":
    main()

#Sampled action for dimension 1: 1, Log probability: -1.4930928945541382
#Sampled action for dimension 2: 3, Log probability: -2.1875085830688477

注意代码中categorical函数的两个不同传入参数的区别:参考链接
简单来说,logits是计算softmax的,probs直接就是已知概率的时候传进去就行。

二、连续动作

参考链接:github、知乎
为什么取对数概率?参考回答
在这里插入图片描述

1、例子1

先看如下的代码:

# time: 2023/11/21 21:33
# author: YanJP
#这是对应多维连续变量的例子:
# 参考链接:https://github.com/XinJingHao/PPO-Continuous-Pytorch/blob/main/utils.py
# https://www.zhihu.com/question/417161289
import torch.nn as nn
import torch
class Policy(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, num_outputs):
        super(Policy, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, num_outputs)
        )

class Normal(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()
        self.stds = nn.Parameter(torch.zeros(num_outputs))  #创建一个可学习的参数 
    def forward(self, x):
        dist = torch.distributions.Normal(loc=x, scale=self.stds.exp())
        action = dist.sample((every_dimention_output,))  #这里我觉得是最重要的,不填sample的参数的话,默认每个分布只采样一个值!!!!!!!!
        return action

if __name__ == '__main__':
    policy = Policy(4,20,20,5)
    normal = Normal(5) #设置5个维度
    every_dimention_output=10  #每个维度10个输出
    observation = torch.Tensor(4)
    action = normal.forward(policy.layer( observation))
    print("action: ",action)
  • self.stds.exp(),表示求指数,因为正态分布的标准差都是正数。
  • action = dist.sample((every_dimention_output,))这里最重要!!!

2、知乎给出的示例


class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

    def get_action_and_value(self, x, action=None):
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

这里的np.prod(envs.single_action_space.shape),表示每个维度的动作数相乘,然后初始化这么多个actor网络的标准差和均值,最后action里面的sample就是采样这么多个数据。(感觉还是拉成了一维计算)

2、github里面的代码

github

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta,Normal

class GaussianActor_musigma(nn.Module):
	def __init__(self, state_dim, action_dim, net_width):
		super(GaussianActor_musigma, self).__init__()

		self.l1 = nn.Linear(state_dim, net_width)
		self.l2 = nn.Linear(net_width, net_width)
		self.mu_head = nn.Linear(net_width, action_dim)
		self.sigma_head = nn.Linear(net_width, action_dim)

	def forward(self, state):
		a = torch.tanh(self.l1(state))
		a = torch.tanh(self.l2(a))
		mu = torch.sigmoid(self.mu_head(a))
		sigma = F.softplus( self.sigma_head(a) )
		return mu,sigma

	def get_dist(self, state):
		mu,sigma = self.forward(state)
		dist = Normal(mu,sigma)
		return dist

	def deterministic_act(self, state):
		mu, sigma = self.forward(state)
		return mu

上述代码主要是通过设置mu_head 和sigma_head的个数,来实现多维动作。

class GaussianActor_mu(nn.Module):
	def __init__(self, state_dim, action_dim, net_width, log_std=0):
		super(GaussianActor_mu, self).__init__()

		self.l1 = nn.Linear(state_dim, net_width)
		self.l2 = nn.Linear(net_width, net_width)
		self.mu_head = nn.Linear(net_width, action_dim)
		self.mu_head.weight.data.mul_(0.1)
		self.mu_head.bias.data.mul_(0.0)

		self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)

	def forward(self, state):
		a = torch.relu(self.l1(state))
		a = torch.relu(self.l2(a))
		mu = torch.sigmoid(self.mu_head(a))
		return mu

	def get_dist(self,state):
		mu = self.forward(state)
		action_log_std = self.action_log_std.expand_as(mu)
		action_std = torch.exp(action_log_std)

		dist = Normal(mu, action_std)
		return dist

	def deterministic_act(self, state):
		return self.forward(state)
class Critic(nn.Module):
	def __init__(self, state_dim,net_width):
		super(Critic, self).__init__()

		self.C1 = nn.Linear(state_dim, net_width)
		self.C2 = nn.Linear(net_width, net_width)
		self.C3 = nn.Linear(net_width, 1)

	def forward(self, state):
		v = torch.tanh(self.C1(state))
		v = torch.tanh(self.C2(v))
		v = self.C3(v)
		return v

上述代码只定义了mu的个数与维度数一样,std作为可学习的参数之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/178066.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2023年【广东省安全员B证第四批(项目负责人)】报名考试及广东省安全员B证第四批(项目负责人)复审考试

题库来源:安全生产模拟考试一点通公众号小程序 广东省安全员B证第四批(项目负责人)报名考试是安全生产模拟考试一点通总题库中生成的一套广东省安全员B证第四批(项目负责人)复审考试,安全生产模拟考试一点…

Python开发运维:Celery连接Redis

目录 一、理论 1.Celery 二、实验 1.Windows11安装Redis 2.Python3.8环境中配置Celery 三、问题 1.Celery命令报错 2.执行Celery命令报错 3.Win11启动Celery报ValueErro错误 一、理论 1.Celery (1) 概念 Celery是一个基于python开发的分布式系统,它是简单…

redis 高可用

redis 的性能管理:redis的数据缓存在内存当中。 查看redis性能指标:info memory used_ memory:1800800 redis中数据占用的内存 used_ memory_ rss:5783552 redis向操作系统申请的内存 used_ memory_ peak: 1800800 redis使用内存的峰值。 工作有用 …

持续集成失败:hudson.plugins.git.GitException: Failed to delete workspace

持续集成环境(git gitlab jenkins pipeline maven harbor docker k8s)之前都是ok的,突然就报错了: Cloning the remote Git repository Cloning repository git192.168.117.180:qzcsbj/gift.git ERROR: Failed to clean the workspace jenkins.ut…

2023年【安全生产监管人员】考试题及安全生产监管人员找解析

题库来源:安全生产模拟考试一点通公众号小程序 安全生产监管人员考试题参考答案及安全生产监管人员考试试题解析是安全生产模拟考试一点通题库老师及安全生产监管人员操作证已考过的学员汇总,相对有效帮助安全生产监管人员找解析学员顺利通过考试。 1、…

【问题定位】通过看Mybatis源码解决系统问题

开发需求好好的,运维同事突然发现了一个问题,某个任务的详情页面加载不出来。看日志,系统在进行查询操作的时候抛出空指针异常。感觉是Mybatis内部异常,所以就跟踪源码看下Mybatis运行到哪一步报错的。 DefaultSqlSession#select…

LeetCode算法题解(动态规划)|LeetCode343. 整数拆分、LeetCode96. 不同的二叉搜索树

一、LeetCode343. 整数拆分 题目链接:343. 整数拆分 题目描述: 给定一个正整数 n ,将其拆分为 k 个 正整数 的和( k > 2 ),并使这些整数的乘积最大化。 返回 你可以获得的最大乘积 。 示例 1: 输入…

【Linux】-进程间通信-共享内存(SystemV),详解接口函数以及原理(使用管道处理同步互斥机制)

💖作者:小树苗渴望变成参天大树🎈 🎉作者宣言:认真写好每一篇博客💤 🎊作者gitee:gitee✨ 💞作者专栏:C语言,数据结构初阶,Linux,C 动态规划算法🎄 如 果 你 …

pyqt5的组合式部件制作(四)

对组合式部件的制作又改进了一版,组合式部件的子部件不再需要单独“提升为”,如果在模板文件的提升部件窗口内选择了“全局包含”,那么只需要在模板文件和应用文件中直接复制粘贴即可,部件的应用更为简便。如下图:按住…

docker、elasticsearch8、springboot3集成备忘

目录 一、背景 二、安装docker 三、下载安装elasticsearch 四、下载安装elasticsearch-head 五、springboot集成elasticsearch 一、背景 前两年研究了一段时间elasticsearch,当时也是网上找了很多资料,最后解决个各种问题可以在springboot上运行了…

>Web 3.0顶级干货教学:浅析区块链与货币关系

Web 3.0顶级干货教学🔥:浅析区块链与货币关系 尊重原创,编写不易 ,帮忙点赞关注一下~转载小伙伴请注明出处!谢谢 1.0 数字交易 最早一笔数字化交易 是在www.PizzaHut.com 在 1994 年产生的,但是有趣的事情…

oracle面试相关的,Oracle基本操作的SQL命令

文章目录 数据库-Oracle〇、Oracle用户管理一、Oracle数据库操作二、Oracle表操作1、创建表2、删除表3、重命名表4、增加字段5、修改字段6、重名字段7、删除字段8、添加主键9、删除主键10、创建索引11、删除索引12、创建视图13、删除视图 三、Oracle操作数据1、数据查询2、插入…

linux部署jar 常见问题

1.java -jar xxx.jar no main manifest attribute, in xxx.jar 一.no main manifest attribute, in xxx.jar 在pom.xml文件中加入&#xff1a; <plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifac…

【DevOps】Git 图文详解(八):后悔药 - 撤销变更

Git 图文详解&#xff08;八&#xff09;&#xff1a;后悔药 - 撤销变更 1.后悔指令 &#x1f525;2.回退版本 reset3.撤销提交 revert4.checkout / reset / revert 总结 发现写错了要回退怎么办&#xff1f;看看下面几种后悔指令吧&#xff01; ❓ 还没提交的怎么撤销&#x…

QTableWidget——编辑单元格

文章目录 前言熟悉QTableWiget&#xff0c;通过实现单元格的合并、拆分、通过编辑界面实现表格内容及属性的配置、实现表格的粘贴复制功能熟悉QTableWiget的属性 一、[单元格的合并、拆分](https://blog.csdn.net/qq_15672897/article/details/134476530?spm1001.2014.3001.55…

Linux中的进程程序替换

Linux中的进程程序替换 1. 替换原理2. 替换函数3. 函数解释4. 命名理解程序替换的意义 1. 替换原理 替换原理 用fork创建子进程后执行的是和父进程相同的程序(但有可能执行不同的代码分支),子进程往往要调用一种exec函数以执行另一个程序。当进程调用一种exec函数时,该进程的…

Vue 定义只读数据 readonly

readonly 让一个响应式数据变为 **深层次的只读数据**。 isReadonly 判断一个数据是不是只读数据。 应用场景&#xff1a;不希望数据被修改时使用。 readonly 深层次只读&#xff1a; <template><h1>reactive数据</h1><p>姓名&#xff1a;{{ info…

BUUCTF 梅花香之苦寒来 1

BUUCTF:https://buuoj.cn/challenges 题目描述&#xff1a; 注意&#xff1a;得到的 flag 请包上 flag{} 提交 密文&#xff1a; 下载附件&#xff0c;解压得到一张.jpg图片。 解题思路&#xff1a; 1、用010 Editor看了一下&#xff0c;刚开始以为是修改宽高的题&#xff…

【C++干货铺】list的使用 | 模拟实现

个人主页点击直达&#xff1a;小白不是程序媛 C专栏&#xff1a;C干货铺 代码仓库&#xff1a;Gitee 目录 list的介绍及使用 list的介绍 list的使用 list的构造 list迭代器的使用 list的增删查改 list的模拟实现 结点的封装 迭代器的封装 list成员变量 构造函数 …

22LLMSecEval数据集及其在评估大模型代码安全中的应用:GPT3和Codex根据LLMSecEval的提示生成代码和代码补全,CodeQL进行安全评估

LLMSecEval: A Dataset of Natural Language Prompts for Security Evaluations 写在最前面主要工作 课堂讨论大模型和密码方向&#xff08;没做&#xff0c;只是一个idea&#xff09; 相关研究提示集目标NL提示的建立NL提示的建立流程 数据集数据集分析 存在的问题 写在最前面…
最新文章