使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的
语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧

2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果的概率)

1.先把二维数组转为一维
2.通过公式得到节点个数和值
3…同2
4.通过节点得到概率(softmax归一化公式)
5.对比模型的和 差值=原始概率-目标结果概率
6.不断优化原来模型的概率
5.激活函数,激活某个节点的函数,可以引入非线性的(因为所有问题不可能是线性的比如 很少图片识别一定可以识别出绝对的正方形,他可能中间有一定弯曲或者线在中心短开了)

在这里插入图片描述
在这里插入图片描述

3.训练的代码
//环境python3.8 最好使用conda进行版本管理,不然每个版本都可能不兼容,到处碰壁

 #安装依赖pip install numpy torch torchvision matplotlib

#文件夹结构,图片一定要是28x28的
在这里插入图片描述

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderclass Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return x#导入数据
def get_data_loader(is_train):#张量,多维数组to_tensor = transforms.Compose([transforms.ToTensor()])# 下载数据集 下载目录data_set = MNIST("", is_train, transform=to_tensor, download=True)#一个批次15张,顺序打乱return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)#评估准确率
def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():#按批次取数据for (x, y) in test_data:#计算神经网络预测值outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):#比较预测结果和测试集结果if torch.argmax(output) == y[i]:#统计正确预测结果数n_correct += 1#统计全部预测结果n_total += 1#返回准确率=正确/全部的return n_correct / n_totaldef main():#加载训练集train_data = get_data_loader(is_train=True)#加载测试集test_data = get_data_loader(is_train=False)#初始化神经网络net = Net()#打印测试网络的准确率 0.1print("initial accuracy:", evaluate(test_data, net))#训练神经网络optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#重复利用数据集 2次for epoch in range(100):for (x, y) in train_data:#初始化 固定写法net.zero_grad()#正向传播output = net.forward(x.view(-1, 28 * 28))#计算差值loss = torch.nn.functional.nll_loss(output, y)#反向误差传播loss.backward()#优化网络参数optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))# #使用3张图片进行预测# for (n, (x, _)) in enumerate(test_data):#     if n > 3:#         break#     predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))#     plt.figure(n)#     plt.imshow(x[0].view(28, 28))#     plt.title("prediction: " + str(int(predict)))# plt.show()image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakpredict = torch.argmax(net.forward(x.view(-1, 28 * 28)))plt.figure(n)plt.imshow(x[0].permute(1, 2, 0))plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

#运行结果 弹框出现图片和识别结果

4.测试电脑的cuda是否安装成功,不成功不能运行下面的代码

import torchdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print('PyTorch version:', torch.__version__)

5.在gpu上运行,需要去官网下载cuda安装
https://developer.nvidia.com/cuda-toolkit-archive
#并且需要安装和torch对应的版本,我的电脑是1660ti的所以安装了10.2的cuda
#安装torchgpu版本

pip install torch==1.9.0+cu102 -f
https://download.pytorch.org/whl/cu102/torch_stable.html

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return xdef get_data_loader(is_train):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = MNIST("", is_train, transform=to_tensor, download=True)return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():for (x, y) in test_data:x, y = x.to(device), y.to(device)outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):if torch.argmax(output.cpu()) == y[i].cpu():n_correct += 1n_total += 1return n_correct / n_totaldef main():train_data = get_data_loader(is_train=True)test_data = get_data_loader(is_train=False)net = Net().to(device)print("initial accuracy:", evaluate(test_data, net))optimizer = torch.optim.Adam(net.parameters(), lr=0.001)for epoch in range(100):for (x, y) in train_data:x, y = x.to(device), y.to(device)net.zero_grad()output = net.forward(x.view(-1, 28 * 28))loss = torch.nn.functional.nll_loss(output, y)loss.backward()optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakx = x.to(device)predict = torch.argmax(net.forward(x.view(-1, 28 * 28)).cpu())plt.figure(n)plt.imshow(x[0].permute(1, 2, 0).cpu())plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155536.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

移动云CNP产品介绍

整体介绍 磐舟devops的核心功能是项目管理和CI流程实现。CD能力也是集成的外部开源产品argoCD。所以 磐舟并不以CD能力见长。一般推荐试用磐舟完成CI,然后试用移动云CNP产品完成CD部署工作。 移动云原生技术平台CNP是面向多云多集群场景的应用管理平台。平台以应用…

Linux—简介安装常用命令系统中软件安装项目部署

目录 1. 前言1.1 什么是Linux1.2 为什么要学Linux1.3 学完Linux能干什么 2. Linux简介2.1 主流操作系统2.2 Linux发展历史2.3 Linux系统版本 3. Linux安装3.1 安装方式介绍3.2 安装VMware3.3 安装Linux3.4 网卡设置3.5 安装SSH连接工具3.5.1 SSH连接工具介绍3.5.2 FinalShell安…

大数据可视化是什么?

大数据可视化是将海量数据通过视觉方式呈现出来,以便于人们理解和分析数据的过程。它可以帮人们发现数据之间的关系、趋势和模式,并制定更明智的决策。大数据可视化通常通过图形、图表、地图和仪表盘等视觉元素来呈现数据。这些元素具有直观、易理解的特…

前端uniapp生成海报绘制canvas画布并且保存到相册【实战/带源码/最新】

目录 插件市场效果如下图注意使用my-share.vue插件文件如下图片hch-posterutilsindex.js draw-demo.vuehch-poster.vue 最后 插件市场 插件市场 效果如下图 注意 主要&#xff1a;使用my-share.vue和绘制canvas的hch-poster.vue这两个使用 使用my-share.vue <template&…

时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测

时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测 目录 时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 x 基本介绍 1.Matlab实现LSTM-Adaboost时间序列预测…

SQL基础理论篇(八):视图

文章目录 简介创建视图修改视图删除视图总结参考文献 简介 视图&#xff0c;即VIEW&#xff0c;是SQL中的一个重要概念&#xff0c;它其实是一种虚拟表(非实体数据表&#xff0c;本身不存储数据)。 视图类似于编程中的函数&#xff0c;也可以理解成是一个访问数据的接口。 从…

数据分析思维与模型:群组分析法

群组分析法&#xff0c;也称为群体分析法或集群分析法&#xff0c;是一种研究方法&#xff0c;用于分析和理解群体内的动态、行为模式、意见、决策过程等。这种方法在社会科学、心理学、市场研究、组织行为学等领域有广泛应用。它可以帮助研究人员或组织更好地理解特定群体的特…

C# Onnx DIS高精度图像二类分割

目录 介绍 效果 模型信息 项目 代码 下载 介绍 github地址&#xff1a;https://github.com/xuebinqin/DIS This is the repo for our new project Highly Accurate Dichotomous Image Segmentation 对应的paper是ECCV2022的一篇文章《Highly Accurate Dichotomous Imag…

Windows + Syslog-ng 发送eventlog 到Splunk indexer

1: 背景: 装了window Splunk universal forwarder 的 window server 要把event log 送到linux 的splunk indexer 上,由于网络的原因,不能直接发送数据到splunk indexer的话,要利用跳板机来实现: 2:架构: 3: 先说明每个类型server 上的安装情况: Window server: 安装S…

Tomcat 9.0.54源码环境搭建

一. 问什么要学习tomcat tomcat是目前非常流行的web容器&#xff0c;其性能和稳定性也是非常出色的&#xff0c;学习其框架设计和底层的实现&#xff0c;不管是使用、性能调优&#xff0c;还是应用框架设计方面&#xff0c;肯定会有很大的帮助 二. 运行源码 1.下载源…

DeepMind 推出 OPRO 技术,可用于优化 ChatGPT 提示

本心、输入输出、结果 文章目录 DeepMind 推出 OPRO 技术&#xff0c;可用于优化 ChatGPT 提示前言消息摘要OPRO的工作原理DeepMind的研究相关链接花有重开日&#xff0c;人无再少年实践是检验真理的唯一标准 DeepMind 推出 OPRO 技术&#xff0c;可用于优化 ChatGPT 提示 编辑…

股票池(三)

3-股票池 文章目录 3-股票池一. 查询股票池支持的类型二. 查询目前股票池对应的股票信息三 查询股票池内距离今天类型最少/最多的股票数据四. 查询股票的池统计信息 一. 查询股票池支持的类型 接口描述: 接口地址:/StockApi/stockPool/listPoolType 请求方式&#xff1a;GET…

Figma 是什么软件?为什么能被Adobe收购

很多人一定早就听说过Figma的名字了。看到很多设计同行推荐&#xff0c;用了很久&#xff0c;疯狂的安利朋友用。是什么让这么多设计师放弃了FigmaSketch的魅力&#xff1f;下面的内容将详细分享一些与Figma相关的知识点&#xff0c;并介绍这个经常听到但不熟悉的工具。 Figma…

MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

MindSpore基础教程&#xff1a;使用 MindCV和 Gradio 创建一个图像分类应用 官方文档教程使用已经弃用的MindVision模块&#xff0c;本文是对官方文档的更新 在这篇博客中&#xff0c;我们将探索如何使用 MindSpore 框架和 Gradio 库来创建一个基于深度学习的图像分类应用。我…

股票基础数据(二)

二. 股票基础数据 文章目录 二. 股票基础数据一. 查询股票融资信息数据二. 查询所有的股票信息三. 查询所有的股票类型信息四. 根据类型查询所有的股票数据信息五. 查询股票当前的基本信息六. 查询股票的K线图, 返回对应的 base64 信息七. 展示股票的K线图数据, 对应的是数据信…

大模型的实践应用7-阿里的多版本通义千问Qwen大模型的快速应用与部署

大家好,我是微学AI,今天给大家介绍一下大模型的实践应用7-阿里的多版本通义千问Qwen大模型的快速应用与部署。阿里云开源了Qwen系列模型,即Qwen-7B和Qwen-14B,以及Qwen的聊天模型,即Qwen-7B-Chat和Qwen-14B-Chat。通义千问模型针对多达 3 万亿个 token 的多语言数据进行了…

LLM之Prompt(二):清华提出Prompt 对齐优化技术BPO

论文题目&#xff1a;《Black-Box Prompt Optimization: Aligning Large Language Models without Model Training》 论文链接&#xff1a;https://arxiv.org/abs/2311.04155 github地址&#xff1a;https://github.com/thu-coai/BPO BPO背景介绍 最近&#xff0c;大型语言模…

米哈游大数据云原生实践

云布道师 近年来&#xff0c;容器、微服务、Kubernetes 等各项云原生技术的日渐成熟&#xff0c;越来越多的公司开始选择拥抱云原生&#xff0c;并将企业应用部署运行在云原生之上。随着米哈游业务的高速发展&#xff0c;大数据离线数据存储量和计算任务量增长迅速&#xff0c…

中大型企业网搭建(毕设类型)

毕业设计类别 某大学网络规划与部署 目录 某大学网络规划与部署 第一章项目概述 1.1 项目背景 1.2 网络需求分析 第二章网络总体设计方案 2.1 网络整体架构 2.2 网络设计思路 第三章 网络技术应用 3.1 DHCP 3.2 MSTP 3.3 VRRP 3.4 OSPF 3.5 VLAN 3.6 NAT 3.7 WLAN 3…

78基于matlab的BiLSTM分类算法,输出迭代曲线,测试集和训练集分类结果和混淆矩阵

基于matlab的BiLSTM分类算法&#xff0c;输出迭代曲线&#xff0c;测试集和训练集分类结果和混淆矩阵&#xff0c;程序有详细注释&#xff0c;数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。