PyTorch – 逻辑回归

data

首先导入torch里面专门做图形处理的一个库,torchvision,根据官方安装指南,你在安装pytorch的时候torchvision也会安装。

我们需要使用的是torchvision.transforms和torchvision.datasets以及torch.utils.data.DataLoader

首先DataLoader是导入图片的操作,里面有一些参数,比如batch_size和shuffle等,默认load进去的图片类型是PIL.Image.open的类型,如果你不知道PIL,简单来说就是一种读取图片的库

torchvision.transforms里面的操作是对导入的图片做处理,比如可以随机取(50, 50)这样的窗框大小,或者随机翻转,或者去中间的(50, 50)的窗框大小部分等等,但是里面必须要用的是transforms.ToTensor(),这可以将PIL的图片类型转换成tensor,这样pytorch才可以对其做处理

torchvision.datasets里面有很多数据类型,里面有官网处理好的数据,比如我们要使用的MNIST数据集,可以通过torchvision.datasets.MNIST()来得到,还有一个常使用的是torchvision.datasets.ImageFolder(),这个可以让我们按文件夹来取图片,和keras里面的flow_from_directory()类似,具体的可以去看看官方文档的介绍。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

# 定义超参数

batch_size = 32

learning_rate = 1e-3

num_epoches = 100

# 下载训练集 MNIST 手写数字训练集

train_dataset = datasets.MNIST(root=\'./data\', train=True,

                               transform=transforms.ToTensor(),

                               download=True)

test_dataset = datasets.MNIST(root=\'./data\', train=False,

                              transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

以上就是我们对图片数据的读取操作

model

之前讲过模型定义的框架,废话不多说,直接上代码

1

2

3

4

5

6

7

8

9

10

class Logstic_Regression(nn.Module):

    def __init__(self, in_dim, n_class):

        super(Logstic_Regression, self).__init__()

        self.logstic = nn.Linear(in_dim, n_class)

    def forward(self, x):

        out = self.logstic(x)

        return out

model = Logstic_Regression(28*28, 10)  # 图片大小是28x28

我们需要向这个模型传入参数,第一个参数定义为数据的维度,第二维数是我们分类的数目。

接着我们可以在gpu上跑模型,怎么做呢?

首先可以判断一下你是否能在gpu上跑

1

torh.cuda.is_available()

如果返回True就说明有gpu支持

接着你只需要一个简单的命令就可以了

1

2

3

4

5

model = model.cuda()

或者

model.cuda()

然后需要定义loss和optimizer

1

2

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

这里我们使用的loss是交叉熵,是一种处理分类问题的loss,optimizer我们还是使用随机梯度下降

train

接着就可以开始训练了

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

for epoch in range(num_epoches):

    print(\'epoch {}\'.format(epoch 1))

    print(\'*\'*10)

    running_loss = 0.0

    running_acc = 0.0

    for i, data in enumerate(train_loader, 1):

        img, label = data

        img = img.view(img.size(0), -1)  # 将图片展开成 28x28

        if use_gpu:

            img = Variable(img).cuda()

            label = Variable(label).cuda()

        else:

            img = Variable(img)

            label = Variable(label)

        # 向前传播

        out = model(img)

        loss = criterion(out, label)

        running_loss  = loss.data[0] * label.size(0)

        _, pred = torch.max(out, 1)

        num_correct = (pred == label).sum()

        running_acc  = num_correct.data[0]

        # 向后传播

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

注意我们如果将模型放到了gpu上,相应的我们的Variable也要放到gpu上,也很简单

1

2

img = Variable(img).cuda()

label = Variable(label).cuda()

然后可以测试模型,过程与训练类似,只是注意要将模型改成测试模式

1

model.eval()

这是跑完100 epoch的结果

具体的结果多久打印一次,如何打印可以自己在for循环里面去设计。

相关代码:pytorch-beginner: pytorch-beginner

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/693032.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

营销系统黑名单优化:位图的应用解析 | 京东云技术团队

背景 营销系统中,客户投诉是业务发展的一大阻碍,一般会过滤掉黑名单高风险账号,并配合频控策略,来减少客诉,进而增加营销效率,减少营销成本,提升营销质量。 营销系统一般是通过大数据分析建模…

2024年了,如何从 0 搭建一个 Electron 应用

简介 Electron 是一个开源的跨平台桌面应用程序开发框架,它允许开发者使用 Web 技术(如 JavaScript、HTML 和 CSS)来构建桌面应用程序。Electron 嵌入了 Chromium(一个开源的 Web 浏览器引擎)和 Node.js(一…

游戏行业洞察:分布式开源爬虫项目在数据采集与分析中的应用案例介绍

前言 我在领导一个为游戏行业巨头提供数据采集服务的项目中,我们面临着实时数据需求和大规模数据处理的挑战。我们构建了一个基于开源分布式爬虫技术的自动化平台,实现了高效、准确的数据采集。通过自然语言处理技术,我们确保了数据的质量和…

【PostgreSQL实现psql连接时候提示用户的密码有效时间】

如下内容使用session_exec插件结合自定函数实现。类似于触发器的原理。 功能需要严格在测试环境测试后,才可在正式环境使用。没有相关要求,还是建议直接查询pg_roles/pg_authid/pg_user; 一、判断是否需要修改用户密码和有效期的检查SQL 首…

【Emgu CV教程】7.1、图像锐化之Laplacian(拉普拉斯)算子锐化

文章目录 一、介绍二、举例1.原始素材2.代码3.运行结果 一、介绍 前面几篇讲的是图像平滑,就是抑制或消除噪声,并使得图像亮度及颜色变化更平缓的操作。在图像处理领域,与平滑操作相对应的,叫图像锐化。 图像锐化就是增强图像的边…

python OpenCV:seamlessClone泊松融合

一、seamlessClone函数的用法 翻译 https://www.learnopencv.com/seamless-cloning-using-opencv-python-cpp/ def seamlessClone(src, dst, mask, p, flags, blendNone): # real signature unknown; restored from __doc__"""seamlessClone(src, dst, mask, …

【Hudi】Upsert原理

17张图带你彻底理解Hudi Upsert原理 1.开始提交:判断上次任务是否失败,如果失败会触发回滚操作。然后会根据当前时间生成一个事务开始的请求标识元数据。2.构造HoodieRecord Rdd对象:Hudi 会根据元数据信息构造HoodieRecord Rdd 对象&#xf…

2024年【起重机司机(限桥式起重机)】试题及解析及起重机司机(限桥式起重机)证考试

题库来源:安全生产模拟考试一点通公众号小程序 起重机司机(限桥式起重机)试题及解析考前必练!安全生产模拟考试一点通每个月更新起重机司机(限桥式起重机)证考试题目及答案!多做几遍,其实通过起重机司机(限桥式起重机)理论考试很…

linux ext3/ext4文件系统(part2 jbd2)

概述 jbd2(journal block device 2)是为块存储设计的 wal 机制,它为要写设备的buffer绑定了一个journal_head,这个journal_head与一个transaction绑定,随着事务状态的转移(运行,生成日志&#…

我为什么不喜欢关电脑?

程序员为什么不喜欢关电脑? 你是否注意到,程序员们似乎从不关电脑?别以为他们是电脑上瘾,实则是有他们自己的原因!让我们一起揭秘背后的原因,看看程序员们真正的“英雄”本色! 一、上大学时。 …

Backtrader 量化回测实践(1)—— 架构理解和MACD/KDJ混合指标

Backtrader 量化回测实践(1)—— 架构理解和MACD/KDJ混合指标 按Backtrader的架构组织,整理了一个代码,包括了Backtrader所有的功能点,原来总是使用SMA最简单的指标,现在稍微增加了复杂性,用MA…

k8s除了可以直接运行docker镜像之外,还可以运行什么? springboot项目打包成的压缩包可以直接运行在docker容器中吗?

Kubernetes(k8s)主要设计用于自动部署、扩展和管理容器化应用程序。虽然它与Docker容器最为密切相关,Kubernetes实际上是与容器运行时技术无关的,这意味着它不仅仅能够管理Docker容器。Kubernetes支持多种容器运行时,包…

[office] EXCEL表格不能使用键盘箭头切换单元格该怎么解决- #媒体#经验分享#知识分享

EXCEL表格不能使用键盘箭头切换单元格该怎么解决? EXCEL表格不能使用键盘箭头切换单元格该怎么解决? 1、入下图所示的键盘。 图中红色标记“1”的地方是Scroll Lock指示灯。Scroll Lock就是“滚动锁定”的意思。当该指示灯亮起来的时候,在excel表格中操…

Android 面试问题 2024 版(其一)

Android 面试问题 2024 版(其一) 一、Java 和 Kotlin二、安卓组件三、用户界面 (UI) 开发四、安卓应用架构五、网络和数据持久性 一、Java 和 Kotlin Java 中的抽象类和接口有什么区别? 答:抽象类是不能实例化的类,它…

使用IntelliJ IDEA查看接口的全部实现方法

在大型Java项目中,经常会使用接口和抽象类进行代码设计。为了更好地了解代码结构和功能,我们需要快速查看一个接口的所有实现类。IntelliJ IDEA提供了一些方便的方法来实现这一目标。 1. 点击查看接口的实现子类 在IDEA中,你可以轻松地查看…

大话设计模式——2.简单工厂模式(Simple Factory Pattern)

定义:又称静态工厂方法,可以根据参数的不同返回不同类的实例,专门定义一个类(工厂类)来负责创建其他类的实例可通过类名直接调用,被创建的实例通常具有共同的父类。 UML图: 例子: 计…

计算机视觉的应用24-ResNet网络与DenseNet网络的对比学习,我们该如何选择。

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用24-ResNet网络与DenseNet网络的对比学习,我们该如何选择。在计算机视觉领域,ResNet(残差网络)和DenseNet(密集网络)都是深度学…

华清远见作业第三十九天——Qt(第一天)

思维导图&#xff1a; 登录界面&#xff1a; 代码&#xff1a; #include "mainwindow.h" #include<QToolBar> #include<QPushButton> MainWindow::MainWindow(QWidget *parent): QMainWindow(parent) {this->resize(600,400);this->setFixedSize…

Mysql 8.0新特性详解

建议使用8.0.17及之后的版本&#xff0c;更新的内容比较多。 1、新增降序索引 MySQL在语法上很早就已经支持降序索引&#xff0c;但实际上创建的仍然是升序索引&#xff0c;如下MySQL 5.7 所示&#xff0c;c2字段降序&#xff0c;但是从show create table看c2仍然是升序。8.0…

ubuntu 22.04.3 live server安装JDK21与远程编程环境和maven

ubuntu 22.04.3 live server安装JDK21与远程编程环境 一、安装jdk21 解压jdk压缩包&#xff0c;命令&#xff1a; tar -zxvf jdk-21_linux-x64_bin.tar.gz打开环境变量&#xff0c;命令&#xff1a; sudo vim /etc/profile配置环境变量 export JAVA_HOME/root/jdk-21.0.2 …