文章目录
- 网络构建
网络构建
在打卡第一天就简单演示了网络构建,一个神经网络模型表示为一个Cell,由不同的子Cell构成。使用这样的嵌套结构可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。
继承nn.Cell类来定义神经网络,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。
construct意为神经网络(计算图)构建,后面单开一章来介绍。
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 10, weight_init="normal", bias_init="zeros"))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logits
nn.Flatten是将28x28的2D张量转换为784大小的连续数组。
nn.SequentialCell是一个有序的Cell容器,输入Tensor将按照定义的顺序通过所有Cell。
代码块里nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。
构造一个输入数据再调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。
model = Network()
X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits
输出结果:Tensor(shape=[1, 10], dtype=Float32, value=
[[-5.08734025e-04, 3.39190010e-04, 4.62840870e-03 … -1.20305456e-03, -5.05689112e-03, 3.99264274e-03]])
再通过一个nn.Softmax层实例来获得预测概率
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
输出结果:Predicted class: [4]