使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的
语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧

2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果的概率)

1.先把二维数组转为一维
2.通过公式得到节点个数和值
3…同2
4.通过节点得到概率(softmax归一化公式)
5.对比模型的和 差值=原始概率-目标结果概率
6.不断优化原来模型的概率
5.激活函数,激活某个节点的函数,可以引入非线性的(因为所有问题不可能是线性的比如 很少图片识别一定可以识别出绝对的正方形,他可能中间有一定弯曲或者线在中心短开了)

在这里插入图片描述
在这里插入图片描述

3.训练的代码
//环境python3.8 最好使用conda进行版本管理,不然每个版本都可能不兼容,到处碰壁

 #安装依赖pip install numpy torch torchvision matplotlib

#文件夹结构,图片一定要是28x28的
在这里插入图片描述

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderclass Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return x#导入数据
def get_data_loader(is_train):#张量,多维数组to_tensor = transforms.Compose([transforms.ToTensor()])# 下载数据集 下载目录data_set = MNIST("", is_train, transform=to_tensor, download=True)#一个批次15张,顺序打乱return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)#评估准确率
def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():#按批次取数据for (x, y) in test_data:#计算神经网络预测值outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):#比较预测结果和测试集结果if torch.argmax(output) == y[i]:#统计正确预测结果数n_correct += 1#统计全部预测结果n_total += 1#返回准确率=正确/全部的return n_correct / n_totaldef main():#加载训练集train_data = get_data_loader(is_train=True)#加载测试集test_data = get_data_loader(is_train=False)#初始化神经网络net = Net()#打印测试网络的准确率 0.1print("initial accuracy:", evaluate(test_data, net))#训练神经网络optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#重复利用数据集 2次for epoch in range(100):for (x, y) in train_data:#初始化 固定写法net.zero_grad()#正向传播output = net.forward(x.view(-1, 28 * 28))#计算差值loss = torch.nn.functional.nll_loss(output, y)#反向误差传播loss.backward()#优化网络参数optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))# #使用3张图片进行预测# for (n, (x, _)) in enumerate(test_data):#     if n > 3:#         break#     predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))#     plt.figure(n)#     plt.imshow(x[0].view(28, 28))#     plt.title("prediction: " + str(int(predict)))# plt.show()image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakpredict = torch.argmax(net.forward(x.view(-1, 28 * 28)))plt.figure(n)plt.imshow(x[0].permute(1, 2, 0))plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

#运行结果 弹框出现图片和识别结果

4.测试电脑的cuda是否安装成功,不成功不能运行下面的代码

import torchdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print('PyTorch version:', torch.__version__)

5.在gpu上运行,需要去官网下载cuda安装
https://developer.nvidia.com/cuda-toolkit-archive
#并且需要安装和torch对应的版本,我的电脑是1660ti的所以安装了10.2的cuda
#安装torchgpu版本

pip install torch==1.9.0+cu102 -f
https://download.pytorch.org/whl/cu102/torch_stable.html

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return xdef get_data_loader(is_train):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = MNIST("", is_train, transform=to_tensor, download=True)return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():for (x, y) in test_data:x, y = x.to(device), y.to(device)outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):if torch.argmax(output.cpu()) == y[i].cpu():n_correct += 1n_total += 1return n_correct / n_totaldef main():train_data = get_data_loader(is_train=True)test_data = get_data_loader(is_train=False)net = Net().to(device)print("initial accuracy:", evaluate(test_data, net))optimizer = torch.optim.Adam(net.parameters(), lr=0.001)for epoch in range(100):for (x, y) in train_data:x, y = x.to(device), y.to(device)net.zero_grad()output = net.forward(x.view(-1, 28 * 28))loss = torch.nn.functional.nll_loss(output, y)loss.backward()optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakx = x.to(device)predict = torch.argmax(net.forward(x.view(-1, 28 * 28)).cpu())plt.figure(n)plt.imshow(x[0].permute(1, 2, 0).cpu())plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155536.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

StringBuffer解析

StringBuffer是可变的字符序列,当一个StringBuffer被创建以后,通过StringBuffer提供append()、insert()、reverse()、setCharAt()、setLength()等方法可以改变这个字符串对象的字符序列。一旦通过StringBuffer生成了最终想要的字符串,就可以调…

移动云CNP产品介绍

整体介绍 磐舟devops的核心功能是项目管理和CI流程实现。CD能力也是集成的外部开源产品argoCD。所以 磐舟并不以CD能力见长。一般推荐试用磐舟完成CI,然后试用移动云CNP产品完成CD部署工作。 移动云原生技术平台CNP是面向多云多集群场景的应用管理平台。平台以应用…

Linux—简介安装常用命令系统中软件安装项目部署

目录 1. 前言1.1 什么是Linux1.2 为什么要学Linux1.3 学完Linux能干什么 2. Linux简介2.1 主流操作系统2.2 Linux发展历史2.3 Linux系统版本 3. Linux安装3.1 安装方式介绍3.2 安装VMware3.3 安装Linux3.4 网卡设置3.5 安装SSH连接工具3.5.1 SSH连接工具介绍3.5.2 FinalShell安…

大数据可视化是什么?

大数据可视化是将海量数据通过视觉方式呈现出来,以便于人们理解和分析数据的过程。它可以帮人们发现数据之间的关系、趋势和模式,并制定更明智的决策。大数据可视化通常通过图形、图表、地图和仪表盘等视觉元素来呈现数据。这些元素具有直观、易理解的特…

vscode自定义代码提示

vscode输入一段代码后,自动给代码提示,那么如何才能自定义呢? 点击左下角设置按钮,弹出一个框。选择点击“用户代码片段”,又弹出一个框提供了很多选项,其中包括了针对某些类型的文件(如&#x…

Python 检测网络是否连通

1 使用 urlib import urllib.requestdef test_internet_connection():url https://www.baidu.comtry:urllib.request.urlopen(url, timeout5)print("网络连接正常")except urllib.error.URLError as ex:print("网络连接异常:" str(ex))test_…

前端uniapp生成海报绘制canvas画布并且保存到相册【实战/带源码/最新】

目录 插件市场效果如下图注意使用my-share.vue插件文件如下图片hch-posterutilsindex.js draw-demo.vuehch-poster.vue 最后 插件市场 插件市场 效果如下图 注意 主要&#xff1a;使用my-share.vue和绘制canvas的hch-poster.vue这两个使用 使用my-share.vue <template&…

时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测

时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测 目录 时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 x 基本介绍 1.Matlab实现LSTM-Adaboost时间序列预测…

SQL基础理论篇(八):视图

文章目录 简介创建视图修改视图删除视图总结参考文献 简介 视图&#xff0c;即VIEW&#xff0c;是SQL中的一个重要概念&#xff0c;它其实是一种虚拟表(非实体数据表&#xff0c;本身不存储数据)。 视图类似于编程中的函数&#xff0c;也可以理解成是一个访问数据的接口。 从…

数据分析思维与模型:群组分析法

群组分析法&#xff0c;也称为群体分析法或集群分析法&#xff0c;是一种研究方法&#xff0c;用于分析和理解群体内的动态、行为模式、意见、决策过程等。这种方法在社会科学、心理学、市场研究、组织行为学等领域有广泛应用。它可以帮助研究人员或组织更好地理解特定群体的特…

C# Onnx DIS高精度图像二类分割

目录 介绍 效果 模型信息 项目 代码 下载 介绍 github地址&#xff1a;https://github.com/xuebinqin/DIS This is the repo for our new project Highly Accurate Dichotomous Image Segmentation 对应的paper是ECCV2022的一篇文章《Highly Accurate Dichotomous Imag…

书摘:C 嵌入式系统设计模式 01

本书的原著为&#xff1a;《Design Patterns for Embedded Systems in C ——An Embedded Software Engineering Toolkit 》&#xff0c;讲解的是嵌入式系统设计模式&#xff0c;是一本不可多得的好书。 模式 &#xff0c;英文为 pattern&#xff0c;指的是在软件开发中&#…

Windows + Syslog-ng 发送eventlog 到Splunk indexer

1: 背景: 装了window Splunk universal forwarder 的 window server 要把event log 送到linux 的splunk indexer 上,由于网络的原因,不能直接发送数据到splunk indexer的话,要利用跳板机来实现: 2:架构: 3: 先说明每个类型server 上的安装情况: Window server: 安装S…

Axios七大特性

Axios是一个基于Promise的HTTP客户端&#xff0c;用于浏览器和Node.js环境中发起HTTP请求。它有许多强大的特性&#xff0c;下面将介绍Axios的七大特性。 1. 支持浏览器和Node.js Axios既可以在浏览器中使用&#xff0c;也可以在Node.js环境中使用&#xff0c;提供了统一的AP…

Tomcat 9.0.54源码环境搭建

一. 问什么要学习tomcat tomcat是目前非常流行的web容器&#xff0c;其性能和稳定性也是非常出色的&#xff0c;学习其框架设计和底层的实现&#xff0c;不管是使用、性能调优&#xff0c;还是应用框架设计方面&#xff0c;肯定会有很大的帮助 二. 运行源码 1.下载源…

DeepMind 推出 OPRO 技术,可用于优化 ChatGPT 提示

本心、输入输出、结果 文章目录 DeepMind 推出 OPRO 技术&#xff0c;可用于优化 ChatGPT 提示前言消息摘要OPRO的工作原理DeepMind的研究相关链接花有重开日&#xff0c;人无再少年实践是检验真理的唯一标准 DeepMind 推出 OPRO 技术&#xff0c;可用于优化 ChatGPT 提示 编辑…

股票池(三)

3-股票池 文章目录 3-股票池一. 查询股票池支持的类型二. 查询目前股票池对应的股票信息三 查询股票池内距离今天类型最少/最多的股票数据四. 查询股票的池统计信息 一. 查询股票池支持的类型 接口描述: 接口地址:/StockApi/stockPool/listPoolType 请求方式&#xff1a;GET…

Figma 是什么软件?为什么能被Adobe收购

很多人一定早就听说过Figma的名字了。看到很多设计同行推荐&#xff0c;用了很久&#xff0c;疯狂的安利朋友用。是什么让这么多设计师放弃了FigmaSketch的魅力&#xff1f;下面的内容将详细分享一些与Figma相关的知识点&#xff0c;并介绍这个经常听到但不熟悉的工具。 Figma…

Linux查看开机启动的服务

在Linux系统中&#xff0c;可以使用不同的命令和工具来查看开机启动的服务。以下是一些常用的方法&#xff1a; systemctl 命令&#xff1a; 使用 systemctl 命令可以查看系统中所有正在运行的服务以及它们的状态。 systemctl list-units --typeservice若要查看某个特定服务的…

Git介绍

Git是一个分布式版本控制系统&#xff0c;**用于跟踪文件和目录的变化&#xff0c;并协助多人协作开发项目。**它是由Linus Torvalds于2005年创建的&#xff0c;现在是一个开源工具&#xff0c;并且被广泛应用于软件开发领域。 以下是Git的一些关键特点和基本概念&#xff1a;…