Pytorch-ResNet-50 网络表情识别项目(深度学习)

ResNet-50 网络表情识别

    • 1. 导入依赖库
    • 2. 加载中文字体文件
    • 3. 设置图像尺寸和训练参数
    • 4. 数据增强和预处理
    • 5. 加载数据集
    • 6. 检查数据维度
    • 7. 定义ResNet50模型
    • 8. 初始化模型、损失函数和优化器
    • 9. 训练和测试函数
    • 10. 训练和测试模型
    • 11. 保存模型
    • 12. 评估数据保存和可视化
  • 原码

本项目采用的是FER-2013数据集加上博主的一些其他数据集整合的
FER-2013数据集链接如下
https://www.kaggle.com/datasets/msambare/fer2013

1. 导入依赖库

代码开始处导入了多个Python库,用于图像处理、数学运算、深度学习模型的构建和训练。

import cv2
import numpy as np
from PIL import ImageFont, ImageDraw
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
from PIL import Image

2. 加载中文字体文件

加载中文字体文件以便在图像上绘制中文标签。

font_path = "SourceHanSansSC-Bold.otf"
font = ImageFont.truetype(font_path, 30)

3. 设置图像尺寸和训练参数

定义了图像的目标尺寸、训练轮数和每批的样本数量。

img_size = 48
targetx = 48
targety = 48
epochs = 50   
batch_size = 64

4. 数据增强和预处理

定义了一个转换流程,包括调整图像大小、随机水平翻转和转换为张量。

transform = transforms.Compose([transforms.Resize((targetx, targety)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),
])

5. 加载数据集

使用ImageFolder加载训练和测试数据集,并通过DataLoader进行批量加载。

train_dataset = datasets.ImageFolder(root="./FER-2013/train" , transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.ImageFolder(root="./FER-2013/test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

6. 检查数据维度

定义了一个函数check_data_dimensions来检查数据加载器返回的批次中图像和标签的维度。

def check_data_dimensions(loader):for images, labels in loader:print("Batch image size:", images.shape)print("Batch label size:", labels.shape)break

7. 定义ResNet50模型

创建了一个ResNet50模型,使用预训练权重,并替换最后的全连接层以适应表情识别的类别数。

class ResNet50Model(nn.Module):def __init__(self, num_classes=7):super(ResNet50Model, self).__init__()self.resnet50 = models.resnet50(pretrained=True)num_ftrs = self.resnet50.fc.in_featuresself.resnet50.fc = nn.Linear(num_ftrs, num_classes)def forward(self, x):return self.resnet50(x)

8. 初始化模型、损失函数和优化器

初始化了ResNet50模型,定义了损失函数和优化器,并根据GPU的可用性将它们移动到GPU。

model = ResNet50Model(num_classes=7)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)

9. 训练和测试函数

定义了训练和测试函数,用于迭代模型的训练和评估。

def train(model, train_loader, criterion, optimizer, device):# ...def test(model, test_loader, criterion, device):# ...

10. 训练和测试模型

在指定的轮数内迭代训练和测试模型,并打印每个epoch的损失和准确率。

for epoch in range(num_epochs):# 训练和测试过程...

11. 保存模型

训练完成后,保存ResNet50模型到文件。

torch.save(model.state_dict(), 'resnet50_final.pth')

12. 评估数据保存和可视化

将训练与测试的损失及准确率保存到.npy文件中,并使用matplotlib绘制损失和准确率图表。

np.save('train_losses.npy', train_losses)
# ...
plt.show()

原码

import cv2
import numpy as np
from PIL import ImageFont, ImageDraw
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
from PIL import Image
#%%
# 加载中文字体文件
font_path = "SourceHanSansSC-Bold.otf"
font = ImageFont.truetype(font_path, 30)
#%%
img_size = 48 #original size of the image
targetx = 48
targety = 48
epochs = 50   
batch_size = 64
# 数据增强和预处理
transform = transforms.Compose([transforms.Resize((targetx, targety)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),
])# 加载数据集
train_dataset = datasets.ImageFolder(root="./FER-2013/train" , transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.ImageFolder(root="./FER-2013/test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
#%%
# 检查数据维度
def check_data_dimensions(loader):for images, labels in loader:print("Batch image size:", images.shape)  # 批次中图像的维度print("Batch label size:", labels.shape)  # 批次中标签的维度break  # 只查看第一个批次即可# 查看训练数据和测试数据的维度
print("Training data dimensions:")
check_data_dimensions(train_loader)
print("\nTesting data dimensions:")
check_data_dimensions(test_loader)
#%%
import torch.nn as nn
import torchvision.models as modelsclass ResNet50Model(nn.Module):def __init__(self, num_classes=7):super(ResNet50Model, self).__init__()#加载预训练的ResNet-50模型self.resnet50 = models.resnet50(pretrained=True)#获取ResNet-50模型的最后一层全连接层的输入特征数量num_ftrs = self.resnet50.fc.in_features。num_ftrs = self.resnet50.fc.in_features#将ResNet-50模型的最后一层全连接层替换为一个新的全连接层,输出特征数量设置为num_classesself.resnet50.fc = nn.Linear(num_ftrs, num_classes)#forward方法定义了前向传播过程。# 在这个简单的类中,仅仅是调用self.resnet50(x),将输入x传递给预训练的ResNet-50模型进行前向传播。def forward(self, x):return self.resnet50(x)# Example usage
model = ResNet50Model(num_classes=7)
print(model)
#%%
# 初始化模型、损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.0001, weight_decay=0.01)# 如果GPU可用,移动模型和损失函数到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
loss_fn.to(device)# 训练模型
num_epochs = 60train_losses = []
train_accuracies = []for epoch in range(num_epochs):model.train()total_loss = 0total_correct = 0total_samples = 0for data, targets in train_loader:# 将输入和标签移动到GPU(如果可用)data, targets = data.to(device), targets.to(device)# 前向传播outputs = model(data)loss = loss_fn(outputs, targets)# 零梯度optimizer.zero_grad()# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()# 计算准确率_, predicted = outputs.max(1)total_correct += predicted.eq(targets).sum().item()total_samples += targets.size(0)avg_loss = total_loss / len(train_loader)train_losses.append(avg_loss)# 计算准确率avg_accuracy = total_correct / total_samplestrain_accuracies.append(avg_accuracy)print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}')print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {avg_accuracy}')# 保存最后训练完的模型
torch.save(model.state_dict(), 'resnet50_final.pth')
print("最后训练完的模型已保存!")test_losses = []
test_accuracies = []
# 测试模型
for epoch in range(num_epochs):model.eval()total_loss = 0total_correct = 0total_samples = 0for data, targets in test_loader:data, targets = data.to(device), targets.to(device)outputs = model(data)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()total_loss += loss.item()# 计算准确率_, predicted = outputs.max(1)total_correct += predicted.eq(targets).sum().item()total_samples += targets.size(0)avg_loss = total_loss / len(test_loader)test_losses.append(avg_loss)# 计算准确率avg_accuracy = total_correct / total_samplestest_accuracies.append(avg_accuracy)print(f'Epoch {epoch+1}/{num_epochs}, test_Loss: {avg_loss}')print(f'Epoch {epoch+1}/{num_epochs}, test_Accuracy: {avg_accuracy}')# 评估数据保存到.npy文件
np.save('train_losses.npy', train_losses)
np.save('train_accuracies.npy', train_accuracies)
np.save('test_losses.npy', test_losses)
np.save('test_accuracies.npy', test_accuracies)# 绘制损失图
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()# 绘制准确率图
plt.figure(figsize=(10, 5))
plt.plot(train_accuracies, label='Training Accuracy')
plt.title('Training Accuracy per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/34214.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

欧盟指控苹果应用商店规则非法压制竞争,面临巨额罚款风险

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Excel 宏录制与VBA编程 —— 14、使用VBA处理Excel事件

简介 若希望特定事件处理程序在触发特定事件时运行,可以为 Application 对象编写事件处理程序。 Application 对象的事件处理程序是全局的,这意味着只要 Microsoft Excel 处于打开状态,事件处理程序将在发生相应的事件时运行,而不…

计算机网络 交换机的基本配置

一、理论知识 1.三种模式: ①用户模式:当登录路由器后,系统自动进入用户EXEC命令模式。 例如: Router> 在用户模式状态下,用户只能查看路由器的连接状态和基本信息,访问其他网络和主机&#xff0c…

Dubbo 中查看动态生成的 class 文件

我们知道,在 Dubbo 框架中,对外发布服务时,会把每个服务提供者的实现类通过 Javassist 包装为一个 Wrapper 类,以减少反射调用开销。这个 Wrapper 是动态生成的,默认是不输出 class 文件的,如果想查看生成的…

数据库管理与数据库语句

数据库用户管理及高级sql语句 数据库管理 数据库用户管理 mysql权限表 在mysql中mysql库中的user表是最重要的权限表,记录允许连接到服务器的账号信息以及全局权限, 在mysql库中db和host表也是重要的权限表 db表中存储了用户对某个数据库的操作权限&…

Hyper-V 简介

Hyper-V 是微软开发的一种虚拟化技术,它允许在单个物理服务器上创建和运行多个虚拟机(VM),每个虚拟机都可以运行不同的操作系统和应用程序。Hyper-V 技术是 Windows Server 的一部分,并且也作为独立产品 Microsoft Hyp…

DataGrip 2024 mac/win版:让数据库管理更简单

JetBrains DataGrip 2024 是一款专为数据库开发者设计的集成开发环境(IDE),它凭借其卓越的性能和丰富的功能,为数据库管理提供了前所未有的便利。 DataGrip 2024 mac/win版获取 DataGrip 2024 支持几乎所有主流的关系型数据库管理系统,如 My…

一个角阀引起的思考和启示

给大家说一说,遇到的这个情况很有可能在你家也会出现。      发现马桶角阀处滴水,当时也没在意,擦干了几次之后,发现还是在滴,看了一下软管是新的,应该不会漏水,那就是角阀出问题了。    …

浅谈目标检测之YOLO(You Only Look Once)v1

简介:本文章要介绍的YOLOv1算法,它与之前的目标检测算法如R-CNN等不同,R-NN等目标检测算法是一种两阶段(two-stage)算法,步骤为先在图片上生成候选框,然后利用分类器对这些候选框进行逐一的判断…

记录一下MATLAB优化器出现的问题和解决

今天MATLAB优化器出了点问题。我想了想,决定解决一下,不然后面项目没有办法进行下去。 我忘了截图了。 具体来说,是出现了下面的问题。 Gurobi: Cplex: 在上次为了强化学习调整了Pytoch环境以后(不知道是不是这个原因&#…

Scala的Trait与Java的Interface:相似性与差异性深度解析

在面向对象编程中,接口(Interface)和特质(Trait)是实现代码复用和模块化设计的重要工具。Java和Scala作为两种流行的编程语言,它们对接口和特质有着不同的实现和理念。本文将深入探讨Scala中的Trait与Java中…

仓库管理系统09--修改用户密码

1、添加窗体 2、窗体布局控件 UI设计这块还是传统的表格布局&#xff0c;采用5行2列 3、创建viewmodel 4、前台UI绑定viewmodel 这里要注意属性绑定和命令绑定及命令绑定时传递的参数 <Window x:Class"West.StoreMgr.Windows.EditPasswordWindow"xmlns"http…

制造业工厂的管理到底有多难

一、引言 随着全球经济的不断发展&#xff0c;制造业作为实体经济的核心&#xff0c;对国家的经济增长起着至关重要的作用。然而&#xff0c;制造业工厂的管理却是一项复杂而艰巨的任务。本文将深入探讨制造业工厂管理所面临的挑战&#xff0c;并提出相应的应对策略。 二、制造…

TCP: 传输控制协议

TCP: 传输控制协议 TCP的服务TCP 的首部小结 本系列文章旨在巩固网络编程理论知识&#xff0c;后续将结合实际开展深入理解的文章。 TCP的服务 T C P和U D P都使用相同的网络层&#xff08;I P&#xff09;&#xff0c;T C P却向应用层提供与U D P完全不同的服务。 T C P提供一…

【已解决】Python报错:AttributeError: module ‘json‘ has no attribute ‘loads‘

&#x1f60e; 作者介绍&#xff1a;我是程序员行者孙&#xff0c;一个热爱分享技术的制能工人。计算机本硕&#xff0c;人工制能研究生。公众号&#xff1a;AI Sun&#xff0c;视频号&#xff1a;AI-行者Sun &#x1f388; 本文专栏&#xff1a;本文收录于《AI实战中的各种bug…

离散数学上机报告

一、 实验题目&#xff08;编程上机题&#xff09; &#xff08;1&#xff09; 从键盘分别对P、Q输入数据1、0&#xff0c;分别输出P∧Q、P∨Q、P→Q的逻辑结果值。 &#xff08;2&#xff09; 从键盘输入无向图的邻接矩阵&#xff0c;判断输出该图结点最大度数、最小度数。 &a…

synchronized关键字和ReentrantLock在不同jdk版本中性能哪个高?该怎么选择呢?

synchronized关键字和ReentrantLock在不同JDK版本中的性能差异经历了显著的变化。早期&#xff0c;在JDK 1.5及以前的版本中&#xff0c;ReentrantLock通常提供了更好的性能&#xff0c;主要是因为synchronized关键字的实现较为简单&#xff0c;没有太多的优化&#xff0c;导致…

图片如何去水印,分享4个小妙招,手把手教会你!

作为一个经常逛社区网站下载表情包、头像的人&#xff0c;遇到的一个大难题就是图片有水印。如何才能快速去除水印&#xff1f;询问了一圈身边朋友&#xff0c;搜集了各种资料&#xff0c;小编整理了4个超好用的方法。 如果大家和小编一样&#xff0c;能坐着就不站着&#xff0…

PHP 高频面试题

PHP 初级面试题及详细解答 1. 什么是 PHP&#xff0c;PHP 的全称是什么&#xff1f; 解答: PHP 是一种流行的开源脚本语言&#xff0c;特别适合用于 web 开发并可以嵌入 HTML。PHP 的全称是 “PHP: Hypertext Preprocessor”&#xff0c;它最初代表的是 “Personal Home Page…

Python简单实现自动识别并填加验证码

实现自动识别网页中的验证码并填写&#xff0c;需要结合使用网络爬虫技术、图像识别&#xff08;OCR&#xff09;&#xff0c;以及可能的浏览器自动化工具&#xff08;如Selenium&#xff09;。以下简单实现一下如何结合这些技术来实现这一目标&#xff1a; 步骤 1: 获取验证码…