PyTorch深度学习实战——人群计数

PyTorch深度学习实战——人群计数

    • 0. 前言
    • 1. 人群计数
      • 1.1 基本概念
      • 1.2 CRSNet 架构
    • 2. 使用 CSRNet 实现人群计数
      • 2.1 模型分析
      • 2.2 数据集分析
      • 2.3 模型构建与训练
    • 相关链接

0. 前言

人群计数是指通过图像或视频分析技术,对给定场景中的人群数量进行估计和统计的过程。人群计数在城市交通监控、公共安全、活动管理等领域具有广泛的应用。例如,在城市交通管理中,可以通过人群计数来评估交通拥堵情况;在公共安全中,可以利用人群计数来监测人员密集区域,及时发现异常情况。本节中,将介绍人群计数的基本概念,并基于 CSRNet 构建人群计数模型。

1. 人群计数

1.1 基本概念

人群计数指的是在一个给定场景中,通过计算机视觉和图像分析技术对人群数量进行精确计数,即估计图像中的人数。目前已经出现了很多不同的方法进行人群计数,例如使用深度学习模型来进行人群检测和跟踪,使用人工智能算法来预测人群密度和人数,以及使用传感器网络来获取实时人流信息等:

  • 单目标人群计数:单目标人群计数是指在给定的图像或视频中,对所有的个体进行逐一检测,然后对其进行计数,这种方法需要先使用目标检测算法(如基于深度学习的目标检测算法)检测出每个人的位置,然后进行数量统计
  • 密度估计人群计数:密度估计人群计数是指通过估计人群的密度分布来推断人群数量,该方法通常基于图像中人群的特征,如头部、肩膀等,通过密度估计算法(如核密度估计、高斯过程回归等)对图像中的人群密度进行建模,从而得到人群数量的估计结果

在构建模型来执行人群计数之前,我们首先了解可用的数据集和模型架构。为了训练能够预测图像中人数的模型,使用的数据集图像中,应该包括所有出现在图像中的人物头部的中心位置。输入图像样本中每个人物头部中心位置构成的图像非常稀疏,图像中使用 N 个白色像素表示图像中的 N 个人。为了便于观察,将标注(稀疏)图像转换为表示图像该区域中的人数的热力图:

标注图像
放大图像左上角,便于观察最终输入-输出对:

示例图像
在上图中,当两个人距离较近时,像素强度很高;当一个人远离其他人时,此人对应的像素密度分布更均匀,且像素强度较低。基本上,按照像素值的总和等于图像中存在的人数生成热力图。

1.2 CRSNet 架构

我们已经了解了人群计数模型需要接收的输入图像、表示图像中人物头部中心位置的稀疏图,以及处理后得到真实输出热力图。接下来,我们将利用 CSRNet 预测图像中的人数。CRSNet 模型架构如下:

CRSNet

在模型架构中,先将图像通过标准 VGG-16 主干网络,然后再通过四个额外的卷积层。接下来,可以使用四种不同的卷积配置(在本节中,我们采取第一种配置),最后通过 1 x 1 x 1 卷积层。
模型使用均方误差 (Mean-Square Error, MSE) 函数,并最小化损失值以学习最佳权重,同时使用平均绝对误差 (Mean Absolute Error, MAE) 跟踪实际人群计数。
需要注意的是,该架构的使用空洞卷积 (dilated convolution) 替代了普通卷积。空洞卷积计算方法如下所示:

空洞卷积
在上图中,第一张图中黄色部分表示普通的卷积核,第二张和第三张中黄色部分图表示扩张卷积核(或空洞卷积核),它们在各个像素之间有间隙,这样可以在不增加卷积核参数数量的情况下,增加感受野,从而提高模型的性能。因为模型需要了解给定人附近的人数,以便估计与此人相对应的像素密度,使用扩张卷积核(有 9 个参数)非普通卷积核(有 49 个参数),能够以更少的参数捕获更多信息。

2. 使用 CSRNet 实现人群计数

2.1 模型分析

在本节中,我们使用 PyTorch 实现 CSRNet 模型以执行人群计数。 在实现人群计数模型之前,首先总结模型实现策略,以对模型有全面的了解:

  1. 导入相关库和数据集
  2. 由于本节所用数据集已经将人物头部的中心位置图像转换为基于高斯滤波器密度的分布,因此无需再次进行转换
  3. 使用神经网络映射输入图像和输出高斯密度图像
  4. 定义函数执行空洞卷积
  5. 定义网络模型,并在批数据上训练模型最小化 MSE 损失

2.2 数据集分析

在本节中,我们使用 Shanghaitech with People Density Map 数据集构建人群计数模型,该数据集是一个用于人群计数研究的数据集,主要用于评估和训练人群计数算法。数据集由两个子数据集组成:Part_APart_BPart_A 包含人群密集场景的图像,而 Part_B 包含人群稀疏场景的图像,每个子数据集都有大约 300 张高分辨率的图像。
对于每张图像,数据集提供了人工标注的人群数量以及相应的人群密度热力图,人群密度热力图是一种灰度图,它通过将每个像素点的值设为该位置上的人群密度来表示人群的分布情况。该数据集可用于训练和评估各种人群计数算法,包括基于密度的方法、检测和回归方法等。可以在 Kaggle 下载 Shanghaitech with People Density Map 数据集,并解压缩。

2.3 模型构建与训练

接下来,使用 PyTorch 实现以上模型策略。

(1) 导入相关库并下载数据集:

import h5py
from scipy import io
from glob import glob
import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset
import cv2
import random
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm as c

定义图像 (image_folder)、目标输出 (gt_folder) 和热力图文件夹 (heatmap_folder) 位置:

part_A = glob('archive/shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/*')
image_folder = 'archive/shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/images/'
heatmap_folder = 'archive/shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/ground-truth-h5/'
gt_folder = 'archive/shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/ground-truth/'

(2) 定义训练、验证数据集和数据加载器:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class Crowds(Dataset):
    def __init__(self, stems):
        self.stems = stems

    def __len__(self):
        return len(self.stems)

    def __getitem__(self, ix):
        _stem = self.stems[ix]
        image_path = f'{image_folder}/{_stem}.jpg'
        heatmap_path = f'{heatmap_folder}/{_stem}.h5'
        gt_path = f'{gt_folder}/GT_{_stem}.mat'
        pts = io.loadmat(gt_path)
        pts = len(pts['image_info'][0,0][0,0][0])

        image = cv2.imread(image_path, 1)
        h, w, _ = image.shape
        with h5py.File(heatmap_path, 'r') as hf:
            gt = hf['density'][:]
        gt = cv2.resize(gt, (int(w/8), int(h/8)))*64
        # gt = resize(gt, 1/8)*64
        return image.copy(), gt.copy(), pts

    def collate_fn(self, batch):
        ims, gts, pts = list(zip(*batch))
        ims = torch.cat([torch.tensor(im)[None].float() for im in ims]).to(device).permute(0,3,1,2)
        # print(ims.shape)
        gts = torch.cat([torch.tensor(gt)[None].float() for gt in gts]).to(device)# .permute(0,3,2,1)
        return ims, gts, torch.tensor(pts).to(device)

    def choose(self):
        return self[random.randint(len(self))]

def stems(split):
    items_new = [item.split('/')[-1] for item in split]
    items = [item.split('.')[0] for item in items_new]
    return items

from sklearn.model_selection import train_test_split
# print(stems(glob(image_folder)))
trn_stems, val_stems = train_test_split(stems(glob(f'{image_folder}/*.jpg')), random_state=10)

trn_ds = Crowds(trn_stems)
val_ds = Crowds(val_stems)

trn_dl = DataLoader(trn_ds, batch_size=1, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, batch_size=1, shuffle=True, collate_fn=val_ds.collate_fn)

调整人群热力图的大小,因为网络的输出尺寸为原始图像尺寸的 1/8,因此我们通过将目标输出图乘以 64,以使图像像素的总和将按比例缩放回原始人群计数。

(3) 定义网络架构.

定义执行空洞卷积的函数 make_layers()

import torch.nn as nn
import torch
from torchvision import models

def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

定义网络架构 CSRNet

class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat,in_channels = 512,dilation = True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained = True)
            self._initialize_weights()
            items = list(self.frontend.state_dict().items())
            _items = list(mod.state_dict().items())
            for i in range(len(self.frontend.state_dict().items())):
                items[i][1].data[:] = _items[i][1].data[:]
    def forward(self,x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

(4) 定义网络训练和验证函数:

def train_batch(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    ims, gts, pts = data
    _gts = model(ims).squeeze(1)
    # print(gts.shape)
    # print(_gts.shape)
    loss = criterion(_gts, gts)
    loss.backward()
    optimizer.step()
    pts_loss = nn.L1Loss()(_gts.sum(), gts.sum())
    return loss.item(), pts_loss.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    ims, gts, pts = data
    _gts = model(ims)
    loss = criterion(_gts, gts)
    pts_loss = nn.L1Loss()(_gts.sum(), gts.sum())
    return loss.item(), pts_loss.item()

(5) 训练模型:

model = CSRNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-6)
n_epochs = 40

train_loss_epochs = []
val_loss_epochs = []
for ex in range(n_epochs):
    N = len(trn_dl)
    trn_loss = []
    val_loss = []
    for bx, data in enumerate(trn_dl):
        loss, pts_loss = train_batch(model, data, optimizer, criterion)
        pos = (ex + (bx+1)/N)
        trn_loss.append(loss)
    train_loss_epochs.append(np.average(trn_loss))

    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss, pts_loss = validate_batch(model, data, criterion)
        pos = (ex + (bx+1)/N)
        val_loss.append(loss)
    val_loss_epochs.append(np.average(val_loss))

绘制模型训练期间训练和验证损失的变化情况(损失值是人群计数的 MAE),如下所示:

epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型检测
(6) 使用训练后的模型对新图像进行推理。

读取测试图像并对其进行预处理:

from matplotlib import cm as c
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
transform=transforms.Compose([
                      transforms.ToTensor(),transforms.Normalize(
                          mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
                  ])

test_folder = 'archive/shanghaitech_with_people_density_map/ShanghaiTech/part_A/test_data/'
imgs = glob(f'{test_folder}/images/*.jpg')

f = random.choice(imgs)
print(f)
img = transform(Image.open(f).convert('RGB')).to(device)

使用训练好的模型处理图像:

output = model(img[None])
print("Predicted Count : ",int(output.detach().cpu().sum().numpy()))
temp = np.asarray(output.detach().cpu().reshape(output.detach().cpu().shape[2],output.detach().cpu().shape[3]))
plt.imshow(temp, cmap = c.jet)
plt.show()

预测结果

从以上输出结果,可以看出模型能够得到合理准确的热力图,预测人数接近实际值。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/235418.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【MATLAB】REMD信号分解+FFT+HHT组合算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 TVFEMDFFTHHT组合算法是一种结合了总体变分模态分解(TVFEMD)、傅里叶变换(FFT)和希尔伯特-黄变换(HHT)的信号分解方…

多维时序 | MATLAB实现RIME-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测

多维时序 | MATLAB实现RIME-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测 目录 多维时序 | MATLAB实现RIME-CNN-LSTM-Multihead-Attention多头注意力机制多变量时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 MATLAB实现RIME-CNN-…

微信小程序:用map()将对象数组中的某一项组合成新数组

使用分析 使用map()方法来遍历 info 数组中的每个元素,并整合每一个对象中的某一项进行新数组的重组 效果展示 这里是查询对象数组中的全部name值 原始数据 提取出name的数组 核心代码 var infos items.map(item > item.name); 完整代码(用微信小程…

基于hadoop下的spark安装

目录 简介 安装准备 spark安装 配置文件配置 简介 Spark主要⽤于⼤数据的并⾏计算,⽽Hadoop在企业主要⽤于⼤数据的存储(⽐如HDFS、Hive和HBase 等),以及资源调度(Yarn)。但是也有很多公司也在使⽤MR2进…

Web server failed to start. Port 8888 was already in use.

端口占用 强制终止占用端口的进程 获取占用端口的进程ID(PID):在终端或命令提示符中运行以下命令以查找占用端口的进程ID: ①在 Unix/Linux/Mac 上:lsof -i :8888 ②在 Windows 上:netstat -ano | findstr …

HTML面试题---专题二

文章目录 一、前言二、解释input标签中占位符属性的用途三、如何在 HTML 中设置复选框或单选按钮的默认选中状态?四、表单输入字段中必填属性的用途是什么?五、如何使用 HTML 创建表格?六、解释a标签中目标属性的用途七、如何创建一个点击后会…

Java飞翔的小鸟

一、项目分析 创建一个窗口和画板,把画板放到窗口上,在画板上绘画图片 (2)让小鸟在画面中动起来,可以上下飞 (3)让地面和管道动起来 (4)碰撞检测 (5&#xf…

Nginx 优化与防盗链

目录 配置Nginx隐藏版本号 Nginx隐藏版本号的方法 修改配置文件法 修改源码法 修改用户与组 设置缓存时间 日志切割 连接超时 更改进程数 配置网页压缩 配置防盗链 fpm参数优化 总结:nginx优化 配置Nginx隐藏版本号 可以使用 Fiddler 工具抓取数据包&…

【Citespace】从Citespace开始的引文可视化分析

CiteSpace 译“引文空间”,是一款着眼于分析科学分析中蕴含的潜在知识,是在科学计量学、数据可视化背景下逐渐发展起来的引文可视化分析软件。由于是通过可视化的手段来呈现科学知识的结构、规律和分布情况,因此也将通过此类方法分析得到的可…

巧用ChatGPT高效搞定Excel数据分析【文末送书-04】

文章目录 一.巧用ChatGPT高效搞定Excel数据分析1. ChatGPT简介2. 安装所需工具2.1 Python2.2 OpenAI GPT库 3. 与ChatGPT交互进行数据分析4. 利用ChatGPT进行筛选和排序5. ChatGPT的局限性和注意事项6. ChatGPT与数据可视化7. ChatGPT与进阶数据分析任务 二. 结论&文末福利…

Windows安装Maven

一、Maven 是什么? Maven 是一个项目管理和整合工具。Maven 为开发者提供了一套完整的构建生命周期框架。开发团队几乎不用花多少时间就能够自动完成工程的基础构建配置,因为 Maven 使用了一个标准的目录结构和一个默认的构建生命周期。 在有多个开发团…

软件开发安全指南

2.1.应用系统架构安全设计要求 2.2.应用系统软件功能安全设计要求 2.3.应用系统存储安全设计要求 2.4.应用系统通讯安全设计要求 2.5.应用系统数据库安全设计要求 2.6.应用系统数据安全设计要求 软件开发全资料获取:点我获取

用Java实现根据数据库中的数量,生成年月份+序号递增

在日常开发中,经常会遇到根据年月日和第几号文件生成对应的编号,今天给大家提供一个简单的工具类 public static final Long CODE1L;/*** param select 数据库中数据总数* return*/public static String SubjectNo(Long select){// 在总数的基础上1&…

c2-C语言--指针

1.用一级指针遍历一维数组 结论 buf[i]<>*(buf i) <> *(p i)<> p[i] #include <stdio.h>int main(){int buf[5] {10,20 ,30 ,40,50}; //buf[0] --- int // buf --&buf[0] ----int *int *p buf;//&buf[0] --- &*(buf0)printf(&quo…

统一存储、全闪阵列、分布式NAS,企业级存储概述

Infortrend普安科技即将迎来公司成立30周年华诞。Infortrend普安科技从无到有&#xff0c;由小做强&#xff0c;为全球用户提供高性能、高可靠、高扩展、环保节能的存储解决方案&#xff0c;在存储领域造就了一段品牌佳话。从1993年成立伊始&#xff0c;Infortrend一直致力于企…

云服务器哪家便宜?亚马逊云科技按需选实例够便宜

随着云计算的迅猛发展&#xff0c;越来越多的企业和个人开始关注云服务器的选择。在众多云服务提供商中&#xff0c;亚马逊云科技&#xff08;Amazon Web Services&#xff0c;AWS&#xff09;凭借其强大的基础设施和丰富的服务&#xff0c;备受业界青睐。本文聚焦一个备受关注…

【lesson3】数据库表的操作

文章目录 创建修改修改表名增加表类型修改表的某一类型的类型修改表某一类型的类型名 删除删除表的某一列删除表 查看查看表信息查看表内容 创建 建表指令&#xff1a; 查看是否建表成功&#xff1a; 查看表的具体信息&#xff1a; 修改 修改表名 法一&#xff1a;修改…

基础宠物商店管理系统(Java)大一程序设计

一.开发环境 Windows 11 -- JDK 21 -- IDEA 2021.3.3 二.需求 三.代码部分 //创建一个宠物类&#xff0c;被另外两类继承public class Pet {private String name;private int age;private String gender;private double cost0;//买进价格private double sellprice0;//卖出价…

hdlbits系列verilog解答(mt2015_q4a)-52

文章目录 一、问题描述二、verilog源码三、仿真结果 一、问题描述 本次我们实现一个简单的组合逻辑输出。 z (x^y) & x 模块声明&#xff1a; module top_module (input x, input y, output z); 二、verilog源码 module top_module (input x, input y, output z);assig…

CRM系统是怎样帮助团队处理业务的?

客户关系管理的核心思想是将企业的客户作为最重要的资源&#xff0c;提供优质的客户服务&#xff0c;满足客户的需求&#xff0c;保证实现客户的终生价值&#xff0c;这也是众多企业使用CRM系统的原因。那么&#xff0c;CRM如何帮助中小企业解决业务与团队之间的问题&#xff1…
最新文章