Java教程

unet 网络

本文主要是介绍unet 网络,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

问题1:'Keyword argument not understood:', 'input

删去input=,output=

论文的几个创新点:数据增强,

论文用的一些方法:数据增强,数据的二值化

神经网络的训练:

遇到的一些问题:1、读取数据时一开始用libtiff读取tif格式图片一直导入不成功,后来更换成cv.imgread

2.tensorflow 调用adam函数时出现错误,通过查阅资料得知是Adam.keras版本需要匹配。

网络的版本号

主要结构:一个语义分割模型,encoder-decoder结构,u字形(论文中的输入大小是512*512,但这副图里给的是572*572,图片数据经过处理

特点:U型结构和skip-connection

/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果

UNet的encoder下采样4次,一共下采样16倍,对称地,其decoder也相应上采样4次,将encoder得到的高级语义特征图恢复到原图片的分辨率。

Skip connection:打破了网络的对称性,提升了网络的表征能力,关于对称性引发的特征退化问题  残差连接(skip connect)/(residual connections)_赵凯月的博客-CSDN博客

医疗影像有什么样的特点:图像语义较为简单、结构较为固定。我们做脑的,就用脑CT和脑MRI,做胸片的只用胸片CT,做眼底的只用眼底OCT,都是一个固定的器官的成像,而不是全身的。由于器官本身结构固定和语义信息没有特别丰富,所以高级语义信息和低级特征都显得很重要(UNet的skip connection(残差连接)和U型结构就派上了用场)。举两个例子直观感受下。

U-net 基于pytorch的代码:

import torch

import torch.nn as nn

import torch.nn.functional as F

class double_conv2d_bn(nn.Module):

def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):

初始化网络

        super(double_conv2d_bn, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels,

                               kernel_size=kernel_size,

                               stride=strides, padding=padding, bias=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels,

                               kernel_size=kernel_size,

                               stride=strides, padding=padding, bias=True)

        self.bn1 = nn.BatchNorm2d(out_channels)

        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):

        out = F.relu(self.bn1(self.conv1(x)))

        out = F.relu(self.bn2(self.conv2(out)))

        return out

class deconv2d_bn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):

        super(deconv2d_bn, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,

                                        kernel_size=kernel_size,

                                        stride=strides, bias=True)

        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):

        out = F.relu(self.bn1(self.conv1(x)))

        return out

class Unet(nn.Module):

    def __init__(self):

        super(Unet, self).__init__()

        self.layer1_conv = double_conv2d_bn(1, 8)

        self.layer2_conv = double_conv2d_bn(8, 16)

        self.layer3_conv = double_conv2d_bn(16, 32)

        self.layer4_conv = double_conv2d_bn(32, 64)

        self.layer5_conv = double_conv2d_bn(64, 128)

        self.layer6_conv = double_conv2d_bn(128, 64)

        self.layer7_conv = double_conv2d_bn(64, 32)

        self.layer8_conv = double_conv2d_bn(32, 16)

        self.layer9_conv = double_conv2d_bn(16, 8)

        self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,

                                      stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(128, 64)

        self.deconv2 = deconv2d_bn(64, 32)

        self.deconv3 = deconv2d_bn(32, 16)

        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        conv1 = self.layer1_conv(x)

        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)

        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)

        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)

        pool4 = F.max_pool2d(conv4, 2)

        conv5 = self.layer5_conv(pool4)

        convt1 = self.deconv1(conv5)

        concat1 = torch.cat([convt1, conv4], dim=1)

        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)

        concat2 = torch.cat([convt2, conv3], dim=1)

        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)

        concat3 = torch.cat([convt3, conv2], dim=1)

        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)

        concat4 = torch.cat([convt4, conv1], dim=1)

        conv9 = self.layer9_conv(concat4)

        outp = self.layer10_conv(conv9)

        outp = self.sigmoid(outp)

        return outp

model = Unet()

inp = torch.rand(10, 1, 224, 224)

outp = model(inp)

print(outp.shape)

 

这篇关于unet 网络的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!