C/C++教程

21 PyTorch 可复现设置

本文主要是介绍21 PyTorch 可复现设置,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

PyTorch 可复现设置

参考链接:

https://www.jianshu.com/p/b95ec7351603

影响模型可复现性有以下几个因素:

1 随机种子

2 Dataloader

3 不确定性的算法

具体的看上面的链接,简单来说,加上下面这两段就ok了:

def set_seed(seed):
    # 随机种子
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.

    # 不确定性算法
    # torch.set_deterministic(True) # 会报错,所以注释掉
    torch.backends.cudnn.enabled = False 
    torch.backends.cudnn.benchmark = False
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)
def worker_init(self, worked_id):
    # dataloader
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

train_loader = DataLoader(xxx, num_workers=0, worker_init_fn=self.worker_init)
这篇关于21 PyTorch 可复现设置的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!