pytorch实战-6手写数字加法机-迁移学习

1 概述

迁移学习概念:将已经训练好的识别某些信息的网络拿去经过训练识别另外不同类别的信息

优越性:提高了训练模型利用率,解决了数据缺失的问题(对于新的预测场景,不需要大量的数据,只需要少量数据即可实现训练,可用于数据点很少的场景)

如何实现:将训练好的一个网络拿来和另一个网络连起来去训练即可实现迁移

训练方式:按是否改变源网络参数可分两类,分别是可改变和不可改变

2 案例 南非贫困预测

2.1 背景

南非存在贫困,1990-2021贫困人口从56%下降到43%,但下降的贫困人口数量和国际人道主义援助资源并不对应,而且大量资金援助一定程度加剧了贫富差距。可以看下具体哪些地区需要援助

2.2 方法

一个方法:夜光光亮遥感数据和人类gdp相关性经实验可达0.8-0.9,但夜光遥感和贫富没太大相关性:夜间光照月亮表示该地区越富有,但越安并不表示该地区越贫穷,也可能无人居住。

另一个方法:光亮遥感数据无法准确预测地区贫穷程度,但卫星遥感数据大体可以做到,判定依据有街道混乱程度等。如果要用深度网络训练,还需要对卫星遥感数据的图片标注贫困程度。非洲能获取到的贫困数据很少,但深度网络需要的数据量很大

最终方法:用迁移学习,将前两种方法合起来,见下图

3 案例2

3.1 背景

任务:区分图像里动物是蚂蚁还是蜜蜂,像素均为224x224

难点:只有244个图像,样本太少不足训练大型卷积网络,准确率只有50%左右

3.2 解决方案

解决方案:resnet与模型迁移,即用已训练好的物体分类的网络加全连接用来区分蚂蚁与蜜蜂

resnet:残差网络,对物体分类有较高精度

3.3 代码实现

3.3.1 准备数据

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as pyplot
import time
import copy
import osdata_path = 'pytorch/jizhi/figure_plus/data'
image_size = 224class TranNet():def __init__(self):super(TranNet, self).__init__()self.train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), transforms.Compose([transforms.RandomSizedCrop(image_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]))self.verify_dataset = datasets.ImageFolder(os.path.join(data_path, 'verify'), transforms.Compose([transforms.Scale(256),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]))self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True, num_workers=4)self.verify_loader = torch.utils.data.DataLoader(self.verify_dataset, batch_size=4, shuffle=True, num_workers=4)self.num_classes = len(self.train_dataset.classes)def exec(self):...def main():TranNet().exec()if __name__ == '__main__':main()

3.3.2 模型迁移

    def exec(self):self.model_prepare()def model_prepare(self):net = models.resnet18(pretrained=True)# float net valuesnum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)# fixed net values'''for param in net.parameters():param.requires_grad = Falsenum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)'''

3.3.3 gpu加速

特点:gpu速度快,但内存低,所以尽量减少在gpu中存储的数据,只用来计算就好

    def model_prepare(self):# jusge whether GPUuse_cuda = torch.cuda.is_available()dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensoritype = torch.cuda.LongTensor if use_cuda else torch.LongTensornet = models.resnet18(pretrained=True)net = net.cuda() if use_cuda else net

3.3.4 训练

    def model_prepare(self):net = models.resnet18(pretrained=True)# jusge whether GPUuse_cuda = torch.cuda.is_available()if use_cuda:dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensoritype = torch.cuda.LongTensor if use_cuda else torch.LongTensornet = net.cuda() if use_cuda else net# float net valuesnum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)# fixed net values'''for param in net.parameters():param.requires_grad = Falsenum_features = net.fc.in_featuresnet.fc = nn.Linear(num_features, 2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)'''record = []num_epochs = 3net.train(True) # open dropoutfor epoch in range(num_epochs):train_rights = []train_losses = []for batch_index, (data, target) in enumerate(self.train_loader):data, target = data.clone().detach().requires_grad_(True), target.clone().detach()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()right = rightness(output, target)train_rights.append(right)train_losses.append(loss.data.numpy())if batch_index % 400 == 0:verify_rights = []for index, (data_v, target_v) in enumerate(self.verify_loader):data_v, target_v = data_v.clone().detach(), target_v.clone().detach()output_v = net(data_v)right = rightness(output_v, target_v)verify_rights.append(right)verify_accu = sum([row[0] for row in verify_rights]) / sum([row[1] for row in verify_rights])record.append((verify_accu))print(f'verify data accu:{verify_accu}')# plotpyplot.figure(figsize=(8, 6))pyplot.plot(record)pyplot.xlabel('step')pyplot.ylabel('verify loss')pyplot.show()

4 手写数字加法机

4.1 网络结构

可以先用cnn识别出两个待求和数字,不要输出,只保留池化层加后面一层全连接层,可以获取图像一维特征,然后将两个图像识别获取的一维特征合并,然后用全连接作为剩下的网络

4.2 代码实现

4.2.1 数据加载

class FigurePlus():def __init__(self):super(FigurePlus, self).__init__()self.image_size = 28self.num_classes = 10self.num_epochs = 3self.batch_size = 64self.train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)self.test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())sampler_a = torch.utils.data.sampler.SubsetRandomSampler(np.random.permutation(range(len(self.train_dataset))))sampler_b = torch.utils.data.sampler.SubsetRandomSampler(np.random.permutation(range(len(self.train_dataset))))self.train_loader_a = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=False, sampler=sampler_a)self.train_loader_b = torch.utils.data.DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=False, sampler=sampler_b)self.verify_size = 5000verify_index_a = range(self.verify_size)verify_index_b = np.random.permutation(range(self.verify_size))test_index_a = range(self.verify_size, len(self.test_dataset))test_index_b = np.random.permutation(test_index_a)verify_sampler_a = torch.utils.data.sampler.SubsetRandomSampler(verify_index_a)verify_sampler_b = torch.utils.data.sampler.SubsetRandomSampler(verify_index_b)test_sampler_a = torch.utils.data.sampler.SubsetRandomSampler(test_index_a)test_sampler_b = torch.utils.data.sampler.SubsetRandomSampler(test_index_b)self.verify_loader_a = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=verify_sampler_a)self.verify_loader_b = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=verify_sampler_b)self.test_loader_a = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=test_sampler_a)self.test_loader_b = torch.utils.data.DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, sampler=test_sampler_b)def gpu_ok(self):use_cuda = torch.cuda.is_available()dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensoritype = torch.cuda.LongTensor if use_cuda else torch.LongTensordef exec(self):passdef main():# TranNet().exec()FigurePlus().exec()if __name__ == '__main__':main()

4.2.2 手写数字加法机实现(网络实现)

    def forward(self, x, y, training=True):x, y = F.relu(self.net1_conv1(x)), F.relu(self.net2_conv1(y))x, y = self.net_pool(x), self.net_pool(y)x, y = F.relu(self.net1_conv2(x)), F.relu(self.net2_conv2(y))x, y = self.net_pool(x), self.net_pool(y)x = x.view(-1, (self.image_size // 4) ** 2 * self.depth[1])y = y.view(-1, (self.image_size // 4) ** 2 * self.depth[1])z = torch.cat((x, y), 1)z = self.fc1(z)z = F.relu(z)z = F.dropout(z, training=self.training)z = F.relu(self.fc2(z))z = F.relu(self.fc3(z))return F.relu(self.fc4(z))

4.2.3 模型迁移

思路:将上一篇弄好的数字识别模型保存到文件,然后在本章节加载进来,将各参数权重赋值到本节创的新网络

torch.save(cnn, model_save_path)

将网络加载进来时,需要源模型的定义,在本章重新定义下,拷贝后稍作修改

class FigureIdentify(nn.Module):def __init__(self):super(FigureIdentify, self).__init__()self.depth = (4, 8)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.pool(x)x = self.conv2(x)x = F.relu(x)x = self.pool(x)x = x.view(-1, (image_size // 4)**2 * self.depth[1])x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)x = F.log_softmax(x, dim=1)return x

注意

1 从文件加载模型后只加载网络权重,没加载网络的方法,需重新定义

2 加载后模型会赋值给一个对象,这个对象需要和保存网络时的网络架构保持一致(尽量保持一致,不然报错!!

模型文件加载进来后,用预训练模式(从模型文件加载网络权重作为初始权重,参数会随新网络的训练跟随大网络参数调节)

注意:加法器两个数字识别网络不可直接将文件加载的网络赋值,因为会共享地址,实际是一组参数,一个网络训练后另一个网络的参数也会变,可以复制出来再操作

    def copy_origin_weight(self, net):self.net1_conv1.weight.data = copy.deepcopy(net.conv1.weight.data)self.net1_conv1.bias.data = copy.deepcopy(net.conv1.bias.data)self.net1_conv2.weight.data = copy.deepcopy(net.conv2.weight.data)self.net1_conv2.bias.data = copy.deepcopy(net.conv2.bias.data)self.net2_conv1.weight.data = copy.deepcopy(net.conv1.weight.data)self.net2_conv1.bias.data = copy.deepcopy(net.conv1.bias.data)self.net2_conv2.weight.data = copy.deepcopy(net.conv2.weight.data)self.net2_conv2.bias.data = copy.deepcopy(net.conv2.bias.data)def main():# TranNet().exec()net = FigurePlusNet()origin_net = torch.load(model_save_path)net.copy_origin_weight(origin_net)criterion = nn.MSELoss()optmizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

如要固定值迁移(即新模型训练过程不改变加载进来模型的权重),设requires_grad = False即可

    def copy_origin_weight_nograd(self, net):self.copy_origin_weight(net)self.net1_conv1.weight.requires_grad = Falseself.net1_conv1.bias.requires_grad = Falseself.net1_conv2.weight.requires_grad = Falseself.net1_conv2.bias.requires_grad = Falseself.net2_conv1.weight.requires_grad = Falseself.net2_conv1.bias.requires_grad = Falseself.net2_conv2.weight.requires_grad = Falseself.net2_conv2.bias.requires_grad = False

4.3 训练与测试

    # train records = []for epoch in range(net.num_epochs):losses = []for index, data in enumerate(zip(net.train_loader_a, net.train_loader_b)):(x1, y1), (x2, y2) = dataif net.gpu_ok():x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()optimizer.zero_grad()net.train()outputs = net(x1.clone().detach(), x2.clone().detach())outputs = outputs.squeeze()labels = y1 + y2loss = criterion(outputs, labels.type(torch.float))loss.backward()optimizer.step()loss = loss.cpu() if net.gpu_ok() else losslosses.append(loss.data.numpy())if index % 300 == 0:verify_losses = []rights = []net.eval()for verify_data in zip(net.verify_loader_a, net.verify_loader_b):(x1, y1), (x2, y2) = verify_dataif net.gpu_ok():x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()outputs = net(x1.clone().detach(), x2.clone().detach())outputs = outputs.squeeze()labels = y1 + y2loss = criterion(outputs, labels.type(torch.float))loss = loss.cpu() if net.gpu_ok() else lossverify_losses.append(loss.data.numpy())right = rightness(outputs.data, labels)rights.append(right)right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])print(f'no.{epoch}, {index}/{len(net.train_loader_a)}, train loss:{np.mean(losses)}, verify loss:{np.mean(verify_losses)}, accu: {right_ratio}')# records.append([np.mean(losses), np.mean(verify_losses), right_ratio])records.append([right_ratio])# plot train datapyplot.figure(figsize=(8, 6))pyplot.plot(records)pyplot.xlabel('step')  pyplot.ylabel('loss & accuracy')# testrights = []net.eval()for test_data in zip(net.test_loader_a, net.test_loader_b):(x1, y1), (x2, y2) = test_dataif net.gpu_ok():x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()outputs = net(x1.clone().detach(), x2.clone().detach())outputs = outputs.squeeze()labels = y1 + y2loss = criterion(outputs, labels.type(torch.float))right = rightness(outputs, labels)rights.append(right)right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])print(f'test accuracy: {right_ratio}')pyplot.show()

4.4 结果

一轮打3个点,总共6轮(多了太慢,没用gpu),总共打大概24个点。发现随着训练轮数增加,准确率逐步提高

4.5 大规模测试

4.5.1 大模型

大模型指迁移学习全连接层用4层网络。以数据量为自变量,分别看5%,50%,100%数据量情况下,迁移学习与不迁移学习准确率随轮数变化趋势

结论:1 数据量100%时,迁移学习与无迁移学习准确率趋势很接近;数据量较小时,迁移学习准确率上升速度会远快于无迁移学习。数据量到50%左右时差异就变得不是很明显了但还是有 2数据量大时,固定值训练模式比预训练模式精度更好

4.5.2 小模型

小模型指迁移学习全连接层用2层而不是4层,仍然以数据量和是否迁移学习为自变量分析

结论:1 数据量小时,迁移学习比无迁移学习准确率上升速度快,数据量大时,迁移学习与无迁移学习这种差异变小 

5 总结

适用场景(不仅限于这些):数据量小

两种迁移方式:固定值和预训练,固定值方式参数变动范围小,训练可能更快

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/644602.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2012-2022年全国各省数字经济相关指标数据合集(18个指标)

2012-2022年全国各省数字经济相关指标数据合集(18个指标) 1、时间:2012-2022年 2、指标:地区、year、互联网接入端口数、互联网宽带接入用户数、互联网域名数、移动电话普及率、长途光缆线路长度(万公里)…

java开发——《并发编程》

目录 一.jmm 二.并发了什么 1.只有一个核(单核)并发还有没有意义 2.单核,还有什么可见性问题 3.并发和并行 三.volitaile 1.变量的可见性问题 2.原因是什么 3.本次修改的变量直接刷到主内存 4.声明其他内存对于这个地址的缓存无效 …

highcharts.css文件的样式覆盖了options的series里面的color问题解决

文章目录 一、问题背景二、解决问题 一、问题背景 原本的charts我们的每个数据是有对应的color显示的,如下图: 后面我们系统做了黑白模式,引入了highcharts的css文件,结果highcharts的css文件中class的颜色样式覆盖了我们数据中的…

【云原生】Docker的端口映射、数据卷、数据卷容器、容器互联

目录 一、端口映射(相当于添加iptables的DANT) 二、数据卷创建(宿主机目录或文件挂载到容器中) 三、数据卷容器(多个容器通过同一个数据卷容器为基点,实现所有容器数据共享) 四、容器互联&am…

Java 设计者模式以及与Spring关系(六) 装饰和模版方法模式

简介: 本文是个系列一次会出两个设计者模式作用,如果有关联就三个,除此外还会讲解在spring中作用。 23设计者模式以及重点模式 我们都知道设计者模式有3类23种设计模式,标红是特别重要的设计者模式建议都会,而且熟读于心&#…

Deployment介绍

1、Deployment介绍 Deployment一般用于部署公司的无状态服务。 格式: apiVersion: apps/v1 kind: Deployment metadata: name: nginx-deployment labels: app: nginx spec: replicas: 3 selector: matchLabels: app: nginx template: metada…

菜鸟导入导出assetbundle

因为菜鸟不会用unity c#什么的,所以最后参考贴吧的方法用的是UABE(Unity Assets Bundle Extractor)和UABEA(Unity Assets Bundle Extractor Avalonia) 可以去github上下载 对于txt、xml什么的可以直接改,但是byte文件里还是会有一些类似乱码的东西&…

Qt5项目拆解第一集解决:中文乱码| 全局字体|注册表|QSS/CSS

# 一、乱码解决代码片段 QTextCodec是Qt中用于处理文本编码和字符集转换的类。它提供了一系列静态函数来实现不同编码的文本转换,包括编码转换、字符集检测和转换、以及数据流中的文本编码处理。QTextCodec类使得Qt可以在不同的编码和字符集之间进行无缝转换,从而方便地处理…

Switch用法以及新特性-最全总结版

本篇文章参考了大佬文章,感谢大佬无私分享: http://t.csdnimg.cn/MjZnX http://t.csdnimg.cn/QFg0x 目录 一、Switch用法:JDK7及以前 1.1、举例一: 1.2、举例二: 二、Switch穿透: 2.1、举例&#xf…

【Linux】常见指令(二)

前言 常见指令第二部分。 文章目录 一、指令&#xff08;下&#xff09;重定向>&#xff1a;输出重定向>>&#xff1a;追加输出<&#xff1a;输入重定向 10. more—显示文本文件内容11.less—逐屏浏览文本文件内容12. head13. tail管道 |14. date—时间指令在这里插…

2024年可能会用到的几个地图可视化模板

前言 在数字化的过程中&#xff0c;数据可视化变得越来越重要。用户喜欢通过酷炫的视觉效果和直观的数据展示来理解数据。可视化地图组件是数据可视化的重要组成部分。这些地图组件提供多样化的效果&#xff0c;能够更好地展示数据的关系和地理分布&#xff0c;直观地将数据与…

裁员潮中的自我成长,小故事,大鼓励

程序员裁员潮&#xff1a;技术变革下的职业危机 科技浪潮滚滚而来&#xff0c;我们了解科技&#xff0c;敬畏科技&#xff0c;拥抱科技。我们怕的不是裁员&#xff0c;而是自己无所适从的样子。 2023年&#xff0c;科技公司裁员的新闻屡见不鲜。据统计&#xff0c;今年以来&…

uniapp设置隐藏原生导航栏(3)

1、单个页面隐藏 在pages.json里配置 (第一种方式) {"path": "pages/home/index","style": {"navigationBarTitleText": "首页","navigationStyle": "custom" // 使用自定义导航栏&#xff0c;系统会关…

SpringBoot3+JDK21集成MyBatisPlus3.5.5

哈喽&#xff0c;大家好&#xff0c;我是呼噜噜&#xff0c;在上一篇文章SpringBoot3Jdk17来了 | 春见知识分享基础上&#xff0c;笔者把jdk17直接换成了jdk21一步到位&#xff0c;来踩踩坑 添加依赖 修改pom.xml文件&#xff1a; <dependency><groupId>com.baom…

日历的实现(java语言,包括钟表盘的实现、日历内部的日程提醒)

整理文件发现了大一的时候的作业&#xff0c;先感慨一波时间过得真的快&#xff01; 手中的这个是一个独立的java文件&#xff0c;可以直接就可以运行&#xff0c;应该是没有什么问题的。不想这个代码就此落灰了&#xff0c;希望可以给友友们带来一点点帮助&#xff01; 运行…

避免邮件进入垃圾箱的实用技巧:提高邮件接收率的策略

邮件进垃圾邮箱一部分原因是IP地址出现了问题&#xff0c;一部分是邮件内容。那我们应该怎么避免邮件进入垃圾邮箱呢&#xff1f; 1、邮件内容 1&#xff09;邮件标题 邮件标题是影响邮件打开率非常重要的因素&#xff0c;所以大家可能会在标题上放置一些吸引人的符号或者词…

聚道云连接器助力钉钉与金蝶云星辰无缝对接,实现多维度数据同步

客户介绍 某企业服务有限公司专注于为企业提供全方位、高质量的企业服务&#xff0c;致力于于企业管理咨询、企业形象策划、市场营销策划、财务管理咨询等方面。该公司拥有一支经验丰富、专业化的团队&#xff0c;他们深入了解企业需求&#xff0c;为客户提供个性化的解决方案…

列表的创建与删除

Python 中列表可以动态地添加、修改和删除元素&#xff0c;是 Python 编程中不可或缺的一部分。本文将介绍如何使用 Python 创建和删除列表&#xff0c;以及常用的方法和技巧。 创建列表 在 Python 中&#xff0c;我们可以使用一对方括号 [ ] 来创建一个空列表&#xff0c;也可…

开源免费无广告Gopeed,现代化的高速下载器,支持(HTTP、BitTorrent、Magnet)等多种协议下载,开源免费、无广告、高度可定制、不限速。

目录 特点 支持的平台 一键部署 体验 特点 全平台支持、开源免费&#xff0c;不限速、无广告 遵循 GPL-3.0 开源协议 支持&#xff08;HTTP、BitTorrent、Magnet&#xff09;协议下载 高速下载&#xff0c;底层使用golang协程并发下载 每日自动更新 tracker 列表 去中心…

IPv4 over IPv6简介

在IPv4 Internet向IPv6 Internet过渡的后期&#xff0c;IPv6网络已被大量部署&#xff0c;此时可能出现IPv4孤岛。利用隧道技术可在IPv6网络上创建隧道&#xff0c;从而实现IPv4孤岛的互连。这类似于在IP网络上利用隧道技术部署VPN。在IPv6网络上用于连接IPv4孤岛的隧道&#x…