博主想把神经网络模型的参数写入 bin 文件，以便在后续创建 IP 的过程中读取这些数据进行验证。本文记录如何用 Python 读取 PyTorch 模块的参数，并对 bin 文件进行写入和读取操作，以 3x3 卷积为例。
本文涉及的模块：
- struct：在 Python 值与以 Python bytes 对象表示的 C 结构之间进行转换。
- PyTorch：神经网络框架。
简单示例:
import os
import struct

SAVE_DIR = "./conv3x3_pool_relu_outputs"

val = -1
# Pack -1 into a C int (4 bytes on common platforms). Without a prefix, 'i' uses
# native byte order, size and alignment (alignment padding may be inserted between
# fields); use e.g. '<i' for a fixed little-endian, standard-size, unpadded layout.
a = struct.pack('i', val)
print(a)  # b'\xff\xff\xff\xff' -- the two's-complement encoding of -1

# open() does not create missing directories, so make sure SAVE_DIR exists first.
os.makedirs(SAVE_DIR, exist_ok=True)
file = os.path.join(SAVE_DIR, "wt.bin")
with open(file, "ab+") as fw:  # binary append mode: every run adds 4 more bytes
    fw.write(a)

with open(file, "rb") as fr:  # binary read mode
    b = struct.unpack('i', fr.read(4))  # read exactly 4 bytes; unpack returns a tuple
print(b)  # (-1,)
print(b[0])  # -1
print(b[0] == val)  # True
完整保存参数代码:
# coding:utf-8
"""
Generate a conv3x3 -> ReLU -> max-pool module and dump its weights as a bin file for testing.
"""
import itertools
import os
import struct

import torch
import torch.nn as nn

# hyper param
Hin = 6
Win = 12
CHin = 16
CHout = 16
step = 0.1
G_SIZE = 8
SAVE_DIR = "./conv3x3_pool_relu_outputs"
seed = 2021
torch.random.manual_seed(seed)


def format_num(x):
    """
    Return a float32 tensor shaped like ``x`` whose entries are randomly +1 or -1.

    The values of ``x`` are NOT used. A fresh standard-normal sample with x's shape
    is drawn (so the result depends on the global torch RNG state). That sample is
    mapped >0 -> 1 and otherwise -> -1.
    """
    return (torch.randn_like(x) > 0).to(torch.float32) * 2 - 1


def _blocked_indices(shape, size):
    """Yield (co, ci, h, w) in the tiled order used by the bin layout."""
    cout, cin, kh, kw = shape
    # Tiles of `size` output x `size` input channels. Within a tile: co, then ci,
    # then the kernel rows/cols (rightmost index varies fastest).
    for i in range(0, cout, size):
        for j in range(0, cin, size):
            yield from itertools.product(
                range(i, i + size), range(j, j + size), range(kh), range(kw)
            )


def save_conv3x3_weight(weight, save_dir="./outputs", filename="conv3x3", size=8):
    """
    Write conv weights to a bin file as 4-byte little-endian signed ints.

    Layout: channels are walked in tiles of ``size`` output x ``size`` input
    channels. Inside a tile the order is co, ci, kernel row, kernel col.

    Args:
        weight: 4-D tensor/Parameter (or array) shaped (out_ch, in_ch, kh, kw).
        save_dir: output directory; created if missing.
        filename: base name. "_weight.bin" is appended unless the name already
            contains ".dat" or ends with ".bin".
        size: tile size; must divide out_ch and in_ch.

    Returns:
        Path of the written file.

    Raises:
        AssertionError: if ``weight`` is not 4-D or its channels are not
            multiples of ``size``.
    """
    shape = weight.shape
    print("save {} weights(bin format) ".format(filename), shape, end=" ---------wait---------- ")
    assert len(shape) == 4 and shape[0] % size == 0 and shape[1] % size == 0, "input error"
    # The original only checked ".dat", so "w.bin" turned into "w.bin_weight.bin".
    if ".dat" not in filename and not filename.endswith(".bin"):
        filename = filename + "_weight.bin"
    if isinstance(weight, torch.Tensor):  # nn.Parameter is a Tensor subclass
        weight = weight.detach().cpu().numpy()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)  # open() will not create directories
    filepath = os.path.join(save_dir, filename)
    # int() truncates toward zero; exact for the +-1 weights produced by format_num.
    values = [int(weight[co][ci][h][w]) for co, ci, h, w in _blocked_indices(shape, size)]
    with open(filepath, "wb") as fw:
        # '<' = little-endian, standard 4-byte int, no padding: host-independent bytes.
        fw.write(struct.pack("<{}i".format(len(values)), *values))
    print("save conv3x3_weight done. save weights to {}".format(filepath))
    return filepath


def load_conv3x3_weight(filepath, shape, size=8):
    """
    Read a file written by ``save_conv3x3_weight`` back into a float32 tensor.

    Args:
        filepath: path of the bin file.
        shape: (out_ch, in_ch, kh, kw) of the saved weight.
        size: tile size used when saving.

    Returns:
        torch.Tensor with ``shape`` and dtype float32.

    Raises:
        AssertionError: if ``shape`` is not 4-D or its channels are not multiples of ``size``.
        ValueError: if the file length does not match ``shape``.
    """
    shape = tuple(shape)
    assert len(shape) == 4 and shape[0] % size == 0 and shape[1] % size == 0, "input error"
    count = shape[0] * shape[1] * shape[2] * shape[3]
    fmt = "<{}i".format(count)
    with open(filepath, "rb") as fr:
        data = fr.read()
    if len(data) != struct.calcsize(fmt):
        raise ValueError(
            "{}: expected {} bytes for shape {}, got {}".format(
                filepath, struct.calcsize(fmt), shape, len(data)
            )
        )
    weight = torch.empty(shape, dtype=torch.float32)
    for (co, ci, h, w), value in zip(_blocked_indices(shape, size), struct.unpack(fmt, data)):
        weight[co, ci, h, w] = value
    return weight


class Conv3x3PoolRelu(nn.Module):
    """
    3x3 conv (stride 1, padding 1) -> ReLU -> 2x2 max-pool with frozen +-1 parameters.

    Args:
        in_channels, out_channels: must be multiples of G_SIZE.
        save: currently unused (kept for API compatibility). forward() always
            dumps the conv weights.
        out_dir: directory the conv weights are written to on every forward().
        save_size: tile size passed to save_conv3x3_weight.
    """

    def __init__(self, in_channels=16, out_channels=32, save=False, out_dir="./outputs", save_size=8):
        super().__init__()
        assert in_channels % G_SIZE == 0 and out_channels % G_SIZE == 0, "input error!!"
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2, 0)
        self.init_weights()
        self.save_dir = out_dir
        self.save_size = save_size

    def forward(self, x):
        """Apply the children in registration order; dumps each Conv2d's weights first."""
        for name, module in self.named_children():
            print(name, module)
            if isinstance(module, nn.Conv2d):
                save_conv3x3_weight(module.weight, self.save_dir, size=self.save_size)
            x = module(x)
        return x

    def init_weights(self):
        """Replace every Conv2d's weight (then bias) with frozen random +-1 parameters."""
        for _, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                # Weight first, then bias: same RNG consumption order as before.
                m.weight = nn.Parameter(format_num(m.weight), requires_grad=False)
                if m.bias is not None:
                    m.bias = nn.Parameter(format_num(m.bias), requires_grad=False)


if __name__ == '__main__':
    model = Conv3x3PoolRelu(8, 8, out_dir=SAVE_DIR)
    x = format_num(torch.randn(1, 8, 4, 4))
    y = model(x)
    print(y.shape, y)
    # read back to validate: same directory, shape and tiled order as the writer
    w = load_conv3x3_weight(
        os.path.join(SAVE_DIR, "conv3x3_weight.bin"), model.conv.weight.shape, size=model.save_size
    )
    print(w)
    print(torch.equal(w, model.conv.weight))
详细步骤如下：使用 struct 库对相应的数据类型进行二进制转换（读取用 unpack，写入用 pack）。注意：格式字符与 C 类型、Python 类型之间的对应关系如下。表中的"标准大小"只在格式字符串以 <、>、! 或 = 开头（标准模式）时适用；不加前缀的原生模式下，大小取决于平台（例如 64 位 Linux 上 l 为 8 字节）。
格式 | C 类型 | Python 类型 | 标准大小 |
---|---|---|---|
x | 填充字节 | 无 | |
c | char | 长度为 1 的字节串 | 1 |
b | signed char | 整数 | 1 |
B | unsigned char | 整数 | 1 |
? | _Bool | bool | 1 |
h | short | 整数 | 2 |
H | unsigned short | 整数 | 2 |
i | int | 整数 | 4 |
I | unsigned int | 整数 | 4 |
l | long | 整数 | 4 |
L | unsigned long | 整数 | 4 |
q | long long | 整数 | 8 |
Q | unsigned long long | 整数 | 8 |
n | ssize_t | 整数 | |
N | size_t | 整数 | |
e | 无（IEEE 754 半精度浮点数 binary16） | float | 2 |
f | float | float | 4 |
d | double | float | 8 |
s | char[] | 字节串 | |
p | char[] | 字节串 | |
P | void * | 整数 | |
写入 bin 文件本质上就是写入二进制数据：如果数据本身已经是二进制（bytes），就不需要再进行 struct 的 pack 操作。另外，对于 Python 的数据类型，写入文件时的字节顺序、大小与对齐方式可以通过格式字符串的第一个字符设置。例如 struct.pack('<i', -1) 表示按小端、标准大小、无填充的方式打包。详细见官方文档[2]。
1、python bin 文件处理 - 云 + 社区 - 腾讯云 (tencent.com)
2、struct — 将字节串解读为打包的二进制数据 — Python 3.8.12 文档