PyTorch常用工具(2)预训练模型

文章目录

  • 前言
  • 2 预训练模型

前言

在训练神经网络的过程中需要用到很多的工具,最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块,合理使用这些工具可以极大地提高编程效率。

由于内容较多,本文分成了五篇文章(1)数据处理(2)预训练模型(3)TensorBoard(4)Visdom(5)CUDA与小结。

整体结构如下:

  • 1 数据处理
    • 1.1 Dataset
    • 1.2 DataLoader
  • 2 预训练模型
  • 3 可视化工具
  • 3.1 TensorBoard
  • 3.2 Visdom
  • 4 使用GPU加速:CUDA
  • 5 小结

全文链接:

  1. PyTorch中常用的工具(1)数据处理
  2. PyTorch常用工具(2)预训练模型
  3. PyTorch中常用的工具(3)TensorBoard
  4. PyTorch中常用的工具(4)Visdom
  5. PyTorch中常用的工具(5)使用GPU加速:CUDA

2 预训练模型

除了加载数据,并对数据进行预处理之外,torchvision还提供了深度学习中各种经典的网络结构以及预训练模型。这些模型封装在torchvision.models中,包括经典的分类模型:VGG、ResNet、DenseNet及MobileNet等,语义分割模型:FCN及DeepLabV3等,目标检测模型:Faster RCNN以及实例分割模型:Mask RCNN等。读者可以通过下述代码使用这些已经封装好的网络结构与模型,也可以在此基础上根据需求对网络结构进行修改:

from torchvision import models
# 仅使用网络结构,参数权重随机初始化
mobilenet_v2 = models.mobilenet_v2()
# 加载预训练权重
deeplab = models.segmentation.deeplabv3_resnet50(pretrained=True)

下面使用torchvision中预训练好的实例分割模型Mask RCNN进行一次简单的实例分割:

In: from torchvision import modelsfrom torchvision import transforms as Tfrom torch import nnfrom PIL import Imageimport numpy as npimport randomimport cv2# 加载预训练好的模型,不存在的话会自动下载# 预训练好的模型保存在 ~/.torch/models/下面detection = models.detection.maskrcnn_resnet50_fpn(pretrained=True)detection.eval()def predict(img_path, threshold):# 数据预处理,标准化至[-1, 1],规定均值和标准差img = Image.open(img_path)transform = T.Compose([T.ToTensor(),T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])])img = transform(img)# 对图像进行预测pred = detection([img])# 对预测结果进行后处理:得到mask与bboxscore = list(pred[0]['scores'].detach().numpy())t = [score.index(x) for x in score if x > threshold][-1]mask = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()pred_boxes = [[(i[0], i[1]), (i[2], i[3])] \for i in list(pred[0]['boxes'].detach().numpy())]pred_masks = mask[:t+1]boxes = pred_boxes[:t+1]return pred_masks, boxes

Transforms中涵盖了大部分对Tensor和PIL Image的常用处理,这些已在上文提到,本节不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如transf = transforms.Normalize(mean=x, std=y);第二步:执行转换操作,例如output = transf(input)。另外还可以将多个处理操作用Compose拼接起来,构成一个处理转换流程。

In: # 随机颜色,以便可视化def color(image):colours = [[0, 255, 255], [0, 0, 255], [255, 0, 0]]R = np.zeros_like(image).astype(np.uint8)G = np.zeros_like(image).astype(np.uint8)B = np.zeros_like(image).astype(np.uint8)R[image==1], G[image==1], B[image==1] = colours[random.randrange(0,3)]color_mask = np.stack([R,G,B],axis=2)return color_mask
In: # 对mask与bounding box进行可视化def result(img_path, threshold=0.9, rect_th=1, text_size=1, text_th=2):masks, boxes = predict(img_path, threshold)img = cv2.imread(img_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)for i in range(len(masks)):color_mask = color(masks[i])img = cv2.addWeighted(img, 1, color_mask, 0.5, 0)cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(255,0,0), thickness=rect_th)return img
In: from matplotlib import pyplot as pltimg=result('data/demo.jpg')plt.figure(figsize=(10, 10))plt.axis('off')img_result = plt.imshow(img)

TensorBoard界面

上述代码完成了一个简单的实例分割任务。如上图所示,Mask RCNN能够分割出该图像中的部分实例,读者可考虑对预训练模型进行微调,以适应不同场景下的不同任务。注意:上述代码均在CPU上进行,速度较慢,读者可以考虑将数据与模型转移至GPU上,具体操作可以参考第4节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/589165.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一起学量化之KDJ指标

KDJ指标,也称为随机指数,是一个常用的技术分析工具。它由三条线组成:K线、D线和J线,分别代表不同的市场动态。KDJ指标通过分析最高价、最低价和收盘价计算得出。 1. KDJ指标理解 J线是移动速度最快的线,可以提供更加敏锐的市场信号。K线是指标的核心,显示市场的即时动态。…

【linux 多线程并发】线程属性设置与查看,绑定CPU,线程分离与可连接,避够多线程下的内存泄漏

线程属性设置 ​专栏内容: 参天引擎内核架构 本专栏一起来聊聊参天引擎内核架构,以及如何实现多机的数据库节点的多读多写,与传统主备,MPP的区别,技术难点的分析,数据元数据同步,多主节点的情况…

LeetCode1275. Find Winner on a Tic Tac Toe Game

文章目录 一、题目二、题解 一、题目 Tic-tac-toe is played by two players A and B on a 3 x 3 grid. The rules of Tic-Tac-Toe are: Players take turns placing characters into empty squares ’ . The first player A always places ‘X’ characters, while the seco…

Keras实现Transformer

# 导入所需的库 import numpy as np from keras.models import Model from keras.layers import Input, Dense, Embedding, MultiHeadAttention from keras.optimizers import Adam# 定义模型参数 vocab_size 10000 # 词汇表大小 embedding_dim 256 # 嵌入维度 num_heads …

营销系统升级:运荔枝无代码集成电商API功能

无代码开发:运荔枝连接电商与CRM 随着电子商务的持续扩张,企业亟需无缝集成电商平台与客户关系管理(CRM)系统,以提高运营效率。运荔枝通过其无代码开发平台,为企业提供了简化的API连接服务。商家可以在不具…

Prometheus 监控进程

prometheus 进程的监控 1. process exporter功能 2. 监控目标对主机进程的监控,chronyd sshd 等服务进程已经已定义脚本运行程序的运行状态监控。 process-compose的安装 监控所有进程 mkdir /data/process_exporter -p cd /data/process_exporter创建配置文件 …

Linux期末复习笔记

期末复习笔记 引言目录操作用户和组用户组 文件及文件权限文件文件目录及分类Linux文件目录文件类型文件权限 磁盘管理磁盘命名规则使用命令行工具管理磁盘分区和文件系统linux中的数据备份策略软件包安装检查维护文件系统 进程管理进程分类ps查看与top查看的区别: …

为什么ChatGPT选择了SSE,而不是WebSocket?

我在探索ChatGPT的使用过程中,发现了一个有趣的现象:ChatGPT在实现流式返回的时候,选择了SSE(Server-Sent Events),而非WebSocket。 那么问题来了:为什么ChatGPT选择了SSE,而不是We…

力扣25题: K 个一组翻转链表

【题目链接】力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台,解题代码如下: class Solution {public ListNode reverseKGroup(ListNode head, int k) {ListNode curNode head;ListNode groupHead, groupTail head, lastGrou…

UART通信协议:串行通信的精华

UART通信协议:串行通信的精华 UART(Universal Asynchronous Receiver/Transmitter)通信协议是一种广泛应用于串行通信的标准,它在电子设备和嵌入式系统中扮演着至关重要的角色。本文将深入介绍UART通信协议的基本原理、工作方式、…

一个可以用于生产环境得PHP上传函数

上传表单 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><title>文件上传</title> </head> <body><h1>选择要上传的文件</h1><!-- 定义一个包含文件输入字段的表单 --…

[每周一更]-(第46期):Linux下配置Java所需环境及Java架构选型

Linux下配置Java所需环境及Java架构选型 一、配置基础环境 1.配置tomcat 环境变量 wget https://dlcdn.apache.org/tomcat/tomcat-10/v10.1.8/src/apache-tomcat-10.1.8-src.tar.gz tar -zxvf apache-tomcat-10.1.8-src.tar.gz 在/etc/profile 末尾追加export CATALINA_HOME…

异常控制流ECF

大家好&#xff0c;我叫徐锦桐&#xff0c;个人博客地址为www.xujintong.com&#xff0c;github地址为https://github.com/xjintong。平时记录一下学习计算机过程中获取的知识&#xff0c;还有日常折腾的经验&#xff0c;欢迎大家访问。 一、异常控制流&#xff08;ECF) 现代系…

[BUG]Datax写入数据到psql报不能序列化特殊字符

1.问题描述 Datax从mongodb写入数据到psql报错如下 org.postgresql.util.PSQLException: ERROR: invalid bytesequence for encoding "UTF8": 0x002.原因分析 此为psql独有的错误&#xff0c;不能对特殊字符’/u0000’,进行序列化&#xff0c;需要将此特殊字符替…

webrtc中的接口代理框架

文章目录 接口代理框架Proxy体系类结构导出接口 webrtc的实际运用PeerConnectionFactoyPeerConnection使用 接口代理框架 webrtc体系庞大&#xff0c;模块化极好&#xff0c;大多数模块都可以独立使用。模块提供接口&#xff0c;外部代码通过接口来使用模块功能。 在webrtc中通…

uni-app 前后端调用实例 基于Springboot

锋哥原创的uni-app视频教程&#xff1a; 2023版uniapp从入门到上天视频教程(Java后端无废话版)&#xff0c;火爆更新中..._哔哩哔哩_bilibili2023版uniapp从入门到上天视频教程(Java后端无废话版)&#xff0c;火爆更新中...共计23条视频&#xff0c;包括&#xff1a;第1讲 uni…

《PCI Express体系结构导读》随记 —— 第I篇 第2章 PCI总线的桥与配置(1)

前言中曾提到&#xff1a;本章重点介绍PCI桥。 在PCI体系结构中含有两类桥&#xff1a;一类是HOST主桥&#xff1b;另一类是PCI桥。在每一个PCI设备中&#xff08;包括PCI桥&#xff09;&#xff0c;都含有一个配置空间。这个配置空间由HOST主桥管理&#xff0c;而PCI桥可以转…

cfa一级考生复习经验分享系列(十五)

备考背景&#xff1a; 本科211石油理科背景&#xff1b;无金融方面专业知识及工作经验&#xff1b;在职期间备考&#xff1b;有效备考时间2个月&#xff1b;12月一级考试10A。 复习进度及教材选择 首先说明&#xff0c;关于教材的经验分享针对非金融背景考生。 第一阶段&#x…

Java EE Servlet之Cookie 和 Session

文章目录 1. Cookie 和 Session1.1 Cookie1.2 理解会话机制 (Session)1.2.1 核心方法 2. 用户登录2.1 准备工作2.2 登录页面2.3 写一个 Servlet 处理上述登录请求2.4 实现登录后的主页 3. 总结 1. Cookie 和 Session 1.1 Cookie cookie 是 http 请求 header 中的一个属性 浏…

[枚举涂块]画家问题

画家问题 题目描述 有一个正方形的墙&#xff0c;由N*N个正方形的砖组成&#xff0c;其中一些砖是白色的&#xff0c;另外一些砖是黄色的。Bob是个画家&#xff0c;想把全部的砖都涂成黄色。但他的画笔不好使。当他用画笔涂画第(i, j)个位置的砖时&#xff0c; 位置(i-1, j)、…