K折交叉验证代码实现——详细注释版

在这里插入图片描述

正常方法

#---------------------------------Torch Modules --------------------------------------------------------
from __future__ import print_function
import numpy as np
import pandas as pd
import torch.nn as nn
import math
import torch.nn.functional as F
import torch
import torchvision
from torch.nn import init
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision import models
import torch.nn.functional as F
from torch.utils import data
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
###-----------------------------------variables-----------------------------------------------
# Per-channel mean/std used by transforms.Normalize (MNIST is single-channel).
mean = [0.5]
std = [0.5]
# batch size
batch_size =128
epoch = 1        # number of training epochs
lr = 0.01        # SGD learning rate
##-----------------------------------Commands to download and perpare the MNIST dataset ------------------------------------
train_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist', train=True, download=True,transform=train_transform),batch_size=batch_size, shuffle=True) # train datasettest_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist', train=False, transform=test_transform),batch_size=batch_size, shuffle=False) # test dataset loader的形状为128,1,28,28
#visualization
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a num_rows x num_cols grid.

    Args:
        imgs: sequence of images; each item is either a 2-D torch tensor
            or a PIL image.
        num_rows, num_cols: grid layout; num_rows * num_cols should equal
            len(imgs).
        titles: optional per-image title strings.
        scale: size multiplier for the matplotlib figure.

    Returns:
        The flattened array of matplotlib axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor image: matplotlib needs a numpy array.
            ax.imshow(img.numpy())
        else:
            # PIL image can be drawn directly.
            ax.imshow(img)
        # Hide axis ticks — these are thumbnails, not plots.
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
# Load the raw MNIST training split (used only for the visualization below).
mnist_train = torchvision.datasets.MNIST(root="../data", train=True,#load MNIST training set
transform=train_transform,
download=True)
# Pull one batch of 18 images and show them on a 2x9 grid.
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9)

在这里插入图片描述

# MLP classifier: 28*28 images are flattened to 784 features, passed through
# one hidden layer of 100 units (the hidden size is an arbitrary choice), and
# mapped to 10 class logits (digits 0-9).
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 100), nn.ReLU(),
                      nn.Linear(100, 10))

def init_weights(m):
    """Initialize the weights of every Linear layer from N(0, 0.01^2)."""
    if isinstance(m, nn.Linear):  # idiom: isinstance, not type(m) ==
        nn.init.normal_(m.weight, std=0.01)

model.apply(init_weights)

## Loss function
# Cross-entropy is the standard loss for multi-class classification.
criterion = torch.nn.CrossEntropyLoss()
# Plain SGD over all model parameters with the globally-configured learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr)
# defining the training function
# Train baseline classifier on clean data
def train(model, optimizer, criterion, epoch):
    """Run one epoch of training over the module-level `train_loader`.

    Args:
        model: the network to train (updated in place).
        optimizer: optimizer wrapping `model`'s parameters.
        criterion: loss function taking (logits, targets).
        epoch: epoch index, used only for progress printing.
    """
    model.train()  # enable training mode (affects dropout/batchnorm if any)
    # target is the digit label 0-9
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten each (1, 28, 28) image into a 784-vector.
        # NOTE(review): requires_grad_ on the input is unnecessary for
        # training the weights — kept to preserve original behavior.
        data = data.view(-1, 28*28).requires_grad_()
        optimizer.zero_grad()                 # reset accumulated gradients
        output = model(data)                  # forward pass
        loss = criterion(output, target)      # loss computation
        loss.backward()                       # backpropagation
        optimizer.step()                      # weight update
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
# to evaluate the model
## validation of test accuracy
def test(model, criterion, val_loader, epoch,train= False):    model.eval()test_loss = 0correct = 0  with torch.no_grad():for batch_idx, (data, target) in enumerate(val_loader):data = data.view(-1, 28*28).requires_grad_()output = model(data)test_loss += criterion(output, target).item() # sum up batch losspred = output.argmax(1, keepdim=True) # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item() # if pred == target then correct +=1test_loss /= len(val_loader.dataset) # average test lossif train == False:print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(test_loss, correct, val_loader.sampler.__len__(),100. * correct / val_loader.sampler.__len__() ))if train == True:print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(test_loss, correct, val_loader.sampler.__len__(),100. * correct / val_loader.sampler.__len__() ))return 100. * correct / val_loader.sampler.__len__() 
# Per-epoch accuracy bookkeeping for the baseline (non-k-fold) run.
test_acc = torch.zeros([epoch])
train_acc = torch.zeros([epoch])
## training the logistic model
for i in range(epoch):
    train(model, optimizer, criterion, i)
    train_acc[i] = test(model, criterion, train_loader, i, train=True)  # accuracy on train split
    test_acc[i] = test(model, criterion, test_loader, i)                # accuracy on held-out split
# Persist the trained weights; exist_ok avoids a race on repeated runs.
os.makedirs('./saved_model', exist_ok=True)
torch.save(model.state_dict(), './saved_model/model_normal.bin')

Train Epoch: 0 [0/60000 (0%)] Loss: 0.314632
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.283743
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.229258
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.219923
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.187836
Train set: Average loss: 0.0020, Accuracy: 55673/60000 (92.7883%)
Test set: Average loss: 0.0020, Accuracy: 9284/10000 (92.8400%)

K折交叉验证

#!pip install sklearn -i https://pypi.mirrors.ustc.edu.cn/simple
from sklearn.model_selection import KFold

# Reload both MNIST splits and merge them into one pool so each fold draws
# from the full 70000 samples.
train_init = datasets.MNIST('./mnist', train=True, transform=train_transform)
test_init = datasets.MNIST('./mnist', train=False, transform=test_transform)
dataFold = torch.utils.data.ConcatDataset([train_init, test_init])

def train_flod_Mnist(k_split_value):
    """Run k-fold cross validation over the merged MNIST pool.

    Args:
        k_split_value: number of folds (KFold n_splits).

    Returns:
        list of numpy arrays, one per fold, each holding that fold's
        per-epoch test accuracy.
    """
    different_k_mse = []
    # Fixed random_state makes the fold split reproducible.
    kf = KFold(n_splits=k_split_value, shuffle=True, random_state=2024)
    for train_index, test_index in kf.split(dataFold):
        train_fold = torch.utils.data.dataset.Subset(dataFold, train_index)
        test_fold = torch.utils.data.dataset.Subset(dataFold, test_index)
        # BUGFIX: train() reads the module-level train_loader, so the
        # per-fold loaders must replace the globals — otherwise every fold
        # silently trains on the original 60000-sample loader (the original
        # logs show "[.../60000" during training but 63000-sample accuracy).
        global train_loader, test_loader
        train_loader = torch.utils.data.DataLoader(
            dataset=train_fold, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            dataset=test_fold, batch_size=batch_size, shuffle=True)
        # Per-fold accuracy bookkeeping.
        test_acc = torch.zeros([epoch])
        train_acc = torch.zeros([epoch])
        # NOTE(review): model/optimizer are NOT reset between folds, so
        # training accumulates across folds (accuracy rises monotonically
        # in the logs). A clean CV protocol would reinitialize them here;
        # left unchanged to preserve the original procedure.
        for i in range(epoch):
            train(model, optimizer, criterion, i)
            train_acc[i] = test(model, criterion, train_loader, i, train=True)
            test_acc[i] = test(model, criterion, test_loader, i)
        #torch.save(model,'perceptron.pt')
        # Record this fold's per-epoch test accuracy.
        different_k_mse.append(np.array(test_acc))
    return different_k_mse
# Map from k (number of folds) to the per-fold accuracy arrays.
# The range covers only k=10 here; widen it to compare several k values.
testAcc_compare_map = {}
for k_split_value in range(10, 10 + 1):
    print('now k_split_value is:', k_split_value)
    testAcc_compare_map[k_split_value] = train_flod_Mnist(k_split_value)

now k_split_value is: 10
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.187658
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.191175
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.166395
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.216365
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.156119
Train set: Average loss: 0.0015, Accuracy: 59648/63000 (94.6794%)
Test set: Average loss: 0.0018, Accuracy: 6550/7000 (93.5714%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.220123
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.135370
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.132626
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.258545
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.181543
Train set: Average loss: 0.0015, Accuracy: 59698/63000 (94.7587%)
Test set: Average loss: 0.0015, Accuracy: 6622/7000 (94.6000%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.263651
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.204342
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.141438
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.137843
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.217006
Train set: Average loss: 0.0014, Accuracy: 59751/63000 (94.8429%)
Test set: Average loss: 0.0014, Accuracy: 6639/7000 (94.8429%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.166842
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.096612
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.230569
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.109163
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.101741
Train set: Average loss: 0.0014, Accuracy: 59757/63000 (94.8524%)
Test set: Average loss: 0.0013, Accuracy: 6669/7000 (95.2714%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.113768
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.202454
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.119112
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.116779
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.376868
Train set: Average loss: 0.0014, Accuracy: 59879/63000 (95.0460%)
Test set: Average loss: 0.0013, Accuracy: 6682/7000 (95.4571%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.100557
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.189366
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.174508
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.104910
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.146227
Train set: Average loss: 0.0013, Accuracy: 59914/63000 (95.1016%)
Test set: Average loss: 0.0013, Accuracy: 6652/7000 (95.0286%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.103640
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.179051
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.138919
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.214437
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.127463
Train set: Average loss: 0.0013, Accuracy: 59986/63000 (95.2159%)
Test set: Average loss: 0.0013, Accuracy: 6674/7000 (95.3429%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.154551
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.157627
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.163700
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.148417
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.130215
Train set: Average loss: 0.0013, Accuracy: 60056/63000 (95.3270%)
Test set: Average loss: 0.0013, Accuracy: 6685/7000 (95.5000%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.146108
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.205999
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.115849
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.222786
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.178309
Train set: Average loss: 0.0012, Accuracy: 60162/63000 (95.4952%)
Test set: Average loss: 0.0013, Accuracy: 6683/7000 (95.4714%)
tensor([0.])
Train Epoch: 0 [0/60000 (0%)] Loss: 0.240678
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.234599
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.183265
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.148125
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.168119
Train set: Average loss: 0.0012, Accuracy: 60174/63000 (95.5143%)
Test set: Average loss: 0.0012, Accuracy: 6716/7000 (95.9429%)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882914.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于潜空间搜索的策略自适应组合优化(NeurIPS2023)(未完)

文章目录 Abstract1 Introduction2 Related work3 Methods3.1 预备知识3.2 COMPASS4 Experiments4.1 TSP、CVRP和JSSP的标准基准测试4.2 对泛化的鲁棒性:解决变异实例4.3 搜索策略分析5 ConclusionAbstract 组合优化是许多现实应用的基础,但设计高效算法以解决这些复杂的、通…

MongoDB Shell 基本命令(三)生成学生脚本信息和简单查询

一、生成学生信息脚本 利用该脚本可以生成任意个学生信息,包括学号、姓名、班级、年级、专业、课程名称、课程成绩等信息,此处生成2万名学生,学生所有信息都是给定范围后随机生成。 生成学生信息后,再来对学生信息进行简单查询。…

关于武汉芯景科技有限公司的限流开关芯片XJ6241开发指南(兼容LTC4411)

一、芯片引脚介绍 1.芯片引脚 二、系统结构图 三、功能描述 1.CTL引脚控制VIN和VOUT的通断 2.CTL引脚控制STAT引脚的状态 3.输出电压高于输入电压加上–VRTO的值,芯片处于关断状态

Artistic Oil Paint 艺术油画着色器插件

只需轻轻一点,即可将您的视频游戏转化为艺术品!(也许更多…)。 ✓ 整个商店中最可配置的选项。 ✓ 六种先进算法。 ✓ 细节增强算法。 ✓ 完整的源代码(脚本和着色器)。 ✓ 包含在“艺术包”中。 &#x1f…

【数组知识的扩展①】

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” ArrayList在Java数组中的使用技巧 这篇博客灵感来源于某一天Aileen(🤫)遇到了一道数组合并的题&…

python 文件防感染扫描

一、安装 首先,你需要安装 secplugs-python-client 库。你可以通过 pip 命令来安装: pip install secplugs-python-client确保你的 Python 环境已经正确设置,并且网络连接畅通,以便能够顺利安装。 二、基本用法 1. 初始化客户…

【记录】Windows|Windows 修改字体大全(Windows 桌面、VSCode、浏览器)

【记录】Windows|Windows 修改字体大全(Windows 桌面、VSCode、浏览器) 前言 最近从学长那里发现了一款非常美观的衡水体字体——Maple Mono SC NF。您可以通过以下链接下载该字体:https://github.com/subframe7536/maple-font/…

TiDB替换Starrocks:业务综合宽表迁移的性能评估与降本增效决策

作者: 我是人间不清醒 原文来源: https://tidb.net/blog/6638f594 1、 场景 业务综合宽表是报表生成、大屏幕展示和数据计算处理的核心数据结构。目前,这些宽表存储在Starrocks系统中,但该系统存在显著的性能瓶颈。例如&#…

Vue组件开发的属性

组件开发的属性: 1.ref属性: 如果在vue里,想要获取DOM对象,并且不想使用JS的原生语法,那么就可以使用ref属性 ref属性的用法: 1)在HTML元素的开始标记中,或者在Vue子组件中的开始…

JVM、字节码文件介绍

目录 初识JVM 什么是JVM JVM的三大核心功能 JVM的组成 字节码文件的组成 基础信息 Magic魔数 主副版本号 其它基础信息 常量池 字段 方法 属性 字节码常用工具 javap jclasslib插件 阿里Arthas 初识JVM 什么是JVM JVM的三大核心功能 1. 解释和运行虚拟机指…

我的世界之合成

合成(Crafting)是一种在Minecraft中获得多种方块、工具和其他资源的方法。合成时,玩家必须先把物品从物品栏移入合成方格中。22的简易合成方格可以直接在物品栏中找到,而33的合成方格需要使用工作台或合成器来打开。 目录 1合成系…

LabVIEW智能螺杆空压机测试系统

基于LabVIEW软件开发的螺杆空压机测试系统利用虚拟仪器技术进行空压机的性能测试和监控。系统能够实现对螺杆空压机关键性能参数如压力、温度、流量、转速及功率的实时采集与分析,有效提高测试效率与准确性,同时减少人工操作,提升安全性。 项…

Ubuntu22.04 制作系统ISO镜像

第一步:安装软件-Systemback 1.如果已经添加过ppa,可以删除重新添加或者跳过此步 sudo add-apt-repository --remove ppa:nemh/systemback 2.添加ppa 我是ubuntu20,但这个软件最后支持的是 ubuntu16.04版本,所以加一个16版本…

C++ | Leetcode C++题解之第480题滑动窗口中位数

题目&#xff1a; 题解&#xff1a; class DualHeap { private:// 大根堆&#xff0c;维护较小的一半元素priority_queue<int> small;// 小根堆&#xff0c;维护较大的一半元素priority_queue<int, vector<int>, greater<int>> large;// 哈希表&#…

自动化测试实施过程中需要考虑的因素!

自动化测试是软件开发过程中不可或缺的一部分，它能够提高测试效率、减少人力成本，并确保软件质量的一致性。然而，自动化测试的实施并非没有挑战。为了确保自动化测试的有效性和可持续性，开发者需要综合考虑多种因素，包…

【CTF-SHOW】Web入门 Web14 【editor泄露-详】【var/www/html目录-详】

editor泄露问题通常出现在涉及文件编辑器或脚本编辑器的题目中，尤其是在Web安全或Pwn（系统漏洞挖掘）类别中。editor泄露的本质是由于系统未能妥善处理临时文件、编辑历史或进程信息，导致攻击者可以通过某种途径获取正在编辑的敏感…

EasyOCR——超强超便捷的OCR开源算法介绍与文本检测模型CRAFT微调方法

背景 最近在实际操作阿拉伯文小语种OCR功能的时候，尝试了诸多开源算法，但效果均不尽如人意。 说实在的，针对阿拉伯文的OCR开源算法，若仅仅是效果没那么优秀，比如识别率能有个70%～80%，我还能微调微调…

【React系列三】—React学习历程的分享

一、组件实例核心—Refs 通过定义 ref 属性可以给标签添加标识 字符串形式的Refs 这种形式已经不再推荐使用&#xff0c;官方不建议使用 https://zh-hans.legacy.reactjs.org/docs/refs-and-the-dom.html#legacy-api-string-refs 回调形式的Refs <script type"te…

PostgreSQL中触发器递归的处理 | 翻译

许多初学者在某个时候都会陷入触发器递归的陷阱。通常&#xff0c;解决方案是完全避免递归。但对于某些用例&#xff0c;您可能必须处理触发器递归。本文将告诉您有关该主题需要了解的内容。如果您曾经被错误消息“超出堆栈深度限制”所困扰&#xff0c;那么这里就是解决方案。…

Pytest参数详解 — 基于命令行模式!

1、--collect-only 查看在给定的配置下哪些测试用例会被执行 2、-k 使用表达式来指定希望运行的测试用例。如果测试名是唯一的或者多个测试名的前缀或者后缀相同&#xff0c;可以使用表达式来快速定位&#xff0c;例如&#xff1a; 命令行-k参数.png 3、-m 标记&#xff08;…