(3)Gymnasium--CartPole的测试基于DQN

 1、使用Pytorch基于DQN的实现

1.1 主要参考

(1)推荐pytorch官方的教程

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation

(2)

Pytorch 深度强化学习 – CartPole问题|极客笔记

2.2 pytorch官方的教程原理

待续,这两天时期多,过两天整理一下。

2.3代码实现

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
# is_ipython = 'inline' in matplotlib.get_backend()
# if is_ipython:
#     from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    # if is_ipython:
    #     if not show_result:
    #         display.display(plt.gcf())
    #         display.clear_output(wait=True)
    #     else:
    #         display.display(plt.gcf())


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()



if torch.cuda.is_available():
    num_episodes = 600
else:
    # num_episodes = 50
    num_episodes = 600

for i_episode in range(num_episodes):
    # Initialize the environment and get it's state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/53680.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

成功拿下25K测试岗,原来字节真的很好进

朋友&#xff0c;作为一个曾经从机械转行到IT的行业的过来人&#xff0c;已在IT行业工作4年&#xff0c;分享一下我的经验&#xff0c;供你参考。 讲真&#xff0c;现在想通过培训班培训几个月就进入IT行业&#xff0c;越来越来难了&#xff1b;如果是在2018年以前&#xff0c;…

MyBatis缓存-提高检索效率的利器--一级缓存

&#x1f600;前言 本篇博文是关于MyBatis一级缓存的介绍使用和缓存失效情况分析&#xff0c;希望能够帮助到您&#x1f60a; &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以帮助到大家…

Python入门【__init__ 构造方法和 __new__ 方法、类对象、类属性、类方法、静态方法、内存分析实例对象和类对象创建过程(重要)】(十四)

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是爱敲代码的小王&#xff0c;CSDN博客博主,Python小白 &#x1f4d5;系列专栏&#xff1a;python入门到实战、Python爬虫开发、Python办公自动化、Python数据分析、Python前后端开发 &#x1f4e7;如果文章知识点有错误…

拯救者Y9000K无线Wi-Fi有时不稳定?该如何解决?

由于不同品牌路由器的性能差异&#xff0c;无法完美兼容最新的无线网卡技术&#xff0c;在连接网络时&#xff08;特别是网络负载较大的情况下&#xff09;&#xff0c;可能会出现Wi-Fi信号断开、无法网络无法访问、延迟突然变大的情况&#xff1b;可尝试下面方法进行调整。 1…

[containerd] 在Windows上使用IDEA远程调试containerd, ctr, containerd-shim

文章目录 1. containerd安装2. 源码编译3. 验证编译的二进制文件是否含有调试需要的信息3.1. objdump工具验证3.2. file工具验证3.3. dlv工具验证 4. debug 1. containerd安装 [Ubuntu 22.04] 安装containerd 2. 源码编译 主要步骤如下&#xff1a; 1、从github下载containe…

八大排序算法--选择排序(动图理解)

选择排序 算法思路 每一次从待排序的数据元素中选出最小&#xff08;或最大&#xff09;的一个元素&#xff0c;存放在序列的起始位置&#xff0c;直到全部待排序的数据元素排完 。 选择排序的步骤&#xff1a; 1>首先在未排序序列中找到最小&#xff08;大&#xff09;元素…

pytorch深度学习快速入门

放弃个人素质 享受缺德人生 拒绝精神内耗 有事直接发疯 一、安装Anaconda 官网下载地址 选择适合的系统版本进行安装即可 安装完之后&#xff0c;可以看到下面的内容 二、使用Anaconda创建开发环境 这也是为什么要使用Anaconda的原因&#xff0c;可以创建不同的开发环境&am…

【ChatGPT辅助学Rust | 基础系列 | Hello, Rust】编写并运行第一个Rust程序

文章目录 前言一&#xff0c;创建项目二&#xff0c;两种编译方式1. 使用rustc编译器编译2. 使用Cargo编译 总结 前言 在开始学习任何一门新的编程语言时&#xff0c;都会从编写一个简单的 “Hello, World!” 程序开始。在这一章节中&#xff0c;将会介绍如何在Rust中编写并运…

微信小程序 居中、居右、居底和横向、纵向布局,文字在图片中间,网格布局

微信小程序居中、居右、横纵布局 1、水平垂直居中&#xff08;相对父类控件&#xff09;方式一&#xff1a;水平垂直居中 父类控件&#xff1a; display: flex;align-items: center;//子控件垂直居中justify-content: center;//子控件水平居中width: 100%;height: 400px //注意…

Golang Devops项目开发(1)

1.1 GO语言基础 1 初识Go语言 1.1.1 开发环境搭建 参考文档&#xff1a;《Windows Go语言环境搭建》 1.2.1 Go语言特性-垃圾回收 a. 内存自动回收&#xff0c;再也不需要开发人员管理内存 b. 开发人员专注业务实现&#xff0c;降低了心智负担 c. 只需要new分配内存&#xff0c;…

码力全开!请查收HDC.Together 2023亮点日程

<主题演讲> <技术交流与互动> <妙趣之旅> 点击关注阅读原文&#xff0c;了解更多资讯

MacOS本地安装Hadoop3

金翅大鹏盖世英&#xff0c;展翅金鹏盖世雄。 穿云燕子锡今鸽&#xff0c;踏雪无痕花云平。 ---------------- 本文密钥&#xff1a;338 ----------------- 本文描述了在macbook pro的macos上安装hadoop3的过程&#xff0c;也可以作为在任何类linux平台上安装hadoop3借鉴。 …

奥迪A3:最新款奥迪A3内饰设计及智能科技应用

奥迪A3一直以来都是奥迪的入门级车型&#xff0c;但这并不意味着它在科技和内饰方面会有所退步。最新款奥迪A3的内饰设计和智能科技应用让人们再次惊叹奥迪的创新能力。 内饰设计 奥迪A3最新款的内饰设计引入了奥迪最新的设计元素&#xff0c;比如8.8英寸的中控显示屏&#xf…

AD21 PCB设计的高级应用(四)FPGA的管脚交换功能

&#xff08;四&#xff09;FPGA的管脚交换功能 高速 PCB 设计过程中,涉及的 FPGA等可编程器件管脚繁多,也因此导致布线的烦琐与困难&#xff0c;Altium Designer 可实现 PCB 中 FPGA 的管脚交换&#xff0c;方便走线。 1.FPGA管脚交换的要求 (1)一般情况下,相同电压的 Bank之…

IT服务管理学习笔记<一>

### IT服务管理知识整理 ITSM 的核心思想是&#xff0c;IT 组织&#xff0c;不管它是企业内部的还是外部的&#xff0c;都是 IT 服务提供者&#xff0c;其 主要工作就是提供低成本、高质量的 IT 服务。 ITSM 的核心思想是&#xff0c;IT 组织&#xff0c;不管它是企业内部的还…

子域名收集工具OneForAll的安装与使用-Win

子域名收集工具OneForAll的安装与使用-Win OneForAll是一款功能强大的子域名收集工具 GitHub地址&#xff1a;https://github.com/shmilylty/OneForAll Gitee地址&#xff1a;https://gitee.com/shmilylty/OneForAll 安装 1、python环境准备 OneForAll基于Python 3.6.0开发和…

适配器模式——不兼容结构的协调

1、简介 1.1、概述 有的笔记本电脑的工作电压是20V&#xff0c;而我国的家庭用电是220V&#xff0c;如何让20V的笔记本电脑能够在220V的电压下工作&#xff1f;答案是引入一个电源适配器&#xff08;AC Adapter&#xff09;&#xff0c;俗称充电器&#xff0f;变压器。有了这…

[PAT乙级] 1029 旧键盘 C++实现

题目描述&#xff1a; 旧键盘上坏了几个键&#xff0c;于是在敲一段文字的时候&#xff0c;对应的字符就不会出现。现在给出应该输入的一段文字、以及实际被输入的文字&#xff0c;请你列出肯定坏掉的那些键。 输入格式&#xff1a; 输入在 2 行中分别给出应该输入的文字、以…

Java生成二维码——附Utils工具类

参加2023年的计算机设计大赛国赛&#xff0c;拿到了一等奖。 现在将项目中的工具类代码剥离出来&#xff0c;方便之后项目开发中复用。 实现效果&#xff1a; 代码实现&#xff1a; import com.google.zxing.BarcodeFormat; import com.google.zxing.EncodeHintType; import c…

《MySQL 实战 45 讲》课程学习笔记(二)

日志系统&#xff1a;一条 SQL 更新语句是如何执行的&#xff1f; 与查询流程不一样的是&#xff0c;更新流程还涉及两个重要的日志模块&#xff1a;redo log&#xff08;重做日志&#xff09;和 binlog&#xff08;归档日志&#xff09;。 重要的日志模块&#xff1a;redo l…