【Python机器学习】深度学习——调参

        先用MLPClassifier应用到two_moons数据集上:

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import mglearn
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
X,y=make_moons(n_samples=100,noise=0.25,random_state=3)
X_train,X_test,y_train,y_test=train_test_split(X,y,stratify=y,random_state=42)mlp=MLPClassifier(solver='lbfgs',random_state=0)
mlp.fit(X_train,y_train)
mglearn.plots.plot_2d_separator(mlp,X_train,fill=True,alpha=.3)
mglearn.discrete_scatter(X_train[:,0],X_train[:,1],y_train)
plt.xlabel('特征0')
plt.ylabel('特征1')
plt.show()

        可以看到,神经网络学到的决策边界完全是非线性的,但相对平滑, 默认情况下,MLP使用100个隐结点,可以减少数量,降低模型复杂度,对于小型数据集来说,仍然可以得到很好的结果。

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import mglearn
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
X,y=make_moons(n_samples=100,noise=0.25,random_state=3)
X_train,X_test,y_train,y_test=train_test_split(X,y,stratify=y,random_state=42)mlp=MLPClassifier(solver='lbfgs',random_state=0,hidden_layer_sizes=[10],max_iter=10000)
mlp.fit(X_train,y_train)
mglearn.plots.plot_2d_separator(mlp,X_train,fill=True,alpha=.3)
mglearn.discrete_scatter(X_train[:,0],X_train[:,1],y_train)
plt.xlabel('特征0')
plt.ylabel('特征1')
plt.show()

可以看到,决策边界更加参差不齐。默认的非线性是relu,如果想要得到更平滑的决策边界,可以添加更多隐单元,或者使用tanh非线性。

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import mglearn
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
X,y=make_moons(n_samples=100,noise=0.25,random_state=3)
X_train,X_test,y_train,y_test=train_test_split(X,y,stratify=y,random_state=42)mlp=MLPClassifier(solver='lbfgs',activation='tanh',random_state=0,hidden_layer_sizes=[10,10],max_iter=10000)
mlp.fit(X_train,y_train)
mglearn.plots.plot_2d_separator(mlp,X_train,fill=True,alpha=.3)
mglearn.discrete_scatter(X_train[:,0],X_train[:,1],y_train)
plt.xlabel('特征0')
plt.ylabel('特征1')
plt.show()

除此以外,还可以利用L2惩罚使权重趋向于0,从而控制神经网络的复杂度,alpha的默认值很小,下面对不同参数下,神经网络结果的可视化:

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import mglearn
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
X,y=make_moons(n_samples=100,noise=0.25,random_state=3)
X_train,X_test,y_train,y_test=train_test_split(X,y,stratify=y,random_state=42)fig,axes=plt.subplots(2,4,figsize=(20,8))
for axx,n_hidden_nodes in zip(axes,[10,100]):for ax,alpha in zip(axx,[0.0001,0.01,0.1,1]):mlp = MLPClassifier(solver='lbfgs'#, activation='tanh', random_state=0, hidden_layer_sizes=[n_hidden_nodes, n_hidden_nodes],alpha=alpha,max_iter=10000)mlp.fit(X_train,y_train)mglearn.plots.plot_2d_separator(mlp,X_train,fill=True,alpha=.3,ax=ax)mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train,ax=ax)ax.set_title('隐单元个数=[{},{}]\nalpha={:.4f}'.format(n_hidden_nodes,n_hidden_nodes,alpha))
plt.show()

神经网络的一个重要性质是:在开始学习之前,权重是随机设置的,这种随机化会影响学到的模型,也就是即使使用完全相同的参数,用的随机种子不同,也可能得到非常不一样的模型:

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import mglearn
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
X,y=make_moons(n_samples=100,noise=0.25,random_state=3)
X_train,X_test,y_train,y_test=train_test_split(X,y,stratify=y,random_state=42)fig,axes=plt.subplots(2,4,figsize=(20,8))
for i,ax in enumerate(axes.ravel()):mlp = MLPClassifier(solver='lbfgs'#,activation='tanh',random_state=i,hidden_layer_sizes=[100,100],max_iter=10000)mlp.fit(X_train,y_train)mglearn.plots.plot_2d_separator(mlp,X_train,fill=True,alpha=.3,ax=ax)mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train,ax=ax)ax.set_title('随机初始化参数={:.4f}'.format(i))
plt.show()

用另一个例子,使用默认参数查看模型的特征数据和精度:


from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge,LinearRegression,Lasso,LogisticRegression
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifierplt.rcParams['font.sans-serif']=['SimHei']
cancer=load_breast_cancer()print('癌症数据集每个特征的最大值:{}'.format(cancer.data.max(axis=0)))
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0
)
mlp=MLPClassifier(random_state=0)
mlp.fit(X_train,y_train)print('训练集精度:{:.4f}'.format(mlp.score(X_train,y_train)))
print('测试集精度:{:.4f}'.format(mlp.score(X_test,y_test)))

MLP模型的精度很好,但是没有其他模型好,原因可能在于数据的缩放。神经网络也要求所有数据特征的变化范围相近,最理想的情况是均值为0,方差为1,人工处理:


#计算每个特征的平均值
mean_on_train=X_train.mean(axis=0)
#计算每个特征的标准差
std_on_train=X_train.std(axis=0)#减去平均值,然后乘标准差的倒数
#计算完成后mean=0,std=1
X_train_scaled=(X_train-mean_on_train)/std_on_train
X_test_scaled=(X_test-mean_on_train)/std_on_trainmlp_std=MLPClassifier(random_state=0)
mlp_std.fit(X_train_scaled,y_train)print('训练集精度:{:.4f}'.format(mlp_std.score(X_train_scaled,y_train)))
print('测试集精度:{:.4f}'.format(mlp_std.score(X_test_scaled,y_test)))

 

        可以看到缩放之后的结果要好很多,另外,增大迭代次数可以提高训练集性能,但不提高泛化性能。

        对特征重要性的可视化:


from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge,LinearRegression,Lasso,LogisticRegression
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifierplt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
cancer=load_breast_cancer()print('癌症数据集每个特征的最大值:{}'.format(cancer.data.max(axis=0)))
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0
)
mlp=MLPClassifier(random_state=0)
mlp.fit(X_train,y_train)plt.figure(figsize=(20,5))
plt.imshow(mlp.coefs_[0],interpolation='none',cmap='viridis')
plt.yticks(range(30),cancer.feature_names)
plt.xlabel('隐单元权重')
plt.ylabel('输入特征')
plt.colorbar()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/616124.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

训练营第四十二天 | 01背包问题,你该了解这些! ● 01背包问题,你该了解这些! 滚动数组 ● 416. 分割等和子集

01背包问题 二维 代码随想录 dp二维数组 优化 01背包问题 一维 代码随想录 dp一维数组 416. 分割等和子集 把数组分成总和相等的两份,如果数组总和为奇数,不能分割,若有符合的数组子集,返回true 代码随想录 class Solution {p…

数据中心建设之——理解基于财务三大报表的BI指标体系搭建

目录 1.1 三张报表的作用 1.2 三张报表长的样子 1.2.1 资产负债表 1.2.2 利润表 1.2.3 现金流 1.3 BI指标构建 1.3.1 盈利能力指标构建 1.3.2 营运能力指标构建 1.3.3 偿债能力指标构建 转眼间,一年又悄然而逝,时光荏苒,岁月如梭 &a…

仓储|仓库管理水墨屏RFID电子标签2.4G基站CK-RTLS0501G功能说明与安装方式

随着全球智能制造进度的推进以及物流智能化管理水平的升级,行业亟需一种既能实现RFID批量读取、又能替代纸质标签在循环作业、供应链管理以及实现动态条码标签显示的产品。在此种行业需求背景下,我是适时推出了基于墨水屏显示技术的VT系列可视化超高频标…

JVM-JVM支持高并发底层原理精讲

一、透彻掌握高并发-从理解JVM开始 二、从线程的开闭看JVM的作用 1.run方法 启动start方法,会调用底层C方法,告诉操作系统当前线程处于可运行状态,而如果直接调用run方法,则就不是以线程的方式来运行了,只是当做一个普…

一套成熟的Spring Cloud智慧工地平台源码,自主版权,支持二次开发!

智慧工地源码,java语言开发的智慧工地源码 智慧工地利用移动互联、物联网、云计算、大数据等新一代信息技术,彻底改变传统施工现场各参建方的交互方式、工作方式和管理模式,为建设集团、施工企业、监理单位、设计单位、政府监管部门等提供一揽…

RabbitMQ(十)队列的声明方式

目录 1.编程式声明补充:RabbitTemplate 和 AmqpAdmin 的区别 2.声明式声明补充:new Queue() 和 QueueBuilder.durable(queueName).build() 的区别 背景: 在学习 RabbitMQ 的使用时, 经常会遇到不同的队列声明方式,有的…

酚醛胶面建筑模板 — 广西厂家直销,质保可靠

在现代建筑行业中,选择高质量的建筑板材对于确保施工质量和工程安全至关重要。广西厂家直销的酚醛胶面建筑板,以其卓越的质量和可靠的质保,成为了建筑行业的优选材料。 产品特性 卓越的耐候性:我们的酚醛胶面建筑板采用高品质酚醛…

图文看懂Android的Matrix原理

Matrix结构 在Android开发中,矩阵是一个非常强大且有趣的工具。位于图形库中,android.graphics.Matrix 是一个 33 的 float 矩阵,其主要作用是坐标变换。 它的结构大概是这样的: 其中每个位置的数值作用和其名称所代表的的含义是…

Vue-18、Vue人员列表排序

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>列表排序</title><script type"text/javascript" src"https://cdn.jsdelivr.net/npm/vue2/dist/vue.js"></script…

Linux中DCHP与时间同步

目录 一、DHCP &#xff08;一&#xff09;工作原理 1.获取 2.续约 &#xff08;二&#xff09;分配方式 &#xff08;三&#xff09;服务器配置 1.随机地址分配 2.固定地址分配 二、时间同步 &#xff08;一&#xff09;ntpdate &#xff08;二&#xff09;chrony …

window-nginx注册服务(nginx-1.24.0.zip)

window-nginx注册服务(nginx-1.24.0.zip) 1、下载当前windows版nginx的稳定版本。 https://nginx.org/en/download.html 2、解压到指定目录中&#xff0c;这里解压到D盘根目录&#xff0c;D:\nginx-1.24.0 3、管理员打开命令行&#xff0c;可先进行相关操作&#xff0c;看一下n…

uni-app修改头像和个人信息

效果图 代码&#xff08;总&#xff09; <script setup lang"ts"> import { reqMember, reqMemberProfile } from /services/member/member import type { MemberResult, Gender } from /services/member/type import { onLoad } from dcloudio/uni-app impor…

Google的Ndk-Sample学习笔记之一(hello-jniCallback)

前言: 近段时间因为项目的需求,需要使用JNI,所以下载了Google的Ndk-Sample学习下,准备记录 下来,留给后期自己查看 问题点一:JNI_OnLoad方法必须返回JNI的版本 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {JNIEnv *env;memset(&g_ctx, 0, sizeof(g_…

亚马逊API:快速查询全球商品数据的技巧!

了解亚马逊API的限制和要求&#xff1a;在使用亚马逊API之前&#xff0c;您需要了解其限制和要求&#xff0c;例如请求频率限制、认证要求等。确保您遵循了API的使用条款&#xff0c;以避免不必要的麻烦。使用合适的亚马逊API服务&#xff1a;亚马逊提供了多个API服务&#xff…

Atlassian版本选择趋势是上云还是本地部署?全面分析两个版本的特性

近日&#xff0c;龙智联合Atlassian举办的DevSecOps研讨会年终专场”趋势展望与实战探讨&#xff1a;如何打好DevOps基础、赋能创新”在上海圆满落幕。龙智Atlassian技术与顾问咨询团队&#xff0c;以及清晖、JamaSoftware、CloudBees等生态伙伴的嘉宾发表了主题演讲&#xff0…

flutter封装dio请求库,让我们做前端的同学可以轻松上手使用,仿照axios的使用封装

dio是一个非常强大的网络请求库&#xff0c;可以支持发送各种网络请求&#xff0c;就像axios一样灵活强大&#xff0c;但是官网没有做一个demo示例&#xff0c;所以前端同学使用起来还是有点费劲&#xff0c;所以就想在这里封装一下&#xff0c;方便前端同学使用。 官网地址&a…

uniapp开发安卓应用微信开放平台创建应用如何获取签名

微信开放平台创建应用时需要应用的签名 比如我们开发了一个应用叫 “滴滴拉屎” 包名&#xff1a;uni.DIDILASHI #mermaid-svg-BUKbltDr30J93dUs {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-BUKbltDr30J93dUs .…

直播带货2024:洗牌、阵痛和暗流涌动

文 | 螳螂观察 作者 | 青月 一天前&#xff0c;大学生齐夏根本不会在直播间购买《额尔古纳河右岸》这种书籍。 她是喜欢看小说&#xff0c;但只钟爱悬疑无限流题材&#xff0c;至于《额尔古纳河右岸》这种讲述一个弱小民族顽强的抗争和优美的爱情的长篇小说&#xff0c;用齐…

vue上传文件加进度条,fake-progress一起使用

el-upload上传过程中加进度条&#xff0c;进度条el-progress配合fake-progress一起使用&#xff0c;效果如下&#xff1a; 安装 npm install fake-progress 在用到的文件里面引用 import Fakeprogress from "fake-progress"; 这个进度条主要是假的进度条&#xff…

轻量级图床Imagewheel本地部署并结合内网穿透实现远程访问

文章目录 1.前言2. Imagewheel网站搭建2.1. Imagewheel下载和安装2.2. Imagewheel网页测试2.3.cpolar的安装和注册 3.本地网页发布3.1.Cpolar临时数据隧道3.2.Cpolar稳定隧道&#xff08;云端设置&#xff09;3.3.Cpolar稳定隧道&#xff08;本地设置&#xff09; 4.公网访问测…