BatchNorm 可以加速模型的收敛并缓解梯度消失问题,是深度学习领域常用的一项技术。
最近仔细学习了 BatchNorm 的原理,因此想自己动手实现一下,加深理解。
代码如下:
import torch
import torch.nn as nn
class MyBatchNorm(nn.Module):
    """Hand-written batch normalization over the channel dimension (dim 1).

    In training mode each channel is normalized with the statistics of the
    current batch, and exponential moving averages of the batch mean and the
    unbiased batch variance are accumulated. In eval mode the accumulated
    running statistics are used instead. A learnable per-channel scale
    (``gamma``) and shift (``beta``) are applied afterwards.

    Args:
        dim: Number of channels, i.e. the expected size of input dimension 1.
        lba: Decay of the moving average; the update is
            ``running = lba * running + (1 - lba) * batch_stat``.
        eps: Constant added to the variance before the square root to avoid
            division by zero.

    Shape:
        Input ``(N, C)`` or ``(N, C, *)``; the output has the same shape as
        the input.
    """

    def __init__(self, dim, lba=0.99, eps=1e-7):
        super().__init__()
        # Learnable affine parameters; initialised to the identity transform
        # (scale 1, shift 0) so a fresh layer only normalizes.
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        # Running statistics are registered as buffers so they are saved in
        # state_dict() and follow the module across .to()/.cuda() calls.
        # The variance starts at 1 so eval-before-train does not divide by
        # sqrt(eps).
        self.register_buffer("mean_whole", torch.zeros(dim))
        self.register_buffer("var_whole", torch.ones(dim))
        self.lba = lba
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` per channel and apply the affine transform.

        Args:
            x: Tensor of shape ``(N, C)`` or ``(N, C, *)`` with ``C == dim``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than 2 dimensions or its channel
                dimension does not match ``dim``.
        """
        if x.dim() < 2:
            raise ValueError(
                f"expected input with at least 2 dims (N, C, ...), got {x.dim()}"
            )
        num_channels = self.gamma.numel()
        if x.size(1) != num_channels:
            raise ValueError(
                f"expected {num_channels} channels at dim 1, got {x.size(1)}"
            )

        # Flatten every trailing dimension so statistics are taken over
        # (batch, positions); the original shape is restored on return.
        orig_shape = x.shape
        x = x.reshape(x.size(0), x.size(1), -1)

        if self.training:
            # Per-channel batch statistics: (b, c, d) -> (1, c, 1).
            mean_batch = torch.mean(x, dim=[0, 2], keepdim=True)
            var_batch = torch.var(x, dim=[0, 2], keepdim=True, unbiased=False)

            # Number of values contributing to each channel's statistics.
            n = x.numel() // x.size(1)
            # The running update must not be recorded by autograd, otherwise
            # the buffers would keep every past graph alive.
            with torch.no_grad():
                # Bessel correction for the running variance; guarded so a
                # single value per channel does not divide by zero.
                unbiased_var = var_batch * (n / max(n - 1, 1))
                # Flatten to (c,) so the buffers keep their shape.
                self.mean_whole.mul_(self.lba).add_(
                    mean_batch.reshape(-1), alpha=1 - self.lba
                )
                self.var_whole.mul_(self.lba).add_(
                    unbiased_var.reshape(-1), alpha=1 - self.lba
                )

            x = (x - mean_batch) / torch.sqrt(var_batch + self.eps)
        else:
            mean = self.mean_whole.view(1, -1, 1)
            var = self.var_whole.view(1, -1, 1)
            x = (x - mean) / torch.sqrt(var + self.eps)

        # Per-channel scale and shift.
        x = x * self.gamma.view(1, -1, 1) + self.beta.view(1, -1, 1)
        return x.reshape(orig_shape)
# Smoke test: run one random (batch=2, channels=3, length=4) tensor through
# the layer in training mode and print the shape of the result.
x = torch.randn(2, 3, 4)
batch_norm = MyBatchNorm(dim=3).train()
b = batch_norm(x)
print(b.shape)
参考资料:
1. 原理
https://zhuanlan.zhihu.com/p/34879333
2. 代码
https://zhuanlan.zhihu.com/p/337732517