搭建全连接网络进行分类(糖尿病为例)

拿来练手,大神请绕道。

1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。

2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。

3.全连接的回归也写了,有空再上传吧。

4.一般都是先写data或者model

import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):def __init__(self):super(FCNet,self).__init__()self.fc1 = nn.Linear(8,16)#初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合# self.fc2 = nn.Linear(50,20)self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的#最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmaxself.sig = nn.Sigmoid()# self.drop = nn.Dropout(0.3)#可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了def forward(self,x):x = self.sig(self.fc1(x))# x = self.sig(self.fc2(x))x = self.sig(self.fc3(x))return x#就是x要怎么在网络中走,要写一遍#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)

首先看看数据是是啥样,outcome就是有没有糖尿病

其实可以手动把csv分成train和test

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)#2.也可以直接重写datasetclass dataset(Dataset):def __init__(self, x, y):#把值拿出来或者变为np类型才能转换为tensor# self.data = torch.tensor(x.values,dtype=torch.float32)# self.labels = torch.tensor(y.values,dtype=torch.float32)self.data = torch.tensor(np.array(x),dtype=torch.float32)self.labels = torch.tensor(np.array(y),dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self,idx):return self.data[idx],self.labels[idx]#应该返回的是list类型,不是字典也不是setBATCH_SIZE = 64#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)

然后就可以写train或者test了,其实test和train一样

from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):my_device = torch.device('cuda')
else:my_device = torch.device('cpu')print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):losses = []acces = []losses_val = []for epoch in range(epochs):loss_batch = 0for i,data in enumerate(dataloader):#需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以inputs,labels = data#print(data)可以打印出来查看一下inputs,labels = inputs.to(my_device),labels.to(my_device)optimizer.zero_grad()#每次要梯度清零outputs = net(inputs)#print(outputs)#model的最后一层是sigmod#labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)#print(labels.squeeze(1).long())loss.backward()optimizer.step()loss_batch += loss.item()length = i#验证的时候不用反向传播和梯度下降这些net.eval()count = 0right = 0loss_batch_val =0with torch.no_grad():for j,data2 in enumerate(valloader):val_inputs,val_labels = data2val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)val_outputs = net(val_inputs)loss_val = criterion(val_outputs,val_labels)#因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中#——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了_,pred = torch.max(val_outputs,1)#print(pred)right = (pred == val_labels).sum().item()count = len(val_labels)acc = right/countloss_batch_val += loss_val.item()length2 = jif epoch % 10 == 9:print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)losses.append(loss_batch/length)acces.append(acc)losses_val.append(loss_batch_val/length2)#可以画一些曲线,输出一些值plt.plot(range(60),losses,color ='blue',label ='train_loss')plt.plot(range(60),acces, color ='red',label ='val_acc')plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')plt.legend()plt.show()torch.save(net.state_dict(),'./weights_epoch1000.pth')#保存参数train(data.train_loader,data.test_lodaer)

最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。

在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图

下面是一些总结:

1.跳跃很大,波动:增大batch_size,减小lr。

2.降低过拟合:

        a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。

        b.batchsize增大,lr减小是有效的。

        c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。

        d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/92179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT的截图识别功能测评:开启图像中的文字与信息的新纪元

文章目录 根据截图,识别菜品根据截图,识别数学公式根据截图生成前端UI代码可视化图像复现案例一案例二 更多可以使用的方向 制作人:川川 辛苦测评,如果对你有帮助支持一下书籍:https://item.jd.com/14049708.html 根据…

自动化测试-友好的第三方库

目录 mock furl coverage deepdiff pandas jsonpath 自动化测试脚本开发中,总是会遇到各种数据处理,例如MOCK、URL处理、JSON数据处理、结果断言等,也会遇到所采用的测试框架不能满足当前需求,这些问题都需要我们自己动手解…

Flink CDC MySQL同步MySQL错误记录

1、启动 Flink SQL [appuserwhtpjfscpt01 flink-1.17.1]$ bin/sql-client.sh2、新建源表 问题1:Encountered “(” 处理方法:去掉int(11),改为int Flink SQL> CREATE TABLE t_user ( > uid int(11) NOT NULL AUTO_INCREMENT COMME…

若依不分离+Thymeleaf select选中多个回显

项目中遇到的场景&#xff0c;亲测实用 表单添加时&#xff0c;select选中多个&#xff0c;编辑表单时&#xff0c;select多选回显&#xff0c;如图 代码&#xff1a; // 新增代码 <label class"col-sm-3 control-label">通道&#xff1a;</label><…

计算机图形学、贝塞尔曲线及绘制方法、反走样问题的解决(附完整代码)

贝塞尔曲线 1. 本次作业实现的函数及简单描述&#xff08;详细代码见后&#xff09;2. 与本次作业有关的基础知识整理3. 代码描述&#xff08;详细&#xff09;4. 完整代码5. 参考文献 &#xff08;本篇为作者学习计算机图形学时根据作业所撰写的笔记&#xff0c; 如有同课程请…

LabVIEW风力涡轮机的雷电流测量系统中集成高速摄像机

LabVIEW风力涡轮机的雷电流测量系统中集成高速摄像机 随着全球风电装机容量的快速增长&#xff0c;雷电活动对风力发电机组造成的损害受到更多关注&#xff0c;特别是在雷电活动强烈的地区。在冬季闪电期间&#xff0c;风力涡轮机等高层结构会受到向上的雷击。众所周知&#x…

Acwing 837. 连通块中点的数量

Acwing 837. 连通块中点的数量 题目描述思路讲解代码展示 题目描述 思路讲解 大家看y总这段代码时要注意&#xff0c;在C操作时&#xff0c;y总先把a&#xff0c;b的根结点取出来了&#xff1a;a find(a), b find(b);&#xff0c;因此接下来是先将集合a接到集合b下再把a的连通…

Android修行手册 - Activity 在 Java 和 Kotlin 中怎么写构造参数

点击跳转>Unity3D特效百例点击跳转>案例项目实战源码点击跳转>游戏脚本-辅助自动化点击跳转>Android控件全解手册点击跳转>Scratch编程案例点击跳转>软考全系列 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&…

构建捡垃圾机器人的 ROS 2 项目

一、说明 本系列是关于学习如何使用 ROS2、Docker 和 Github 设计、设置和维护机器人项目。 先决条件 — ROS2 软件包的基本知识、实现发布者、订阅者、操作并连接它们。 我们之前在 ROS2 中了解了不同的部分。但是&#xff0c;在我们转向实际的基于硬件的项目之前&#xff0c;…

阿里云ECS服务器无法发送邮件问题解决方案

这篇文章分享一下自己把项目部署在阿里云ECS上之后&#xff0c;登录邮件提醒时的邮件发送失败问题&#xff0c;无法连接发送邮箱的服务器。 博主使用的springboot提供的发送邮件服务&#xff0c;如下所示&#xff0c;为了实现异步的效果&#xff0c;新开了一个线程来发送邮件。…

基于 SpringBoot 2.7.x 使用最新的 Elasticsearch Java API Client 之 ElasticsearchClient

1. 从 RestHighLevelClient 到 ElasticsearchClient 从 Java Rest Client 7.15.0 版本开始&#xff0c;Elasticsearch 官方决定将 RestHighLevelClient 标记为废弃的&#xff0c;并推荐使用新的 Java API Client&#xff0c;即 ElasticsearchClient. 为什么要将 RestHighLevelC…

Windows的批处理——获取系统时间、生成当天日期日志

Windows批处理基础https://coffeemilk.blog.csdn.net/article/details/132118351 一、Windows批处理的日期时间 在我们进行软件开发的过程中&#xff0c;有时候会使用到一些批处理命令&#xff0c;其中就涉及到获取系统日期、时间来进行一些逻辑的判断处理&#xff1b;那么我们…

Tomcat启动后的日志输出为乱码

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

[Linux] 4.常用初级指令

pwd&#xff1a;显示当前文件路径 ls:列出当前文件夹下有哪些文件 mkdir空格文件名&#xff1a;创建一个新的文件夹 cd空格文件夹名&#xff1a;进入文件夹 cd..&#xff1a;退到上一层文件夹 ls -a&#xff1a;把所有文件夹列出来 .代表当前文件夹 ..代表上层文件夹 用…

探索ClickHouse——连接Kafka和Clickhouse

安装Kafka 新增用户 sudo adduser kafka sudo adduser kafka sudo su -l kafka安装JDK sudo apt-get install openjdk-8-jre下载解压kafka 可以从https://downloads.apache.org/kafka/下找到希望安装的版本。需要注意的是&#xff0c;不要下载路径包含src的包&#xff0c;否…

最新ChatGPT网站系统源码+支持GPT4.0+支持AI绘画Midjourney绘画+支持国内全AI模型

一、SparkAI创作系统 SparkAi系统是基于很火的GPT提问进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作ChatGPT系统&#xff1f;小编这里写一个详细图文教程吧&a…

CCF CSP认证 历年题目自练Day18

CCF CSP认证 历年题目自练Day18 题目一 试题编号&#xff1a; 201809-1 试题名称&#xff1a; 卖菜 时间限制&#xff1a; 1.0s 内存限制&#xff1a; 256.0MB 问题描述&#xff1a; 问题描述   在一条街上有n个卖菜的商店&#xff0c;按1至n的顺序排成一排&#xff0c;这…

Apollo自动驾驶系统概述(文末参与活动赠送百度周边)

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 「推荐专栏」&#xff1a; ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄&#xff0c;vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄&#xff…

大喜国庆,聊聊我正式进入职场的这三个月...

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

基础数据结构之——【顺序表】(上)

从今天开始更新数据结构的相关内容。&#xff08;我更新博文的顺序一般是按照我当前的学习进度来安排&#xff0c;学到什么就更新什么&#xff08;简单来说就是我的学习笔记&#xff09;&#xff0c;所以不会对一个专栏一下子更新到底&#xff0c;哈哈哈哈哈哈哈&#xff01;&a…