Java教程

GFPGAN源码分析—第八篇

本文主要是介绍GFPGAN源码分析—第八篇,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

2021SC@SDUSC

源码:

models\init.py

models\gfpgan_model.py

本篇主要分析init.py与models\gfpgan_model.py下的

class GFPGANModel(BaseModel) 类init(self, opt) 方法

目录

init.py

gfpgan_model.py

class GFPGANModel(BaseModel)

init(self, opt)

init_training_settings(self)


init.py

自动扫描和导入注册表的模型模块

#在models文件夹下扫描所有以 '_model.py' 结尾的文件
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# 导入所有模型模块
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]

那么实际上就是导入models文件夹下gfpgan_model.py文件,接下来我们来看一下

gfpgan_model.py

本文件中只包含GFPGANModel(BaseModel)一个类

创建了一个MODEL_REGISTRY对象,并在类定义的时候用装饰器装饰它,以装饰器的形式调用MODEL_REGISTRY类的register函数

@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
    """GFPGAN model for <Towards real-world blind faces restoratin with generative facial prior>"""

class GFPGANModel(BaseModel)

基于生成性人脸先验信息的真实盲脸修复 的 GFPGAN 模型

init(self, opt)

简单看一下代码

super(GFPGANModel, self).__init__(opt)
self.idx = 0

# 网络定义
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)

# 读取预训练的模型
load_path = self.opt['path'].get('pretrain_network_g', None)
#如果路径不为空
if load_path is not None:
    param_key = self.opt['path'].get('param_key_g', 'params')
    self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))

if self.is_train:
    self.init_training_settings()

在读取预训练的模型时,实际上就是从train_gfpgan_v1.yml配置文件中读取到相应的参数的数值与路径。

init_training_settings(self)

初始化训练设置

1.读取opt['train']

train_opt = self.opt['train']

2.定义net_d

#构建网络
self.net_d = build_network(self.opt['network_d'])
#将模型放到gpu(cuda)上
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# 读取与训练好的模型
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
    self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

3.定义net_g

# net_g_ema 仅用于在一个GPU上测试并保存
# 不需要使用DistributedDataParallel进行包装
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# 读取预训练模型
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
    self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
    self.model_ema(0)  # copy net_g weight

self.net_g.train()
self.net_d.train()
self.net_g_ema.eval()

根据配置文件:net_g读取预训练模型为arcface_resnet18.pth

4.面部组件网络

if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
            self.use_facial_disc = True
        else:
            self.use_facial_disc = False

        if self.use_facial_disc:
            # left eye
            self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
            self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
            self.print_network(self.net_d_left_eye)
            load_path = self.opt['path'].get('pretrain_network_d_left_eye')
            if load_path is not None:
                self.load_network(self.net_d_left_eye, load_path, True, 'params')
            # right eye
            self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
            self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
            self.print_network(self.net_d_right_eye)
            load_path = self.opt['path'].get('pretrain_network_d_right_eye')
            if load_path is not None:
                self.load_network(self.net_d_right_eye, load_path, True, 'params')
            # mouth
            self.net_d_mouth = build_network(self.opt['network_d_mouth'])
            self.net_d_mouth = self.model_to_device(self.net_d_mouth)
            self.print_network(self.net_d_mouth)
            load_path = self.opt['path'].get('pretrain_network_d_mouth')
            if load_path is not None:
                self.load_network(self.net_d_mouth, load_path, True, 'params')

            self.net_d_left_eye.train()
            self.net_d_right_eye.train()
            self.net_d_mouth.train()

            # ----------- 定义面部组件的 gan loss ----------- #
            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

5.定义损失

if train_opt.get('pixel_opt'):
    self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
    self.cri_pix = None

if train_opt.get('perceptual_opt'):
    self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
    self.cri_perceptual = None
    # pyramid loss, component style loss, identity loss 都使用L1损失
    self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

    # gan loss (wgan)
    self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

6.identity loss的定义

if 'network_identity' in self.opt:
    self.use_identity = True
else:
    self.use_identity = False

if self.use_identity:
    # 定义 identity network
    self.network_identity = build_network(self.opt['network_identity'])
    self.network_identity = self.model_to_device(self.network_identity)
    self.print_network(self.network_identity)
    load_path = self.opt['path'].get('pretrain_network_identity')
    if load_path is not None:
        self.load_network(self.network_identity, load_path, True, None)
    self.network_identity.eval()
    for param in self.network_identity.parameters():
        param.requires_grad = False

# 正则化权重
self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
self.net_d_reg_every = train_opt['net_d_reg_every']

# 设置优化器和调度程序
self.setup_optimizers()
self.setup_schedulers()

这篇关于GFPGAN源码分析—第八篇的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!