DNQ算法原理(Deep Q Network)

1.强化学习概念

学习系统没有像很多其它形式的机器学习方法一样被告知应该做出什么行为

必须在尝试了之后才能发现哪些行为会导致奖励的最大化

当前的行为可能不仅仅会影响即时奖励,还会影响下一步的奖励以及后续的所有奖励

uTools_1689855542629

每一个动作(action)都能影响代理将来的状态(state)

通过一个标量的奖励(reward)信号来衡量成功

目标:选择一系列行动来最大化未来的奖励

具体的过程就是先观察,再行动,再观察....

uTools_1689855720456

状态(state)

Experience is a sequence of observations, actions, rewards.

The state is a summary of experience.

uTools_1689855933110

2.马尔科夫决策

马尔科夫决策要求:

  1. 能够检测到理想的状态

  2. 可以多次尝试

  3. 系统的下个状态只与当前状态信息有关,而与更早之前的状态无关,在决策过程中还和当前采取的动作有关

马尔科夫决策过程由5个元素构成:

S:表示状态集(states)

A:表示一组动作(actions)

P:表示状态转移概率P。表示在当前s∈S状态下,经过a∈A作用后,会转移到的其他状态的概率分布情况,在状态s下执行动作a,转移到s'的概率可以表示为p(s|s,a)

R:奖励函数(reward function)表示agent采取某个动作后的即时奖励

y:折扣系数意味着当下的reward比未来反馈的reward更重要

状态价值函数:v(s)=E[Ut|St=s]

t时刻的状态s能获得的未来回报的期望

价值函数用来衡量某一状态或状态 - 动作对的优劣价,累计奖励的期望

最优价值函数:所有策略下的最优累计奖励期望v (s)=max v.(s)

策略:己知状态下可能产生动作的概率分布

3.Bellman方程

Bellman方程:当前状态的价值和下一步的价值及当前的奖励(Reward)有关

价值函数分解为当前的奖励和下一步的价值两部分

这个过程通常采用迭代法实现,每一次迭代都会更新一次状态的值函数,直到收敛为止

值迭代求解

值迭代是一种求解Bellman方程的方法,其基本思想是通过不断迭代更新状态的值函数,直到收敛到最优解。具体步骤如下:

  1. 初始化值函数V(s)为0,或者任意一个非负值。

  2. 对于每个状态s,按照以下公式更新值函数:

V(s) = max{R(s, a) + γ * V(next_state)},其中a为状态s的一个动作,next_state为动作a对应的下一个状态,γ为折扣因子。 3. 重复步骤2,直到值函数收敛到最优解。

值迭代的时间复杂度为O(NT^2),其中N为状态数,T为迭代次数。值迭代的优点是计算量较小,缺点是只能找到局部最优解,而无法保证全局最优解。

前提是安装一下gym

pip install一下就可以了

import numpy as np
import sys
from gym.envs.toy_test import discrete
​
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3
​
class CridworldEnv(discrete.DiscreteEnv):
    metadata = {'render.modss':['humin','ansi']}
    
    def __init__(self, shape=[4,4]):
        if not isinstance(shape, (list, tuple)) or not len[shape] == 2:
            raise ValueError('shape argument must be a list/tuple of length 2')
​
        self.shape = shape
​
# 定义状态空间、动作空间、转移概率和即时奖励
state_space = [0, 1, 2, 3, 4]
action_space = [0, 1, 2, 3]
transition_probabilities = {
    (0, 0): [0.5, 0.5],
    (0, 1): [0.5, 0.5],
    (1, 0): [0.1, 0.8, 0.1],
    (1, 1): [0.8, 0.1, 0.1],
    (2, 0): [0.5, 0.5],
    (2, 1): [0.5, 0.5],
    (3, 0): [0.8, 0.1, 0.1],
    (3, 1): [0.1, 0.8, 0.1],
    (4, 0): [0.5, 0.5],
    (4, 1): [0.5, 0.5]
}
reward_matrix = {
    (0, 0): [-1, -1],
    (0, 1): [10, -1],
    (0, 2): [-1, 10],
    (1, 0): [-1, -1],
    (1, 1): [-1, -1],
    (1, 2): [-1, -1],
    (2, 0): [-1, -1],
    (2, 1): [10, -1],
    (2, 2): [-1, 10],
    (3, 0): [-1, -1],
    (3, 1): [-1, -1],
    (3, 2): [-1, -1],
    (4, 0): [-1, -1],
    (4, 1): [-1, -1]
}
​
# 定义值函数初始值和折扣因子
V = {s: 0 for s in state_space}
gamma = 0.9
​
# 值迭代求解
T = 1000  # 迭代次数
for t in range(T):
    for s in state_space:
        Q = {a: 0 for a in action_space}
        for a in action_space:
            for next_s in state_space:
                Q[a] += transition_probabilities[(s, a)][next_s] * (reward_matrix[(s, a)][next_s] + gamma * V[next_s])
        V[s] = max(Q.values())
​
# 输出最优值函数和最优策略
print("Optimal value function:")
for s in state_space:
    print("V(%d) = %f" % (s, V[s]))
​
print("Optimal policy:")
for s in state_space:
    max_action = argmax(Q.items(), key=lambda x: x[1])[0]
    print("Policy for state %d: take action %d" % (s, max_action))

手写案例:

import numpy
from gridworld import GridworldEnv
​
env = GridworldEnv()
​
def value_iteration(env, theta=0.0001,discount_factor = 1.0):
    def one_setp_lookahead(state, v):
        A = np.roros(env.nA)
        #更新值
        for a in range(env.nA):
            for prob,next_state,reward,done in env.P[state][a]:
                A[a] += ropb*(reward + discount_factor*v[next_state])
        return A
    w = np.reros(env.nS)
    
    #进行一个迭代更新
    while True:
        delta = 0
        
        for s in range(env.nS):
            # Do a one step lookahead to find the best action
            A = one_step_lookahead(s,v)
            # Calculate delta across all states seen so far
            best_action_value = np.max(A)
            # Update the value function
            delta = max(delta,np.abs(best_action_value-v[s]))
            v[s] = best_action_value
        # Check if we can stop
        if delta < theta:
            break
    policy = np.zeros((env.nS,env.nA))
    for s in range(env.nS):
        A = one_step_lookahead(s,v)
        best_action_value = np.max(A)
        policy[s,best_action_value] = 1.0
    return policy,v
​
policy, v = value_iteration(env)
​
print("Policy Probability Distribution")
print(policy)
print("")
​
print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

4.Q-learning

uTools_1689918204365

针对图例的形式,我们想要走到5号Goal State,我们要给靠近5的几条路径上加上一些分数奖励,这样才能吸引智能体靠近,并获取达到最后的目的。

Q-learning是强化学习的主要算法之一,是一种无模型的学习方法。它基于一个关键假设,即智能体和环境的交互可看作为一个Markov决策过程(MDP),根据智能体当前所处的状态和所选择的动作,决定一个固定的状态转移概率分布、下一个状态、并得到一个即时回报。Q-learning的目标是寻找一个策略可以最大化将来获得的报酬。

Q-learning的内在思想是通过一个价值表格或价值函数来选取价值最大的动作。Q(s,a)表示在某一具体初始状态s和动作a的情况下,对未来收益的期望值。Q-Learning算法维护一个Q-table,Q-table记录了不同状态下s(s∈S),采取不同动作a(a∈A)的所获得的Q值。在探索环境之前,初始化Q-table,当智能体与环境交互的过程中,算法利用贝尔曼方程来迭代更新Q(s,a),每一轮结束后就生成了一个新的Q-table。智能体不断与环境进行交互,不断更新这个表格,使其最终能收敛。最终,智能体就能通过表格判断在某个状态下采取什么动作,才能获得最大的Q值。

Q-learning迭代计算

Step1 给定学习参数γ和reward矩阵R

Step2 令Q=0

Step3 For each episode

步骤3中也可以细分:首先,可以随机选择一个初始状态s。然后当没有达到目标状态,则执行一下几步,在当前状态s的所有可能行为中选取一个行为a,再利用选定的行为a,得到下一个状态s1,按照前面规定的计算方式来计算Q(s, a),再把s1赋值给我们的s,进行下一步迭代计算。

这可能需要上千上万次才能收敛到一个状态。

5.Deep Q Network

uTools_1689920362769

Q-table是Q学习算法中的一个关键概念,它是一个表格,记录了每个状态和动作对应的最大Q值。

Q-table中的每一行代表一个状态,每一列代表一个动作,表格中的每个元素Q(s,a)表示在状态s下采取动作a所能获得的最大收益的期望值。在Q-learning算法中,智能体通过不断探索环境,与环境交互,更新Q-table,从而逐渐学习到在特定状态下采取何种动作能够获得最大的收益.

  1. Convert image to grayscale

  2. Resize image to 80 * 80

  3. Stack last 4 frames to produce an 80 * 80 * 4 input array for network

Exploration VS Exploitation : we both need.

δ - greedy exploration : have chances to explore.

6.DQN的环境搭建

我们主要是以小鸟为例子进行操作的。

uTools_1689930427128

import tensorflow as tf
import cv2
import sys
sys.path('game')
import random
import numpy as np
from collections import deque
​
GAME = 'bird'
# 或上或下
ACTIONs = 2
GAMMA = 0.99
OBSERVE = 1000
ECPLORE = 3000000
FINAL_EPSILOW = 0.0001
INITIAL = 0.1
REPLAY_MOMORY = 50000
RATCH = 32
FRAME_PER_ACTION = 1
​
def createNetwork():
    # 三层卷积的形式
    # 注意,池化层是没有参数的
    W_conv1 = weights_variable([8, 8, 4, 32])
    b_conv1 = bias_variable([32])
    
    W_conv2 = weights_variable([4, 4, 32, 64])
    b_conv2 = bias_variable([64])
    
    W_conv3 = weights_variable([3, 3, 64, 64])
    b_conv3 = bias_variable([32])
    
    W_fc1 = weights_variable([1600,512])
    b_fc1 = weights_variable([512])
    
    W_fc1 = weights_variable([512,ACTIONS])
    b_fc1 = weights_variable([ACTIONS])
    
    s = tf.placeholder('float', [None,80,80,4])
    
    h_conv1 = tf.nn.relu(conv2d(s,W_conv1,4)+b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    
    h_conv2 = tf.nn.relu(conv2d(h-pool1,W_conv2,2)+b_conv2)
    # h_pool2 = max_pool_2x2(h_conv2)
    h_conv3 = tf.nn.relu(conv2d(h-pool1,W_conv3,1)+b_conv3)
    
    # reshape是将连接操作,将立体图转化为向量化数据
    h_conv3_flat = tf.reshape(h_conv3, [-1,1600])
    
    h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat,W_fc1)+b_fc1)
    
    readout = tf.matmul(h_fc1,W_fc2) + b_fc2
    return s,readout,h_fc1
​
def weights_variable(shape):
    initial = tf.truncated_normal(shape,stddev=0.01)
    return tf.Variable(initial)
def bias_variable(shape):
    initial = tf.constant(0.01,shape = shape)
    return tf.Variable(initial)
def conv2d(x,W,stride):
    return tf.nn.conv2d(x,W,strides=[1,stride,stride,1],padding='SAME')
def max_pool_2x2(x):
    return nn.max_pool(x,ksize = [1,2,2,1],strides=[1,stride,stride,1],padding='SAME')
​
def trainNetwork(s,readout,,h_fc1,sess):
    
    a = tf.placeholder('float', [None,ACTIONS])
    y = tf.placeholder('float', [None])
    
    readout_action = tf.reduce_mean(tf.multiply(readout,a),reduce_indices = 1)
    cost = tf.reduce_mean(tf.square(y = readout_action))
    train_step = tf.train.AdamOptimizer(1e-6).minimaize(cost)
    
    game_state = game.GameState()
    
    D = deque()
    do_nothing = np.zeros(ACTIONS)
    do_nothing[0] = 1
    
    x_t,r_0,terminal = game_state.frame_step(do_nothing)
    # 将图变为80*80的二维图像,在转化为1,255的
    x_t = cv2.cvtColor(cv2.resize(x_t,(80,80),cv2.COLOR_BGR2CRAY))
    ret,x_t = cv2.threshold(x_t,1,255,cv2.THRESH_BINARY)
    
    s_t = np.stack((x_t,x_t,x_t,x_t),axis = 2)
    
    saver = tf.train.Saver()
    see.run(tf.initialize_all_variables())
    checkpoint = tf.train.get_checkpoint_state('saved network')
    
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print('Successfully loaded')
    else:
        print('load failed')
        
    epsilon = INITIAL_EPSILOW
    t = 0
    while 'flappy bird' != 'angry bird':
        readout_t = readout.eval(feed_dict = {s:[s_t]})[0]
        a_t = np.zeros([ACTIONS])
        action_index = 0
        
        if t % 1 == 0:
            if random.random() <= epsilon:
                print('Rondom Action')
                action_index = random.randint(ACTIONS)
                a_t[action_index] = 1
            else:
                # 决定小鸟向上飞还是向下
                action_index = np.argmax(readout_t)
                a_t[action_index] = 1
        x_t1_colored,r_t,r_t,terminal = game_state.frame_step(a_t)
        x_t = cv2.cvtColor(cv2.resize(x_t1,colored,(80,80),cv2.COLOR_BGR2CRAY))
        ret,x_t = cv2.threshold(x_t1,1,255,cv2.THRESH_BINARY)
        x_t1 = np.reshape(x_t1, (80,80,1))
        s_t1 = np.append(x_t1, s_t[:,:,3],axis = 2)
        
        # 强化学习
        D.append(s_t,a_t,r_t,s_t1,terminal)
        # s_t当前状态
        # a_t当前动作
        # r_t奖励和回馈
        # s_t1新的状态
        # terminal判断是否结束
        if len(D) > REPLAY_MOMORY:
            D.popleft()
        
        if t > OBSERVE:
            minibatch = random.sample(D,BATCH)
            
            s_j_batch = [d[0] for d in minibatch]
            
            a_batch = [d[1] for d in minibatch]
            
            r_batch = [d[2] for d in minibatch]
            
            s_j1_batch = [d[3] for d in minibatch]
            
            y_batch = []
            
            # 神经网络的输出值
            readout_j1_batch = readout.eval(feed_dict = [s:s_j1_batch])
            for i in range(0, len(minibatch)):
                terminal = minibatch[i][4]
                
                if terminal:
                    y_batch.append(r_batch[i])
                else:
                    y_batch.append(r_batch[i] + GAMMA*np.max(readout_j1_batch[i]))
                    
            train_step.run(feed_dict = {
                y:y_batch,
                a:a_batch,
                s:s_j_batch,
            })
            
            # update information
            s_t = s_t1
            t += 1
            if t % 10000 == 0:
                saver.save(sess, './',global_step = t)
                
            state = ''
            if t <= OBSERVE:
                state = 'OBSERVE'
            else:
                state = 'train'
                
            print 
    
def playGame():
    sess = tf.InterativeSession()
    s,readout,h_fel = createNetwork()
    # 训练
    trainNetwork()
    
​
def main():
    playGame()
    
if __name__ == '__main__':
    main()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/86536.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

手机无人直播软件,有哪些优势?

近年来&#xff0c;随着手机直播的流行和直播带货的市场越来越大&#xff0c;手机无人直播软件成为许多商家开播带货的首选。在这个领域里&#xff0c;声音人无人直播系统以其独特的优势&#xff0c;成为市场上备受瞩目的产品。接下来&#xff0c;我们将探讨手机无人直播软件给…

OpenCV中QR二维码的生成与识别(CIS摄像头解析)

1、QR概述 QR(Quick Response)属于二维条码的一种&#xff0c;意思是快速响应的意思。QR码不仅信息容量大、可靠性高、成本低&#xff0c;还可表示汉字及图像等多种文字信息、其保密防伪性强而且使用非常方便。更重要的是QR码这项技术是开源的&#xff0c;在移动支付、电影票、…

Elasticsearch Split和shrink API

背景&#xff1a; 尝试解决如下问题&#xff1a;单分片存在过多文档&#xff0c;超过lucene限制 分析 1.一般为日志数据或者OLAP数据&#xff0c;直接删除索引重建 2.尝试保留索引&#xff0c;生成新索引 - 数据写入新索引&#xff0c;查询时候包含 old_index,new_index 3.…

内容分发网络CDN与应用程序交付网络ADN之间的异同

当您想要提高网站性能时&#xff0c;需要考虑许多不同的配置和设施&#xff0c;CDN和ADN是我们常遇见的几种选项之一。“CDN”指“内容分发网络”&#xff0c;而“ADN”指“应用程序交付网络”&#xff0c;但他们两者很容易被混淆&#xff0c;虽然它们的功能和作用都有较大差异…

使用多个神经网络进行细菌分类(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

vellum (Discovering Houdini VellumⅡ柔体系统)学习笔记

视频地址&#xff1a; https://www.bilibili.com/video/BV1ve411u7nE?p3&spm_id_frompageDriver&vd_source044ee2998086c02fedb124921a28c963&#xff08;搬运&#xff09; 个人笔记如有错误欢迎指正&#xff1b;希望可以节省你的学习时间 ~享受艺术 干杯&#x1f37b…

[Mac软件]AutoCAD 2024 for Mac(cad2024) v2024.3.61.182中文版支持M1/M2/intel

下载地址&#xff1a;前往黑果魏叔官网 AutoCAD是一款计算机辅助设计&#xff08;CAD&#xff09;软件&#xff0c;目前已经成为全球最受欢迎的CAD软件之一。它可以在二维和三维空间中创建精确的技术绘图&#xff0c;并且可以应用于各种行业&#xff0c;如建筑、土木工程、机械…

【操作系统】24王道考研笔记——第三章 内存管理

第三章 内存管理 一、内存管理概念 1.基本概念 2.覆盖与交换 覆盖技术&#xff1a; 交换技术&#xff1a; 总结&#xff1a; 3.连续分配管理方式 单一连续分配 固定分区分配 动态分区分配 动态分区分配算法&#xff1a; 总结&#xff1a; 4.基本分页存储管理 定义&#xf…

【Unity3D赛车游戏】【二】如何制作一个真实模拟的汽车

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;Uni…

VoxWeekly|The Sandbox 生态周报|20230821

欢迎来到由 The Sandbox 发布的《VoxWeekly》。我们会在每周发布&#xff0c;对上一周 The Sandbox 生态系统所发生的事情进行总结。 如果你喜欢我们内容&#xff0c;欢迎与朋友和家人分享。请订阅我们的 Medium 、关注我们的 Twitter&#xff0c;并加入 Discord 社区&#xf…

01、Cannot resolve MVC View ‘xxxxx前端页面‘

Cannot resolve MVC View ‘xxxxx前端页面’ 没有找到对应的mvc的前端页面。 代码&#xff1a;前端这里引入了 thymeleaf 模板 解决&#xff1a; 需要添加 thymeleaf 的依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>s…

基于nginx禁用访问ip

一、背景 网络安全防护时&#xff0c;禁用部分访问ip,基于nginx可快速简单实现禁用。 二、操作 1、创建 conf.d文件夹 在nginx conf 目录下创建conf.d文件夹 Nginx 扩展配置文件一般在conf.d mkdir conf.d 2、新建blocksip.conf文件 在conf.d目录新建禁用ip的扩展配置文…

DevExpress WPF HeatMap组件,一个高度可自定义的热图控件!

像所有DevExpress UI组件一样&#xff0c;HeatMap组件针对速度进行了优化&#xff0c;包括数十个自定义设置和高级API&#xff0c;因此用户可以快速将美观的数据可视化集成到下一个WPF应用程序中。 P.S&#xff1a;DevExpress WPF拥有120个控件和库&#xff0c;将帮助您交付满…

vscode里配置C#环境并运行.cs文件

vscode是一款跨平台、轻量级、开源的IDE, 支持C、C、Java、C#、R、Python、Go、Nodejs等多种语言的开发和调试。下面介绍在vscode里配置C#环境。这里以配置.Net SDK v5.0&#xff0c;语言版本为C#9.0&#xff0c;对应的开发平台为VS2019&#xff0c;作为案例说明。 1、下载vsc…

文件四剑客

目录 前言 一、正则表达式 二、grep 三、find 四、sed 五、awk 前言 文件四剑客是指在计算机领域中常用的四个命令行工具&#xff0c;包括awk、find、grep和sed。它们在处理文本文件和搜索文件时非常强大和实用。 1. awk是一种强大的文本处理工具&#xff0c;它允许用户根据指…

数据结构——栈和队列

栈和队列的建立 前言一、栈1.栈的概念2.栈的实现3.代码示例&#xff08;1&#xff09;Stack.h&#xff08;2&#xff09;Stack.c&#xff08;3&#xff09;Test.c&#xff08;4&#xff09;运行结果&#xff08;5&#xff09;完整代码演示 二、队列1.队列的概念2.队列的实现3.代…

ps吸管工具用不了怎么办?

我们的办公神器ps软件&#xff0c;大家一定是耳熟能详的吧。Adobe photoshop是电影、视频和多媒体领域的专业人士&#xff0c;使用3D和动画的图形和Web设计人员&#xff0c;以及工程和科学领域的专业人士的理想选择。Photoshop支持宽屏显示器的新式版面、集20多个窗口于一身的d…

软件测试技术分享丨遇到bug怎么分析?

为什么定位问题如此重要&#xff1f; 可以明确一个问题是不是真的“bug” 很多时候&#xff0c;我们找到了问题的原因&#xff0c;结果发现这根本不是bug。原因明确&#xff0c;误报就会降低 多个系统交互&#xff0c;可以明确指出是哪个系统的缺陷&#xff0c;防止“踢皮球…

IDEA中导出Javadoc遇到的GBK编码错误的解决思路和应用

IDEA中导出Javadoc遇到的GBK编码错误的解决思路和应用 ​ 当我们在导出自己写的项目的api文档的时候呢&#xff0c;有的时候会出现以下问题&#xff1a;也就是GBK编码错误不可导出 错误描述&#xff1a;编码GBK的不可映射字符无法导出&#xff0c;可以看出这是我们自己写的中文…

容器和云原生(三):kubernetes搭建与使用

目录 单机K8S docker containerd image依赖 kubeadm初始化 验证 crictl工具 K8S核心组件 上文安装单机docker是很简单docker&#xff0c;但是生产环境需要多个主机&#xff0c;主机上启动多个docker容器&#xff0c;相同容器会绑定形成1个服务service&#xff0c;微服务…
最新文章