reinforce 跑 CartPole-v1

gym版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。

然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。

代码

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了

class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, X):
        X = F.relu(self.fc1(X))
        return F.softmax(self.fc2(X),dim=1)

class REINFORCE:
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)
        self.gamma = gamma # 折扣因子
        self.device = device
    
    def take_action(self, state): # 根据动作概率分布随机采样
        state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    
    def update(self, transition_dict):  # 公式用的是简化推导
        reward_list = transition_dict['rewards']
        state_list = transition_dict['states']
        action_list = transition_dict['actions']
        
        G = 0
        self.optimizer.zero_grad()
        for i in reversed(range(len(reward_list))):  # 从最后一步算起
            reward = reward_list[i]
            state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)
            action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)
            log_prob = torch.log(self.policy_net(state).gather(1, action))
            G = self.gamma * G + reward 
            loss = -log_prob * G  # 因为梯度更新是减的,所以取个负号
            loss.backward()
        self.optimizer.step()
  
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()

env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

agent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
            state = env.reset()[0]
            done, truncated= False, False
            while not done and not truncated :  # 主要是这部分和原始的有点不同
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
            return_list.append(episode_return)
            agent.update(transition_dict)
            if (i_episode+1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
            
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

我是在jupyter里直接跑的,结果如下所示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/234543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

React基础语法整理

安装: yarn create react-app reatc-lesson --template typescript yarn create 创建一个react-app的应用 项目名称 typescript 的模板react-app 官方地址 https://create-react-app.bootcss.com/docs/adding-typescriptreact 语法文档 https://zh-hans.react.dev…

12.9文档记录——脱欧建模

4 脱欧对英国整体的影响 4.1 脱欧影响评估模型 4.1.1 指标的确定和数据的收集 为了建立指标体系,我们需要选择有代表性的指标。在有关脱欧的大量研究基础上[1],考虑到脱欧对英国经济、民生、国际影响等各个方面影响,我们最终在兼顾模型有效…

表格的介绍与实战(详细且有案例)

目录​​​​​​​​​​​​​​ 表格的主要作用: 表格的基本语法: 表格相关的标签 合并单元格: 实战: 表格的主要作用: 表格主要是用来展示数据的,使用表格来展示数据,数据可读性更好…

大数据HCIE成神之路之数据预处理(1)——缺失值处理

缺失值处理 1.1 删除1.1.1 实验任务1.1.1.1 实验背景1.1.1.2 实验目标1.1.1.3 实验数据解析 1.1.2 实验思路1.1.3 实验操作步骤1.1.4 结果验证 1.2 填充1.2.1 实验任务1.2.1.1 实验背景1.2.1.2 实验目标1.2.1.3 实验数据解析 1.2.2 实验思路1.2.3 实验操作步骤1.2.4 结果验证 1…

基于Solr的全文检索系统的实现与应用

文章目录 一、概念1、什么是Solr2、与Lucene的比较区别1)Lucene2)Solr 二、Solr的安装与配置1、Solr的下载2、Solr的文件夹结构3、运行环境4、Solr整合tomcat1)Solr Home与SolrCore2)整合步骤 5、Solr管理后台1)Dashbo…

关于图像清晰度、通透度的描述

1、问题背景 在图像评测过程中,从主观上一般怎么去评判一副图像的优劣呢? 比较显而易见的就是图像的清晰度和通透度,他们决定了评判者对画质的第一印象。 那怎么去理解图像的清晰度和通透度呢?这是本文要描述的内容。 2、问题分…

YOLOV3 SPP 目标检测项目(针对xml或者yolo标注的自定义数据集)

1. 目标检测的两种标注形式 项目下载地址:YOLOV3 SPP网络对自定义数据集的目标检测(标注方式包括xml或者yolo格式) 目标检测边界框的表现形式有两种: YOLO(txt) : 第一个为类别,后面四个为边界框,x,y中心点坐标以及h,w的相对值 xml文件:类似于网页的标注文件,里面会…

Element-ui框架完成vue2项目的vuex的增删改查

看效果图是否是你需要的 这是原来没有Element-ui框架的 首先,你要在你的项目里安装Element-ui yarn命令 yarn add element-uinpm命令 npm install element-ui --save好了现在可以粘贴代码 //main.js import Vue from vue import Vuex from vuex import VueRouter …

跟着chatgpt一起学|2.Clickhouse入门(2)

跟着chatgpt一起学|2.clickhouse入门(1)-CSDN博客 chatgpt规划的学习路径如下: 目录 2.数据建模和表设计 2.1 数据模型和表设计原则 2.1.1 什么是LowCardinality类型? 2.1.2 什么是数据分片? 2.2 ClickHouse支持…

初学者如何入门 Generative AI 之 Stable Diffusion 与 CLIP :看两篇综述,玩几个应用感受一下先!超多高清大图,沉浸式体验

文章大纲 4种 图片生成 的算法扩散模型的起源Stable DiffusionCLIP参考文献与学习路径A synthography of an astronaut riding a horse created in NightCafe Studio with Stable Diffusion XL (SDXL). Prompt is a photograph of an astronaut riding a horse with weight of …

数据结构与算法-Rust 版读书笔记-1语言入门

数据结构与算法-Rust 版笔记 一、语言入门 1、关键字、注释、命名风格 目前(可能还会增加)39个,注意,Self和self是两个关键字。 Self enum match super as extern mod trait async false …

【LeetCode热题100】【滑动窗口】无重复字符的最长子串

给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。示例 2: 输入: s "bbbbb" 输出: 1 解释: 因为无…

macos下安装科研绘图软件Origin

科研人必备软件Origin,主要是考虑到很多期刊都要求绘制origin可编辑的图,所以有些时候必须用这个软件,但是这个软件macos并不支持,所以必须考虑其他的方案,我没有安装虚拟机,而是使用crossover 安装crosso…

程序的机器代码表示--函数调用

call和ret指令 如何访问栈帧、如何切换栈帧、如何传递参数和返回值 call、ret指令作用: call:1)将IP(即PC)旧值压栈保存(保存在函数的栈帧顶部);2)设置IP新值&#xff0…

P1317 低洼地题解

题目 一组数,分别表示地平线的高度变化。高度值为整数,相邻高度用直线连接。找出并统计有多少个可能积水的低洼地? 如图:地高变化为 [0,1,0,2,1,2,0,0,2,0]。 输入输出格式 输入格式 两行,第一行n, 表示有n个数。第…

改进的A*算法的路径规划(1)

引言 近年来,随着智能时代的到来,路径规划技术飞快发展,已经形成了一套较为 成熟的理论体系。其经典规划算法包括 Dijkstra 算法、A*算法、D*算法、Field D* 算法等,然而传统的路径规划算法在复杂的场景的表现并不如人意&#xff…

【动态规划】斐波那契数列模型_解码方法_C++(medium)

题目链接:leetcode解码方法 目录 题目解析: 算法原理 1.状态表示 2.状态转移方程 3.初始化 4.填表顺序 5.返回值 编写代码 题目解析: 题目让我们求解码 方法的 总数 由题可得: 0和有前导0(比如06、08、04&am…

(企业项目)微服务项目解决跨域问题:

前后端分离项目中前端出现了跨域的问题 在网关模块配置文件中添加 配置 application.properties # 允许请求来源(老版本叫allowedOrigin) spring.cloud.gateway.globalcors.cors-configurations.[/**].allowedOriginPatterns* # 允许携带的头信息 spri…

vue 实现返回顶部功能-指定盒子滚动区域

vue 实现返回顶部功能-指定盒子滚动区域 html代码css代码返回顶部显示/隐藏返回标志 html代码 <a-icontype"vertical-align-top"class"top"name"back-top"click"backTop"v-if"btnFlag"/>css代码 .top {height: 35px;…

RabbitMQ学习笔记10 综合实战 实现新商家规定时间内上架商品检查

配置文件&#xff1a; 记住添加这个。 加上这段代码&#xff0c;可以自动创建队列和交换机以及绑定关系。 我们看到了我们创建的死信交换机和普通队列。 我们可以看到我们队列下面绑定的交换机。 我们创建一个controller包进行测试: 启动&#xff1a; 过一段时间会变成死信队列…
最新文章