Java教程

GFPGAN源码分析—第六篇

本文主要是介绍GFPGAN源码分析—第六篇,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

2021SC@SDUSC

源码:archs\gfpganv1_clean_arch.py

本篇主要分析gfpganv1_clean_arch.py下的

class GFPGANv1Clean(nn.Module)类_init_()方法

目录

class GFPGANv1Clean(nn.Module)

init()

(1)channels的设置

(2)调用torch.nn.Conv2d()创建了一层卷积神经网络

(3)下采样(downsample)

(4)上采样(upsample)

(5)全连接层

(6)创建self.stylegan_decoder

(7)如果decoder_load_path不为空则读取

(8)for SFT(SFT layer)


class GFPGANv1Clean(nn.Module)

        继承自nn.Module类,使得我们可以使用很多现成的类,比如本类中使用的Conv2d以及RelU激活函数等等。

init()

参数:

self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False

在class GFPGANer()-init()中被调用时:

self.gfpgan = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=channel_multiplier,
    decoder_load_path=None,
    fix_decoder=False,
    num_mlp=8,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True)

(1)channels的设置

实际调用的时候narrow=1,

channels保存了经过convolution层后的输出的通道数

unet_narrow = narrow * 0.5

channels = {
    '4': int(512 * unet_narrow),
    '8': int(512 * unet_narrow),
    '16': int(512 * unet_narrow),
    '32': int(512 * unet_narrow),
    '64': int(256 * channel_multiplier * unet_narrow),
    '128': int(128 * channel_multiplier * unet_narrow),
    '256': int(64 * channel_multiplier * unet_narrow),
    '512': int(32 * channel_multiplier * unet_narrow),
    '1024': int(16 * channel_multiplier * unet_narrow)
}

(2)调用torch.nn.Conv2d()搭建卷积神经网络

#out_size=512,so log_size=9
self.log_size = int(math.log(out_size, 2))
#first_out_size = 512
first_out_size = 2 ** (int(math.log(out_size, 2)))
#channels['512']=32*2*0.5=32
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

在这里介绍一下nn.Conv2d()的几个参数

in_channels: int,#输入的通道数目【必选】
out_channels: int,# 输出的通道数目【必选】
kernel_size: _size_2_t,#卷积核的大小,类型为int(方形边长) 或者元组(长和宽)【必选】
stride: _size_2_t = 1,#步长
padding: Union[str, _size_2_t] = 0,#边界增益,可以控制输出结果的尺寸
dilation: _size_2_t = 1,#控制卷积核之间的间距
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',  # TODO: refine this type
device=None,
dtype=None

那么可以得知

self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)


#实际上是传入通道为3(RGB)的输入,使用边长为1的卷积核,最后获得通道为32的输出
#由于卷积核边长为1,我们输入与输入的图片大小仍然保持一致,但增加了通道数

(3)下采样(downsample)

可以看到实际上是调用ResBlock做了下采样

# 输入图片的通道数(实际为32)
in_channels = channels[f'{first_out_size}']
 #创建ModuleList容器
self.conv_body_down = nn.ModuleList()
# i从self.log_size(9)->3      :7次循环
for i in range(self.log_size, 2, -1):
    out_channels = channels[f'{2 ** (i - 1)}']
    #调用ResBlock残差网络做下采样,并将该module添加到设置的ModuleList
    self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
    #这一层的输出管道数作为下一层输入的管道数
    in_channels = out_channels

介绍一下nn.ModuleList()

nn.ModuleList,它是一个储存不同module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。
#注意nn.ModuleList则没有实现内部forward函数,所以需要手动实现

最后一层卷积层的搭建:

#最终输出通道数为channels['4']=256,使用边长为3的卷积核,步长为1,padding为1,保证维度不变
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

(4)上采样(upsample)

#输入通道数为channels['4']=256,即下采样的输出的通道数
        in_channels = channels['4']
        #创建ModuleList容器
        self.conv_body_up = nn.ModuleList()
        # i从3->self.log_size(9)     :7次循环
        for i in range(3, self.log_size + 1):
            # 定义输出的通道数
            out_channels = channels[f'{2 ** i}']
            # 调用带有上采样ResBlock残差网络,并将该module添加到设置的ModuleList
            self.conv_body_up.append(ResBlock(in_channels, out_channels, 
                                              mode='up'))
            #这一层的输出管道数作为下一层输入的管道数
            in_channels = out_channels

(5)全连接层

根据传入的参数different_w,选择每个输出样本的大小,并搭建相应的全连接层。

if different_w:
    #16*512=8192
    linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
    print(linear_out_channel)
else:
    #512
    linear_out_channel = num_style_feat
#全连接层size of each input sample:4096,size of each output sample:8192
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

(6)创建self.stylegan_decoder

self.stylegan_decoder = StyleGAN2GeneratorCSFT(
    out_size=out_size,
    num_style_feat=num_style_feat,
    num_mlp=num_mlp,
    channel_multiplier=channel_multiplier,
    narrow=narrow,
    sft_half=sft_half)

(7)如果decoder_load_path不为空则读取

if decoder_load_path:
    self.stylegan_decoder.load_state_dict(
        torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
if fix_decoder:
    for name, param in self.stylegan_decoder.named_parameters():
        param.requires_grad = False

(8)for SFT(SFT layer)

#ModuleList
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
  # i从3->self.log_size(9)     :7次循环
for i in range(3, self.log_size + 1):
    # 定义输出的通道数
    out_channels = channels[f'{2 ** i}']
     #输出通道数是否减半
    if sft_half:
        sft_out_channels = out_channels
    else:
        sft_out_channels = out_channels * 2
         #使用nn.Sequential搭建网络,并添加到ModuleList
    self.condition_scale.append(
        nn.Sequential(
             #卷积核边长为3,步长为1,输出与输出保持相同维度
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, 
                                                                         True),
            nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
    self.condition_shift.append(
        nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, 
                                                                         True),
            nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

nn.Sequential是一个有序的容器,其中传入的是构造器类(各种用来处理input的类),最终input会被Sequential中的构造器依次执行。

这篇关于GFPGAN源码分析—第六篇的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!