1、导入库
import torchimport torch.nn as nn
2、搭建卷积神经网络
class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=3) self.linear = nn.Linear(5766,10) self.relu = nn.ReLU(inplace=True) self.maxpooling = nn.MaxPool2d((2,2)) def forward(self, x): x = self.conv(x) #print (x.shape) x = self.maxpooling(x) #print (x.shape) x = self.relu(x) x = x.view(1,-1) #print (x.shape) x = self.linear(x) return x
对于新手来说,可以先熟悉pytorch的格式。网络定义一般由两部分组成,
def __init__(self):用来定义网络节点参数; def forward(self, x): 将节点连接成图。 卷积计算规则,对我们输入形状(1,1,64,64),四个维度分别是(batch,channel,height,width)。batch:一次训练的批次,channel图像通道(比如RGB,channel = 3)。height,width分别指图像的高和宽。 new_height= (height - kernel_size + 2×padding)/(stride[0])+1;padding默认为0,意思的在周围补一圈零; stride默认为1,因此。 new_height = new_width = (64 - 3)/(1) + 1 = 62。 由于输出通道数为6,所以通过卷积层后维度(1,6,62,62) 经过pooling后,(1,6,31,31) x.view(1,-1):把x伸缩为(1,?)的维度,即(1,1×6×31×31)=(1,5766) nn.Linear(5766,10),把(1,5766)映射为(1,10)的维度。这样整个网络其实输入(1,1,64,64),输出(1,10) 3、添加训练数据
if __name__ =='__main__': device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = Net().to(device) optimizer = torch.optim.Adam(net.parameters()) criterion = nn.MSELoss() net.train() epoch = 100 input = torch.randn(1,1,64,64).cuda() output = torch.ones(1,10).cuda() batch = 32
optimizer:是优化器,即所谓的反向传播算法。 criterion = nn.MSELoss()定义损失函数。 input = torch.randn(1,1,64,64).cuda() output = torch.ones(1,10).cuda()。定义训练样本,注意如果在gpu中训练,在pytorch中需要.cuda()把数据从cpu中导入到gpu中 网络的功能是给定随机噪声向量,输出是逼近1的单位向量。 4、训练:
for step in range(epoch): prediction = net(input) loss = criterion(prediction, output) optimizer.zero_grad() #消除优化器梯度 loss.backward() optimizer.step() if step % 10 == 0: print("EPOCHS: {},Loss:{:4f}".format(step, loss))
loss.backward() 指自动求导 optimizer.step() 指根据自动求导反向传播优化参数。 5、我们可以输出样本看看结果:
print (prediction.cpu().detach().numpy()) #返回一个新的 从当前图中分离的 Variable。print (output.cpu().numpy())
注意输出结果时必须对张量.cpu()把张量从gpu转到cpu中。 对于计算图中的张量(比如x,prediction),必须加.detach()从计算图中导出才能转化成numpy。 输出结果: [[1.0127475 0.98897606 1.002695 0.9881151 1.0137383 1.0051517 1.0140573 1.0051212 1.0088345 0.9978328 ]] [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] (若有新的见解请加以批评指正)