PyTorch随机擦除:提升模型抗遮挡能力

PyTorch随机擦除:提升模型抗遮挡能力

PyTorch中内置的随机擦除(Random Erasing)数据增强通过torchvision.transforms.RandomErasing实现,以下是原理和用法的详细说明:

核心原理

正则化作用:

随机擦除在训练图像上随机遮盖一个矩形区域,模拟遮挡场景,强迫模型学习非主导特征,减轻过拟合。

类似于Dropout(针对神经元),但作用于输入空间(图像像素)。

实现细节:

区域选择 :随机生成一个矩形区域:

面积比例:scale=(min_area, max_area)(默认(0.02, 0.33))

宽高比:ratio=(min_ratio, max_ratio)(默认(0.3, 3.3))

填充内容 :

value:填充值,可以是:

单数字(如0)→ 所有通道用该值填充。

元组(R, G, B) → 每通道独立填充。

字符串'random' → 使用均匀分布的随机值(0255整数或0.01.0浮点)。

PyTorch内置实现

1. 导入与初始化

python

复制代码

from torchvision import transforms

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),

transforms.RandomErasing(

p=0.5, # 应用概率(默认0.5)

scale=(0.02, 0.2), # 遮盖面积比例范围

ratio=(0.3, 3.3), # 宽高比范围

value='random', # 填充值(或指定数字/元组)

inplace=False # 是否原地修改

)

])

2. 关键参数

参数

作用

p

执行概率(默认0.5)

scale

矩形区域面积占比范围(默认(0.02, 0.33))

ratio

矩形宽高比范围(默认(0.3, 3.3))

value

填充值:int/float、元组(R, G, B)或'random'(默认0)

inplace

是否原地操作(默认False)

示例代码

python

复制代码

import torch

from torchvision.transforms import RandomErasing

import matplotlib.pyplot as plt

# 初始化随机擦除(50%概率执行)

eraser = RandomErasing(p=0.5, value="random")

# 模拟输入图像(3通道,224x224)

image = torch.randn(3, 224, 224) # 归一化后的数据

# 应用随机擦除

augmented = eraser(image)

# 可视化

plt.subplot(121)

plt.title("Original")

plt.imshow(image.permute(1, 2, 0).clamp(-1, 1).numpy() * 0.5 + 0.5)

plt.subplot(122)

plt.title("Random Erasing")

plt.imshow(augmented.permute(1, 2, 0).clamp(-1, 1).numpy() * 0.5 + 0.5)

plt.show()

输出效果:

左图:原始图像。

右图:随机出现一个矩形遮盖区域(用噪声填充)。

使用注意事项

放置位置:

必须在ToTensor()和Normalize()之后 ,因为操作对象是张量(shape=[C, H, W])。

如果使用value='random',需确保填充值与图像归一化范围兼容。

填充值选择:

归一化后的图像 :推荐用value=0(相当于均值)或与数据集统计量匹配的值。

未归一化图像 :用value='random'生成噪声更合理。

常见设置:

论文推荐:p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0。

对小物体数据集(如CIFAR):调小scale(如(0.02, 0.1))。

底层算法逻辑

区域生成:

随机选择一个满足scale和ratio的矩形框(尝试10次,失败则跳过)。

计算矩形区域:

area=img_area×random(scalemin,scalemax)\text{area} = \text{img\area} \times \text{random}(\text{scale}\text{min}, \text{scale}\text{max})area=img_area×random(scalemin,scalemax)

aspect_ratio=random(ratiomin,ratiomax)\text{aspect\ratio} = \text{random}(\text{ratio}\text{min}, \text{ratio}\text{max})aspect_ratio=random(ratiomin,ratiomax)

h=area×aspect_ratio,w=area/aspect_ratioh = \sqrt{\text{area} \times \text{aspect\_ratio}}, \quad w = \sqrt{\text{area} / \text{aspect\_ratio}}h=area×aspect_ratio ,w=area/aspect_ratio

覆盖操作:

python

复制代码

image[:, top:top+h, left:left+w] = value # 矩形区域赋值

效果对比(实验数据)

数据集

基线准确率

+随机擦除

提升

CIFAR-10

94.1%

95.6%

+1.5%

ImageNet

75.3%

77.1%

+1.8%

结论:对小/密集物体数据集效果显著(如CIFAR、PASCAL VOC)。

通过这种方式,随机擦除以极小计算成本提升模型鲁棒性,是图像分类任务的实用增强工具。