机器学习复习(4)——CNN算法

目录

数据增强方法

CNN图像分类数据集构建

导入数据集

定义trainer

超参数设置

数据增强

构建CNN网络

开始训练

模型测试

数据增强方法

# 一般情况下,我们不会在验证集和测试集上做数据扩增
# 我们只需要将图片裁剪成同样的大小并装换成Tensor就行
test_tfm = transforms.Compose([transforms.Resize((128, 128)),transforms.ToTensor(),
])# 当然,我们也可以再测试集中对数据进行扩增(对同样本的不同装换)
#  - 用训练数据的装化方法(train_tfm)去对测试集数据进行转化,产出扩增样本
#  - 对同个照片的不同样本分别进行预测
#  - 最后可以用soft vote / hard vote 等集成方法输出最后的预测
train_tfm = transforms.Compose([# 图片裁剪 (height = width = 128)transforms.Resize((128, 128)),transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),# ToTensor() 放在所有处理的最后transforms.ToTensor(),
])

CNN图像分类数据集构建

class FoodDataset(Dataset):# 构造函数def __init__(self, path, tfm=test_tfm, files=None):# 调用父类的构造函数super(FoodDataset).__init__()# 存储图像文件夹路径self.path = path# 从路径中获取所有以.jpg结尾的文件,并按字典顺序排序self.files = sorted([os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")])# 如果提供了文件列表,则使用该列表代替自动搜索得到的列表if files is not None:self.files = files# 打印路径中的一个样本文件路径print(f"One {path} sample", self.files[0])# 存储用于图像变换的函数self.transform = tfm# 返回数据集中的样本数def __len__(self):return len(self.files)# 根据索引获取单个样本def __getitem__(self, idx):# 获取文件名fname = self.files[idx]# 打开图像文件im = Image.open(fname)# 应用变换im = self.transform(im)# 尝试从文件名中提取标签,如果失败则设置为-1(表示测试集中没有标签)try:label = int(fname.split("/")[-1].split("_")[0])except:label = -1  # 测试集没有label# 返回图像和标签return im, label

导入数据集

注意这里的“私有方法”

_dataset_dir = config['dataset_dir']#“_”是为了避免和python中的dataset重名train_set = FoodDataset(os.path.join(_dataset_dir,"training"), tfm=train_tfm)
train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=True)valid_set = FoodDataset(os.path.join(_dataset_dir,"validation"), tfm=test_tfm)
valid_loader = DataLoader(valid_set, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=True)# 测试级保证输出顺序一致
test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=test_tfm)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], shuffle=False, num_workers=0, pin_memory=True)

定义trainer

def trainer(train_loader, valid_loader, model, config, device, rest_net_flag=False):# 定义交叉熵损失函数,用于评估分类任务的模型性能criterion = nn.CrossEntropyLoss()# 初始化优化器,这里使用Adam优化器optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])# 根据rest_net_flag标志选择模型保存路径save_path = config['save_path'] if rest_net_flag else config['resnet_save_path']# 初始化TensorBoard的SummaryWriter,用于记录训练过程writer = SummaryWriter()# 如果'models'目录不存在,则创建该目录if not os.path.isdir('./models'):os.mkdir('./models')# 初始化训练参数:训练轮数、最佳损失、步骤计数器和早停计数器n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0# 进行多个训练周期for epoch in range(n_epochs):# 设置模型为训练模式model.train()# 初始化损失记录器和准确率记录器loss_record = []train_accs = []# 使用tqdm显示训练进度条train_pbar = tqdm(train_loader, position=0, leave=True)# 遍历训练数据for x, y in train_pbar:# 重置优化器梯度optimizer.zero_grad()# 将数据和标签移动到指定设备(如GPU)x, y = x.to(device), y.to(device)# 进行一次前向传播pred = model(x)# 计算损失loss = criterion(pred, y)# 反向传播loss.backward()# 如果启用梯度裁剪,则应用梯度裁剪if config['clip_flag']:grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)# 进行一步优化(梯度下降)optimizer.step()# 记录当前步骤step += 1# 计算准确率并记录损失和准确率acc = (pred.argmax(dim=-1) == y.to(device)).float().mean()l_ = loss.detach().item()loss_record.append(l_)train_accs.append(acc.detach().item())train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')train_pbar.set_postfix({'loss': f'{l_:.5f}', 'acc': f'{acc:.5f}'})# 计算并记录平均训练损失和准确率mean_train_acc = sum(train_accs) / len(train_accs)mean_train_loss = sum(loss_record) / len(loss_record)writer.add_scalar('Loss/train', mean_train_loss, step)writer.add_scalar('ACC/train', mean_train_acc, step)# 设置模型为评估模式model.eval()# 初始化验证集损失记录器和准确率记录器loss_record = []test_accs = []# 遍历验证数据for x, y in valid_loader:x, y = x.to(device), y.to(device)with torch.no_grad():pred = model(x)loss = criterion(pred, y)acc = (pred.argmax(dim=-1) == y.to(device)).float().mean()loss_record.append(loss.item())test_accs.append(acc.detach().item())# 计算并打印平均验证损失和准确率mean_valid_acc = sum(test_accs) / len(test_accs)mean_valid_loss = sum(loss_record) / len(loss_record)print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, acc: {mean_train_acc:.4f} Valid loss: {mean_valid_loss:.4f}, acc: {mean

超参数设置

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {'seed': 6666,'dataset_dir': "../input/data",'n_epochs': 10,      'batch_size': 64, 'learning_rate': 0.0003,           'weight_decay':1e-5,'early_stop': 300,'clip_flag': True, 'save_path': './models/model.ckpt','resnet_save_path': './models/resnet_model.ckpt'
}
print(device)
all_seed(config['seed'])

数据增强

test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=train_tfm)
test_loader_extra1 = DataLoader(test_set, batch_size=config['batch_size'], shuffle=False, num_workers=0, pin_memory=True)test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=train_tfm)
test_loader_extra2 = DataLoader(test_set, batch_size=config['batch_size'], shuffle=False, num_workers=0, pin_memory=True)test_set = FoodDataset(os.path.join(_dataset_dir,"test"), tfm=train_tfm)
test_loader_extra3 = DataLoader(test_set, batch_size=config['batch_size'], shuffle=False, num_workers=0, pin_memory=True)

构建CNN网络

class Classifier(nn.Module):def __init__(self):super(Classifier, self).__init__()# input 維度 [3, 128, 128]self.cnn = nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1),  # [64, 128, 128]nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2, 2, 0),      # [64, 64, 64]nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2, 2, 0),      # [128, 32, 32]nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2, 2, 0),      # [256, 16, 16]nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2, 2, 0),       # [512, 8, 8]nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2, 2, 0),       # [512, 4, 4])self.fc = nn.Sequential(nn.Linear(512*4*4, 1024),nn.ReLU(),nn.Linear(1024, 512),nn.ReLU(),nn.Linear(512, 11))def forward(self, x):out = self.cnn(x)out = out.view(out.size()[0], -1)return self.fc(out)

举一个具体的例子来解释: out = out.view(out.size()[0], -1)

假设我们有一个4维的张量 out,其维度是 [10, 3, 32, 32]。这个张量可以被理解为一个小批量(batch)的图像数据,其中:

  • 10 是批处理大小(batch size),表示有10个图像。
  • 3 是通道数(channels),例如在RGB图像中有3个颜色通道。
  • 3232 是图像的高度和宽度。

现在,我们想将这个4维张量转换为2维张量,以便它可以被用作全连接层(dense layer)的输入。这就是 out.view(out.size()[0], -1) 用途所在。

执行这个操作后,张量的形状将会是:

  • 第一个维度仍然是10,这保持了批处理大小不变。
  • 第二个维度是由-1指定的,这让PyTorch自动计算这个维度的大小。在我们的例子中,其余的维度(3, 32, 32)将被展平,所以第二个维度的大小是 3 * 32 * 32 = 3072。

因此,执行 out = out.view(out.size()[0], -1) 后,out 的形状将会从 [10, 3, 32, 32] 变为 [10, 3072]。这个新的二维张量可以被看作是一个包含10个样本的数据批次,每个样本都被展平为3072个特征的一维数组。这种形状的张量适合作为全连接层的输入。

1. Conv2d(卷积层)

卷积层的输出尺寸可以用以下公式计算:

其中:

  • 输入尺寸是输入特征图的高度或宽度。
  • 卷积核尺寸是卷积核的高度或宽度。
  • 填充(Padding)是在输入特征图周围添加的零的层数。
  • 步长(Stride)是卷积核移动的步幅。

2. MaxPool2d(最大池化层)

最大池化层的输出尺寸可以用类似的公式计算:

对于最大池化,通常不使用填充

 假设我们有一个大小为[32, 32](高度32,宽度32)的输入特征图,并且我们想应用以下两个层:

  1. Conv2d层,卷积核大小为[3, 3],步长为1,填充为1
  2. MaxPool2d层,池化核大小为[2, 2],步长为2

对于Conv2d层,输出尺寸计算如下:

对于MaxPool2d层,输出尺寸计算如下:

所以,经过这两层处理后,最终输出的特征图尺寸将会是[16, 16]

开始训练

model = Classifier().to(device)
trainer(train_loader, valid_loader, model, config, device)

或者可以通过调用pytorch官方的一些标准model进行训

from torchvision.models import resnet50
resNet = resnet50(pretrained=False)
# 残差网络
resNet = resNet.to(device)
trainer(train_loader, valid_loader, resNet, config, device)

模型测试

model_best = Classifier().to(device)
model_best.load_state_dict(torch.load(config['save_path']))
model_best.eval()
prediction = []
with torch.no_grad():for data,_ in test_loader:test_pred = model_best(data.to(device))test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)prediction += test_label.squeeze().tolist()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/657489.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【国产MCU】-CH32V307-GPIO控制:输入与输出

GPIO控制:输入与输出 文章目录 GPIO控制:输入与输出1、GPIO简单介绍2、驱动API介绍3、GPIO配置代码实现3.1 GPIO配置为输出3.2 GPIO配置为输入CH32V307的GPIO口可以配置成多种输入或输出模式,内置可关闭的上拉或下拉电阻,可以配置成推挽或开漏功能。GPIO口还可以复用成其他…

一文掌握SpringBoot注解之@Component 知识文集(8)

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用。生成对抗网络(GAN)是一种强大的生成模型,在手写数字生成方面具有广泛的应用前景。通过生成…

类和对象 第六部分 继承 第一部分:继承的语法

一.继承的概念 继承是面向对象的三大特性之一 有些类与类之间存在特殊的关系,例如下图: 我们可以发现,下级别的成员除了拥有上一级的共性,还有自己的特性,这个时候,我们可以讨论利用继承的技术,…

【前端素材】bootstrap3 实现地产置业公司source网页设计

一、需求分析 地产置业公司的网页通常是该公司的官方网站,旨在向访问者提供相关信息和服务。这些网页通常具有以下功能: 公司介绍:网页通常包含有关公司背景、历史、核心价值观和使命等方面的信息。此部分帮助访问者了解公司的身份和目标。 …

LabVIEW船舶自动识别系统

在现代航海领域,安全高效的船舶自动识别系统对于保障航行安全和提高船舶管理效率非常重要。介绍了利用LabVIEW软件开发的一个船舶自动识别系统,该系统通过先进的数据采集和信号处理技术,显著提升了传统自动识别系统的性能。 这个船舶自动识别…

代理IP在游戏中的作用有哪些?

游戏代理IP的作用是什么?IP代理软件相当于连接客户端和虚拟服务器的软件“中转站”,在我们向远程服务器提出需求后,代理服务器首先获得用户的请求,然后将服务请求转移到远程服务器,然后将远程服务器反馈的结果转移到客…

【lesson1】高并发内存池项目介绍

文章目录 这个项目做的是什么?这个项目的要求的知识储备和难度?什么是内存池池化技术内存池内存池主要解决的问题malloc 这个项目做的是什么? 当前项目是实现一个高并发的内存池,他的原型是google的一个开源项目tcmalloc&#xf…

江西省考报名照不能成功上传的原因

江西省考报名照需要根据以下要求生成: 1、近期6个月,免冠证件照 2、照片背景白背景 3、照片文件jpg格式,20-100kb 4、照片像素大小,295x413像素 5、照片必须使用审核工具审核后才能上传

【Java】内存溢出和内存泄露的区别

目录 概念 内存溢出分类 内存泄漏分类 发生场景以及解决方法 内存溢出 内存泄漏 解决方法 这道题是面试常考的,一定要区分好区别,我之前就是直接认为内存溢出就是内存泄漏了 概念 内存溢出:是指程序在申请内存时,没有足够…

React18-模拟列表数据实现基础表格功能

文章目录 分页功能分页组件有两种接口参数分页类型用户列表参数类型 模拟列表数据分页触发方式实现目录 分页功能 分页组件有两种 table组件自带分页 <TableborderedrowKey"userId"rowSelection{{ type: checkbox }}pagination{{position: [bottomRight],pageSi…

海外云手机对于亚马逊卖家的作用

近年来&#xff0c;海外云手机作为一种新型模式迅速崭露头角&#xff0c;成为专业的出海SaaS平台软件。海外云手机在云端运行和存储数据&#xff0c;通过网页端操作&#xff0c;将手机芯片放置在机房&#xff0c;通过网络连接到服务器&#xff0c;为用户提供便捷的上网功能。因…

Spring Boot 中文件上传

Spring Boot 中文件上传 一、MultipartFile二、单文件上传案例三、多文件上传案例四、Servlet 规范五、Servlet 规范实现文件上传 上传文件大家用的最多的就是 Apache Commons FileUpload&#xff0c;这个库使用非常广泛。Spring Boot3 版本中已经不能使用了。代替它的是 Sprin…

苍穹外卖项目可以写的简历和如何优化简历

文章目录 重点写中规写添加自己个性的项目面试会问道的问题 我是一名双非大二计算机本科生&#xff0c;希望我的分享对你有帮助&#xff0c;点赞关注不迷路。 简历编写一直是很多人求职人的心病&#xff0c;我自己上学期有一门课程是去校内企业面试&#xff0c;当时我就感受出…

MySQL库表操作 作业

题目&#xff1a; 1. sql语句分为几类?2. 表的约束有哪些,分别是什么,设置的语法分别是什么?3. 做出班级表,学生表的E-R图,数据库模型图,以及核心的sql语句. 1. MySQL致力于支持全套ANSI/ISO SQL标准。在MySQL数据库中&#xff0c;SQL语句主要可以划分为以下几类: > DD…

【三维重建】三角化

三角化要解决的问题是&#xff1a; 已知两个相机的内参K、K、相机之间的旋转平移矩阵R、t以及匹配点p、p&#xff0c;如何求得P点的三维坐标&#xff1f; 线性解法

彻底解决 MAC Android Studio gradle async 时出现 “connect timed out“ 问题

最近在编译一个比较老的项目&#xff0c;git clone 之后使用 async 之后出现一下现象&#xff1a; 首先确定是我网络本身是没有问题的&#xff0c;尝试几次重新 async 之后还是出现问题&#xff0c;网上找了一些方法解决了本问题&#xff0c;以此来记录一下问题是如何解决的。 …

【Tomcat与网络5】再论Tomcat的工作过程与两种经典的设计模式

前面两篇&#xff0c;我们重点分析了Tomcat的容器和连接器的基本设计&#xff0c;今天我们来看一下两个机构如何在service的调度下进行协同工作的。 目录 1.模板模式与Tomcat的重用性设计 2.观察者模式与Tomcat可扩展性设计 1.模板模式与Tomcat的重用性设计 首先&#xff0…

油分离器的介绍

压缩机的排气中带有冷冻机油&#xff0c;这些冷冻机油如果随制冷剂蒸汽进入冷凝器、蒸发器后将 在传热表面形成油膜&#xff0c;从而影响换热效果。因此通常在压缩机与冷凝器之间装设油分离器&#xff0c;用 来分离制冷剂蒸汽中挟带的冷冻机油。在氟利昂制冷系统中&#xff0c;…

读AI3.0笔记10_读后总结与感想兼导读

1. 基本信息 AI 3.0 (美)梅拉妮米歇尔 著 四川科学技术出版社,2021年2月出版 1.1. 读薄率 书籍总字数355千字&#xff0c;笔记总字数33830字。 读薄率33830355000≈9.53% 1.2. 读厚方向 千脑智能 脑机穿越 未来呼啸而来 虚拟人 新机器人 如何创造可信的AI 新机器智…