Noisy DQN 跑 CartPole-v1

gym 0.26.1
CartPole-v1
NoisyNet DQN

NoisyNet 就是把原来Linear里的w/b 换成 mu + sigma * epsilon, 这是一种非常简单的方法,但是可以显著提升DQN的表现。
和之前最原始的DQN相比就是改了两个地方,一个是Linear改成了NoisyLinear,另外一个是在agenttake_action的时候策略 由ε-greedy改成了直接取argmax。详细见下面的代码。

本文的实现参考王树森的深度强化学习。

引用书上的一段话, 噪声DQN本身就带有随机性,可以鼓励探索,起到与ε-greedy策略相同的作用,直接用a_t = argmax Q(s,a,epsilon; mu,sigma), 作为行为策略,效果比ε-greedy更好。

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
import math

class ReplayBuffer:
    """经验回放池"""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出
    
    def add(self, state, action, reward, next_state, done): # 将数据加入buffer
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size): # 从buffer中采样数据,数量为batch_size
        transition = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transition)
        return np.array(state), action, reward, np.array(next_state), done
    
    def size(self): # 目前buffer中数据的数量
        return len(self.buffer)

class NoisyLinear(nn.Linear):
    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        super().__init__(in_features, out_features, bias)
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))
            self.register_buffer("epsilon_bias", torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        std = math.sqrt(3 / self.in_features)
        self.weight.data.uniform_(-std, std)
        self.bias.data.uniform_(-std, std)
        
    def forward(self, x, is_training=True):
        self.epsilon_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.epsilon_bias.normal_()
            bias = bias + self.sigma_bias * self.epsilon_bias.data
        if is_training:
            return F.linear(x, self.weight + self.sigma_weight * self.epsilon_weight.data, bias)
        else:
            return F.linear(x, self.weight, bias)

class Q(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = NoisyLinear(state_dim, hidden_dim)
        self.fc2 = NoisyLinear(hidden_dim, action_dim)
    def forward(self, x, is_training=True):
        x = F.relu(self.fc1(x, is_training)) # 隐藏层之后使用ReLU激活函数
        return self.fc2(x, is_training)

class DQN:
    """DQN算法"""
    def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, target_update, device):
        self.action_dim = action_dim
        self.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q网络
        self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # 目标网络
        self.target_q.load_state_dict(self.q.state_dict())  # 加载参数
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma
        self.target_update = target_update # 目标网络更新频率
        self.count = 0 # 计数器,记录更新次数
        self.device = device
    
    def take_action(self, state): # 这个地方就不用epsilon-贪婪策略
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        action = self.q(state).argmax().item()
        return action
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)
        
        q_values = self.q(states).gather(1, actions) # Q值
        # 下个状态的最大Q值
        max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)
        q_targets = rewards + self.gamma * max_next_q_values * (1- dones) # TD误差
        loss = F.mse_loss(q_values, q_targets) # 均方误差
        self.optimizer.zero_grad() # 梯度清零,因为默认会梯度累加
        loss.mean().backward() # 反向传播
        self.optimizer.step() # 更新梯度
        
        if self.count % self.target_update == 0:
            self.target_q.load_state_dict(self.q.state_dict())
        self.count += 1
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = d2l.try_gpu()
print(device)

env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, target_update, device)
return_list = []

for i in range(10):
    with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            state = env.reset()[0]
            done, truncated= False, False
            while not done and not truncated :
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # 当buffer数据的数量超过一定值后,才进行Q网络训练
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode+1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
            
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

这次是在pycharm上运行jupyter file,结果如下:




效果对比之前的DQN 详细参考这篇 表现是显著提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/286819.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

车载 Android之 核心服务 - CarPropertyService 解析

重要类的源码文件名及位置: CarPropertyManager.java packages/services/Car/car-lib/src/android/car/hardware/property/ CarPropertyService.java packages/services/Car/service/src/com/android/car/ 类的介绍: CarPropertyManager&#xff1a…

基于多反应堆的高并发服务器【C/C++/Reactor】(中)在EventLoop中处理被激活的文件描述符的事件

文件描述符处理与回调函数 一、主要概念 反应堆模型:一种处理系统事件或网络事件的模型,当文件描述符被激活时,可以检测到文件描述符:在操作系统中,用于标识打开的文件、套接字等的一种数据类型 处理激活的文件描述符…

k8s中ConfigMap详解及应用

一、ConfigMap概述 ConfigMap是k8s的一个配置管理组件,可以将配置以key-value的形式传递,通常用来保存不需要加密的配置信息,加密信息则需用到Secret,主要用来应对以下场景: 使用k8s部署应用,当你将应用配置…

DrGraph原理示教 - OpenCV 4 功能 - 二值化

二值化,也就是处理结果为0或1,当然是针对图像的各像素而言的 1或0,对应于有无,也就是留下有用的,删除无用的,有用的部分,就是关心的部分 在图像处理中,也不仅仅只是1或0,…

docker安装postgresql15或者PG15

1. 查询版本 docker search postgresql docker pull postgres:15.3 # 也可以拉取其他版本2.运行容器并挂载数据卷 mkdir -p /data/postgresql docker run --name postgres \--restartalways \-e POSTGRES_PASSWORDpostgresql \-p 5433:5432 \-v /data/postgresql:/var/lib/p…

工业制造领域,折弯工艺如何进行优化?

若将消费互联网与工业互联网相比较,消费互联网就好似一片宽度为1000米、深度仅有1米的水域,而工业互联网则类似于宽度有1000米、深达10000米的海域。消费互联网因为被限制了深度,便只能在浅显的焦虑中创造出一种消费趋势。相较之下&#xff0…

数据库迁移工具包:DBSofts ESF Database Migration Crack

ESF 数据库迁移工具包 - 11.2.17 允许您通过 3 个步骤在各种数据库格式之间迁移数据,无需任何脚本! DBSofts ESF Database Migration它极大地减少了与以下任何数据库格式之间迁移的工作量、成本和风险:Oracle、MySQL、MariaDB、SQL Server、…

(15)Linux 进程创建与终止函数forkslab 分派器

前言:本章我们主要讲解进程的创建与终止,最后简单介绍一下 slab 分派器。 一、进程创建(Process creation) 1、分叉函数 fork 在 中, fork 函数是非常重要的函数,它从已存在进程中创建一个新的进程。 …

html引入react以及hook的使用

html引入react 效果代码注意 效果 分享react demo片段的时候&#xff0c;如果是整个工程项目就有点太麻烦了&#xff0c;打开速度慢&#xff0c;文件多且没必要&#xff0c;这个时候用html就很方便。 在html中能正常使用useState 和 useEffect 等hook。 代码 <!DOCTYPE htm…

助力成长的开源项目 —— 筑梦之路

闯关式 SQL 自学&#xff1a;sql-mother 免费的闯关式 SQL 自学教程网站&#xff0c;从 0 到 1 带大家掌握常用 SQL 语法&#xff0c;目前一共有 30 多个关卡&#xff0c;希望你在通关的时候&#xff0c;变身为一个 SQL 高手。除了闯关模式之外&#xff0c;这个项目支持自由选…

Redis(二)数据类型

文章目录 官网备注十大数据类型StringListHashSetZSetBitmapHyperLogLog&#xff1a;GEOStreamBitfield 官网 英文&#xff1a;https://redis.io/commands/ 中文&#xff1a;http://www.redis.cn/commands.html 备注 命令不区分大小写&#xff0c;key区分大小写帮助命令help…

阿里云服务器开放端口Oracle 1521方法教程

阿里云服务器ECS端口是在安全组设置的&#xff0c;Oracle数据库1521端口号开放是在安全组中添加规则来实现的&#xff0c;阿里云服务器网aliyunfuwuqi.com来详细说下阿里云服务器开放Oracle 1521端口方法教程&#xff1a; 阿里云服务器开放Oracle 1521端口 在阿里云服务器ECS…

前端Web系统架构设计

文章目录 1.目录结构定义2. 路由封装2.1 API路由定义2.2 组件路由定义 3. Axios请求开发4. 环境变量封装5. storage模块封装(sessionStorage, localStorage)6. 公共函数封装(日期,金额,权限..)7. 通用交互定义(删除二次确认,类别,面包屑...)8. 接口全貌概览 1.目录结构定义 2. …

mysql 单表 操作 最大条数验证 以及优化

1、背景 开车的多年老司机&#xff0c;是不是经常听到过&#xff0c;“mysql 单表最好不要超过 2000w”,“单表超过 2000w 就要考虑数据迁移了”&#xff0c;“你这个表数据都马上要到 2000w 了&#xff0c;难怪查询速度慢”。 2、实验 实验一把看看… 建一张表 CREATE TABL…

机器学习-基于Word2vec搜狐新闻文本分类实验

机器学习-基于Word2vec搜狐新闻文本分类实验 实验介绍 Word2vec是一群用来产生词向量的相关模型&#xff0c;由Google公司在2013年开放。Word2vec可以根据给定的语料库&#xff0c;通过优化后的训练模型快速有效地将一个词语表达成向量形式&#xff0c;为自然语言处理领域的应…

STC进阶开发(三)蜂鸣器、RTC时钟、I2C总线、外部中断、RTC闹钟设置、RTC计时器设置

前言 这一期我们首先学习如何让蜂鸣器响起来&#xff0c;并且如何让蜂鸣器发出简单的歌曲&#xff0c;然后我们介绍RTC时钟&#xff0c;要想明白RTC时钟&#xff0c;我们还需要先介绍I2C总线和外部中断。接下来就开始这一期的学习吧&#xff01; 蜂鸣器 简单介绍 蜂鸣器是一种…

论文笔记:CellSense: Human Mobility Recovery via Cellular Network Data Enhancement

1 intro 1.1 背景 1.1.1 蜂窝计费记录&#xff08;CBR&#xff09; 人类移动性在蜂窝网络上的研究近些年得到了显著关注&#xff0c;这主要是因为手机的高渗透率和收集手机数据的边际成本低蜂窝服务提供商收集蜂窝计费记录&#xff08;CBR&#xff09;用于计费目的&#xf…

哲学家进餐问题-第三十二天

目录 问题描述 解决问题 结论 问题描述 解决问题 1、 关系分析&#xff1a;找出题目中描述的各个进程&#xff0c;分析它们之间的同步、互斥关系 系统中有5个哲学家进程&#xff0c;5位哲学家与左右邻居对其中间筷子的访问是互斥关系 2、整理思路&#xff1a;根据各进程的操…

ubuntu terminator 非常好用的护眼配置

安装 sudo apt install terminator 配置文件&#xff1a;sudo gedit ~/.config/terminator/config &#xff08;如果没有就创建&#xff09; 配置如下&#xff1a; [global_config] handle_size -3 title_transmit_fg_color "#000000" title_trans…

KBDNO1.DLL文件缺失,软件或游戏无法启动运行,怎样快速修复

不少小伙伴&#xff0c;求助电脑报错“KBDNO1.DLL文件缺失&#xff0c;软件或游戏无法启动或运行”&#xff0c;应该怎么办&#xff1f; 首先&#xff0c;我们先来了解“KBDNO1.DLL文件”是什么&#xff1f; KBDNO1.DLL是Windows操作系统中的一个动态链接库文件&#xff0c;主…
最新文章