搭建全连接网络进行分类(糖尿病为例)

拿来练手,大神请绕道。

1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。

2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。

3.全连接的回归也写了,有空再上传吧。

4.一般都是先写data或者model

import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):def __init__(self):super(FCNet,self).__init__()self.fc1 = nn.Linear(8,16)#初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合# self.fc2 = nn.Linear(50,20)self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的#最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmaxself.sig = nn.Sigmoid()# self.drop = nn.Dropout(0.3)#可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了def forward(self,x):x = self.sig(self.fc1(x))# x = self.sig(self.fc2(x))x = self.sig(self.fc3(x))return x#就是x要怎么在网络中走,要写一遍#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)

首先看看数据是是啥样,outcome就是有没有糖尿病

其实可以手动把csv分成train和test

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)#2.也可以直接重写datasetclass dataset(Dataset):def __init__(self, x, y):#把值拿出来或者变为np类型才能转换为tensor# self.data = torch.tensor(x.values,dtype=torch.float32)# self.labels = torch.tensor(y.values,dtype=torch.float32)self.data = torch.tensor(np.array(x),dtype=torch.float32)self.labels = torch.tensor(np.array(y),dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self,idx):return self.data[idx],self.labels[idx]#应该返回的是list类型,不是字典也不是setBATCH_SIZE = 64#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)

然后就可以写train或者test了,其实test和train一样

from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):my_device = torch.device('cuda')
else:my_device = torch.device('cpu')print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):losses = []acces = []losses_val = []for epoch in range(epochs):loss_batch = 0for i,data in enumerate(dataloader):#需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以inputs,labels = data#print(data)可以打印出来查看一下inputs,labels = inputs.to(my_device),labels.to(my_device)optimizer.zero_grad()#每次要梯度清零outputs = net(inputs)#print(outputs)#model的最后一层是sigmod#labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)#print(labels.squeeze(1).long())loss.backward()optimizer.step()loss_batch += loss.item()length = i#验证的时候不用反向传播和梯度下降这些net.eval()count = 0right = 0loss_batch_val =0with torch.no_grad():for j,data2 in enumerate(valloader):val_inputs,val_labels = data2val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)val_outputs = net(val_inputs)loss_val = criterion(val_outputs,val_labels)#因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中#——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了_,pred = torch.max(val_outputs,1)#print(pred)right = (pred == val_labels).sum().item()count = len(val_labels)acc = right/countloss_batch_val += loss_val.item()length2 = jif epoch % 10 == 9:print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)losses.append(loss_batch/length)acces.append(acc)losses_val.append(loss_batch_val/length2)#可以画一些曲线,输出一些值plt.plot(range(60),losses,color ='blue',label ='train_loss')plt.plot(range(60),acces, color ='red',label ='val_acc')plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')plt.legend()plt.show()torch.save(net.state_dict(),'./weights_epoch1000.pth')#保存参数train(data.train_loader,data.test_lodaer)

最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。

在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图

下面是一些总结:

1.跳跃很大,波动:增大batch_size,减小lr。

2.降低过拟合:

        a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。

        b.batchsize增大,lr减小是有效的。

        c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。

        d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/92179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT的截图识别功能测评:开启图像中的文字与信息的新纪元

文章目录 根据截图,识别菜品根据截图,识别数学公式根据截图生成前端UI代码可视化图像复现案例一案例二 更多可以使用的方向 制作人:川川 辛苦测评,如果对你有帮助支持一下书籍:https://item.jd.com/14049708.html 根据…

[C#]C#最简单方法获取GPU显存真实大小

你是否用下面代码获取GPU显存容量? using System.Management; private void getGpuMem() {ManagementClass c new ManagementClass("Win32_VideoController");foreach (ManagementObject o in c.GetInstances()){string gpuTotalMem String.For…

自动化测试-友好的第三方库

目录 mock furl coverage deepdiff pandas jsonpath 自动化测试脚本开发中,总是会遇到各种数据处理,例如MOCK、URL处理、JSON数据处理、结果断言等,也会遇到所采用的测试框架不能满足当前需求,这些问题都需要我们自己动手解…

Flink CDC MySQL同步MySQL错误记录

1、启动 Flink SQL [appuserwhtpjfscpt01 flink-1.17.1]$ bin/sql-client.sh2、新建源表 问题1:Encountered “(” 处理方法:去掉int(11),改为int Flink SQL> CREATE TABLE t_user ( > uid int(11) NOT NULL AUTO_INCREMENT COMME…

Pytorch中关于forward函数的理解与用法

目录 前言1. 问题所示2. 原理分析2.1 forward函数理解2.2 forward函数用法 前言 深入深度学习框架的代码,发现forward函数没有被显示调用 但代码确重写了forward函数,于是好奇是不是python的魔术方法作用 1. 问题所示 代码如下所示: cla…

若依不分离+Thymeleaf select选中多个回显

项目中遇到的场景&#xff0c;亲测实用 表单添加时&#xff0c;select选中多个&#xff0c;编辑表单时&#xff0c;select多选回显&#xff0c;如图 代码&#xff1a; // 新增代码 <label class"col-sm-3 control-label">通道&#xff1a;</label><…

计算机图形学、贝塞尔曲线及绘制方法、反走样问题的解决(附完整代码)

贝塞尔曲线 1. 本次作业实现的函数及简单描述&#xff08;详细代码见后&#xff09;2. 与本次作业有关的基础知识整理3. 代码描述&#xff08;详细&#xff09;4. 完整代码5. 参考文献 &#xff08;本篇为作者学习计算机图形学时根据作业所撰写的笔记&#xff0c; 如有同课程请…

LabVIEW风力涡轮机的雷电流测量系统中集成高速摄像机

LabVIEW风力涡轮机的雷电流测量系统中集成高速摄像机 随着全球风电装机容量的快速增长&#xff0c;雷电活动对风力发电机组造成的损害受到更多关注&#xff0c;特别是在雷电活动强烈的地区。在冬季闪电期间&#xff0c;风力涡轮机等高层结构会受到向上的雷击。众所周知&#x…

HelloWorld显示Go语言交叉编译的强大20230926

环境介绍 开发环境:windows 10 IDE:goland 实现的目标: 在windows10下编译go,分别在linux centos6和linux centos8上进行运行 具体流程 1.在windows10上建立项目 a. 打开GoLand&#xff0c;选择New Project。 b. 为项目取一个名称&#xff0c;例如HelloWorld&#xff0c…

Acwing 837. 连通块中点的数量

Acwing 837. 连通块中点的数量 题目描述思路讲解代码展示 题目描述 思路讲解 大家看y总这段代码时要注意&#xff0c;在C操作时&#xff0c;y总先把a&#xff0c;b的根结点取出来了&#xff1a;a find(a), b find(b);&#xff0c;因此接下来是先将集合a接到集合b下再把a的连通…

Android修行手册 - Activity 在 Java 和 Kotlin 中怎么写构造参数

点击跳转>Unity3D特效百例点击跳转>案例项目实战源码点击跳转>游戏脚本-辅助自动化点击跳转>Android控件全解手册点击跳转>Scratch编程案例点击跳转>软考全系列 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&…

构建捡垃圾机器人的 ROS 2 项目

一、说明 本系列是关于学习如何使用 ROS2、Docker 和 Github 设计、设置和维护机器人项目。 先决条件 — ROS2 软件包的基本知识、实现发布者、订阅者、操作并连接它们。 我们之前在 ROS2 中了解了不同的部分。但是&#xff0c;在我们转向实际的基于硬件的项目之前&#xff0c;…

阿里云ECS服务器无法发送邮件问题解决方案

这篇文章分享一下自己把项目部署在阿里云ECS上之后&#xff0c;登录邮件提醒时的邮件发送失败问题&#xff0c;无法连接发送邮箱的服务器。 博主使用的springboot提供的发送邮件服务&#xff0c;如下所示&#xff0c;为了实现异步的效果&#xff0c;新开了一个线程来发送邮件。…

基于 SpringBoot 2.7.x 使用最新的 Elasticsearch Java API Client 之 ElasticsearchClient

1. 从 RestHighLevelClient 到 ElasticsearchClient 从 Java Rest Client 7.15.0 版本开始&#xff0c;Elasticsearch 官方决定将 RestHighLevelClient 标记为废弃的&#xff0c;并推荐使用新的 Java API Client&#xff0c;即 ElasticsearchClient. 为什么要将 RestHighLevelC…

SuffixArray练习题

SuffixArray练习题 🍉题目 import java.util.Arrays;class SuffixArray {//LCP:Longest common prefix/*字符串后缀,指从字符串某个* 位置开始到字符串末尾的字串,原串和空串也是后缀* Create the LCP array from the suffix array* 从后缀数组创建LCP数组* @param s the…

Windows的批处理——获取系统时间、生成当天日期日志

Windows批处理基础https://coffeemilk.blog.csdn.net/article/details/132118351 一、Windows批处理的日期时间 在我们进行软件开发的过程中&#xff0c;有时候会使用到一些批处理命令&#xff0c;其中就涉及到获取系统日期、时间来进行一些逻辑的判断处理&#xff1b;那么我们…

Tomcat启动后的日志输出为乱码

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

[Linux] 4.常用初级指令

pwd&#xff1a;显示当前文件路径 ls:列出当前文件夹下有哪些文件 mkdir空格文件名&#xff1a;创建一个新的文件夹 cd空格文件夹名&#xff1a;进入文件夹 cd..&#xff1a;退到上一层文件夹 ls -a&#xff1a;把所有文件夹列出来 .代表当前文件夹 ..代表上层文件夹 用…

朴素贝叶斯分类(下):数据挖掘十大算法之一

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据开发、数据分析等。 🐴欢迎小伙伴们点赞👍🏻、收藏⭐️、…

探索ClickHouse——连接Kafka和Clickhouse

安装Kafka 新增用户 sudo adduser kafka sudo adduser kafka sudo su -l kafka安装JDK sudo apt-get install openjdk-8-jre下载解压kafka 可以从https://downloads.apache.org/kafka/下找到希望安装的版本。需要注意的是&#xff0c;不要下载路径包含src的包&#xff0c;否…