李沐动手学习深度学习——4.2练习

1. 在所有其他参数保持不变的情况下,更改超参数num_hiddens的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。

通过改变隐藏层的数量,导致就是函数拟合复杂度下降,隐藏层过多可能导致过拟合,而过少导致欠拟合。
我们将层数改为128可得:
在这里插入图片描述

2. 尝试添加更多的隐藏层,并查看它对结果有何影响。

过拟合,导致测试机精确度下降。

3. 改变学习速率会如何影响结果?保持模型架构和其他超参数(包括轮数)不变,学习率设置为多少会带来最好的结果?

过高的学习率导致,梯度跨度过大,使得降低不到对应的驻点。
过低的学习率导致训练缓慢,需要增加epoch。
在训练轮数不变的情况下,我们可以通过for 设置不同的学习率找出最合适的学习率。一般来说设置为0.01或者0.1足以

4. 通过对所有超参数(学习率、轮数、隐藏层数、每层的隐藏单元数)进行联合优化,可以得到的最佳结果是什么?

跑了一次学习率lr=0.01的情况:
在这里插入图片描述

需要大量的训练,但是目前我训练结果是学习率lr=0.1、轮数是num_epochs=10,隐藏层数为1,隐藏层数单元num_hiddens=128。

5. 描述为什么涉及多个超参数更具挑战性。

因为组合的情况更多,当层数越多时,训练时间也更多,这玩意就是炼丹了,看你自己的GPU还有时间、运气。

6. 如果想要构建多个超参数的搜索方法,请想出一个聪明的策略。

套用for 循环暴力破解,时间上肯定慢的要死,我们可以先固定其他变量,挑选一个变量寻找最优解,以此类推对所有的超参数这样使用,但是这种做法肯定不是最优的,只是能够较好的找出比较好的超参数。

由于学校穷逼所以没有闲置GPU服务器,所有的模型只能在colab上进行运行,其中遇到了d2l的版本对应问题,所以对于d2l.train_ch3跑不起来,只能使用自写进行替代如下:

import torch.nn
from d2l import torch as d2l
from IPython import displayclass Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * n       # 创建一个长度为 n 的列表,初始化所有元素为0.0。def add(self, *args):           # 累加self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):                # 重置累加器的状态,将所有元素重置为0.0self.data = [0.0] * len(self.data)def __getitem__(self, idx):     # 获取所有数据return self.data[idx]def accuracy(y_hat, y):"""计算正确的数量:param y_hat::param y::return:"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)            # 在每行中找到最大值的索引,以确定每个样本的预测类别cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy(net, data_iter):"""计算指定数据集的精度:param net::param data_iter::return:"""if isinstance(net, torch.nn.Module):net.eval()                  # 通常会关闭一些在训练时启用的行为metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]class Animator:"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量的绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):"""向图表中添加多个数据点:param x::param y::return:"""if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def train_epoch_ch3(net, train_iter, loss, updater):"""训练模型一轮:param net:是要训练的神经网络模型:param train_iter:是训练数据的数据迭代器,用于遍历训练数据集:param loss:是用于计算损失的损失函数:param updater:是用于更新模型参数的优化器:return:"""if isinstance(net, torch.nn.Module):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。net.train()# 训练损失总和, 训练准确总和, 样本数metric = Accumulator(3)for X, y in train_iter:  # 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。# 使用pytorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()  # 方法用于计算损失的平均值updater.step()else:# 使用定制(自定义)的优化器和损失函数l.sum().backward()updater(X.shape())metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):"""训练模型():param net::param train_iter::param test_iter::param loss::param num_epochs::param updater::return:"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):trans_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, trans_metrics + (test_acc,))train_loss, train_acc = trans_metricsprint(trans_metrics)def predict_ch3(net, test_iter, n=6):"""进行预测:param net::param test_iter::param n::return:"""global X, yfor X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + "\n" + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/714605.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】表的内连和外连(重点)

表的连接分为内连和外连。 一、内连接 内连接实际上就是利用 where 子句对两种表形成的笛卡儿积进行筛选,前面学习的查询都是内连接,也是在开发过程中使用的最多的连接查询。 select 字段 from 表1 inner join 表2 on 连接条件 and 其他条件; 注意&…

Linux使用基础命令

1.常用系统工作命令 (1).用echo命令查看SHELL变量的值 qiangziqiangzi-virtual-machine:~$ echo $SHELL /bin/bash(2).查看本机主机名 qiangziqiangzi-virtual-machine:~$ echo $HOSTNAME qiangzi-virtual-machine (3).date命令用于显示/设置系统的时间或日期 qiangziqian…

Linux多线程服务端编程:使用muduo C++网络库 学习笔记 附录B 从《C++ Primer(第4版)》入手学习C++

这是作者为《C Primer(第4版)(评注版)》写的序言,文中“本书”指的是这本书评注版。 B.1 为什么要学习C 2009年本书作者Stanley Lippman先生应邀来华参加上海祝成科技举办的C技术大会,他表示人们现在还用…

扩展学习|大数据分析的现状和分类

文献来源:[1] Mohamed A , Najafabadi M K , Wah Y B ,et al.The state of the art and taxonomy of big data analytics: view from new big data framework[J].Artificial Intelligence Review: An International Science and Engineering Journal, 2020(2):53. 下…

蓝桥杯(3.2)

1209. 带分数 import java.io.*;public class Main {static BufferedReader br new BufferedReader(new InputStreamReader(System.in));static PrintWriter pw new PrintWriter(new OutputStreamWriter(System.out));static final int N 10;static int n, cnt;static int[…

LabVIEW流量控制系统

LabVIEW流量控制系统 为响应水下航行体操纵舵翼环量控制技术的试验研究需求,通过LabVIEW开发了一套小量程流量控制系统。该系统能够满足特定流量控制范围及精度要求,展现了其在实验研究中的经济性、可靠性和实用性,具有良好的推广价值。 项…

tritonserver学习之八:redis_caches实践

tritonserver学习之一:triton使用流程 tritonserver学习之二:tritonserver编译 tritonserver学习之三:tritonserver运行流程 tritonserver学习之四:命令行解析 tritonserver学习之五:backend实现机制 tritonserv…

【C++初阶】内存管理

目录 一.C语言中的动态内存管理方式 二.C中的内存管理方式 1.new/delete操作内置类型 2.new和delete操作自定义类型 3.浅识抛异常 (内存申请失败) 4.new和delete操作自定义类型 三.new和delete的实现原理 1.内置类型 2.自定义类型 一.C语…

C++学习笔记:二叉搜索树

二叉搜索树 什么是二叉搜索树?搜索二叉树的操作查找插入删除 二叉搜索树的应用二叉搜索树的代码实现K模型:KV模型 二叉搜索树的性能怎么样? 什么是二叉搜索树? 二叉搜索树又称二叉排序树,它或者是一棵空树,或者是具有以下性质的二叉树: 若它的左子树…

Linux安装Nginx详细步骤

1、创建两台虚拟机,分别为主机和从机,区别两台虚拟机的IP地址 2、将Nginx素材内容上传到/usr/local目录(pcre,zlib,openssl,nginx) 附件 3、安装pcre库   3.1 cd到/usr/local目录 3.2 tar -zxvf pcre-8.36.tar.gz 解压 3.3 cd…

MATLAB图像噪声添加与滤波

在 MATLAB 中添加图像噪声和进行滤波通常使用以下函数: 添加噪声:可以使用imnoise函数向图像添加各种类型的噪声,如高斯噪声、椒盐噪声等。 滤波:可以使用各种滤波器对图像进行滤波处理,例如中值滤波、高斯滤波等。 …

前端学习、HTML

html是由一些标签构成的,标签之间可以嵌套,每个标签都有开始标签和结束标签,也有部分标签只有开始标签,没有结束标签。html的标签也可以成为元素。(树形结构) html文件的最顶层标签就是html。 head用来放…

**蓝桥OJ 178全球变暖 DFS

蓝桥OJ 178全球变暖 思路: 将每一座岛屿用一个颜色scc代替, 用dx[]和dy[]判断他的上下左右是否需要标记颜色,如果已经标记过颜色或者是海洋就跳过.后面的淹没,实际上就是哪个块上下左右有陆地,那么就不会被淹没,我用一个tag标记,如果上下左右一旦有海洋,tag就变为false.如果tag…

用冒泡排序模拟C语言中的内置快排函数qsort!

目录 ​编辑 1.回调函数的介绍 2. 回调函数实现转移表 3. 冒泡排序的实现 4. qsort的介绍和使用 5. qsort的模拟实现 6. 完结散花 悟已往之不谏,知来者犹可追 创作不易,宝子们!如果这篇文章对你们有帮助的话,别忘了给个免…

机器学习:模型评估和模型保存

一、模型评估 from sklearn.metrics import accuracy_score, confusion_matrix, classification_report# 使用测试集进行预测 y_pred model.predict(X_test)# 计算准确率 accuracy accuracy_score(y_test, y_pred) print(f"Accuracy: {accuracy*100:.2f}%")# 打印…

整数和浮点数在内存中的存储(大小端字节序,浮点数的存取)

目录 1.整数在内存中的存储 2.大小端字节序和字节序判断 2.1什么是大小端? 2.2为什么会有大小端 3.浮点数在内存中的存储 3.1浮点数的存储 3.1.1 浮点数存的过程 3.1.2 浮点数取的过程 3.2 解析 3.3 验证浮点数的存储方式 1.整数在内存中的存储 整数的二进…

亿道信息轻工业三防EM-T195,零售、制造、仓储一网打尽

厚度仅10.5mm,重量仅0.65千克的EM-T195,其紧凑而纤薄的设计为以往加固型平板带来了全新的轻薄概念。尽管设计时尚、轻薄,但经过军用认证的强固性仍然能够承受所有具有挑战性的环境条件。随身携带无负担的轻便性加上抗震功能使其成为餐厅、酒店…

数独游戏(dfs)

代码注释如下 #include <iostream> using namespace std; const int N 10; bool col[N][N], rol[N][N], cell[3][3][N]; char g[N][N]; bool dfs(int x, int y) { //用bool这样在找到一个方案就可以迅速退出if(y 9) x, y 0; //若y超出边界&#xff0c;则第二…

S1---FPGA硬件板级原理图实战导学

视频链接 FPGA板级实战导学01_哔哩哔哩_bilibili FPGA硬件板级原理图实战导学 【硬件电路设计的方法和技巧-哔哩哔哩】硬件电路设计的方法和技巧01_哔哩哔哩_bilibili&#xff08;40min&#xff09; 【高速板级硬件电路设计-哔哩哔哩】 高速板级硬件电路设计1_哔哩哔哩_bil…

【RT-Thread基础教程】邮箱的使用

文章目录 前言一、邮箱的特性二、邮箱操作函数2.1 创建邮箱创建动态邮箱创建静态邮箱 2.2 删除邮箱2.3 发邮件2.4 取邮件 三、示例代码总结 前言 RT-Thread是一个开源的实时嵌入式操作系统&#xff0c;广泛应用于各种嵌入式系统和物联网设备。在RT-Thread中&#xff0c;邮箱是…