ChatGPT的强化学习部分介绍——PPO算法实战LunarLander-v2

PPO算法

近线策略优化算法(Proximal Policy Optimization Algorithms) 即属于AC框架下的算法,在采样策略梯度算法训练方法的同时,重复利用历史采样的数据进行网络参数更新,提升了策略梯度方法的学习效率。 PPO重要的突破就是在于对新旧新旧策略器参数进行了约束,希望新的策略网络和旧策略网络的越接近越好。 近线策略优化的意思就是:新策略网络要利用到旧策略网络采样的数据进行学习,不希望这两个策略相差特别大,否则就会学偏。PPO依然是openai在2017年提出来的,论文地址

PPO-clip的损失函数,其中损失包含三个部分:

  • clip部分:加权采样后的优势值越好(可理解层价值网络的评分),同时采样通过clip防止新旧策略网络相差过大
  • VF部分: 价值网络预测的价值和环境真是的回报值越接近越好
  • S 部分:策略网络输出策略的熵值,越大越好,这个explore的思想,希望策略网络输出的动作分布概率不要太集中,提高了每个动作都有机会在环境中发生的可能。

实战部分

导入必要的包

%matplotlib inline
import matplotlib.pyplot as plt

from IPython import display

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm.notebook import tqdm
准备环境

准备好openai开发的LunarLander-v2的游戏环境。

seed = 543
def fix(env, seed):
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
import gym
import random
env = gym.make('LunarLander-v2' ,render_mode='rgb_array')
fix(env, seed) # fix the environment Do not revise this !!!

下面是采用代码去输出 环境的 观测值,一个8维向量,动作是一个标量,4选一。

  • 该环境共有 8 个观测值,分别是: 水平坐标 x; 垂直坐标 y; 水平速度; 垂直速度; 角度; 角速度; 腿1触地; 腿2触地;
  • 可以采取四种离散的行动,分别是: 0 代表不采取任何行动 1.代表主引擎向左喷射 2 .代表主引擎向下喷射 3 .代表主引擎向右喷射
  • 环境中的 reward 大致是这样计算: 小艇坠毁得到 -100 分; 小艇在黄旗帜之间成功着地则得 100~140 分; 喷射主引擎(向下喷火)每次 -0.3 分; 小艇最终完全静止则再得 100 分

在这里插入图片描述

随机动作玩5把

env.reset()

img = plt.imshow(env.render())

done = False
rewords = []
for i in range(5):
    env.reset()[0]
    img = plt.imshow(env.render())
    total_reward = 0
    done = False
    while not done:
        action = env.action_space.sample()
        observation, reward, done, _ , _= env.step(action)
        total_reward += reward
        img.set_data(env.render())
        display.display(plt.gcf())
        display.clear_output(wait=True)
    rewords.append(total_reward)

在这里插入图片描述

只有一把分数是正的,只有一次平安落地,小艇最终完全静止
在这里插入图片描述

搭建PPO agent

class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCriticDiscrete(nn.Module):
    def __init__(self, state_dim, action_dim, n_latent_var):
        super(ActorCriticDiscrete, self).__init__()

        # actor
        self.action_layer = nn.Sequential(
                nn.Linear(state_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, action_dim),
                nn.Softmax(dim=-1)
                )

        # critic
        self.value_layer = nn.Sequential(
               nn.Linear(state_dim, 128),
               nn.ReLU(),
               nn.Linear(128, 64),
               nn.ReLU(),
               nn.Linear(64, 1)
               )

    def act(self, state, memory):
        state = torch.from_numpy(state).float()
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)
        action = dist.sample()

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))

        return action.item()

    def evaluate(self, state, action):
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()

        state_value = self.value_layer(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy

class PPOAgent:
    def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.timestep = 0
        self.memory = Memory()

        self.policy = ActorCriticDiscrete(state_dim, action_dim, n_latent_var)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
        self.policy_old = ActorCriticDiscrete(state_dim, action_dim, n_latent_var)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def update(self):   
        # Monte Carlo estimate of state rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards), reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards:
        rewards = torch.tensor(rewards, dtype=torch.float32)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).detach()
        old_actions = torch.stack(self.memory.actions).detach()
        old_logprobs = torch.stack(self.memory.logprobs).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values : 新策略 重用 旧样本进行训练 
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta__old): 
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:计算优势值
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages ###  重要性采样的思想,确保新的策略函数和旧策略函数的分布差异不大
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages ### 采样clip的方式过滤掉一些新旧策略相差较大的样本
            loss = -torch.min(surr1, surr2)  + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

    def step(self, reward, done):
        self.timestep += 1 
        # Saving reward and is_terminal:
        self.memory.rewards.append(reward)
        self.memory.is_terminals.append(done)

        # update if its time
        if self.timestep % update_timestep == 0:
            self.update()
            self.memory.clear_memory()
            self.timstamp = 0

    def act(self, state):
        return self.policy_old.act(state, self.memory)

训练PPO agent


state_dim = 8 ### 游戏的状态是个8维向量
action_dim = 4 ### 游戏的输出有4个取值
n_latent_var = 256           # 神经元个数
update_timestep = 1200      # 每多少补跟新策略
lr = 0.002                  # learning rate
betas = (0.9, 0.999)
gamma = 0.99                # discount factor
K_epochs = 4                # update policy for K epochs
eps_clip = 0.2              # clip parameter for PPO  论文中表明0.2效果不错
random_seed = 1 

agent = PPOAgent(state_dim ,action_dim,n_latent_var,lr,betas,gamma,K_epochs,eps_clip)
# agent.network.train()  # Switch network into training mode 
EPISODE_PER_BATCH = 5  # update the  agent every 5 episode
NUM_BATCH = 200     # totally update the agent for 400 time


avg_total_rewards, avg_final_rewards = [], []

# prg_bar = tqdm(range(NUM_BATCH))
for i in range(NUM_BATCH):

    log_probs, rewards = [], []
    total_rewards, final_rewards = [], []
    values    = []
    masks     = []
    entropy = 0
    # collect trajectory
    for episode in range(EPISODE_PER_BATCH):
        ### 重开一把游戏
        state = env.reset()[0]
        total_reward, total_step = 0, 0
        seq_rewards = []
        for i in range(1000):  ###游戏未结束

            action = agent.act(state) ### 按照策略网络输出的概率随机采样一个动作
            next_state, reward, done, _, _ = env.step(action) ### 与环境state进行交互,输出reward 和 环境next_state
            state = next_state
            total_reward += reward
            total_step += 1     
            rewards.append(reward) ### 记录每一个动作的reward
            agent.step(reward, done)   
            if done:  ###游戏结束
                final_rewards.append(reward)
                total_rewards.append(total_reward)
                break

    print(f"rewards looks like ", np.shape(rewards))  
    if len(final_rewards)> 0 and len(total_rewards) > 0:
        avg_total_reward = sum(total_rewards) / len(total_rewards)
        avg_final_reward = sum(final_rewards) / len(final_rewards)
        avg_total_rewards.append(avg_total_reward)
        avg_final_rewards.append(avg_final_reward)

PPO agent在玩5把游戏


fix(env, seed)
agent.policy.eval() # set the network into evaluation mode
test_total_reward = []
for i in range(5):
    actions = []
    state = env.reset()[0]
    img = plt.imshow(env.render())
    total_reward = 0
    done = False
    while not done :
        action= agent.act(state)
        actions.append(action)
        state, reward, done, _, _ = env.step(action)
        total_reward += reward
        img.set_data(env.render())
        display.display(plt.gcf())
        display.clear_output(wait=True)
    test_total_reward.append(total_reward)

从下图可以看到,玩得比随机的时候要好很多,有三把都平安着地,小艇最终完全静止。
在这里插入图片描述

其中有3把得分超过200,证明ppo在300轮已经学到了如何玩这个游戏,比之前的随机agent要强了不少。
在这里插入图片描述
github 源码

参考

ChatGPT的强化学习部分介绍
基于人类反馈的强化学习(RLHF) 理论
Anaconda安装GYM &关于Box2D安装的相关问题
conda-forge / packages / box2d-py
jupyter Notebook 内核似乎挂掉了,它很快将自动重启

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/19107.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

尚硅谷-宋红康-JVM上中下篇完整笔记-JVM中篇

一.Class文件结构 1.概述 1.1 字节码文件的跨平台性 所有的JVM全部遵守Java虚拟机规范:Java SE Specifications,也就是说所有的JV环境都是一样的,这样一来字节码文件可以在各种JVM上运行。 1.2 Java的前端编译器 想要让一个Java程序正确地运行在JVM中&am…

177_模型_Power BI 进销存6大日期维度期初与期末

177_模型_Power BI 进销存6大日期维度期初与期末 一、背景 在经销存报表设计中,经常会遇到的便是期初与期末。当然我们这里说期初与期末指的是期初库存与期末库存。 这里的期一般常见的会有:年月日。本案例将演示 6 大日期维度,分别是&…

勒索病毒“顽疾”,没有“特效药”吗?

基础设施瘫痪、企业和高校重要文件被加密、毕业论文瞬间秒没……这就是六年前的今天,WannaCry勒索攻击爆发时的真实场景。攻击导致150多个国家数百万台计算机受影响,也让勒索病毒首次被全世界广泛关注。 六年后,勒索攻击仍是全球最严重的网络…

Kali E:Unable to locate package错误解决

默认的新装的kali 可能都会遇到这个安装报错E: Unable to locate package httrack问题,今天我记录下彻底解决过程和效果。 Command httrack not found, but can be installed with: apt install httrack Do you want to install it? (N/y)y apt install httrack Re…

什么是域名流量劫持?

作为传统的互联网攻击方式,域名流量劫持已经十分常见,这种网络攻击将会在不经授权的情况下控制或重定向一个域名的DNS记录。域名劫持的影响难以估量,因为它可以导致在访问一个网站时,用户被引导到另一个不相关的网站,对…

深入理解java虚拟机精华总结:运行时栈帧结构、方法调用、字节码解释执行引擎

深入理解java虚拟机精华总结:运行时栈帧结构、方法调用、字节码解释执行引擎 运行时栈帧结构局部变量表操作数栈动态连接方法返回地址 方法调用解析分派静态分派动态分派 基于栈的字节码解释执行引擎 运行时栈帧结构 Java虚拟机以方法作为最基本的执行单元&#xf…

Vue3 自定义指令让元素自适应高度,el-table在可视区域内滚起来

我始终坚持,前端开发不能满足于实现功能,而是需要提供优秀的交互与用户体验。即使没有产品没有UI的小项目,也可以自己控制出品质量,做到小而美。所以前端们不仅仅需要了解框架如何用,还要学习一些设计、交互、体验的知…

诗圣杜甫不同时期的代表作

杜甫一生大致分为四个时期。 中青年时期 青年时,作为官三代的杜甫并不缺钱,四处游历,与李白、高适唱和、仙游,成为佳话。这个时期杜甫的作品热血豪迈,气势蓬勃。代表作首推《望岳》: 岱宗夫如何&#xf…

2023/5/8总结

JAVA基础知识(2) 1.方法 1、方法定义 格式:public static void 方法名(){ //方法体 } 2、方法调用 格式:方法名(); 3、方法的通用格式 public static 返回值类型方法名&…

C++面向对象(黑马程序员)

内存分区模型 #include<iostream> using namespace std;//栈区数据注意事项&#xff1a;不要返回局部变量的地址 //栈区的数据由编译器管理开辟和释放int* func(int b) //形参数据也会放在栈区 {b 100;int a 10; //局部变量存放在栈区&#xff0c;栈区的数据在函数执…

存bean和取bean

准备工作存bean获取bean三种方式 准备工作 bean:一个对象在多个地方使用。 spring和spring boot&#xff1a;spring和spring boot项目&#xff1b;spring相当于老版本 spring boot本质还是spring项目&#xff1b;为了方便spring项目的搭建&#xff1b;操作起来更加简单 spring…

vue+express+mysql做一个简单前后端交互,从数据库中读取数据渲染到页面

1.下载上次的包 npm I &#xff0c;同时下载新的包 axios 2.打开数据库服务器&#xff0c;同时使用新建数据库一样&#xff0c;数据包名 3.新建一个项目 4.全局注册axios 5.新建一个server文件夹&#xff08;里面在建一个index.js的主文件&#xff09;用来放我们后端写的东西 …

数据结构与算法基础(王卓)(36):交换排序之快排【第三阶段:深挖解决问题】精华!精华!精华!!!重要的事情说三遍

目录 Review&#xff1a; 具体问题&#xff1a; 操作核心&#xff1a; 注&#xff1a; 操作分解&#xff1a; 操作实现&#xff1a; 问题&#xff08;1&#xff09;&#xff1a;进行不一样次数的 if / else 判断 问题&#xff08;2&#xff09;&#xff1a;通过判断条件…

TypeScript 最近各版本主要特性总结

&#xff08;在人生的道路上&#xff0c;当你的期望一个个落空的时候&#xff0c;你也要坚定&#xff0c;要沉着。——朗费罗&#xff09; TypeScript 官网 在线运行TypeScript代码 第三方中文博客 特性 typescript是javascript的超集&#xff0c;向javascript继承额外的编辑…

2023鲁大师评测沟通会:鲁大师尊享版登场、“鲁小车”正式上线

作为硬件评测界的“老兵”&#xff0c;鲁大师不仅有着十几年的硬件评测经验&#xff0c;并且一直都在不断地尝试、不断地推陈出新。在5月9日举行的“2023年鲁大师评测沟通会”上&#xff0c;鲁大师向大众展示了在过去一年间取得的成果。 PC业务迭代升级&#xff0c;鲁大师客户端…

干货满满!破解FP安全收款难题

怎样安全收款是做擦边产品卖家比较忧虑的问题&#xff0c;2023年已经即将来到了年中&#xff0c;跨境卖家们在这一方面做得怎么样了呢&#xff1f; 这期分享破解FP独立站收款难题的方法。 一、商家破解FP收款难题方法 1.第三方信用通道 优点&#xff1a;信用卡在国外使用率比…

好家伙,又一份牛逼笔记面世了...

最近网传的一些裁员的消息&#xff0c;搞的人心惶惶。已经拿到大厂offer的码友来问我&#xff1a;大厂还能去&#xff0c;去了会不会被裁。 还在学习的网友来问我&#xff1a;现在还要冲互联网么&#xff1f; 我是认为大家不用恐慌吧&#xff0c;该看啥看啥&#xff0c;该学啥…

ASEMI代理ADI亚德诺ADUM3211TRZ-RL7原厂芯片

编辑-Z ADUM3211TRZ-RL7参数描述&#xff1a; 型号&#xff1a;ADUM3211TRZ-RL7 数据速率&#xff1a;10 Mbps 传播延迟&#xff1a;50 ns 脉冲宽度失真&#xff1a;3 ns 脉冲宽度&#xff1a;100 ns 输出上升/下降时间&#xff1a;2.5 ns 供电电流&#xff1a;2.6 mA …

Maven与spring学习

目录 该如何学习Maven&#xff0c;是先该学习spring还是先学习Maven 能讲一下该如何学习Maven吗&#xff1f; 火狐浏览器有能让网页翻译成为中文的插件吗 秋田和柴犬是同一个狗吗 该如何学习Maven&#xff0c;是先该学习spring还是先学习Maven 学习Maven可以与学习Spring同…

一键安装k8s脚本

服务器配置 节点(华为云服务器)配置master 2vCPUs | 4GiB | s6.large.2 CentOS 7.8 64bit node1 2vCPUs | 8GiB | s6.large.4 CentOS 7.8 64bit node2 2vCPUs | 8GiB | s6.large.4 CentOS 7.8 64bit 1.master节点安装脚本&#xff1a;install_k8s_master.sh。 sh文件上传到…