【强化学习】常用算法之一 “A3C”

 

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?type=blog个人简介:打工人。

持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。

如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666@foxmail.com 

        强化学习是一种机器学习的方法,旨在通过与环境进行交互学习来最大化累积奖励。强化学习研究的核心问题是“智能体(agent)在不断与环境交互的过程中如何选择行为以最大化奖励”。其中,A3C算法(Asynchronous Advantage Actor-Critic)是一种基于策略梯度的强化学习方法,通过多个智能体的异步训练来实现快速而稳定的学习效果。

本文将详细讲解强化学习常用算法之一“A3C”


目录

一、A3C算法的简介

二、A3C算法的发展历程

三、A3C算法的公式和原理讲解

        1. A3C算法的公式

        2. A3C算法的原理

四、A3C算法的功能

五、A3C算法的示例代码

        分解代码

        完整代码

六、总结


一、A3C算法的简介

        A3C(Asynchronous Advantage Actor-Critic)算法是一种在强化学习领域中应用广泛的算法,它结合了策略梯度方法和价值函数的学习,用于近似解决马尔可夫决策过程(Markov Decision Process)问题。A3C算法在近年来备受关注,因为它在处理大规模连续动作空间和高维状态空间方面具有出色的性能。 

二、A3C算法的发展历程

        A3C算法是对DQN(Deep Q Network)算法在强化学习领域的一个重要延伸和改进。DQN算法在2013年被DeepMind团队首次提出,并在很多任务上取得了令人瞩目的效果。然而,DQN算法在处理连续动作空间、高维状态空间等复杂问题上面临着困难。为了解决这些问题,研究人员开始关注基于策略梯度的方法,并提出了A3C算法。

三、A3C算法的公式和原理讲解

        1. A3C算法的公式

        A3C算法的目标是最大化累积奖励,将这一目标表示为优化问题,可以用如下的公式表示:

L(θ) = -E[logπ(a|s;θ)A(s,a)]

        其中,L(θ)表示损失函数,θ表示模型参数,π(a|s;θ)表示在状态s下选择动作a的概率,A(s,a)表示在状态s选择动作a相对于平均回报的优势函数。A3C算法的优化目标是最小化损失函数L(θ)

        2. A3C算法的原理

        A3C算法采用Actor-Critic结构,由Actor和Critic两个网络组成。Actor网络的目标是学习策略函数,即在给定状态下选择动作的概率分布。Critic网络的目标是学习状态值函数或者状态-动作值函数,用于评估不同状态或状态-动作对的价值。

        A3C算法的训练过程可以分为以下几个步骤:

  • 初始化神经网络参数。
  • 创建多个并行的训练线程,每个线程独立运行一个智能体与环境交互,并使用Actor和Critic网络实现策略和价值的近似。
  • 每个线程根据当前的策略网络选择动作,并观测到新的状态和奖励,将这些信息存储在经验回放缓冲区中。
  • 当一个线程达到一定的时间步数或者轨迹结束时,该线程将经验回放缓冲区中的数据抽样出来,并通过计算优势函数进行梯度更新。
  • 每个线程进行一定次数的梯度更新后,将更新的参数传递给主线程进行整体参数更新。
  • 重复上述步骤直到达到预定的训练轮次或者达到终止条件为止。

        A3C算法采用了Asynchronous(异步)的训练方式,每个线程独立地与环境交互,并通过参数共享来实现梯度更新。这种异步训练的方式可以提高训练的效率和稳定性,并且能够学习到更好的策略和价值函数。

四、A3C算法的功能

        A3C算法具有以下功能和特点:

  • 支持连续动作空间和高维状态空间的强化学习;
  • 通过多个并行的智能体实现快速而稳定的训练;
  • 利用Actor和Critic两个网络分别学习策略和价值函数,具有更好的学习效果和收敛性;
  • 通过异步训练的方式提高了训练的效率和稳定性。

五、A3C算法的示例代码

        下面是一个简单的A3C算法的示例代码

        分解代码

        首先,导入需要的库和模块:

import gym
import torch
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

        然后定义A3C算法中的Actor和Critic网络: 

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.actor = nn.Linear(128, 2)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        state_value = self.critic(x)
        return action_probs, state_value

        定义ActorCritic类来实现Actor和Critic网络的结构。

        然后定义训练函数,该函数将使用A3C算法进行智能体的训练:

def train(global_model, rank):
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    model.load_state_dict(global_model.state_dict())
    optimizer = optim.Adam(global_model.parameters(), lr=0.01)
    torch.manual_seed(123 + rank)
    max_episode_length = 1000

    for episode in range(2000):
        state = env.reset()
        done = False
        episode_length = 0
        while not done and episode_length < max_episode_length:
            episode_length += 1
            state = torch.from_numpy(state).float()
            action_probs, _ = model(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            next_state, reward, done, _ = env.step(action.item())

            if done:
                reward = -1

            next_state = torch.from_numpy(next_state).float()
            next_state_value = model(next_state)[-1]
            model_value = model(state)[-1]

            delta = reward + (0.99 * next_state_value * (1 - int(done))) - model_value

            actor_loss = -dist.log_prob(action) * delta.detach()
            critic_loss = delta.pow(2)

            loss = actor_loss + critic_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            state = next_state.numpy()

        global_model.load_state_dict(model.state_dict())

        在训练函数中,每个进程都会拷贝全局模型,并在自己的进程中进行模型的训练。训练过程中,智能体使用Actor模型根据当前状态选择动作,然后与环境进行交互,得到下一状态和奖励。根据奖励和下一状态的估值更新网络参数,得到一个损失函数。然后使用反向传播算法更新网络参数。

        最后,运行训练函数:

if __name__ == '__main__':
    num_processes = 4
    global_model = ActorCritic()
    global_model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(global_model, rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

        完整代码

# -*- coding: utf-8 -*-
import gym
import torch
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda")

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.actor = nn.Linear(128, 2)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        state_value = self.critic(x)
        return action_probs, state_value

def train(global_model, rank):
    env = gym.make('CartPole-v1')
    model = ActorCritic().to(device)
    model.load_state_dict(global_model.state_dict())
    optimizer = optim.Adam(global_model.parameters(), lr=0.01)
    torch.manual_seed(123 + rank)
    max_episode_length = 1000

    for episode in range(2000):
        state = env.reset()
        done = False
        episode_length = 0
        while not done and episode_length < max_episode_length:
            episode_length += 1
            state = torch.from_numpy(state).float().to(device)
            action_probs, _ = model(state)
            dist = Categorical(action_probs)
            action = dist.sample().to(device)
            next_state, reward, done, _ = env.step(action.item())

            if done:
                reward = -1

            next_state = torch.from_numpy(next_state).float().to(device)
            next_state_value = model(next_state)[-1].cpu()
            model_value = model(state)[-1].cpu()

            delta = reward + (0.99 * next_state_value * (1 - int(done))) - model_value

            actor_loss = -dist.log_prob(action) * delta.detach().to(device)
            critic_loss = delta.pow(2).to(device)

            loss = actor_loss + critic_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            state = next_state.cpu().numpy()

        global_model.load_state_dict(model.state_dict())

if __name__ == '__main__':
    num_processes = 4
    global_model = ActorCritic()
    global_model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(global_model, rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

        因为分解代码是cpu运行...的,速度很慢,所以我在完整代码里加了cuda去运行,但是我的电脑...唉,自己看吧。。。

        上述代码使用了OpenAI Gym中的’CartPole-v1’环境作为示例,通过A3C算法训练智能体在该环境中尽可能长时间地保持杆的平衡。 

六、总结

        A3C算法是一种基于策略梯度的强化学习算法,通过多个并行的智能体异步地与环境交互,并利用Actor和Critic网络实现策略和价值的近似,从而实现快速而稳定的强化学习训练。A3C算法具有良好的学习效果和收敛性,并且适用于处理连续动作空间和高维状态空间的问题。本文讲解了A3C算法的介绍、详细讲解其发展历程、算法公式、原理、功能和示例代码,希望能让读者对A3C算法有了更加深入和全面的理解。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/32898.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

VUE3-组件问题

VUE3-组件问题 文章目录 VUE3-组件问题一、S-Table1.问题描述2.问题展示3.问题解决 二、form表单无法显示1.问题描述2.问题展示3.问题解决 三、input 框为不可编辑状态四、Echarts组件未渲染五、图片正常引用&#xff0c;但是部署服务器部署不上去&#xff0c;看不到图片1.图片…

element封装 table表格 ,插槽的使用,修改el-table-column的值

举例 vue2这种不封装的 直接写的很罗嗦麻烦 下面圈起来的可以封装一个对象 进行循环 弊端: 循环后 无法进行获取更改某一列的值 比如data日期我需要转换年月日 不循环我直接在这个el-table-column的这一列进行写&#xff08;如下&#xff09; <el-table-column label&quo…

Vue3解决:[Vue warn]: Failed to resolve component: el-table(或el-button) 的三种解决方案

1、问题描述&#xff1a; 其一、报错为&#xff1a; [Vue warn]: Failed to resolve component: el-table If this is a native custom element, make sure to exclude it from component resolution via compilerOptions.isCustomElement. at <App> 或者&#xff1a; …

实验 4:排序与查找

东莞理工的学生可以借鉴&#xff0c;请勿抄袭 1.实验目的 通过实验达到&#xff1a; 理解典型排序的基本思想&#xff0c;掌握典型排序方法的思想和相应实现算法&#xff1b; 理解和掌握用二叉排序树(BST)实现动态查找的基本思想和相应的实现 算法。 理解和掌握哈希(HASH)存…

【Django】图形验证码显示及验证

图形验证码显示及验证 开发项目时&#xff0c;在登陆或者注册功能中为防止脚本恶意破坏系统&#xff0c;故而添加图形验证码。 文章目录 图形验证码显示及验证1 安装配置2 验证码显示及正确性验证3 效果显示 1 安装配置 安装第三方库 pip install django-simple-captcha配置s…

《计算机系统与网络安全》 第六章 密钥管理

&#x1f337;&#x1f341; 博主 libin9iOak带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33…

mysql8.0新特性详解

一、my.ini或my.cnf的全局参数 一个连接最少占用内存是256K&#xff0c;最大是64M&#xff0c;如果一个连接的请求数据超过64MB&#xff08;比如排序&#xff09;&#xff0c;就会申请临时空间&#xff0c;放到硬盘上。 #最大连接数 max_connections3000 #最大用户连接数 max_…

RS485转Profinet通讯

RS485转Profinet通讯 概述系统组成流量积算仪网关 软件总结 概述 一个支持RS485的流量积算仪的数据要被Profinet的PLC读取。制作一个网关&#xff0c;实现RS485到Profinet的转换。 系统组成 流量积算仪 支持RS485通讯&#xff0c;通讯协议是modbus RTU。采用功能码3可以读取…

ChatGPT从入门到精通,深入认识Prompt

ChatGPT从入门到精通&#xff0c;一站式掌握办公自动化/爬虫/数据分析和可视化图表制作 全面AI时代就在转角 道路已经铺好了 “局外人”or“先行者” 就在此刻 等你决定 让ChatGPT帮你高效实现职场办公&#xff01;行动起来吧。欢迎关注专栏 1、ChatGPT从入门到精通&#xff0…

【SQL Server】数据库开发指南(八)高级数据处理技术 MS-SQL 事务、异常和游标的深入研究

本系列博文还在更新中&#xff0c;收录在专栏&#xff1a;#MS-SQL Server 专栏中。 本系列文章列表如下&#xff1a; 【SQL Server】 Linux 运维下对 SQL Server 进行安装、升级、回滚、卸载操作 【SQL Server】数据库开发指南&#xff08;一&#xff09;数据库设计的核心概念…

【Unity 实用插件篇】 | UI适配神器 Device Simulator 移动设备模拟器 的详细使用方法

前言 【Unity 实用插件篇】 UI适配神器 Device Simulator 移动设备模拟器 的详细使用方法一、安装Device Simulator包二、使用Device Simulator模拟各种设备三、自定义设备类型信息 总结 &#x1f3ac; 博客主页&#xff1a;https://xiaoy.blog.csdn.net &#x1f3a5; 本文由…

错误C2039:‘退出‘:不是‘`全局名称空间‘的成员

问题 VC\Tools\MSVC\14.27.29110\include 目录里的cstdint文件的内容 原因 一种典型的Microsoft产品错误. 解决 运行 点击修复

Windows 驱动开发环境搭建

Windows 驱动开发环境搭建及 windbg 调试工具安装使用 引言了解 Windows 驱动开发环境下载 Windows 驱动开发环境根据需要下载安装对应版本的 Visual Studio下载安装对应的 WDK 工具包 编写第一个驱动代码总结参考资料 引言 对于 Windows 驱动开发&#xff0c;在微软官方的文档…

windows 下安装 mysql-8.0.25 解压版

介绍 此文介绍 mysql-8.0.25-winx64 的 zip 解压版&#xff0c;在 windows 下的安装与配置过程。 官方下载 官网下载页&#xff1a; https://downloads.mysql.com/archives/community/ 进入官网&#xff0c;选择默认版本就行&#xff0c;不需要包含测试工具套件的版本 本地解…

【spring源码系列-03】xml配置文件启动spring时refresh的前置工作

Spring源码系列整体栏目 内容链接地址【一】spring源码整体概述https://blog.csdn.net/zhenghuishengq/article/details/130940885【二】通过refresh方法剖析IOC的整体流程https://blog.csdn.net/zhenghuishengq/article/details/131003428【三】xml配置文件启动spring时refres…

[RocketMQ] Consumer消费者启动主要流程源码 (六)

客户端常用的消费者类是DefaultMQPushConsumer, DefaultMQPushConsumer的构造器以及start方法的源码。 1.创建DefaultMQPushConsumer实例 最终都是调用下面四个参数的构造函数: /*** 创建DefaultMQPushConsumer实例** param namespace namespace地址* par…

两两交换链表中的节点(LeetCode 24)

题目 24. 两两交换链表中的节点 思路 最开始自己画&#xff0c;越画越复杂比较复杂&#xff0c;写不出来&#xff01;&#xff08;呜呜&#xff09;去看了解题思路&#xff0c;发现只需要三步。&#xff0c;按以下思路写了代码&#xff0c;循环停止那里的条件我还以有更好的写…

【Docker】Docker Desktop更换非C盘符(减轻占用率)

Win10中的Docker Desktop调整到其他盘符&#xff0c;由于新版本已经不让修改软连接了&#xff0c;只好另谋策略&#xff0c;最终还是改成功了。 出现问题 使用软连接修改 上面代码我们可以科幻的理解一下 幻想破灭 //TODO 用户点击执行安装 if(检查文件夹是否软连接){有则&a…

虚拟机中Ubuntu 22上传框被黑框包裹的解决方法

虚拟机中Ubuntu 22上传框被黑框包裹的解决方法 现象解决方法 现象 在vm17下的ubuntu22使用上传表单时出现了这种不和谐的现象&#xff0c;被领导批评一通。最后费劲心思&#xff0c;找到了这个问题的解决方法。 解决方法 解决方法特别容易&#xff0c;在虚拟机的设置中&…

模型实战(13)之YOLOv8实现手语字母检测与识别+权重分享

YOLOv8实现手语字母检测与识别+权重分享 本文借助yolov8 实现手语字母的检测与识别:先检测手的ROI,进而对手语表达的字母含义进行识别全文将从环境搭建、模型训练及预测来展开对整个算法流程进行讲解文中给出了开源数据集链接及从 Roboflow 上的下载教程实现效果如下: 1. 环…