3.10. 多层感知机的简洁实现
import torch
from torch import nn
from torch. nn import init
import numpy as np
import sys
sys. path. append( ".." )
import d2lzh_pytorch as d2l
3.10.1. 定义模型
# Model: flatten -> Linear(784 -> 256) -> ReLU -> Linear(256 -> 10).
# NOTE(review): 784 inputs is presumably a flattened 28x28 Fashion-MNIST
# image and 10 outputs one logit per class — confirm against the data loader.
num_inputs, num_outputs, num_hiddens = 784, 10, 256

layers = [
    d2l.FlattenLayer(),                   # collapse each example to a vector
    nn.Linear(num_inputs, num_hiddens),   # input -> hidden
    nn.ReLU(),                            # hidden non-linearity
    nn.Linear(num_hiddens, num_outputs),  # hidden -> output logits
]
net = nn.Sequential(*layers)

# Re-initialise every parameter (weights AND biases) from N(0, 0.01^2),
# overriding nn.Linear's default initialisation.
for param in net.parameters():
    init.normal_(param, mean=0, std=0.01)
3.10.2. 读取数据并训练模型
# Load Fashion-MNIST train/test iterators and train the model.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# CrossEntropyLoss applies log-softmax internally, which is why `net` ends
# with a plain Linear layer (raw logits) and no softmax.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
num_epochs = 5

# NOTE(review): the two positional `None`s presumably fill train_ch3's
# `params` / `lr` slots, which go unused when an optimizer is supplied —
# confirm against d2lzh_pytorch.train_ch3.
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              None, None, optimizer)
# Pull one batch from the test set and compare true vs. predicted labels.
# The builtin next() is used because DataLoader iterators in current PyTorch
# no longer expose the Python-2-style `.next()` method.
X, y = next(iter(test_iter))
true_labels = d2l.get_fashion_mnist_labels(y.numpy())

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())

# Each title puts the true label on the first line, the prediction below it.
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
d2l.show_fashion_mnist(X[0:9], titles[0:9])