CrossEntropyLoss(预测值,label)
需要的输入维度是:
[ batch_size, n ]
时,label的维度是1,size为[ batch_size ]
[ m, n ]
,label的维度是1,size为[ m ]
一个案例即可说明:
import torch import torch.nn as nn import numpy as np a = torch.tensor(np.random.random((30, 5))) b = torch.tensor(np.random.randint(0, 4, (30))).long() loss = nn.CrossEntropyLoss() print("a的维度:", a.size()) # torch.Size([30, 5]) print("b的维度:", b.size()) # torch.Size([30]) print(loss(a, b)) # tensor(1.6319, dtype=torch.float64)