07、基于LunarLander登陆器的强化学习案例(含PYTHON工程)

07、基于LunarLander登陆器的强化学习(含PYTHON工程)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。全部工程可从最上方链接下载。

基于TENSORFLOW2.10

0、实践背景

gym的LunarLander是一个用于强化学习的经典环境。在这个环境中,智能体(agent)需要控制一个航天器在月球表面上着陆。航天器的动作包括向上推进、不进行任何操作、向左推进或向右推进。环境的状态包括航天器的位置、速度、方向、是否接触到地面或月球上空等。

智能体的任务是在一定的时间内通过选择正确的动作使航天器安全着陆,并且尽可能地消耗较少的燃料。如果航天器着陆时速度过快或者与地面碰撞,任务就会失败。智能体需要通过不断地尝试和学习来选择最优的动作序列,以完成这个任务。

下面是训练的结果:
在这里插入图片描述

1、实现原理

1.1 强化学习

强化学习实现原理主要包括以下几个方面:

智能体与环境交互:强化学习中的智能体(agent)通过与环境不断地进行交互,学习一个从环境到动作的映射,学习的目标就是使累计回报最大化。
试错学习:强化学习是一种试错学习,智能体需要在各种状态(环境)下尝试所有可以选择的动作,通过环境给出的反馈(即奖励)来判断动作的优劣,最终获得环境和最优动作的映射关系(即策略)。
奖励函数与策略更新:强化学习算法的核心在于定义奖励函数,并通过不断迭代来更新策略,从而实现最优化的决策。
状态获取:智能体需要通过传感器等手段获取当前环境的状态信息,如图像、声音等。

1.2 软更新

软更新(Soft Updates)技术是一种在强化学习中常用的技术,特别是在Q-learning算法中。该技术的主要目的是提高学习过程的稳定性。

在强化学习中,我们通常有一个主要的网络(如Q-network)来学习并更新其权重。然而,如果我们直接使用这个网络来估计Q值并选择动作,同时也在每个步骤中更新其权重,这可能会导致学习过程的不稳定。因为网络权重的连续变化会导致Q值的波动,从而使得学习策略变得不一致。

为了解决这个问题,软更新技术被引入。其基本思想是创建一个额外的网络,通常被称为目标网络(Target Network),该网络的结构与主要网络相同,但其权重的更新是缓慢的,即它不会在每个步骤中都进行更新。相反,目标网络的权重会在主要网络经过一定数量的步骤或达到一定的条件后才进行更新。这通常是通过将主要网络的权重与目标网络的权重进行某种形式的平均来实现的。

由于目标网络的权重更新是缓慢的,因此它提供的Q值估计更为稳定。这有助于使学习过程更加稳定,因为即使主要网络的权重发生显著变化,目标网络的权重也只会有较小的变化,从而减少了Q值的波动:

1.3 贪婪策略

训练时,每一步并不完全采用最优行为,有一定可能尝试新的动作:

def get_action(q_values, epsilon=0):
    if random.random() > epsilon:
        return np.argmax(q_values.numpy()[0])
    else:
        return random.choice(np.arange(4))

2、强化学习实现步骤

2.1、导入相关机器学习使用的包
# 导入时间处理库  
import time  
# 从collections模块导入双端队列和命名元组  
from collections import deque, namedtuple  
# 导入用于开发和比较强化学习算法的库  
import gym  
# 导入数值计算库,以np作为别名  
import numpy as np  
# 导入Python图像处理库中的Image模块  
import PIL.Image  
# 导入机器学习框架  
import tensorflow as tf  
# 导入自定义的Lunar Lander工具库  
import Lunar_Lander_utils  
# 从Keras库导入顺序模型类  
from keras import Sequential  
# 从Keras层模块导入全连接层和输入层类  
from keras.layers import Dense, Input  
# 从Keras损失模块导入均方误差损失函数  
from keras.losses import MSE  
# 从Keras优化器模块导入Adam优化器  
from keras.optimizers import Adam
2.2、LunarLander登陆器环境加载

在gym库中的使用指导可以参考:LunarLander

我们关注的是可以从这个交互接口中得到什么和控制什么,对于此处的登陆器,我们关注可以得到它的哪些状态和对其进行那些操作
在这里插入图片描述
依据官方手册,存在四种可用的离散动作:不执行任何操作、启动左方向引擎、启动主引擎、启动右方向引擎。能够得到的状态是一个8维向量,包括着陆器在x和y方向上的坐标、x和y方向上的线速度、角度、角速度,以及两个布尔值,表示每个着陆腿是否与地面接触。

# 使用gym库创建一个名为'LunarLander-v2'的环境,并设置渲染模式为'rgb_array'  
# 'rgb_array'模式返回一个numpy数组,表示环境的RGB图像  
env = gym.make('LunarLander-v2', render_mode='rgb_array')  
  
# 重置环境到初始状态,并返回初始状态  
env.reset()  
  
# 使用PIL库(Python Imaging Library)从环境的渲染数组创建一个图像  
PIL.Image.fromarray(env.render())  
  
# 获取观测空间(状态)的尺寸,这是一个8维向量  
state_size = env.observation_space.shape  
  
# 获取动作空间的数量,这表示有多少种可能的离散动作可以选择  
num_actions = env.action_space.n  
  
# 打印状态空间和动作空间的信息  
print('State Shape:', state_size)  
print('Number of actions:', num_actions)  
2.3、创建神经网络结构-使用软更新
# 创建一个名为Q-Network的神经网络  
q_network = Sequential([
    Input(shape=state_size),  # 输入层,形状由state_size定义  
    Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  
    Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  
    Dense(units=num_actions, activation='linear'),  # 输出层,单元数由num_actions定义,使用线性激活函数  
])

# 这里是软更新的网络(Target Q-Network)  
target_q_network = Sequential([
    Input(shape=state_size),  # 输入层,形状由state_size定义  
    Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  
    Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  
    Dense(units=num_actions, activation='linear'),  # 输出层,单元数由num_actions定义,使用线性激活函数  
])

2.4、强化学习的误差计算与梯度下降

首先是误差计算的函数,这边的Q-learning算法类似于一种迭代算法,
在这里插入图片描述
这就好像我们在高中学习的数组题目中,已经知道了an和an+1的关系式,去求解详细的an的表达式。此处误差计算的代码如下(值得注意的是,下一步的回报Q(s’,a’)是使用Target Q-Network计算的,而当前步的是使用Q-Network网络计算的):

def compute_loss(experiences, gamma, q_network, target_q_network):  
    """  
    计算损失函数。  
  
    参数:  
      experiences: 一个包含["state", "action", "reward", "next_state", "done"]的namedtuples的元组  
      gamma: (浮点数) 折扣因子。  
      q_network: (tf.keras.Sequential) 用于预测q_values的Keras模型  
      target_q_network: (tf.keras.Sequential) 用于预测目标的Keras模型  
  
    返回:  
      loss: (TensorFlow Tensor(shape=(0,), dtype=int32)) y目标与Q(s,a)值之间的均方误差。  
    """  
  
    # 解压经验元组的小批量数据  
    states, actions, rewards, next_states, done_vals = experiences  
  
    # 计算最大的Q^(s,a),reduce_max用于求最大值  
    max_qsa = tf.reduce_max(target_q_network(next_states), axis=-1)  
  
    # 如果回合结束,设置y = R,否则设置y = R + γ max Q^(s,a)。  
    y_targets = rewards + (gamma * max_qsa * (1 - done_vals))  
  
    # 获取q_values  
    q_values = q_network(states)  
    q_values = tf.gather_nd(q_values, tf.stack([tf.range(q_values.shape[0]),  
                                                tf.cast(actions, tf.int32)], axis=1))  
  
    # 计算损失  
    loss = MSE(y_targets, q_values)  
  
    return loss

学习算法的定义如下所示,使用了软更新技术:


def agent_learn(experiences, gamma):
    """  
    更新Q网络的权重。  

    参数:  
      experiences: 一个包含["state", "action", "reward", "next_state", "done"]的namedtuples的元组  
      gamma: (浮点数) 折扣因子。  

    """
    # 使用tf.GradientTape()来计算损失相对于权重的梯度  
    with tf.GradientTape() as tape:
        # 调用compute_loss函数计算损失  
        loss = compute_loss(experiences, gamma, q_network, target_q_network)

        # 使用GradientTape计算损失相对于q_network的可训练变量的梯度  
    gradients = tape.gradient(loss, q_network.trainable_variables)

    # 使用优化器应用梯度,从而更新q_network的权重  
    optimizer.apply_gradients(zip(gradients, q_network.trainable_variables))

    # 使用软更新技术将q_network的权重更新至target_q_network  
    Lunar_Lander_utils.update_target_network(q_network, target_q_network)

Lunar_Lander_utils.update_target_network(q_network, target_q_network)是软更新的关键所在:

def update_target_network(q_network, target_q_network):
    for target_weights, q_net_weights in zip(target_q_network.weights, q_network.weights):
        target_weights.assign(TAU * q_net_weights + (1.0 - TAU) * target_weights)
2.5、强化学习的训练过程

在这里插入图片描述

# 重置环境至初始状态并获得初始状态  
state,_ = env.reset()  
total_points = 0  
  
# 这里进行一次模拟,最多运行max_num_timesteps个时间步  
for t in range(max_num_timesteps):  
    # 从当前状态S使用ε-贪婪策略选择一个动作A  
    # 从元组中提取NumPy数组  
    # (注:这部分代码被注释掉了,所以下面的state_array并不会实际运行)  
    # if state[0].shape == ():  
    #     state_array = state  
    # else:  
    #     state_array = state[0]  
    # 将state_array转换为NumPy数组  
    state_qn = np.expand_dims(state, axis=0)  
    # 得到每个动作的回报数值,是一个1x4的数组,分别表示4个action的回报  
    q_values = q_network(state_qn)  
    # 此处实行贪婪策略,从当前最优action和随机action中选择  
    action = Lunar_Lander_utils.get_action(q_values, epsilon)  
  
    # 执行上述动作后得到的新状态、奖励、是否完成等信息  
    next_state, reward, done, _, _ = env.step(action)  
  
    # 将经验元组(S,A,R,S')存储在记忆缓冲区中  
    # 使用memory存储历史数据  
    memory_buffer.append(experience(state, action, reward, next_state, done))  
  
    # 只在特定的时间步进行更新  
    update = Lunar_Lander_utils.check_update_conditions(t, NUM_STEPS_FOR_UPDATE, memory_buffer)  
  
    if update:  
        # 从D中随机抽取小批量的经验元组(S,A,R,S')  
        # 只随机取MINIBATCH_SIZE个数据进行一次训练  
        experiences = Lunar_Lander_utils.get_experiences(memory_buffer)  
  
        # 设置y目标,执行梯度下降步骤,并更新网络权重  
        agent_learn(experiences, GAMMA)  
  
    state = next_state.copy()  
    total_points += reward  
  
    if done:  
        break  
  
# 将本次总得分添加到历史得分中  
total_point_history.append(total_points)  
# 计算最近num_p_av次得分的平均值  
av_latest_points = np.mean(total_point_history[-num_p_av:])  
  
# 更新ε值  
epsilon = Lunar_Lander_utils.get_new_eps(epsilon)

3、LunarLander文件解释

Lunar_Lander.py:运行此文件进行训练
lunar_lander_model.h5:Lunar_Lander.py训练得到的模型文件
Lunar_Lander_test.py:此文件调用h5模型并运行模拟器,将数据打包成视频格式,视频位于Lunar_Lander_videos文件夹
Lunar_Lander_utils.py:函数库

注意:运行Lunar_Lander_test.py出现长时间(大于20s)无返回0的情况,需要重新运行。这是因为LunarLander一直悬浮在空中了(相当于直升机了)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/213138.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【论文 | 联邦学习】 | Towards Personalized Federated Learning 走向个性化的联邦学习

Towards Personalized Federated Learning 标题:Towards Personalized Federated Learning 收录于:IEEE Transactions on Neural Networks and Learning Systems (Mar 28, 2022) 作者单位:NTU,Alibaba Group,SDU&…

【设计模式-4.1】行为型——观察者模式

说明:本文介绍设计模式中行为型设计模式中的,观察者模式; 商家与顾客 观察者模式属于行为型设计模式,关注对象的行为。以商家与顾客为例,商家有商品,顾客来购买商品,如果商家商品卖完了&#…

go语言学习-并发编程(并发并行、线程协程、通道channel)

1、 概念 1.1 并发和并行 并发:具有处理多个任务的能力 (是一个处理器在处理任务),cpu处理不同的任务会有时间错位,比如有A B 两个任务,某一时间段内在处理A任务,这时A任务需要停止运行一段时间,那么会切换到处理B任…

DockerFile常用保留字指令及知识点合集

目录 DockerFile加深理解&#xff1a; DockerFile常用保留字指令 保留字&#xff1a; RUN&#xff1a;容器构建时需要运行的命令 COPY&#xff1a;类似ADD&#xff0c;拷贝文件和目录到镜像中。 将从构建上下文目录中 <源路径> 的文件/目录复制到新的一层的镜像内的 …

【动态规划】LeetCode-面试题 17.16. 按摩师

&#x1f388;算法那些事专栏说明&#xff1a;这是一个记录刷题日常的专栏&#xff0c;每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目&#xff0c;在这立下Flag&#x1f6a9; &#x1f3e0;个人主页&#xff1a;Jammingpro &#x1f4d5;专栏链接&…

vs 安装 qt qt扩展

1 安装qt 社区版 免费 Download Qt OSS: Get Qt Online Installer 2 vs安装 qt vs tools 3 vs添加 qt添加 bin/cmake.exe 路径 3.1 扩展 -> qt versions 3.2

【STM32】STM32学习笔记-新建工程(04)

00. 目录 文章目录 00. 目录01. 创建STM32工程02. STM32工程编译和下载03. LED测试04. 型号分类及缩写05. 工程结构06. 附录 01. 创建STM32工程 【STM32】STM32F103C8T6 创建工程模版详解(固件库) 02. STM32工程编译和下载 2.1 选择下载器位ST-Link Debugger 2.2 勾选上电…

04. 函数

目录 1、前言 2、Python中的函数 2.1、内置函数 2.2、自定义函数 2.3、函数调用 3、函数的参数 3.1、形参和实参 3.2、位置参数&#xff08;Positional Arguments&#xff09; 3.3、默认参数&#xff08;Default Arguments&#xff09;&#xff1a; 3.4、关键字参数&a…

如何为C#WinFrom编译的.exe添加个性化图标

1、在VS中点击菜单栏上的“项目”,找到最下面的属性&#xff0c;单击进去 2、加载自定义的.ico文件&#xff0c;如果没有此格式的文件可以使用此网站去转换&#xff1a;图标制作大师 - 轻松制作网站favicon图标 3、重新编译文件即可

【【水 MicroBlaze 最后的介绍和使用】】

水 MicroBlaze 最后的介绍和使用 我对MicroBlaze 已经有了一个普遍的理解 了 现在我将看的两个 一个是 AXI4接口的 DDR读写实验 还有一个是 AXI DMA 环路实验 虽然是 水文 但是 也许能从中 得到一些收获 第一个是 AXI DDR 读写实验 Xilinx 从 Spartan-6 和 Virtex-6 系列开始…

SSM框架(六):SpringBoot技术及整合SSM

文章目录 一、概述1.1 简介1.2 起步依赖1.3 入门案例1.4 快速启动 二、基础配置2.1 三种配置文件方式2.2 yaml文件格式2.3 yaml读取数据方式&#xff08;3种&#xff09; 三、多环境开发3.1 yml文件-多环境开发3.2 properties文件-多环境开发3.3 多环境命令行启动参数设置3.4 多…

【数值计算方法(黄明游)】函数插值与曲线拟合(一):Lagrange插值【理论到程序】

​ 文章目录 一、近似表达方式1. 插值&#xff08;Interpolation&#xff09;2. 拟合&#xff08;Fitting&#xff09;3. 投影&#xff08;Projection&#xff09; 二、Lagrange插值1. 天书1. 人话拉格朗日插值方法a. 线性插值&#xff08;n1&#xff09;基本思想线性插值与线…

解决uview中uni-popup弹出层不能设置高度问题

开发场景&#xff1a;点击条件筛选按钮&#xff0c;在弹出的popup框中让用户选择条件进行筛选 但是在iphone12/13pro展示是正常&#xff0c;但是切换至其他手机型号就填充满了整个屏幕&#xff0c;需要给这个弹窗设置一个固定的高度 iphone12/13pro与其他型号手机对比 一开始…

智能优化算法应用:基于海洋捕食者算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于海洋捕食者算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于海洋捕食者算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.海洋捕食者算法4.实验参数设定5.算法结果…

工业机器视觉megauging(向光有光)使用说明书(四,轻量级的visionpro)

第三个相机的添加&#xff0c;突然发现需要补充一下&#xff1a; 第一步&#xff0c;假定你对c#编程懂一点&#xff0c;我们添加了一个页面“相机三”在tabcontrol1&#xff1a; 第二步&#xff0c;添加dll到工具箱&#xff1a; 第三步&#xff0c;点击‘浏览’&#xff0c;找…

Web前端JS如何获取 Video/Audio 视音频声道(左右声道|多声道)、视音频轨道、音频流数据

写在前面&#xff1a; 根据Web项目开发需求&#xff0c;需要在H5页面中&#xff0c;通过点击视频列表页中的任意视频进入视频详情页&#xff0c;然后根据视频的链接地址&#xff0c;主要是 .mp4 文件格式&#xff0c;在进行播放时实时的显示该视频的音频轨道情况&#xff0c;并…

Fiddler抓包工具之Fiddler+willow插件应用

安装Fiddler的安装包地址&#xff1a;fillderwillow 解压后安装fiddler4和willow1.4.*版本。 安装成功后&#xff0c;启动fiddler后会出现willow插件按钮&#xff1a; 说明安装成功。 重定向 willow重定向 进入willow界面后&#xff0c;通过右键->Add Project ->Add Ru…

canvas基础:fillStyle 和strokeStyle示例

canvas实例应用100 专栏提供canvas的基础知识&#xff0c;高级动画&#xff0c;相关应用扩展等信息。 canvas作为html的一部分&#xff0c;是图像图标地图可视化的一个重要的基础&#xff0c;学好了canvas&#xff0c;在其他的一些应用上将会起到非常重要的帮助。 文章目录 上色…

Spring Task 超详解版

目录 一、定时任务的理解 二、入门案例 三、Cron表达式 四、Cron实战案例 五、多线程案例 一、定时任务的理解 定时任务即系统在特定时间执行一段代码&#xff0c;它的场景应用非常广泛&#xff1a; 购买游戏的月卡会员后&#xff0c;系统每天给会员发放游戏资源。管理系…

基于姿态估计的3D动画生成

在本文中&#xff0c;我们将尝试通过跟踪 2D 视频中的动作来渲染人物的 3D 动画。 在 3D 图形中制作人物动画需要大量的运动跟踪器来跟踪人物的动作&#xff0c;并且还需要时间手动制作每个肢体的动画。 我们的目标是提供一种节省时间的方法来完成同样的任务。 我们对这个问题…
最新文章