语义分割

感谢UP主霹雳吧啦Wz个人主页-哔哩哔哩视频制作并分享的优质学习资源

1. 什么是语义分割

按照类别将图片上的像素点进行分类,大汪、二汪、三汪都属于“狗”类别,大喵、二喵、三喵都属于“猫”类别

其他的分割任务

  • 实例分割:大汪、二汪、三汪属于三种不同的“狗”实例
  • 全景分割:语义分割+实例分割+背景划分

2. 数据集格式介绍

2.1 PASCAL VOC格式

主要参考:PASCAL VOC2012数据集介绍_pascal voc 2012-CSDN博客

数据集镜像下载地址:Index of /dataset-mirror/voc

在Pascal VOC数据集中主要包含20个目标类别,下图展示了所有类别的名称以及所属超类。

img

下载后将文件进行解压,解压后的文件目录结构如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
VOCdevkit
└── VOC2012
├── Annotations 所有的图像标注信息(XML文件)
├── ImageSets
│ ├── Action 人的行为动作图像信息
│ ├── Layout 人的各个部位图像信息
│ │
│ ├── Main 目标检测分类图像信息
│ │ ├── train.txt 训练集(5717)
│ │ ├── val.txt 验证集(5823)
│ │ └── trainval.txt 训练集+验证集(11540)
│ │
│ └── Segmentation 目标分割图像信息
│ ├── train.txt 训练集(1464)
│ ├── val.txt 验证集(1449)
│ └── trainval.txt 训练集+验证集(2913)

├── JPEGImages 所有图像文件
├── SegmentationClass 语义分割png图(基于类别)
└── SegmentationObject 实例分割png图(基于目标)

关于图像分割我们重点关注:

1
2
3
4
└── Segmentation          目标分割图像信息
├── train.txt 训练集(1464)
├── val.txt 验证集(1449)
└── trainval.txt 训练集+验证集(2913)
  • 首先在Segmentarion文件夹中,读取对应的txt文件。比如使用train.txt中的数据进行训练,那么读取该txt文件,解析每一行,每一行对应一个图像的索引。

  • 根据索引在JPEGImages 文件夹中找到对应的图片。还是以2007_000323为例,可以找到2007_000323.jpg文件。

2007_000323.jpg

  • 根据索引在SegmentationClass文件夹中找到相应的标注图像(.png)。还是以2007_000323为例,可以找到2007_000323.png文件。

20017_000323.png

注意,在语义分割中对应的标注图像(.png)用PIL的Image.open()函数读取时,默认是P模式,即一个单通道的图像。在背景处的像素值为0,目标边缘处用的像素值为255(训练时一般会忽略像素值为255的区域),目标区域内根据目标的类别索引信息进行填充,例如人对应的目标索引是15,所以目标区域的像素值用15填充。

至于为什么单通道图像会是彩色的,是因为使用了调色板,将值为15的像素值映射成粉色,255的像素值映射为白色。

2.2 MS COCO格式

MS COCO数据集介绍以及pycocotools简单使用_coco数据集最多一张图有多少个instance-CSDN博客

3. 语义分割评价指标

image-20251107211653856

  • PA:所有正确像素的个数 / 所有像素
  • mean Accuracy:每个类别的正确率求和,再取平均
  • 常用mean IoU:image-20251107212101261如图所示:3 / (1+2+3)

矩阵方式查看

image-20251107212259175

image-20251107212407658

mean Accuracy:求和后取平均

mean IoU:求和后取平均

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
from torchvision.datasets import VOCSegmentation
import matplotlib.pyplot as plt
from PIL import Image
import os
import os.path as osp

# --- 1. 配置和常量 ---
# 设置计算设备
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 类别数量
NUM_CLASSES = 21
# 批次大小
BATCH_SIZE = 4
# 学习率
LEARNING_RATE = 1e-4
# 轮次数
NUM_EPOCHS = 10
# 图像缩放尺寸
IMAGE_SIZE = 256
# 根目录,VOCdevkit 在其之下
DATASET_ROOT = './DataSet'

def multiclass_dice_score(preds, targets, num_classes=NUM_CLASSES, ignore_index=255):
"""计算多类别 Dice 系数 (Mean Dice)。"""
preds = torch.argmax(preds, dim=1)

total_dice = 0.0
valid_classes = 0

# 遍历每个前景类别
for i in range(1, num_classes):
pred_i = (preds == i).float()
target_i = (targets == i).float()

# 排除忽略索引 255
valid_mask = (targets != ignore_index).float()
pred_i = pred_i * valid_mask
target_i = target_i * valid_mask

# Dice 计算: 2 * Intersection / (Prediction + Target)
intersection = (pred_i * target_i).sum()
union = pred_i.sum() + target_i.sum()

if union > 0:
dice = (2. * intersection + 1e-6) / (union + 1e-6)
total_dice += dice.item()
valid_classes += 1

return total_dice / valid_classes if valid_classes > 0 else 0.0

def multiclass_iou_score(preds, targets, num_classes=NUM_CLASSES, ignore_index=255):
"""计算多类别 IoU (Mean IoU)。"""
preds = torch.argmax(preds, dim=1)

total_iou = 0.0
valid_classes = 0

# 遍历每个前景类别
for i in range(1, num_classes):
pred_i = (preds == i).float()
target_i = (targets == i).float()

# 排除忽略索引 255
valid_mask = (targets != ignore_index).float()
pred_i = pred_i * valid_mask
target_i = target_i * valid_mask

# IoU 计算: Intersection / Union
intersection = (pred_i * target_i).sum()
# Union = Prediction + Target - Intersection (TP + FP + FN)
union = pred_i.sum() + target_i.sum() - intersection

if union > 0:
iou = intersection / (union + 1e-6)
total_iou += iou.item()
valid_classes += 1

return total_iou / valid_classes if valid_classes > 0 else 0.0

#-----------------------------------------------------------------#
# 自定义 VOC 数据转换类,确保输入图像和真实标签图像使用不同的插值模式
# transforms.Compose的作用详见:Pytorch深度学习实践/8.3
# 对于平滑的图像,比如输入图像,Resize使用双线性插值(BILINEAR)
# 对于离散的图像,比如标签图像,Resize使用最近邻插值(NEAREST)
#------------------------------------------------------------------#
class VOCDataTransform:
def __init__(self, size):
self.size = size
# 图像转换:Resize使用双线性插值(BILINEAR),连续数据平滑处理,然后转为Tensor并标准化
self.image_transform = transforms.Compose([
transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
# 将维度从(H,W,C)改变为(C,H,W)
transforms.ToTensor(),
# 使用 ImageNet 的常用均值和标准差进行标准化
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 标签转换:Resize使用最近邻插值(NEAREST),离散数据保护类别完整完整性,然后转为Tensor并标准化
self.target_transform = transforms.Compose([
transforms.Resize((size, size), interpolation=transforms.InterpolationMode.NEAREST),
# 将维度从(H,W,C)改变为(C,H,W)
transforms.ToTensor(),
# 将 ToTensor() 后的浮点数 * 255 恢复到原始索引范围,避免只显示轮廓,不填充颜色
transforms.Lambda(lambda t: (t * 255).long())
])

def __call__(self, image, target):
image = self.image_transform(image)
target = self.target_transform(target)
# 调整标签形状为 [H, W],并将数据类型转换为 Long,这是 CrossEntropyLoss 所需要的
target = target.squeeze(0).long()
# VOC 的掩码标签中,255 表示边界/忽略区域。
return image, target


# 实例化转换
transform = VOCDataTransform(IMAGE_SIZE)

# 加载 VOC 2012 训练集
train_dataset = VOCSegmentation(
root = DATASET_ROOT, # VOC 数据集的存放位置
year = '2012',
image_set = 'train',
download = False,
transforms = transform
)

# 加载 VOC 2012 验证集
val_dataset = VOCSegmentation(
root = DATASET_ROOT,
year = '2012',
image_set = 'val',
download = False,
transforms = transform
)

train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0, # 使用多进程加载数据
pin_memory=True
)

val_loader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=0,
pin_memory=True
)

class UNet(nn.Module):
def __init__(self, in_channels=3, num_classes=21):
super(UNet, self).__init__()

def conv_block(in_c, out_c):
return nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True),
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True)
)

# 左侧编码器
self.encoder1 = conv_block(in_channels, 64)
self.encoder2 = conv_block(64, 128)
self.encoder3 = conv_block(128, 256)
self.encoder4 = conv_block(256, 512)

self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

# 底部过渡层
self.bottleneck = conv_block(512, 1024)

# 右侧解码器,其中上采样采用转置卷积,尺寸和通道都减半
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.decoder4 = conv_block(1024, 512) # in_c = 1024是因为拼接了,后续同理

self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.decoder3 = conv_block(512, 256)

self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.decoder2 = conv_block(256, 128)

self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder1 = conv_block(128, 64)

# 最终输出层,使用1x1卷积核
self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

def forward(self, x):
# 编码
e1 = self.encoder1(x)
pool1 = self.pool(e1)

e2 = self.encoder2(pool1)
pool2 = self.pool(e2)

e3 = self.encoder3(pool2)
pool3 = self.pool(e3)

e4 = self.encoder4(pool3)
pool4 = self.pool(e4)

# 过渡层
b = self.bottleneck(pool4)

# 解码(上采样 + 拼接 + 卷积)
d4 = self.upconv4(b)
d4 = torch.cat((e4, d4), dim=1)
d4 = self.decoder4(d4)

d3 = self.upconv3(d4)
d3 = torch.cat((e3, d3), dim=1)
d3 = self.decoder3(d3)

d2 = self.upconv2(d3)
d2 = torch.cat((e2, d2), dim=1)
d2 = self.decoder2(d2)

d1 = self.upconv1(d2)
d1 = torch.cat((e1, d1), dim=1)
d1 = self.decoder1(d1)

# 最终输出
out = self.final_conv(d1)

return out

def train_model(model, loader, optimizer, criterion):
model.train()
running_loss = 0.0
num_samples = 0

for images, targets in loader:
images = images.to(DEVICE)
targets = targets.to(DEVICE)
optimizer.zero_grad()
outputs = model(images)
# CrossEntropyLoss: 适用于多类别分割,并设置自动处理忽略边界索引 255
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# 注意:CrossEntropyLoss返回的是当前批次 (Batch) 的平均损失
# 所以乘上批次大小即可得到当前批次的总损失
running_loss += loss.item() * images.size(0)
num_samples += images.size(0)

# 返回样本加权后的平均损失
epoch_loss = running_loss / num_samples if num_samples > 0 else 0.0
return epoch_loss

@torch.no_grad()
def validate_model(model, loader):
model.eval()
total_dice = 0.0
total_iou = 0.0
num_samples = 0

for images, targets in loader:
images = images.to(DEVICE)
targets = targets.to(DEVICE)

outputs = model(images)

# 计算多类别平均 Dice Score
dice = multiclass_dice_score(outputs, targets, num_classes=NUM_CLASSES, ignore_index=255)
total_dice += dice * images.size(0)

iou = multiclass_iou_score(outputs, targets, num_classes=NUM_CLASSES, ignore_index=255)
total_iou += iou * images.size(0)

num_samples += images.size(0)

avg_dice = total_dice / num_samples if num_samples > 0 else 0.0
avg_iou = total_iou / num_samples if num_samples > 0 else 0.0

return avg_dice, avg_iou

# ----------------------------------------------------
# 将灰度图像转变为RBG图像
# 参数1 mask: 包含索引(或者说标签)的二维张量
# 参数2 palette: 调色板数组,数组下标对应一个颜色
# ----------------------------------------------------
def mask_to_rgb(mask, palette):
# 确保传入的是二维张量
if len(mask.shape) > 2:
# 如果不是二维的,一般都是带了通道的,且通道一般是1,因为是灰度图像
# squeeze(),不填参数,删除维度为 1 的所有维度
mask = mask.squeeze()

# 创建一个RBG图像,且与张量尺寸一致
rgb_img = Image.new("RGB", (mask.shape[1], mask.shape[0]))
# 后续修改像素值对应的颜色
pixels = rgb_img.load()
# 转换为 NumPy 数组以便高效迭代
mask_np = mask.astype(np.uint8)
# 双层循环遍历每一个像素
for y in range(mask.shape[0]): # 高度 (H) 或 行索引 (y)
for x in range(mask.shape[1]): # 宽度 (W) 或 列索引 (x)
color_idx = mask_np[y, x]
# 确保索引在调色板范围内,VOC 2012 类别索引范围是 0-20
color_idx = color_idx if 0 <= color_idx < len(palette) else 0
# PIL/Pillow 库在访问像素时,(第一个参数) 对应于宽度/列,(第二个参数) 对应于高度/行
pixels[x, y] = palette[color_idx]
return rgb_img

# 可视化预测
@torch.no_grad()
def visualize_predictions(model, dataset, num_samples=3):
model.eval()

# PASCAL VOC 2012 默认调色板 (21 个类别)
palette = [
(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128),
(128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0),
(64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128),
(192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
(0, 64, 128)
]

fig, axes = plt.subplots(num_samples, 3, figsize=(10, 3 * num_samples))

for i in range(num_samples):
# 随机选择一个样本
idx = np.random.randint(0, len(dataset))
img, target = dataset[idx]

# 预测掩码
# 给图片添加一个维度0(Batch):[B, C, H, W]
img_tensor = img.unsqueeze(0).to(DEVICE)
# out本来是[B, 21, H, W], squeeze(0)后变为[21, H, W]
output = model(img_tensor).squeeze(0)
# 沿着dim=0,也就是在21个通道中,取得最大值, pred_mask变为[H, W]
pred_mask = torch.argmax(output, dim=0).cpu().numpy()

# 反标准化图像用于显示
img_np = img.cpu().numpy().transpose(1, 2, 0)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img_np = std * img_np + mean
img_np = np.clip(img_np, 0, 1)

# 分别将真实的标签和预测的掩码转换为彩色图像。
target_mask_rgb = mask_to_rgb(target.numpy(), palette)
pred_mask_rgb = mask_to_rgb(pred_mask, palette)

# 绘制结果
axes[i, 0].imshow(img_np)
axes[i, 0].set_title("Original Image", fontsize=10)
axes[i, 1].imshow(target_mask_rgb)
axes[i, 1].set_title("Ground Truth", fontsize=10)
axes[i, 2].imshow(pred_mask_rgb)
axes[i, 2].set_title("Prediction", fontsize=10)

for ax in axes[i]: ax.set_xticks([]); ax.set_yticks([])

plt.tight_layout()
plt.show()

if __name__ == "__main__":
# 诊断检查:确保 VOCdevkit/VOC2012 路径存在,避免尝试下载
EXPECTED_PATH = osp.join(DATASET_ROOT, 'VOCdevkit', 'VOC2012')
if osp.exists(EXPECTED_PATH):
print("数据集加载成功。")
else:
print("=" * 60)
print("错误:未找到数据集路径。")
print(f"请确认您的脚本运行目录下存在 '{EXPECTED_PATH}' 文件夹,否则 PyTorch 将尝试下载数据集。")
print("=" * 60)
# 不使用 sys.exit,允许代码继续运行,但用户应知晓问题
model = UNet(in_channels = 3, num_classes = NUM_CLASSES).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
# ----------------------------------------------------
# 应用损失权重来解决背景陷阱(Background Trap)
# ----------------------------------------------------
# 权重数组:21个类别。
weights = torch.ones(NUM_CLASSES)
# 类别 0 (背景) 保持默认权重 1.0。
# 类别 1-20 (前景) 权重提高到 5.0,强制模型更关注前景错误。
weights[1:] = 5.0
weights = weights.to(DEVICE)
# 在 CrossEntropyLoss 中使用加权数组
criterion = nn.CrossEntropyLoss(weight = weights, ignore_index = 255)

print("\n--- Start Training U-Net (VOC 2012) ---")
print(f"使用的设备: {DEVICE}")

best_dice = 0.0

for epoch in range(1, NUM_EPOCHS + 1):
train_loss = train_model(model, train_loader, optimizer, criterion)
val_dice, val_iou = validate_model(model, val_loader)
print(f"Epoch {epoch}/{NUM_EPOCHS} "
f"| Train Loss: {train_loss:.4f} "
f"| Val Avg Dice: {val_dice:.4f} "
f"| Val Avg IoU: {val_iou:.4f}")

if val_dice > best_dice:
print("*" * 60)
print("Best model updated:")
best_dice = val_dice
if not os.path.exists("./models"):
os.makedirs("./models")
torch.save(model.state_dict(), "./models/best_unet_voc2012.pth")
print("Model saved to: ./models/best_unet_voc2012.pth")
print("*" * 60)
print("\n--- Training Complete, Starting Visualization ---")

try:
model.load_state_dict(torch.load("./models/best_unet_voc2012.pth"))
visualize_predictions(model, val_dataset, num_samples=3)
except FileNotFoundError:
print("Warning: Saved model file not found. Visualization skipped.")

语义分割
https://blog.gutaicheng.top/2025/11/07/语义分割/
作者
GuTaicheng
发布于
2025年11月7日
许可协议