Swin-Unet is a segmentation network built on the Swin Transformer backbone (see the Swin Transformer introduction), combined with the structure of the U-Net network (see U-Net in Tensorflow Deep Learning Algorithms (III)).
It differs from the Swin Transformer as follows. On the encoder side there are the same 4 stages as in the Swin Transformer, but the Swin Transformer Block counts are [2, 2, 2, 2] rather than the Swin Transformer's [2, 2, 6, 2]. On the decoder side, since it upsamples, Patch Embedding and Patch Merging are no longer used; instead it uses Patch Expanding, the inverse of Patch Merging: each step doubles the feature map's height and width while halving its channel count.
Let's look at the code implementation of Patch Expanding.
import torch
import torch.nn as nn
from einops import rearrange

class PatchExpand(nn.Module):
    """
    Patch expanding: doubles the spatial size and halves the channel count.
    """
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: height and width of the decoder feature map
            dim: number of feature map channels
            dim_scale: spatial upscaling factor
            norm_layer: normalization along the channel dimension
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # expand the channel count with a fully connected layer
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # first double the channel count
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        # split the channels into 2x2 spatial groups and tile them into a
        # larger feature map, enlarging the spatial size
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        # after doubling, dividing the channels by 4 effectively halves them
        x = x.view(B, -1, C // 4)
        x = self.norm(x)
        return x
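As a quick sanity check, here is a minimal sketch (the batch size, resolution, and channel count are hypothetical values, chosen to match the bottleneck of the default configuration) showing how PatchExpand trades channels for spatial size:

# (B, H*W, C) -> (B, 2H*2W, C/2): the linear layer first maps C to 2C, then
# rearrange regroups each token's channels into a 2x2 block of new tokens.
expand = PatchExpand(input_resolution=(7, 7), dim=768)
x = torch.randn(2, 7 * 7, 768)
out = expand(x)
print(out.shape)  # torch.Size([2, 196, 384]), i.e. (B, 14*14, C/2)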
The encoder side is essentially identical to the Swin Transformer, so let's focus on the decoder. It uses the BasicLayer_up class to pair SwinTransformerBlocks with Patch Expanding.
import torch.utils.checkpoint as checkpoint

class BasicLayer_up(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    A BasicLayer_up contains an even number of SwinTransformerBlocks and one
    upsample layer (the Patch Expanding layer).
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
        """
        Args:
            dim: number of feature map channels
            input_resolution: height and width of the feature map
            depth: number of Swin Transformer Blocks in this stage
            num_heads: number of attention heads in this stage
            window_size: number of patches along each side of the attention window
            mlp_ratio: channel expansion factor of the first fully connected layer in the MLP block
            qkv_bias: if True, add a learnable bias to Q, K, V in the self-attention formula
            qk_scale: constant scaling factor in the window self-attention formula
            drop: dropout rate, 0 by default
            attn_drop: dropout rate used inside self-attention, 0 by default
            drop_path: probability of dropping the entire residual branch of a
                Swin Transformer Block (LN plus W-MSA or SW-MSA), keeping only the
                identity connection; a form of stochastic depth, 0 by default
            norm_layer: normalization along the channel dimension
            upsample: Patch Expanding used for upsampling
            use_checkpoint: whether to use PyTorch gradient checkpointing
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        # build SwinTransformerBlocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # distinguishes W-MSA from SW-MSA: shift 0 means
                                 # W-MSA, window_size // 2 means SW-MSA
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        # patch expanding layer; None for the last decoder stage
        if upsample is not None:
            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        # pass through each SwinTransformerBlock
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # upsample via Patch Expanding
        if self.upsample is not None:
            x = self.upsample(x)
        return x
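To see one decoder stage end to end, here is a minimal sketch (hypothetical dimensions; it assumes SwinTransformerBlock is defined as in the official repository): two blocks alternate W-MSA and SW-MSA, then Patch Expanding doubles the resolution and halves the channels.

# One decoder stage: depth=2 Swin blocks (W-MSA, then SW-MSA), then PatchExpand.
layer_up = BasicLayer_up(dim=384, input_resolution=(14, 14), depth=2,
                         num_heads=12, window_size=7, upsample=PatchExpand)
x = torch.randn(2, 14 * 14, 384)
out = layer_up(x)
print(out.shape)  # torch.Size([2, 784, 192]), i.e. (B, 28*28, C/2)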
The SwinTransformerBlock code is the same as in the Swin Transformer, so it is not repeated here.
There are also the skip connections from the encoder to the decoder. For these we need to look at Swin-Unet's main class.
from timm.models.layers import trunc_normal_

class SwinTransformerSys(nn.Module):
    """ The Swin-Unet network model
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        """
        Args:
            img_size: original image size
            patch_size: number of pixels along each side of a patch
            in_chans: number of channels of the input image
            num_classes: number of classes
            embed_dim: number of feature map channels after patch embedding
            depths: number of Swin Transformer Blocks in each encoder stage
            depths_decoder: number of Swin Transformer Blocks in each decoder stage
            num_heads: number of attention heads in each stage
            window_size: number of patches along each side of the attention window
            mlp_ratio: channel expansion factor of the first fully connected layer in the MLP block
            qkv_bias: if True, add a learnable bias to Q, K, V in the self-attention formula
            qk_scale: constant scaling factor in the self-attention formula
            drop_rate: dropout rate, 0 by default
            attn_drop_rate: dropout rate used inside self-attention, 0 by default
            drop_path_rate: probability of dropping the entire residual branch of a
                Swin Transformer Block (LN plus W-MSA or SW-MSA), keeping only the
                identity connection; a form of stochastic depth, 0.1 by default
            norm_layer: normalization along the channel dimension
            ape: whether to use absolute position embedding, False by default
            patch_norm: if True, add normalization after patch embedding
            use_checkpoint: whether to use PyTorch gradient checkpointing
            final_upsample: the Patch Expanding after decoder stage 4
            **kwargs:
        """
        super().__init__()
        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(
            depths, depths_decoder, drop_path_rate, num_classes))
        self.num_classes = num_classes
        # number of stages
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # number of channels output by encoder stage 4 (Swin-Tiny: 768)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # number of channels output by decoder stage 4 (192)
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample
        # split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        # height and width of the feature map
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # each block drops its entire residual branch with a different probability,
        # increasing linearly from 0 up to drop_path_rate (0.1)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        # build encoder layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):  # each layer corresponds to a stage
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # only the first 3 stages have PatchMerging; the last does not
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)
        # build decoder layers
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):  # each layer corresponds to a stage
            # fully connected layer that halves the channel count after the skip concatenation
            concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
                                      int(embed_dim * 2 ** (self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
            if i_layer == 0:  # the first decoder stage only upsamples
                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                                         patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
                                       dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
                                         input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                                           patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
                                         depth=depths[(self.num_layers - 1 - i_layer)],
                                         num_heads=num_heads[(self.num_layers - 1 - i_layer)],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(depths[:(self.num_layers - 1 - i_layer) + 1])],
                                         norm_layer=norm_layer,
                                         # only the first 3 decoder stages have PatchExpand; the last does not
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)
        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)
        # after the last decoder stage, apply FinalPatchExpand_X4
        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
        self.apply(self._init_weights)
Here a FinalPatchExpand_X4 class appears; let's look at its implementation.
class FinalPatchExpand_X4(nn.Module):
    """
    The Patch Expanding after decoder stage 4:
    quadruples the spatial size while keeping the channel count unchanged.
    """
    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: height and width of the feature map
            dim: number of feature map channels
            dim_scale: spatial upscaling factor
            norm_layer: normalization along the channel dimension
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # expand the channel count with a fully connected layer
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # first expand the channel count 16x
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        # split the channels into 4x4 spatial groups and tile them into a
        # larger feature map, enlarging the spatial size
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c',
                      p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2))
        # collapse the expanded channels back to the original channel count
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x)
        return x
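The same kind of shape check for the final 4x expansion (hypothetical values matching the default configuration, where the decoder output is 56x56 with 96 channels):

# (B, H*W, C) -> (B, 4H*4W, C): channels go C -> 16C -> C while each token
# becomes a 4x4 block of output positions.
up = FinalPatchExpand_X4(input_resolution=(56, 56), dim=96)
x = torch.randn(2, 56 * 56, 96)
out = up(x)
print(out.shape)  # torch.Size([2, 50176, 96]), i.e. (B, 224*224, C)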
Back in the SwinTransformerSys code:
    def _init_weights(self, m):
        """
        Initialize the weights and biases of fully connected and LayerNorm layers.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    # Encoder and bottleneck
    def forward_features(self, x):
        """
        Encoder pass
        """
        # patch embedding (split the image into patches)
        x = self.patch_embed(x)
        # absolute position embedding
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        # features saved for the skip connections
        x_downsample = []
        # pass through each encoder stage
        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)
        x = self.norm(x)  # B L C
        return x, x_downsample

    # Decoder and skip connections
    def forward_up_features(self, x, x_downsample):
        """
        Decoder pass, including the skip-connection concatenations
        """
        # pass through each decoder stage
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                # concatenate the matching encoder skip features before
                # entering the Swin Transformer Blocks
                x = torch.cat([x, x_downsample[3 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)
        x = self.norm_up(x)  # B L C
        return x

    def up_x4(self, x):
        """
        Entered after the last decoder stage completes
        """
        H, W = self.patches_resolution
        B, L, C = x.shape
        assert L == H * W, "input features has wrong size"
        if self.final_upsample == "expand_first":
            x = self.up(x)
            x = x.view(B, 4 * H, 4 * W, -1)
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = self.output(x)
        return x

    def forward(self, x):
        """
        Forward pass
        """
        x, x_downsample = self.forward_features(x)
        x = self.forward_up_features(x, x_downsample)
        x = self.up_x4(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
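With all the pieces in place, a quick smoke test (a minimal sketch; the batch size and class count here are hypothetical, and the remaining encoder classes PatchEmbed, BasicLayer, and PatchMerging are assumed to be available from the official repository) confirms that the decoder restores the full input resolution:

# Full model: 224x224 input -> per-pixel class logits at 224x224.
model = SwinTransformerSys(img_size=224, patch_size=4, in_chans=3, num_classes=3)
x = torch.randn(2, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([2, 3, 224, 224])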
Next comes model training. Here I discarded the original framework's training code and instead used code similar to my earlier U-Net training code.
import torch
from torch import optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import copy
from torch.utils.tensorboard import SummaryWriter
from swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
from utils import DiceLoss
RUN_NAME = 'swinunetv1'
N_CLASSES = 3
INPUT_SIZE = 128
EPOCHS = 21
LEARNING_RATE = 0.01
START_FRAME = 16
DROP_RATE = 0.5
DATA_PATH = '/media/jingzhi/新加卷/'
IMAGE_PATH = 'data_dataset_voc/JPEGImagespng/'
MASK_PATH = 'data_dataset_voc/SegmentationClassPNG-new/'
TEST_IMAGE_PATH = 'test_dataset_voc/JPEGImagespng/'
TEST_MASK_PATH = 'test_dataset_voc/SegmentationClassPNG-new/'
IMAGE_TYPE = '.png'
MASK_TYPE = '.png'
LOG_PATH = './runs'
SAVE_PATH = './'
REAL_HEIGHT = 3000
REAL_WIDTH = 4096
IMG_HEIGHT = 224
IMG_WIDTH = 224
RANDOM_SEED = 42
VALID_RATIO = 0.2
BATCH_SIZE = 32
NUM_WORKERS = 1
CLASSES = {1: 'line'}
class LineDataset(Dataset):
def __init__(self, root_dir=DATA_PATH, transform=None):
self.root_dir = root_dir
listname = []
        for imgfile in os.listdir(DATA_PATH + IMAGE_PATH):
            parts = imgfile.split('.')  # avoid shadowing the built-in list
            n = len(parts)
            if '.' + parts[n - 1] == IMAGE_TYPE:
                # keep an interior dot in the file name (e.g. "a.b.png" -> "a.b")
                if n > 2:
                    filename = parts[0] + '.' + parts[1]
                else:
                    filename = parts[0]
                listname.append(filename)
self.ids = listname
if transform is None:
self.transform1 = transforms.Compose(
[transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
transforms.ToTensor()])
self.transform2 = transforms.Compose(
[transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
transforms.ToTensor()])
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
def __len__(self):
return len(self.ids)
def __getitem__(self, index):
id = self.ids[index]
image = Image.open(self.root_dir + IMAGE_PATH + id + IMAGE_TYPE)
mask = Image.open(self.root_dir + MASK_PATH + id + MASK_TYPE)
image = self.transform1(image)
mask = self.transform2(mask)
return image, mask
def get_trainloader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS):
train_loader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
return train_loader
def get_dataloader(dataset, batch_size=BATCH_SIZE, random_seed=RANDOM_SEED,
valid_ratio=VALID_RATIO, shuffle=True, num_workers=NUM_WORKERS):
error_msg = "[!] valid_ratio should be in the range [0, 1]."
assert ((valid_ratio >= 0) and (valid_ratio <= 1)), error_msg
n = len(dataset)
n_valid = int(valid_ratio * n)
n_train = n - n_valid
torch.manual_seed(random_seed)
train_dataset, valid_dataset = random_split(dataset, (n_train, n_valid))
#
train_loader = DataLoader(train_dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)
return train_loader, valid_loader
def show_dataset(dataset, n_sample=4):
plt.figure(figsize=(30, 15))
for i in range(n_sample):
image, mask = dataset[i]
image = transforms.ToPILImage()(image)
mask = transforms.ToPILImage()(mask)
print(i, image.size, mask.size)
plt.tight_layout()
ax = plt.subplot(n_sample, 1, i + 1)
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
plt.imshow(image, cmap="Greys")
plt.imshow(mask, alpha=0.3, cmap="OrRd")
if i == n_sample - 1:
plt.show()
break
class Test_LineDataset(Dataset):
def __init__(self, root_dir=DATA_PATH, transform=None):
self.root_dir = root_dir
listname = []
for imgfile in os.listdir(DATA_PATH + TEST_MASK_PATH):
if '.' + imgfile.split('.')[1] == MASK_TYPE:
filename = imgfile.split('.')[0]
listname.append(filename)
self.ids = listname
if transform is None:
self.transform = transforms.Compose([transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
transforms.ToTensor()])
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
def __len__(self):
return len(self.ids)
def __getitem__(self, index):
id = self.ids[index]
image = Image.open(self.root_dir + TEST_IMAGE_PATH + id + IMAGE_TYPE)
mask = Image.open(self.root_dir + TEST_MASK_PATH + id + MASK_TYPE)
image = self.transform(image)
mask = self.transform(mask)
return image, mask
def get_validloader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS):
valid_loader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
return valid_loader
def show_test_dataset(dataset, n_sample=2):
plt.figure(figsize=(30, 15))
for i in range(n_sample):
image = dataset[i][0]
image = transforms.ToPILImage()(image)
print(i, image.size)
plt.tight_layout()
ax = plt.subplot(1, n_sample, i + 1)
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
plt.imshow(image, cmap="Greys")
if i == n_sample - 1:
plt.show()
break
def labels():
l = {}
for i, label in enumerate(CLASSES):
l[i] = label
return l
def tensor2np(tensor):
tensor = tensor.squeeze().cpu()
return tensor.detach().numpy()
def normtensor(tensor):
tensor = torch.where(tensor < 0., torch.zeros(1).cuda(), torch.ones(1).cuda())
return tensor
def count_params(model):
pytorch_total_params = sum(p.numel() for p in model.parameters())
return pytorch_total_params
def cal_iou(outputs, labels, SMOOTH=1e-6):
with torch.no_grad():
outputs = outputs.squeeze(1).bool()
labels = labels.squeeze(1).bool()
intersection = (outputs & labels).float().sum((1, 2))
union = (outputs | labels).float().sum((1, 2))
iou = (intersection + SMOOTH) / (union + SMOOTH)
return iou
def get_iou_score(outputs, labels):
A = labels.squeeze(1).bool()
pred = torch.where(outputs < 0., torch.zeros(1).cuda(), torch.ones(1).cuda())
B = pred.squeeze(1).bool()
intersection = (A & B).float().sum((1, 2))
union = (A | B).float().sum((1, 2))
iou = (intersection + 1e-6) / (union + 1e-6)
return iou.cpu().detach().numpy()
def train(model, device, trainloader, optimizer, loss_function, dice_function, epoch):
model.train()
# model.is_train = True
running_loss = 0
mask_list, iou = [], []
for i, (input, mask) in enumerate(trainloader):
input, mask = input.to(device), mask.to(device)
predict = model(input)
loss_ce = loss_function(predict, mask)
loss_dice = dice_function(predict, mask, softmax=True)
loss = 0.4 * loss_ce + 0.6 * loss_dice
iou.append(get_iou_score(predict, mask).mean())
running_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if ((i + 1) % 10) == 0:
pred = normtensor(predict[0])
img, pred, mak = tensor2np(input[0]), tensor2np(pred), tensor2np(mask[0])
print(f'Epoch: {epoch} | Item: {i} | Train loss: {loss:.5f}')
mean_iou = np.mean(iou)
total_loss = running_loss / len(trainloader)
writer.add_scalar('training loss', total_loss, epoch)
return total_loss, mean_iou
def test(model, device, testloader, loss_function, dice_function, best_iou, epoch):
model.eval()
# model.is_train = False
running_loss = 0
mask_list, iou = [], []
with torch.no_grad():
for i, (input, mask) in enumerate(testloader):
input, mask = input.to(device), mask.to(device)
predict = model(input)
loss_ce = loss_function(predict, mask)
loss_dice = dice_function(predict, mask, softmax=True)
loss = 0.4 * loss_ce + 0.6 * loss_dice
running_loss += loss.item()
iou.append(get_iou_score(predict, mask).mean())
if ((i + 1) % 1) == 0:
pred = normtensor(predict[0])
img, pred, mak = tensor2np(input[0]), tensor2np(pred), tensor2np(mask[0])
print(f'Epoch: {epoch} | Item: {i} | test loss: {loss:.5f}')
test_loss = running_loss / len(testloader)
mean_iou = np.mean(iou)
writer.add_scalar('val loss', test_loss, epoch)
    if mean_iou > best_iou:
        try:
            torch.save(model.state_dict(), SAVE_PATH + RUN_NAME + '.pth')
        except Exception:
            print('Cannot export weights')
return test_loss, mean_iou
def model_pipeline(prev_model=None):
best_model = None
model, criterion1, criterion2, optimizer = make_model(prev_model)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
best_iou = -1
for epoch in range(EPOCHS):
t0 = time.time()
train_loss, train_iou = train(model, device, trainloader, optimizer, criterion1, criterion2, epoch)
t1 = time.time()
print(f'Epoch: {epoch} | Train loss: {train_loss:.5f} | Train IoU: {train_iou:.3f} | Time: '
f'{(t1 - t0):.1f}s')
t0 = time.time()
test_loss, test_iou = test(model, device, validloader, criterion1, criterion2, best_iou, epoch)
t1 = time.time()
print(f'Epoch: {epoch} | Valid loss: {test_loss:.5f} | Valid IoU: {test_iou:.3f} | Time: '
f'{(t1 - t0):.1f}s')
scheduler.step()
if best_iou < test_iou:
best_iou = test_iou
best_model = copy.deepcopy(model)
return best_model
def make_model(prev_model=None):
    if prev_model is None:
        # num_classes must match the dataset (the SwinTransformerSys default is 1000)
        model = SwinTransformerSys(num_classes=N_CLASSES).to(device)
    else:
        model = prev_model
    print("Number of parameter:", count_params(model))
    criterion1 = nn.BCEWithLogitsLoss()
    criterion2 = DiceLoss(2)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0001)
    return model, criterion1, criterion2, optimizer
def predict(model, test_loader, device):
model.eval()
predicted_masks = []
back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
with torch.no_grad():
for i, (input, _) in enumerate(test_loader):
input = input.to(device)
predict = model(input)
predict = back_transform(predict)
predict = (predict > 0).type(torch.float)
predicted_masks.append(predict)
predicted_masks = torch.cat(predicted_masks)
return predicted_masks
def show_sample_test_result(test_dataset, predicted_mask, n_samples=60):
plt.rcParams['figure.figsize'] = (30, 15)
back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
for i in range(n_samples):
sample = predicted_mask[i]
sample = torch.squeeze(sample, dim=0)
sample = transforms.ToPILImage()(sample)
X = test_dataset[i][0]
X = back_transform(X)
X = transforms.ToPILImage()(X)
if (i + 1) % 4 != 0:
index = (i + 1) % 4
else:
index = 4
ax = plt.subplot(2, 2, index)
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
plt.imshow(X, cmap="Greys")
plt.imshow(sample, alpha=0.7, cmap="winter")
# if i == n_samples - 1:
if i % 3 == 0 and i != 0:
plt.show()
# break
if __name__ == '__main__':
writer = SummaryWriter(LOG_PATH)
dataset = LineDataset(DATA_PATH)
valid_dataset = Test_LineDataset(DATA_PATH)
trainloader = get_trainloader(dataset=dataset)
validloader = get_validloader(dataset=valid_dataset)
# show_dataset(dataset)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = UNet_ResNet()
# model = model.to(device)
# model.load_state_dict(torch.load(SAVE_PATH + RUN_NAME + '.pth'))
# print(device)
model = model_pipeline()
writer.close()
# predict_mask = predict(model, validloader, device)
# show_sample_test_result(valid_dataset, predict_mask)
The DiceLoss code is as follows:
import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk
class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            # binary mask for class i
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
def _dice_loss(self, score, target):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score * target)
y_sum = torch.sum(target * target)
z_sum = torch.sum(score * score)
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss
def forward(self, inputs, target, weight=None, softmax=False):
if softmax:
inputs = torch.softmax(inputs, dim=1)
# target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i])
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes
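A quick way to check the expected shapes (a minimal sketch with hypothetical sizes; softmax=True normalizes the logits over the class dimension before the Dice terms are computed):

# Dummy 2-class logits and a matching one-hot target of shape (B, C, H, W).
dice = DiceLoss(n_classes=2)
logits = torch.randn(4, 2, 224, 224)
target = torch.zeros(4, 2, 224, 224)
target[:, 0] = 1.0                 # every pixel labeled as class 0
loss = dice(logits, target, softmax=True)
print(loss.item())                 # a scalar in [0, 1]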
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0
    else:
        return 0, 0
def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
if len(image.shape) == 3:
prediction = np.zeros_like(label)
for ind in range(image.shape[0]):
slice = image[ind, :, :]
x, y = slice.shape[0], slice.shape[1]
if x != patch_size[0] or y != patch_size[1]:
slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0
input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
net.eval()
with torch.no_grad():
outputs = net(input)
out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
out = out.cpu().detach().numpy()
if x != patch_size[0] or y != patch_size[1]:
pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
else:
pred = out
prediction[ind] = pred
else:
input = torch.from_numpy(image).unsqueeze(
0).unsqueeze(0).float().cuda()
net.eval()
with torch.no_grad():
out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
prediction = out.cpu().detach().numpy()
metric_list = []
for i in range(1, classes):
metric_list.append(calculate_metric_percase(prediction == i, label == i))
if test_save_path is not None:
img_itk = sitk.GetImageFromArray(image.astype(np.float32))
prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
img_itk.SetSpacing((1, 1, z_spacing))
prd_itk.SetSpacing((1, 1, z_spacing))
lab_itk.SetSpacing((1, 1, z_spacing))
sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
return metric_list
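Finally, a small sanity check for calculate_metric_percase (assumes medpy is installed; the two masks here are hypothetical):

# Two overlapping 4x4 squares: Dice measures overlap, hd95 the boundary distance.
pred = np.zeros((8, 8))
pred[2:6, 2:6] = 1
gt = np.zeros((8, 8))
gt[3:7, 3:7] = 1
dice, hd95 = calculate_metric_percase(pred, gt)
print(dice, hd95)  # Dice = 2*9/(16+16) = 0.5625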