PyTorch面部表情识别项目实战

新书速览|PyTorch深度学习与企业级项目实战-CSDN博客

本书案例比较丰富、比较完整,可以用于课题研究、毕业论文素材,值得大家收藏。

人脸表情是人类信息交流的重要方式,它所包含的人体行为信息与人的情感状态、精神状态、健康状态等有着极为密切的关联。因此,通过对人脸表情的识别可以获得很多有价值的信息,从而分析人类的心理活动和精神状态,并为各种机器视觉和人工智能控制系统的应用提供解决方案。 所以本项目在研究人脸面部表情识别的过程中,借助人工智能算法的优势开展基于深度神经网络的图像分类实验。

借助MobileNetv3模型进行迁移学习,经过足够多次的迭代,分类准确率可以达到90%。这里使用的MMAFEDB数据集包含128 000MMA面部表情图像数据集 MMAFEDB数据集包含用于、验证和测试的目录。 每个目录包含对应7个面部表情类别的7个子目录。

MMAFEDB数据集数据说明如下:

  1. Angry:愤怒。
  2. Disgust:厌恶。
  3. Fear:恐惧。
  4. Happy:快乐。
  5. Neutral:中性。
  6. Sad:悲伤。
  7. Surprise:惊讶。

MMAFEDB数据集数据来源如下:

https://www.kaggle.com/mahmoudima/mma-facial-expression?select=MMAFEDB

相对重量级网络而言,轻量级网络的特点是参数少、计算量小、推理时间短,更适用于存储空间和功耗受限的场景,例如移动端嵌入式设备等边缘计算设备。因此,轻量级网络受到了广泛的关注,其中MobileNet可谓是其中的佼佼者。MobileNetV3经过了V1和V2前两代的积累,性能和速度都表现优异,受到学术界和工业界的追捧,无疑是轻量级网络的“抗把子”。

本项目这里加载预训练的MobileNetv3模型,由于预训练的模型与我们的任务需要不一样,因此需要修改最后的全连接层,将输出维度修改为我们的任务要求中的7个分类(7种面部表情)。 但是需要注意冻结其他层的参数,防止训练过程中将其改动,然后训练微调最后一层即可。

由于我们的任务为多分类问题,因此损失函数需要使用交叉熵损失函数(Cross Entropy),但是这里没有采用框架自带的损失函数,而是自己实现了一个损失函数。虽说大多数情况下,框架自带的损失函数就能够满足我们的需求,但是对于一些特定任务是无法满足的,需要我们进行自定义。自定义函数需要继承 nn.Module 子类,然后定义好参数和所需的变量,在forward方法中编写计算损失函数的过程,然后PyTorch会自动计算反向传播需要的梯度,不需要我们自己进行计算。

###########expression_on_face.py#####################
import torchvision
from torch import nn
import numpy as np
import os
import pickle
import torchfrom torchvision import transforms, datasets
import torchvision.models as models
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as pltepochs = 10
lr = 0.03
batch_size = 32
image_path = './The_expression_on_his_face/train'
save_path = './chk/expression_model.pkl'device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 1.数据转换
data_transform = {# 训练中的数据增强和归一化'train': transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪transforms.RandomHorizontalFlip(),  # 左右翻转transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 均值方差归一化])
}# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(image_path),transform=data_transform['train'])# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size,True)print('using {} images for training.'.format(len(train_dataset)))# 4.建立分类标签与索引的关系
cloth_list = train_dataset.class_to_idx
class_dict = {}
for key, val in cloth_list.items():class_dict[val] = key
with open('class_dict.pk', 'wb') as f:pickle.dump(class_dict, f)print(class_dict.values())# 自定义损失函数,需要在forward中定义过程
class MyLoss(nn.Module):def __init__(self):super(MyLoss, self).__init__()# 参数为传入的预测值和真实值,返回所有样本的损失值# 自己只需定义计算过程,反向传播PyTroch会自动记录def forward(self, pred, label):# pred:[32, 4] label:[32, 1] 第一维度是样本数exp = torch.exp(pred)tmp1 = exp.gather(1, label.unsqueeze(-1)).squeeze()tmp2 = exp.sum(1)softmax = tmp1 / tmp2log = -torch.log(softmax)return log.mean()# 5.加载MobileNetv3模型
#加载预训练好的MobileNetv3模型
model = torchvision.models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
# 冻结模型参数
for param in model.parameters():param.requires_grad = False# 修改最后一层的全连接层
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 7)# 将模型加载到cpu中
model = model.to('cpu')# criterion = nn.CrossEntropyLoss() # 损失函数
criterion = MyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 优化器# 6.模型训练
best_acc = 0  			# 最优精确率
best_model = None  		# 最优模型参数for epoch in range(epochs):model.train()running_loss = 0  	# 损失epoch_acc = 0  		# 每个epoch的准确率epoch_acc_count = 0 	# 每个epoch训练的样本数train_count = 0  	# 用于计算总的样本数,方便求准确率train_bar = tqdm(train_loader)for data in train_bar:images, labels = dataoptimizer.zero_grad()output = model(images.to(device))loss = criterion(output, labels.to(device))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# 计算每个epoch正确的个数epoch_acc_count += (output.argmax(axis=1) == labels.view(-1)).sum()train_count += len(images)# 每个epoch对应的准确率epoch_acc = epoch_acc_count / train_count# 打印信息print("【EPOCH: 】%s" % str(epoch + 1))print("训练损失为%s" % str(running_loss))print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')if epoch_acc > best_acc:best_acc = epoch_accbest_model = model.state_dict()# 在训练结束保存最优的模型参数if epoch == epochs - 1:# 保存模型torch.save(best_model, save_path)print('Finished Training')# 加载索引与标签映射字典
with open('class_dict.pk', 'rb') as f:class_dict = pickle.load(f)# 数据变换
data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor()])# 图片路径
img_path = r'./The_expression_on_his_face/test.jpg'# 打开图像
#为了避免通道数不匹配,使用灰度图像(1通道),使用RGB图像(3通道)
#解决方式:加载图像时,做一下转换
img = Image.open(img_path).convert('RGB')# 对图像进行变换
img = data_transform(img)plt.imshow(img.permute(1, 2, 0))
plt.show()# 将图像升维,增加batch_size维度
img = torch.unsqueeze(img, dim=0)# 获取预测结果
pred = class_dict[model(img).argmax(axis=1).item()]
print('【预测结果分类】:%s' % pred)

代码运行结果如下:

using 851 images for training.
dict_values(['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise'])
train epoch[1/10] loss:1.943: 100%|██████████| 27/27 [00:10<00:00,  2.65it/s]0%|          | 0/27 [00:00<?, ?it/s]【EPOCH: 】1
训练损失为54.592151284217834
训练精度为23.14%
train epoch[2/10] loss:1.906: 100%|██████████| 27/27 [00:10<00:00,  2.50it/s]0%|          | 0/27 [00:00<?, ?it/s]【EPOCH: 】2
训练损失为51.95514786243439
训练精度为27.96%
train epoch[3/10] loss:1.873: 100%|██████████| 27/27 [00:10<00:00,  2.68it/s]
【EPOCH: 】3
训练损失为54.413649916648865
训练精度为29.96%
train epoch[4/10] loss:1.508: 100%|██████████| 27/27 [00:10<00:00,  2.60it/s]
【EPOCH: 】4
训练损失为51.14111852645874
训练精度为30.66%
train epoch[5/10] loss:1.816: 100%|██████████| 27/27 [00:13<00:00,  2.05it/s]
【EPOCH: 】5
训练损失为52.17003357410431
训练精度为32.07%
train epoch[6/10] loss:1.833: 100%|██████████| 27/27 [00:11<00:00,  2.31it/s]0%|          | 0/27 [00:00<?, ?it/s]【EPOCH: 】6
训练损失为51.988134145736694
训练精度为31.37%
train epoch[7/10] loss:1.907: 100%|██████████| 27/27 [00:10<00:00,  2.49it/s]
【EPOCH: 】7
训练损失为51.65321123600006
训练精度为32.54%
train epoch[8/10] loss:1.993: 100%|██████████| 27/27 [00:11<00:00,  2.40it/s]
【EPOCH: 】8
训练损失为51.17294144630432
训练精度为33.72%
train epoch[9/10] loss:1.682: 100%|██████████| 27/27 [00:13<00:00,  2.02it/s]
【EPOCH: 】9
训练损失为52.21281313896179
训练精度为29.49%
train epoch[10/10] loss:1.926: 100%|██████████| 27/27 [00:15<00:00,  1.75it/s]
【EPOCH: 】10
训练损失为50.530142426490784
训练精度为32.43%
Finished Training
【预测结果分类】:angry

注意,这里只是列出计算10个周期,如果需要提高训练精度,需要增加训练周期100个以及扩大训练样本量。读者可以自行尝试。

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/47224.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker自建私有仓库遇到https问题

记录一下自己在自建Docker仓库的时候遇到的一个报错 问题 docker push registry:5000/library/centos:7 The push refers to repository [registry:5000/library/centos] Get "https://registry:5000/v2/": http: server gave HTTP response to HTTPS client解决办…

关于Ubuntu22.04中的Command ‘vim‘ not found, but can be installed with:

前言 在Ubuntu终端编辑文本内容时需要利用vim&#xff0c;但新安装的虚拟机中并未配置vim&#xff0c;本文记录了vim的安装过程。 打开终端后&#xff0c;在home目录中输入 vim test.txt但提示报错&#xff0c;提示我们没有找到vim&#xff0c;需要通过以下命令进行安装&…

yearrecord——一个类似痕迹墙的React数据展示组件

介绍一下自己做的一个类似于力扣个人主页提交记录和GitHub主页贡献记录的React组件。 下图分别是力扣个人主页提交记录和GitHub个人主页的贡献记录&#xff0c;像这样类似痕迹墙的形式可以比较直观且高效得展示一段时间内得数据记录。 然而要从0实现这个功能还是有一些麻烦得…

vue搜索框过滤--- computed、watch区别

vue组件选项&#xff08;component options&#xff09; 1. computed&#xff08;计算属性&#xff09; 用途&#xff1a;computed属性用于声明性地描述一些依赖其它响应式属性的数据。当依赖的响应式属性变化时&#xff0c;计算属性会自动重新求值。缓存&#xff1a;计算属性…

等保-Linux等保测评

等保-Linux等保测评 1.查看相应文件&#xff0c;账户xiaoming的密码设定多久过期 rootdengbap:~# chage -l xiaoming Last password change : password must be changed Password expires : pass…

数据库管理-第221期 Oracle的高可用-04(20240717)

数据库管理221期 2024-07-17 数据库管理-第221期 Oracle的高可用-04&#xff08;20240717&#xff09;1 ADG2 连接配置2.1 TNS2.2 JDBC2.3 JAVA连接池2.3.1 Oracle UCP2.3.2 应用连接池基础配置 总结 数据库管理-第221期 Oracle的高可用-04&#xff08;20240717&#xff09; 作…

mysql5.7版本字符集编码

默认character_set_databaselatin1 当你字段插入中文值的时候&#xff0c;会报错。 所以修改为了character_set_databaseutf8既可以。 character_set_server他的范围更大&#xff0c;属于服务器级别。

自然语言处理NLP--文本相似度面试题

自然语言处理NLP--文本相似度面试题 问题 1: 什么是文本相似度&#xff0c;如何在搜索系统中应用&#xff1f;问题 2: 如何使用TF-IDF进行文本相似度计算&#xff1f;问题 3: 使用Word2Vec进行文本相似度计算的过程是怎样的&#xff1f;问题 4: BERT如何用于文本相似度计算&…

LeetCode 852, 20, 51

目录 852. 山脉数组的峰顶索引题目链接标签二分思路代码 三分思路代码 20. 有效的括号题目链接标签思路代码 51. N 皇后题目链接标签思路回溯如何保证皇后之间无法互相攻击 代码 852. 山脉数组的峰顶索引 题目链接 852. 山脉数组的峰顶索引 标签 数组 二分查找 二分 思路…

网络安全-网络安全及其防护措施6

26. 访问控制列表&#xff08;ACL&#xff09; ACL的定义和作用 访问控制列表&#xff08;ACL&#xff09;是一种网络安全机制&#xff0c;用于控制网络设备上的数据包流量。通过ACL&#xff0c;可以定义允许或拒绝的流量&#xff0c;增强网络的安全性和管理效率。ACL通过在路…

逍遥模拟器安装Magisk和EDXPosed教程

资源下载&#xff1a; 逍遥模拟器安装Magisk和EDXPosed教程 - 多开鸭资源下载&#xff1a; MagiskEDXP教程文件 单独的逍遥模拟器使用的版本EDXPosed打包下载&#xff08;下载之后解压出来一共4个文件&#xff09;&#xff1a; 如果要按本教程安装就务必使用这里的安装包&…

翁恺-C语言程序设计-10-0. 说反话

10-0. 说反话 给定一句英语&#xff0c;要求你编写程序&#xff0c;将句中所有单词的顺序颠倒输出。 输入格式&#xff1a;测试输入包含一个测试用例&#xff0c;在一行内给出总长度不超过80的字符串。字符串由若干单词和若干空格组成&#xff0c;其中单词是由英文字母&#…

爬虫(一)——爬取快手无水印视频

前言 最近对爬虫比较感兴趣&#xff0c;于是浅浅学习了一些关于爬虫的知识。爬虫可以实现很多功能&#xff0c;非常有意思&#xff0c;在这里也分享给大家。由于爬虫能实现的功能太多&#xff0c;而且具体的实现方式也有所不同&#xff0c;所以这里开辟了一个新的系列——爬虫…

记录贴-芋道源码-环境搭建

文字讲解 链接: 芋道源码-环境搭建&#xff08;一&#xff09;后端 链接: 芋道源码-环境搭建&#xff08;二&#xff09;前端 链接: 基于FastGPT和芋道源码挑战一句话生成代码 视频讲解 链接: 芋道源码零基础启动教程&#xff08;上&#xff09; 链接: 芋道源码零基础启动教程…

bs4取值技巧的详细介绍

1. 基本取值方法&#xff1a; find()&#xff1a; 查找第一个匹配的标签。soup.find(h1) # 查找第一个<h1>标签find_all()&#xff1a; 查找所有匹配的标签。soup.find_all(a) # 查找所有<a>标签select()&#xff1a; 使用CSS选择器查找标签。soup.select(.item…

进阶篇:如何使用 Stable Diffusion 优化神经网络训练

进阶篇&#xff1a;如何使用 Stable Diffusion 优化神经网络训练 一、引言 随着深度学习的发展&#xff0c;神经网络模型在各个领域取得了显著的成果。然而&#xff0c;在训练复杂神经网络时&#xff0c;模型的稳定性和优化问题始终是一个挑战。Stable Diffusion&#xff08;…

用AI生成Springboot单元测试代码太香了

你好&#xff0c;我是柳岸花开。 在当今软件开发过程中&#xff0c;单元测试已经成为保证代码质量的重要环节。然而&#xff0c;编写单元测试代码却常常让开发者头疼。幸运的是&#xff0c;随着AI技术的发展&#xff0c;我们可以利用AI工具来自动生成单元测试代码&#xff0c;极…

基于单片机的停车场车位管理系统设计

1.简介 停车场车位管理系统是日常中随处可见的一种智能化车位管理技术&#xff0c;使用该技术可以提高车位管理效率&#xff0c;从而减轻人员车位管理工作负荷。本系统集成车牌识别、自动放行、自助缴费等技术&#xff0c;并且具备车位占用状态实时监测与车位数量实时统计、查询…

SQL进阶--条件分支

一、问题引入 在SQL中&#xff0c;虽然不像某些编程语言&#xff08;如C、Java或Python&#xff09;那样直接支持if-else这样的条件分支语句&#xff0c;但它提供了几种方式来实现条件逻辑&#xff0c;这些方式主要通过CASE语句、IF()函数&#xff08;在某些数据库如MySQL中&a…

C# - WINFORM - 控件树遍历与特定控件操作方案概述

1.全局控件遍历 实现了一个通用函数EnumerateAllControls, 它可以遍历指定窗体或容器内的所有控件&#xff0c;打印出每个控件的名称和类型。 private void EnumerateAllControls(Control parent) {foreach (Control control in parent.Controls){Console.WriteLine($"C…