from PIL import Image image=Image.open("./xxx.png") #读取图片 img_data = np.array(image) #将图片转换为np对象 (此时img_data的大小为 [H,W,3],其中W为图片的宽,H为图片的高,3为RGB通道数)
为什么要增加成四维呢?
因为pytorch中的数据为tensor(张量),而张量的描述格式为(batch_size,色彩通道数量,高度,宽度)
,而一张图片一般是3维结构(高度,宽度,色彩通道数量)
,明显差一个维度,因此需要在第一个位置增加一个维度。
此外,还注意到tensor的第二个参数为通道数,而RGB的第三个才是通道数,因此需要在此处转换一下。
转换步骤:将三个通道的数据拆开,再拼起来
img_R = img_data[:,:,0] img_G = img_data[:,:,1] img_B = img_data[:,:,2] img = np.array([img_R,img_G,img_B]) # 此时img的大小为[3,H,W]
使用unsqueeze()来增加维度:x = torch.from_numpy(img).float().unsqueeze(0)
,其中的参数0是指“在第0个维度增加一维”
import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class NET(nn.Module): # 搭建网络结构 def __init__(self): super(NET, self).__init__() self.conv11 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,padding=1) self.conv12 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,padding=1) def forward(self,inputs): x11 = self.conv11(inputs) # 卷积 x11 = F.relu(x11) # relu激活 x12 = self.conv12(x11) x12 = F.relu(x12) flatten = torch.flatten(x12) # 平坦化 output = F.log_softmax(flatten) # softmax处理(使用log_softmax能够防止单纯使用softmax时的边界溢出问题) return output
net = NET() # 实例化网络 output = net(img) # 此处的img为之前经过“转换步骤”转换过的数据