强化学习玩flappy_bird

强化学习玩flappy_bird(代码解析)

游戏地址:https://flappybird.io/

该游戏的规则是:

  • 点击屏幕则小鸟立即获得向上速度。

  • 不点击屏幕则小鸟受重力加速度影响逐渐掉落。

  • 小鸟碰到地面会死亡,碰到水管会死亡。(碰到天花板不会死亡)

  • 小鸟通过水管会得分。

    img

    具体的网络结构如图所示,网络架构是拿到游戏状态(每个样本维度是 80 * 80 * 4),然后卷积(输出维度 20 * 20 * 32)、池化(输出 10 * 10 * 32)、卷积(输出 5 * 5 * 64)、卷积(输出 5 * 5 * 64)、reshape(1600)、全连接层(512)、输出层(2)

一、flappy_bird_utils.py

"""
游戏素材加载
"""
import pygame
import sys
import os

assets_dir = os.path.dirname(__file__)

def load():
    # 小鸟挥动翅膀的3个造型
    PLAYER_PATH = (
        assets_dir + '/assets/sprites/redbird-upflap.png',
        assets_dir + '/assets/sprites/redbird-midflap.png',
        assets_dir + '/assets/sprites/redbird-downflap.png'
    )

    # 游戏背景图,纯黑色是为了训练降低干扰
    BACKGROUND_PATH = assets_dir + '/assets/sprites/background-black.png'

    # 水管图片
    PIPE_PATH = assets_dir + '/assets/sprites/pipe-green.png'

    IMAGES, SOUNDS, HITMASKS = {}, {}, {}
    #初始化了三个空字典:IMAGES用于存储加载的图片资源,SOUNDS用于存储加载的声音资源,HITMASKS用于存储碰撞掩码(用于检测游戏中的碰撞)

    # 加载数字0~9的图片,类型是Surface图像
    #使用convert_alpha()方法将图片转换为带有透明度的格式
    IMAGES['numbers'] = (
        pygame.image.load(assets_dir + '/assets/sprites/0.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/1.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/2.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/3.png').convert_alpha(),    # convert/conver_alpha是为了将图片转成绘制用的像素格式,提高绘制效率
        pygame.image.load(assets_dir + '/assets/sprites/4.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/5.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/6.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/7.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/8.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/9.png').convert_alpha()
    )

    # 地面图片
    IMAGES['base'] = pygame.image.load(assets_dir + '/assets/sprites/base.png').convert_alpha()


    #根据操作系统类型,设置声音文件的扩展名。Windows系统使用.wav,其他系统使用.ogg
    if 'win' in sys.platform:
        soundExt = '.wav'
    else:
        soundExt = '.ogg'

    # 各种Sound对象
    #加载各种游戏音效,并将它们存储在SOUNDS字典中
    SOUNDS['die']    = pygame.mixer.Sound(assets_dir + '/assets/audio/die' + soundExt)
    SOUNDS['hit']    = pygame.mixer.Sound(assets_dir + '/assets/audio/hit' + soundExt)
    SOUNDS['point']  = pygame.mixer.Sound(assets_dir + '/assets/audio/point' + soundExt)
    SOUNDS['swoosh'] = pygame.mixer.Sound(assets_dir + '/assets/audio/swoosh' + soundExt)
    SOUNDS['wing']   = pygame.mixer.Sound(assets_dir + '/assets/audio/wing' + soundExt)

    # 加载背景图片
    IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()

    # 加载小鸟的3个姿态
    IMAGES['player'] = (
        pygame.image.load(PLAYER_PATH[0]).convert_alpha(),
        pygame.image.load(PLAYER_PATH[1]).convert_alpha(),
        pygame.image.load(PLAYER_PATH[2]).convert_alpha(),
    )

    # 加载水管图片,并使用rotate()方法将其旋转180度以创建上方的水管图片,然后将这两个图片存储在IMAGES字典中
    IMAGES['pipe'] = (
        pygame.transform.rotate(
            pygame.image.load(PIPE_PATH).convert_alpha(), 180),
        pygame.image.load(PIPE_PATH).convert_alpha(),
    )

    # 计算水管图片的bool掩码
    #为水管图片生成碰撞掩码,并将它们存储在HITMASKS字典中。
    HITMASKS['pipe'] = (
        getHitmask(IMAGES['pipe'][0]),
        getHitmask(IMAGES['pipe'][1]),
    )

    # 生成小鸟图片的bool掩码
    HITMASKS['player'] = (
        getHitmask(IMAGES['player'][0]),
        getHitmask(IMAGES['player'][1]),
        getHitmask(IMAGES['player'][2]),
    )

    return IMAGES, SOUNDS, HITMASKS

# 生成图片的bool掩码矩阵,true表示对应像素位置不是透明的部分
def getHitmask(image):
    """returns a hitmask using an image's alpha."""
    mask = []
    for x in range(image.get_width()):#遍历所有的像素点
        mask.append([])
        for y in range(image.get_height()):
            mask[x].append(bool(image.get_at((x,y))[3]))    # 像素点是RGBA,例如:(83, 56, 70, 255),最后是透明度(0是透明,255是不透明)
    #对于图像中的每一个像素点,使用 image.get_at((x,y)) 获取该点的颜色值。颜色值通常以 RGBA(红色、绿色、蓝色、透明度)格式存储,其中 A 代表 Alpha 通道,即透明度。image.get_at((x,y))[3] 就是获取该像素点的 Alpha 值。
    return mask

这里面需要解释的碰撞掩码是什么?

碰撞掩码(Collision Mask)是一种在计算机图形学和游戏开发中用于检测物体间碰撞的技术。它通常由一个布尔矩阵表示,其中每个像素点的值表示该点是否是物体的一部分。在处理碰撞检测时,通过比较两个物体的碰撞掩码可以判断它们是否重叠,从而确定是否发生了碰撞。

以下是碰撞掩码的一些关键点:

  1. 透明度判断:在许多游戏中,碰撞掩码是通过检查图像的透明度(Alpha通道)来生成的。如果图像的某个像素点是不透明的(例如,Alpha值为255),那么在碰撞掩码中对应的位置会被标记为True或实体部分;如果是透明的(Alpha值为0),则被标记为False或非实体部分。

  2. 简化碰撞检测:使用碰撞掩码可以避免直接对图像的每个像素点进行碰撞检测,这样可以显著提高碰撞检测的效率,尤其是在处理复杂图形或大规模场景时。

  3. 灵活性:碰撞掩码可以根据需要设计成不同的形状和大小,从而实现精确的碰撞检测。例如,一个角色的碰撞掩码可以是其轮廓的形状,而不仅仅是一个矩形或正方形。

  4. 性能优化:在游戏开发中,碰撞检测通常是一个计算密集型的过程。通过使用碰撞掩码,可以减少不必要的像素点比较,从而提高游戏性能。

  5. 应用场景:碰撞掩码不仅用于检测角色与障碍物之间的碰撞,还可以用于检测子弹与目标的碰撞、角色间的交互等。

getHitmask函数就是用来生成碰撞掩码的。它通过遍历图像的每个像素点,并检查其透明度来创建一个布尔矩阵。这个矩阵随后可以用于游戏中的碰撞检测逻辑,以判断小鸟是否与水管或其他物体发生了碰撞。

二、wrapped_flappy_bird.py

import numpy as np
import sys
import random
import pygame
from . import flappy_bird_utils
import pygame.surfarray as surfarray#用于将pygame的Surface对象转换为NumPy数组
from pygame.locals import *
from itertools import cycle#用于创建一个可循环的对象

# 屏幕宽*高
FPS = 30
SCREENWIDTH  = 288
SCREENHEIGHT = 512

# 初始化游戏,创建一个时钟对象来控制帧率,设置游戏窗口的尺寸和标题
pygame.init()
FPSCLOCK = pygame.time.Clock()  # FPS限速器
SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))   # 宽*高
pygame.display.set_caption('Flappy Bird')   # 标题

# 加载素材
IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()

PIPEGAPSIZE = 100 # 上下水管之间的距离是固定的100像素
BASEY = SCREENHEIGHT * 0.79 # 地面图片的y坐标

'''
地面图片在游戏窗口中的垂直位置。SCREENHEIGHT是游戏窗口的高度,
乘以0.79后得到一个值,这个值就是地面图片距离窗口顶部的像素距离。
因此,BASEY变量代表了地面图片在Y轴(垂直轴)上的位置,
它被设置在屏幕高度的79%的位置,这样地面图片会显示在屏幕的下半部分。
'''

# 小鸟图片的宽*高
PLAYER_WIDTH = IMAGES['player'][0].get_width()
PLAYER_HEIGHT = IMAGES['player'][0].get_height()
# 水管图片的宽*高
PIPE_WIDTH = IMAGES['pipe'][0].get_width()
PIPE_HEIGHT = IMAGES['pipe'][0].get_height()

# 背景图片的宽
BACKGROUND_WIDTH = IMAGES['background'].get_width()

# 创建一个循环对象,小鸟图片动画播放顺序
PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])

'''
0 表示第一张图片(翅膀上挥)
1 表示第二张图片(翅膀中挥)
2 表示第三张图片(翅膀下挥)
序列最后再次包含1,以实现翅膀的自然循环
'''

# Flappy bird游戏类
class GameState:
    def __init__(self):
        self.score = 0#初始化玩家的得分为 0
        self.playerIndex = 0#初始化玩家小鸟的当前动画索引为 0,这将决定小鸟显示哪一张动画图片
        self.loopIter = 0#初始化一个循环计数器,可能用于跟踪动画或游戏循环的次数

        # 玩家初始坐标
        self.playerx = int(SCREENWIDTH * 0.2)#设置玩家小鸟的初始 x 坐标,位于屏幕宽度的 20% 位置
        self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)#计算并设置玩家小鸟的初始 y 坐标,使得小鸟位于屏幕垂直居中的位置

        # 地面图片需要跑马灯效果,它比屏幕宽一点,每帧向左移动,当要耗尽时重新回到右边,如此往复
        self.basex = 0         # 地面图片的x坐标
        self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH  # 地面图片比屏幕宽度长多少像素,就是它可以移动的距离

        newPipe1 = getRandomPipe()  # 生成一对上下管子
        newPipe2 = getRandomPipe()  # 再生成一对上下管子

        # 上面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离
        self.upperPipes = [
            {'x': SCREENWIDTH, 'y': newPipe1[0]['y']},
            {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},
        ]
        # 下面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离
        self.lowerPipes = [
            {'x': SCREENWIDTH, 'y': newPipe1[1]['y']},
            {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},
        ]

        # 水管的水平移动速度,每次x-4实现向左移动
        self.pipeVelX = -4

        # 小鸟Y方向速度
        self.playerVelY    =  0
        # 小鸟Y方向重力加速度,每帧作用域playerVelY,令其Y速度向下加大
        self.playerAccY    =   1
        # 点击后,小鸟Y方向速度重置为-9,也就是开始向上移动
        self.playerFlapAcc =  -9

        # 小鸟Y方向速度限制
        self.playerMaxVelY =  10   # Y向下最大速度10

    # 执行一次操作,返回操作后的画面、本次操作的奖励(活着+0.1,死了-1,飞过水管+1)、游戏是否结束
    def frame_step(self, input_actions):
        # 给pygame对积累的事件做一下默认处理
        pygame.event.pump()

        # 活着就奖励0.1分
        reward = 0.01
        # 是否死了
        terminal = False

        # 必须传有效的action,[1,0]表示不点击,[0,1]表示点击,全传0是不对的
        if sum(input_actions) != 1:#检查 input_actions 确保只有一个动作被执行
            raise ValueError('Multiple input actions!')

        # 每3帧换一次小鸟造型图片,loopIter统计经过了多少帧
        if (self.loopIter + 1) % 3 == 0:
            self.playerIndex = next(PLAYER_INDEX_GEN)
        self.loopIter += 1

        # 让地面向左移动,游戏开始的时候地面x=0,逐步减小x
        if self.basex + self.pipeVelX <= -self.baseShift:
            self.basex = 0
        else: # 图片即将滚动耗尽,重置x坐标
            self.basex += self.pipeVelX

        # 点击了屏幕
        if input_actions[1] == 1:
            self.playerVelY = self.playerFlapAcc # 将小鸟y方向速度重置为-9,也就是向上移动
            #SOUNDS['wing'].play()   # 播放扇翅膀的声音
        elif self.playerVelY < self.playerMaxVelY:  # 没点击屏幕并且没达到最大掉落速度,继续施加重力加速度
            self.playerVelY += self.playerAccY

        # 将速度施加到小鸟的y坐标上
        self.playery += self.playerVelY
        if self.playery < 0:    # 撞到上边缘不算死
            self.playery = 0 # 限制它别飞出去
        elif self.playery + PLAYER_HEIGHT >= BASEY: # 小鸟碰到地面
            self.playery = BASEY - PLAYER_HEIGHT # 限制它别穿地

        # 让上下水管都向左移动一次
        for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
            uPipe['x'] += self.pipeVelX
            lPipe['x'] += self.pipeVelX

        # 判断小鸟是否穿过了一排水管,因为上下水管x一样,只需要用上排水管判断
        playerMidPos = self.playerx + PLAYER_WIDTH / 2  # 小鸟中心的x坐标(这个是固定值,小鸟实际不会动,是水管在动)
        for pipe in self.upperPipes:    # 检查与上排水管的关系
            pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 # 水管中心的x坐标
            if pipeMidPos <= playerMidPos < pipeMidPos + abs(self.pipeVelX): # 小鸟x坐标刚刚飞过了水管x中心(4是水管的移动速度)
                self.score += 1 # 游戏得分+1
                #SOUNDS['point'].play()
                reward = 100  # 产生强化学习的动作奖励10分

        # 最左侧水管马上离开屏幕,生成新水管
        if 0 < self.upperPipes[0]['x'] < 5:
            newPipe = getRandomPipe()
            self.upperPipes.append(newPipe[0])
            self.lowerPipes.append(newPipe[1])

        # 最左侧水管彻底离开屏幕,删除它的上下2根水管
        if self.upperPipes[0]['x'] < -PIPE_WIDTH:
            self.upperPipes.pop(0)
            self.lowerPipes.pop(0)

        # 检查小鸟是否碰到水管
        isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 'index': self.playerIndex}, self.upperPipes, self.lowerPipes)
        if isCrash:  # 死掉了
            #SOUNDS['hit'].play()
            #SOUNDS['die'].play()
            reward = -10 # 负向激励分
            terminal = True # 本次操作导致游戏结束了

        ##### 进入重绘 #######

        # 贴背景图
        SCREEN.blit(IMAGES['background'], (0,0))
        # 画水管
        for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
            SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))
            SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))
        # 画地面
        SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
        # 画得分(训练时候别打开,造成干扰了)
        #showScore(self.score)
        # 画小鸟
        SCREEN.blit(IMAGES['player'][self.playerIndex], (self.playerx, self.playery))
        # 重绘
        pygame.display.update()
        # 留存游戏画面(截图是列优先存储的,需要转行行优先存储)
        # https://stackoverflow.com/questions/34673424/how-to-get-numpy-array-of-rgb-colors-from-pygame-surface
        image_data = pygame.surfarray.array3d(pygame.display.get_surface()).swapaxes(0,1)
        # 死亡则重置游戏状态
        if terminal:
            self.__init__()
        # 控制FPS
        FPSCLOCK.tick(FPS)
        return image_data, reward, terminal

# 生成一对水管,放到屏幕外面
def getRandomPipe():
    gapY = random.randint(70, 140)#生成一个介于 70 到 140 之间的随机整数,并将其赋值给变量 gapY。这个随机数决定了水管之间缝隙的上边缘的 y 坐标

    # 注:每一对水管的缝隙高度都是一样的PIPEGAPSIZE,gayY决定的是缝隙的上边缘y坐标
    pipeX = SCREENWIDTH + 10    # 水管出现在屏幕右侧之外

    return [
        {'x': pipeX, 'y': gapY - PIPE_HEIGHT},  # 计算上面水管图片的y坐标,就是缝隙上边缘y减去水管本身高度
        {'x': pipeX, 'y': gapY + PIPEGAPSIZE},  # 计算下面水管图片的y坐标,就是缝隙上边缘y加上缝隙本身高度
    ]

# 检查小鸟是否碰到水管或者地面(天花板不算)
def checkCrash(player, upperPipes, lowerPipes):
    pi = player['index']    # 小鸟的第几张图片

    # 图片的宽*高
    player['w'] = IMAGES['player'][pi].get_width()
    player['h'] = IMAGES['player'][pi].get_height()

    # 小鸟碰到了地面
    if player['y'] + player['h'] >= BASEY - 1:
        return True
    else: # 小鸟与水管进行碰撞检测
        # 小鸟图片的矩形区域
        playerRect = pygame.Rect(player['x'], player['y'], player['w'], player['h'])

        # 每一对水管
        for uPipe, lPipe in zip(upperPipes, lowerPipes):
            # 上面水管的矩形
            uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
            # 下面水管的矩形
            lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)

            # 小鸟图片的非透明像素掩码
            pHitMask = HITMASKS['player'][pi]
            # 上水管的非透明像素掩码
            uHitmask = HITMASKS['pipe'][0]
            # 下水管的非透明像素掩码
            lHitmask = HITMASKS['pipe'][1]

            # 检测小鸟与上面水管的碰撞
            uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)
            # 检测小鸟与下面水管的碰撞
            lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)

            if uCollide or lCollide:
                return True
    return False


# 2个矩形区域的碰撞检测
def pixelCollision(rect1, rect2, hitmask1, hitmask2):
    '''
    rect1 和 rect2 是参与碰撞检测的两个矩形区域,通常是游戏中对象的位置和大小
    hitmask1 和 hitmask2 是与这两个矩形关联的碰撞掩码,它们是布尔数组,表示相应对象的哪些部分是实体(非透明)
    '''
    # 计算两个矩形的交集,即它们重叠的区域。如果没有重叠(即两个矩形没有碰撞),则 clip 方法返回一个宽度或高度为 0 的矩形
    rect = rect1.clip(rect2)

    # 相交面积为0
    if rect.width == 0 or rect.height == 0:
        return False

    # 相交矩形x,y相对于2个矩形左上角的距离
    x1, y1 = rect.x - rect1.x, rect.y - rect1.y
    #计算交集区域相对于 rect1 的相对位置
    x2, y2 = rect.x - rect2.x, rect.y - rect2.y#同理

    # 检查相交矩形内的每个点,是否在2个矩形内同时是非透明点,那么就碰撞了
    for x in range(rect.width):
        for y in range(rect.height):
            if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:
                return True
    return False

# 展示得分,传入一个整数得分
def showScore(score):
    # 转成单个数字的列表
    scoreDigits = [int(x) for x in list(str(score))]
    #将得分 score 转换成字符串,然后将其每个字符(即每个单独的数字)转换成整数,并存储在列表 scoreDigits 中。这样,得分就被分解成了单个数字的列表

    # 计算展示所有数字要占多少像素宽度
    totalWidth = 0
    for digit in scoreDigits:
        totalWidth += IMAGES['numbers'][digit].get_width()
        '''
        遍历 scoreDigits 列表中的每个数字
        将每个数字图像的宽度累加到 totalWidth
        '''

    # 计算绘制起始x坐标
    Xoffset = (SCREENWIDTH - totalWidth) / 2

    # 逐个数字绘制
    for digit in scoreDigits:
        SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, 20))    # y坐标贴近屏幕上边缘
        Xoffset += IMAGES['numbers'][digit].get_width() # 移动绘制x坐标

三、q_game.py

"""
强化学习q learning flappy bird
"""
from game.wrapped_flappy_bird import GameState
import time
import numpy as np 
import skimage.color
import skimage.transform
import skimage.exposure
import tensorflow as tf 
import random 
import argparse

# 命令行参数
parser = argparse.ArgumentParser()#创建一个 ArgumentParser 对象,用于定义命令行参数
parser.add_argument("--model-only", help="加载已有模型,不随机探索,仍旧训练", action='store_true')
args = parser.parse_args()#解析命令行输入的参数,并将它们存储在 args 变量中

# 测试用代码
def _test_save_img(img):
    # 把每一帧图片存储到文件里,调试用
    from PIL import Image
    im = Image.fromarray((img*255).astype(np.uint8), mode='L') # 图片已经被处理为0~1之间的亮度值,所以*255取整数变灰度展示
    im.save('./img.jpg')

# 构建卷积神经网络
def build_model():
    # 卷积神经网络:https://blog.csdn.net/FontThrone/article/details/76652753
    model = tf.keras.models.Sequential([#创建一个 Sequential 模型,它是 tf.keras 中用于线性堆叠网络层的模型类
        tf.keras.layers.Input(shape=(80,80,4)),
        tf.keras.layers.Conv2D(filters=32, kernel_size=(8, 8), padding='same',strides=4, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(4, 4), padding='same',strides=2, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same',strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Flatten(),#将三维的卷积层输出展平为一维,以便传入到全连接层
        tf.keras.layers.Dense(256, activation='relu'),#定义一个具有 256 个单元的全连接层,并使用 ReLU 激活函数
        tf.keras.layers.Dense(2), # 对应2个action未来总回报预期
    ])
    model.compile(loss='mse', optimizer='adam')#编译模型,指定均方误差(MSE)作为损失函数,使用 Adam 优化器

    # 尝试加载之前保存的模型参数
    try:
        model.load_weights('./weights.h5')
        print('加载模型成功...................')
    except:
        pass
    return model

# 创建游戏
game = GameState()
# 卷积模型
model = build_model()

# 执行1帧游戏
def run_one_frame(action):
    global game 
    # image_data:执行动作后的图像(288*512*3的RGB三维数组)
    # reward:本次动作的奖励
    # terminal:游戏是否失败
    img, reward, terminal = game.frame_step(action)
    # RGB转灰度图
    img = skimage.color.rgb2gray(img)
    # 压缩到80*80的图片(根据RGB算出来的亮度,其数值很小)
    img = skimage.transform.resize(img, (80,80))
    # 把亮度标准化到0~1之间,用作模型输入
    img = skimage.exposure.rescale_intensity(img, out_range=(0,1))
    return img,reward,terminal

# 强化学习初始化状态
def reset_stat():
    # 执行第一帧,不点击
    img_t,_,_ =  run_one_frame([1,0])
    '''
    使用 numpy.stack 函数将首帧图像 img_t 重复四次,沿着第三个维度堆叠,
    形成初始状态 stat_t。这是因为卷积神经网络需要连续几帧的图像作为输入
    '''
    stat_t = np.stack([img_t] * 4, axis=2)
    return stat_t 

# 初始状态
stat_t = reset_stat()
# 训练样本
transitions = []#用于存储训练过程中的状态转换样本

# 时刻
t = 0

# 随机探索的概率控制,定义了随机探索概率的初始值、最终值和每次更新的步长。
INIT_EPSILON = 0.1
FINAL_EPSILON = 0.005
EPSLION_DELTA = 1e-6
# 最大留存样本个数
TRANS_CAP =  20000
# 至少有多少样本才训练
TRANS_SIZE_FIT = 10000
# 训练集大小
BATCH_SIZE = 32
# 未来激励折扣
GAMMA = 0.99

# 随机探索概率
if args.model_only: # 不随机探索(极低概率)
    epsilon = FINAL_EPSILON
else:
    epsilon = INIT_EPSILON

# 打印一些进度信息
rand_flap =0    # 随机点击次数
rand_noflap = 0 # 随机不点击次数
model_flap=0    # 模型点击次数
model_noflap=0  # 模型不点击次数
model_train_times = 0   # 模型训练次数

# 游戏启动
while True:    
    # 动作
    action_t = [0,0]

    action_type = '随机'#设置动作类型默认为 '随机',这将在选择动作时用于判断动作是随机选择的还是基于模型经验选择的。

    # 随着学习,降低随机探索的概率,让模型趋于稳定
    if (t <= TRANS_SIZE_FIT and not args.model_only) or random.random() <= epsilon:
        '''判断是否应该进行随机探索。如果在观察期内(t <= TRANS_SIZE_FIT)或者随机数小于或等于 epsilon,则执行随机探索'''
        n = random.random()
        if n <= 0.95:
            action_index = 0
            rand_noflap+=1
        else:
            action_index = 1
            rand_flap+=1
        #print('[随机探索] t时刻进行随机动作探索...')
    else: # 模型预测2个操作的未来累计回报
        action_type = '经验'
        Q_t = model.predict(np.expand_dims(stat_t, axis=0))[0]
        #使用当前的模型和状态 stat_t 来预测两个动作的未来总回报
        action_index = np.argmax(Q_t)   # 回报最大的action下标
        if action_index==0:
            model_noflap+=1
        else:
            model_flap+=1
        #print('[已有经验] 预测t时刻2个动作的未来总回报 -- 不点击:{} 点击:{}'.format(Q_t[0], Q_t[1]))

    action_t[action_index] = 1
    #print('时刻t将执行的动作为{}'.format(action_t))

    # 执行当前动作,返回操作后的图片、本次激励、游戏是否结束
    img_t1, reward, terminal = run_one_frame(action_t)
    _test_save_img(img_t1)
    img_t1 = img_t1.reshape((80,80,1)) # 增加通道维度,因为我们要最近4帧作为4通道图片,用作卷积模型输入
    stat_t1 = np.append(stat_t[:,:,1:], img_t1, axis=2) # 80*80*4,淘汰当前的第0通道,添加最新t1时刻到第3通道

    # 收集训练样本(保留有限的)
    transitions.append({
        'stat_t': stat_t,   # t时刻状态
        'stat_t1': stat_t1, # t1时刻状态
        'reward': reward,   # 本次动作的激励得分
        'terminal': terminal,   # 执行动作后游戏是否结束(ps: 结束意味着没有未来激励了)
        'action_index': action_index,   # 执行了什么动作(0:不点击,1:点击)
    })
    if len(transitions) > TRANS_CAP:
        transitions.pop(0)
    
    # 游戏结束则重置stat_t
    if terminal:
        stat_t = reset_stat()
        #print('死了!!!!!!! 状态t重置为初始帧...')
    else:   # 否则切为新的状态
        stat_t = stat_t1
        #print('没死~~~ 状态t切换为状态t1...')

    # 过了观察期,开始训练
    if t >= TRANS_SIZE_FIT and t % 10 == 0:
        minibatch = random.sample(transitions, BATCH_SIZE)
        # 模型训练的输入:t时刻的状态(最近4帧图片)
        inputs_t = np.concatenate([tran['stat_t'].reshape((1,80,80,4)) for tran in minibatch])
        #print('inputs_t shape', inputs_t.shape)
        ######################################################
        # 模型训练的输出:t时刻的未来总激励(Q_t = reward+gamma*Q_t1)
        # 1,让模型预测t时刻2种action的未来总激励
        Q_t = model.predict(inputs_t, batch_size=len(minibatch))
        # 2,让模型预测t1时刻2种action的未来总激励
        input_t1 = np.concatenate([tran['stat_t1'].reshape((1,80,80,4)) for tran in minibatch])
        Q_t1 = model.predict(input_t1, batch_size=len(minibatch))
        # 3,保留t1时刻2个action中最大的未来总激励
        Q_t1_max = [max(q) for q in Q_t1]
        # 4,t时刻进行action_index动作得到真实激励
        reward_t = [tran['reward'] for tran in minibatch]
        # 5,t时刻进行了什么action
        action_index_t = [tran['action_index'] for tran in minibatch]
        # 6,t1时刻是否死掉了
        terminal = [tran['terminal'] for tran in minibatch]
        # 7,修正训练的目标Q_t=reward+gamma*Q_t1
        # (t时刻action_index的未来总激励=action_index真实激励+t1时刻预测的最大未来总激励)
        for i in range(len(minibatch)):
            if terminal[i]:
                Q_t[i][action_index_t[i]] = reward_t[i] # 因为t1时刻已经死了,所以没有t1之后的累计激励
            else:
                Q_t[i][action_index_t[i]] = reward_t[i] + GAMMA*Q_t1_max[i] # Q_t=reward+Q_t1
        # print('Q_t shape', Q_t.shape)
        # 训练一波
        #print(inputs_t)
        #print(Q_t)
        model.fit(inputs_t, Q_t, batch_size=len(minibatch))
        model_train_times += 1
        # 训练1次则降低些许的随机探索概率
        if epsilon > FINAL_EPSILON:
            epsilon -= EPSLION_DELTA
        
        # 每5000次batch保存一次模型权重(不适用saved_model,后续加载只会加载权重,模型结构还是程序构造,因为这样可以保持keras model的api)
        if model_train_times % 5000 == 0:
            model.save_weights('./weights.h5')

        ######################################################
    if t % 100 == 0:
        print('总帧数:{} 剩余探索概率:{}% 累计训练次数:{} 累计随机点:{} 累计随机不点:{} 累计模型点:{} 累计模型不点:{} 训练集:{} '.format(
            t, round(epsilon * 100, 4), model_train_times, rand_flap, rand_noflap, model_flap, model_noflap,
            len(transitions)))
    t = t + 1
    #time.sleep(1)

四、text_game.py

"""
演示pygame制作的flappy bird如何逐帧调用执行
"""
from game.wrapped_flappy_bird import GameState
from random import random
import time

# 创建游戏
game = GameState()

# 游戏启动
while True:
    r = random()
    if r <= 0.92:  # 92%的概率不点击屏幕
        game.frame_step([1,0]) # 动作:[1,0] 表示不点击
    else: # 8%的概率点击屏幕
        game.frame_step([0,1]) # 动作:[0,1] 表示点击

五、训练结果

请添加图片描述

代码源自:强化学习Deep Q-Network自动玩flappy bird | 鱼儿的博客 (yuerblog.cc)

仅想具体看一下工作原理和代码,仅供学习使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/596586.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【iOS】NSOperation、NSOperationQueue

文章目录 前言一、NSOperation、NSOperationQueue 简介二、NSOperation、NSOperationQueue 操作和操作队列三、NSOperation四、NSOperationQueue五、NSOperationQueue 控制串行执行、并发执行六、 NSOperation 操作依赖七、NSOperation 优先级八、NSOperation、NSOperationQueu…

安卓应用开发(一):工具与环境

开发工具 Android Studio&#xff0c;用于开发 Android 应用的官方集成开发环境 (IDE)。包括以下功能&#xff1a; 基于Gradle的构建系统 gradle是一个项目构建工具&#xff0c;将源工程打包构建为apk 安卓模拟器统一环境代码编辑模拟器实时更新Github集成Lint功能&#xff0…

2023年乡镇街道边界数据、行政村边界、省市县区划边界、建筑轮廓边界数据、流域边界数据、降雨量分布、气温分布、道路网分布

数据范围&#xff1a;全国行政区划-行政乡镇街道边界 数据类型&#xff1a;面状数据&#xff0c;全国各省市县【乡镇-边界】乡村界、乡村范围 数据属性&#xff1a;标准12位行政区划编码、乡镇名称、所属地区 分辨率&#xff1a;1:2万--1:5万 数据格式&#xff1a;SHP数据&…

一、Vagrant搭建相关环境

目录 一、创建Vagrant相关环境1.下载安装VirtualBox2.在BlOS中设置CPU虚拟化3.使用Vagrant新建linux虚拟机3.1下载Vagrant3.2Vagrant官方镜像仓库3.3使用Vagrant初始化一个centos7的虚拟机 4.设置固定ip地址 二、安装docker1.按照docker 三、docker安装一些中间件1.mysql安装2.…

C++例题:大数运算---字符串相加(使用数字字符串来模拟竖式计算)

1.代码速览 class Solution2 { public:string addStrings(string num1, string num2){//end1和end1是下标int end1 num1.size() - 1;int end2 num2.size() - 1;string str;//下标(指针)从后向前走,走到头才可以结束,所以是end>0int next 0;while (end1 > 0 || end2 &…

C#连接S7-200 smart通讯测试

honeytree 一、编程环境 VS2022软件&#xff0c;选择windows窗体应用&#xff08;.NET FrameWork&#xff09;&#xff1a;​博途TIA/WINCC社区VX群 ​博途TIA/WINCC社区VX群 添加NuGet程序包&#xff1b;S7netplus 二、引用http://S7.net 三、建立PLC链接 S7-200smart和…

嵌入式5-6QT

1> 思维导图 2> 自由发挥应用场景&#xff0c;实现登录界面。 要求&#xff1a;尽量每行代码都有注释。 #include "mywidget.h"MyWidget::MyWidget(QWidget *parent): QWidget(parent) {//设置标题this->setWindowTitle("MYQQ");//设置图标this…

03_Redis

文章目录 Redis介绍安装及使用redis的核心配置数据结构常用命令stringlistsethashzset(sortedset) 内存淘汰策略Redis的Java客户端JedisRedisson Redis 介绍 Redis是一个NoSQL数据库。 NoSQL: not only SQL。表示非关系型数据库&#xff08;不支持SQL标准语法&#xff09;。 …

一个新细节,Go 1.17 将允许切片转换为数组指针!

在 Go 语言中&#xff0c;一个切片&#xff08;slice&#xff09;包含了对其支持数组的引用&#xff0c;无论这个数组是作为一个独立的变量存在于某个地方&#xff0c;还是仅仅是一个为支持分片而分配的匿名数组。 其切片基本结构都如下&#xff1a; // runtime/slice.go typ…

Sentinel流量防卫兵

1、分布式服务遇到的问题 服务可用性问题 服务可用性场景 服务雪崩效应 因服务提供者的不可用导致服务调用者的不可用,并将不可用逐渐放大的过程&#xff0c;就叫服务雪崩效应导致服务不可用的原因&#xff1a; 在服务提供者不可用的时候&#xff0c;会出现大量重试的情况&…

鸿蒙内核源码分析(原子操作篇) | 谁在为原子操作保驾护航

基本概念 在支持多任务的操作系统中&#xff0c;修改一块内存区域的数据需要“读取-修改-写入”三个步骤。然而同一内存区域的数据可能同时被多个任务访问&#xff0c;如果在修改数据的过程中被其他任务打断&#xff0c;就会造成该操作的执行结果无法预知。 使用开关中断的方…

CTF(Web)中关于执行读取文件命令的相关知识与绕过技巧

在我遇到的题目中&#xff0c;想要读取文件必然是要执行cat /flag这个命令&#xff0c;但是题目当然不会这么轻松。让你直接cat出来&#xff0c;必然会有各种各样的滤过条件&#xff0c;你要做的就是尝试各种方法在cat /flag的基础上进行各种操作构建出最终的payload。 下面我…

38-1 防火墙了解

一、防火墙的概念: 防火墙(Firewall),也称防护墙,是由Check Point创立者Gil Shwed于1993年发明并引入国际互联网(US5606668 [A]1993-12-15)。它是一种位于内部网络与外部网络之间的网络安全系统,是一项信息安全的防护系统,依照特定的规则,允许或是限制传输的数据通过。…

《QT实用小工具·五十八》模仿VSCode的可任意拖拽的Tab标签组

1、概述 源码放在文章末尾 该项目实现了模仿VSCode的可任意拖拽的Tab标签组&#xff0c;包含如下功能&#xff1a; 拖拽标签页至新窗口 拖拽标签页合并控件 无限嵌套的横纵分割布局&#xff08;类似Qt Creator的编辑框&#xff09; 获取当前使用的标签组、标签页 自动向上合并…

【CTF Reverse】XCTF GFSJ0490 simple-unpack Writeup(UPX壳+脱壳+反汇编)

simple-unpack 菜鸡拿到了一个被加壳的二进制文件 解法 拉进 exeinfope。 检测到是 UPX 打包的 ELF 文件。 NOT Win EXE - .o - ELF [ 64bit obj. Exe file - CPU : AMD x86-64 - OS/ABI: Linux/GNU ]Detected UPX! packer - http://upx.github.io -> try unpack with &…

Linux第三节--常见的指令介绍集合(持续更新中)

点赞关注不迷路&#xff01;&#xff0c;本节涉及初识Linux第三节&#xff0c;主要为常见的几条指令介绍。 如果文章对你有帮助的话 欢迎 评论&#x1f4ac; 点赞&#x1f44d;&#x1f3fb; 收藏 ✨ 加关注&#x1f440; 期待与你共同进步! Linux下基本指令 1. man指令 Linu…

在uniapp里面使用 mp-html 并且开启 latex 功能

在uniapp里面使用 mp-html 并且开启 latex 功能 默认情况下 mp-html 是不会开启 latex 功能的, 如果需要开启 latex 功能是需要到代码操作拉取代码自行打包的。 这里说一下 mp-html 里面的 latex 功能是由 https://github.com/rojer95/katex-mini 提供的技术实现&#xff0c;…

科技园3d数据可视化展示

我们的园区安防3D可视化报警平台通过高精度三维建模技术&#xff0c;将管理对象场景以立体、直观的方式呈现&#xff0c;实现管理对象的全面可视化。同时&#xff0c;平台与各类业务系统深度对接&#xff0c;实现数据、告警、远程操作的整合展示和联动响应&#xff0c;构建中心…

Java基础教程 - 4 流程控制

更好的阅读体验&#xff1a;点这里 &#xff08; www.doubibiji.com &#xff09; 更好的阅读体验&#xff1a;点这里 &#xff08; www.doubibiji.com &#xff09; 更好的阅读体验&#xff1a;点这里 &#xff08; www.doubibiji.com &#xff09; 4 流程控制 4.1 分支结构…

CSS Web服务器、2D、动画和3D转换

Web服务器 我们自己写的网站只能自己访问浏览&#xff0c;但是如果想让其他人也浏览&#xff0c;可以将它放到服务器上。 什么是Web服务器 服务器(我们也会称之为主机)是提供计算服务的设备&#xff0c;它也是一台计算机。在网络环境下&#xff0c;根据服务器提供的服务类型不…