PyTorch面部表情识别项目实战

新书速览|PyTorch深度学习与企业级项目实战-CSDN博客

本书案例比较丰富、比较完整,可以用于课题研究、毕业论文素材,值得大家收藏。

人脸表情是人类信息交流的重要方式,它所包含的人体行为信息与人的情感状态、精神状态、健康状态等有着极为密切的关联。因此,通过对人脸表情的识别可以获得很多有价值的信息,从而分析人类的心理活动和精神状态,并为各种机器视觉和人工智能控制系统的应用提供解决方案。 所以本项目在研究人脸面部表情识别的过程中,借助人工智能算法的优势开展基于深度神经网络的图像分类实验。

借助MobileNetv3模型进行迁移学习,经过足够多次的迭代,分类准确率可以达到90%。这里使用的MMAFEDB数据集包含128 000MMA面部表情图像数据集 MMAFEDB数据集包含用于、验证和测试的目录。 每个目录包含对应7个面部表情类别的7个子目录。

MMAFEDB数据集数据说明如下:

  1. Angry:愤怒。
  2. Disgust:厌恶。
  3. Fear:恐惧。
  4. Happy:快乐。
  5. Neutral:中性。
  6. Sad:悲伤。
  7. Surprise:惊讶。

MMAFEDB数据集数据来源如下:

https://www.kaggle.com/mahmoudima/mma-facial-expression?select=MMAFEDB

相对重量级网络而言,轻量级网络的特点是参数少、计算量小、推理时间短,更适用于存储空间和功耗受限的场景,例如移动端嵌入式设备等边缘计算设备。因此,轻量级网络受到了广泛的关注,其中MobileNet可谓是其中的佼佼者。MobileNetV3经过了V1和V2前两代的积累,性能和速度都表现优异,受到学术界和工业界的追捧,无疑是轻量级网络的“抗把子”。

本项目这里加载预训练的MobileNetv3模型,由于预训练的模型与我们的任务需要不一样,因此需要修改最后的全连接层,将输出维度修改为我们的任务要求中的7个分类(7种面部表情)。 但是需要注意冻结其他层的参数,防止训练过程中将其改动,然后训练微调最后一层即可。

由于我们的任务为多分类问题,因此损失函数需要使用交叉熵损失函数(Cross Entropy),但是这里没有采用框架自带的损失函数,而是自己实现了一个损失函数。虽说大多数情况下,框架自带的损失函数就能够满足我们的需求,但是对于一些特定任务是无法满足的,需要我们进行自定义。自定义函数需要继承 nn.Module 子类,然后定义好参数和所需的变量,在forward方法中编写计算损失函数的过程,然后PyTorch会自动计算反向传播需要的梯度,不需要我们自己进行计算。

###########expression_on_face.py#####################
import torchvision
from torch import nn
import numpy as np
import os
import pickle
import torchfrom torchvision import transforms, datasets
import torchvision.models as models
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as pltepochs = 10
lr = 0.03
batch_size = 32
image_path = './The_expression_on_his_face/train'
save_path = './chk/expression_model.pkl'device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 1.数据转换
data_transform = {# 训练中的数据增强和归一化'train': transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪transforms.RandomHorizontalFlip(),  # 左右翻转transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 均值方差归一化])
}# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(image_path),transform=data_transform['train'])# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size,True)print('using {} images for training.'.format(len(train_dataset)))# 4.建立分类标签与索引的关系
cloth_list = train_dataset.class_to_idx
class_dict = {}
for key, val in cloth_list.items():class_dict[val] = key
with open('class_dict.pk', 'wb') as f:pickle.dump(class_dict, f)print(class_dict.values())# 自定义损失函数,需要在forward中定义过程
class MyLoss(nn.Module):def __init__(self):super(MyLoss, self).__init__()# 参数为传入的预测值和真实值,返回所有样本的损失值# 自己只需定义计算过程,反向传播PyTroch会自动记录def forward(self, pred, label):# pred:[32, 4] label:[32, 1] 第一维度是样本数exp = torch.exp(pred)tmp1 = exp.gather(1, label.unsqueeze(-1)).squeeze()tmp2 = exp.sum(1)softmax = tmp1 / tmp2log = -torch.log(softmax)return log.mean()# 5.加载MobileNetv3模型
#加载预训练好的MobileNetv3模型
model = torchvision.models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
# 冻结模型参数
for param in model.parameters():param.requires_grad = False# 修改最后一层的全连接层
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 7)# 将模型加载到cpu中
model = model.to('cpu')# criterion = nn.CrossEntropyLoss() # 损失函数
criterion = MyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 优化器# 6.模型训练
best_acc = 0  			# 最优精确率
best_model = None  		# 最优模型参数for epoch in range(epochs):model.train()running_loss = 0  	# 损失epoch_acc = 0  		# 每个epoch的准确率epoch_acc_count = 0 	# 每个epoch训练的样本数train_count = 0  	# 用于计算总的样本数,方便求准确率train_bar = tqdm(train_loader)for data in train_bar:images, labels = dataoptimizer.zero_grad()output = model(images.to(device))loss = criterion(output, labels.to(device))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# 计算每个epoch正确的个数epoch_acc_count += (output.argmax(axis=1) == labels.view(-1)).sum()train_count += len(images)# 每个epoch对应的准确率epoch_acc = epoch_acc_count / train_count# 打印信息print("【EPOCH: 】%s" % str(epoch + 1))print("训练损失为%s" % str(running_loss))print("训练精度为%s" % (str(epoch_acc.item() * 100)[:5]) + '%')if epoch_acc > best_acc:best_acc = epoch_accbest_model = model.state_dict()# 在训练结束保存最优的模型参数if epoch == epochs - 1:# 保存模型torch.save(best_model, save_path)print('Finished Training')# 加载索引与标签映射字典
with open('class_dict.pk', 'rb') as f:class_dict = pickle.load(f)# 数据变换
data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor()])# 图片路径
img_path = r'./The_expression_on_his_face/test.jpg'# 打开图像
#为了避免通道数不匹配,使用灰度图像(1通道),使用RGB图像(3通道)
#解决方式:加载图像时,做一下转换
img = Image.open(img_path).convert('RGB')# 对图像进行变换
img = data_transform(img)plt.imshow(img.permute(1, 2, 0))
plt.show()# 将图像升维,增加batch_size维度
img = torch.unsqueeze(img, dim=0)# 获取预测结果
pred = class_dict[model(img).argmax(axis=1).item()]
print('【预测结果分类】:%s' % pred)

代码运行结果如下:

using 851 images for training.
dict_values(['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise'])
train epoch[1/10] loss:1.943: 100%|██████████| 27/27 [00:10<00:00,  2.65it/s]0%|          | 0/27 [00:00<?, ?it/s]【EPOCH: 】1
训练损失为54.592151284217834
训练精度为23.14%
train epoch[2/10] loss:1.906: 100%|██████████| 27/27 [00:10<00:00,  2.50it/s]0%|          | 0/27 [00:00<?, ?it/s]【EPOCH: 】2
训练损失为51.95514786243439
训练精度为27.96%
train epoch[3/10] loss:1.873: 100%|██████████| 27/27 [00:10<00:00,  2.68it/s]
【EPOCH: 】3
训练损失为54.413649916648865
训练精度为29.96%
train epoch[4/10] loss:1.508: 100%|██████████| 27/27 [00:10<00:00,  2.60it/s]
【EPOCH: 】4
训练损失为51.14111852645874
训练精度为30.66%
train epoch[5/10] loss:1.816: 100%|██████████| 27/27 [00:13<00:00,  2.05it/s]
【EPOCH: 】5
训练损失为52.17003357410431
训练精度为32.07%
train epoch[6/10] loss:1.833: 100%|██████████| 27/27 [00:11<00:00,  2.31it/s]0%|          | 0/27 [00:00<?, ?it/s]【EPOCH: 】6
训练损失为51.988134145736694
训练精度为31.37%
train epoch[7/10] loss:1.907: 100%|██████████| 27/27 [00:10<00:00,  2.49it/s]
【EPOCH: 】7
训练损失为51.65321123600006
训练精度为32.54%
train epoch[8/10] loss:1.993: 100%|██████████| 27/27 [00:11<00:00,  2.40it/s]
【EPOCH: 】8
训练损失为51.17294144630432
训练精度为33.72%
train epoch[9/10] loss:1.682: 100%|██████████| 27/27 [00:13<00:00,  2.02it/s]
【EPOCH: 】9
训练损失为52.21281313896179
训练精度为29.49%
train epoch[10/10] loss:1.926: 100%|██████████| 27/27 [00:15<00:00,  1.75it/s]
【EPOCH: 】10
训练损失为50.530142426490784
训练精度为32.43%
Finished Training
【预测结果分类】:angry

注意,这里只是列出计算10个周期,如果需要提高训练精度,需要增加训练周期100个以及扩大训练样本量。读者可以自行尝试。

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/47224.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于Ubuntu22.04中的Command ‘vim‘ not found, but can be installed with:

前言 在Ubuntu终端编辑文本内容时需要利用vim&#xff0c;但新安装的虚拟机中并未配置vim&#xff0c;本文记录了vim的安装过程。 打开终端后&#xff0c;在home目录中输入 vim test.txt但提示报错&#xff0c;提示我们没有找到vim&#xff0c;需要通过以下命令进行安装&…

yearrecord——一个类似痕迹墙的React数据展示组件

介绍一下自己做的一个类似于力扣个人主页提交记录和GitHub主页贡献记录的React组件。 下图分别是力扣个人主页提交记录和GitHub个人主页的贡献记录&#xff0c;像这样类似痕迹墙的形式可以比较直观且高效得展示一段时间内得数据记录。 然而要从0实现这个功能还是有一些麻烦得…

等保-Linux等保测评

等保-Linux等保测评 1.查看相应文件&#xff0c;账户xiaoming的密码设定多久过期 rootdengbap:~# chage -l xiaoming Last password change : password must be changed Password expires : pass…

mysql5.7版本字符集编码

默认character_set_databaselatin1 当你字段插入中文值的时候&#xff0c;会报错。 所以修改为了character_set_databaseutf8既可以。 character_set_server他的范围更大&#xff0c;属于服务器级别。

LeetCode 852, 20, 51

目录 852. 山脉数组的峰顶索引题目链接标签二分思路代码 三分思路代码 20. 有效的括号题目链接标签思路代码 51. N 皇后题目链接标签思路回溯如何保证皇后之间无法互相攻击 代码 852. 山脉数组的峰顶索引 题目链接 852. 山脉数组的峰顶索引 标签 数组 二分查找 二分 思路…

逍遥模拟器安装Magisk和EDXPosed教程

资源下载&#xff1a; 逍遥模拟器安装Magisk和EDXPosed教程 - 多开鸭资源下载&#xff1a; MagiskEDXP教程文件 单独的逍遥模拟器使用的版本EDXPosed打包下载&#xff08;下载之后解压出来一共4个文件&#xff09;&#xff1a; 如果要按本教程安装就务必使用这里的安装包&…

爬虫(一)——爬取快手无水印视频

前言 最近对爬虫比较感兴趣&#xff0c;于是浅浅学习了一些关于爬虫的知识。爬虫可以实现很多功能&#xff0c;非常有意思&#xff0c;在这里也分享给大家。由于爬虫能实现的功能太多&#xff0c;而且具体的实现方式也有所不同&#xff0c;所以这里开辟了一个新的系列——爬虫…

用AI生成Springboot单元测试代码太香了

你好&#xff0c;我是柳岸花开。 在当今软件开发过程中&#xff0c;单元测试已经成为保证代码质量的重要环节。然而&#xff0c;编写单元测试代码却常常让开发者头疼。幸运的是&#xff0c;随着AI技术的发展&#xff0c;我们可以利用AI工具来自动生成单元测试代码&#xff0c;极…

基于单片机的停车场车位管理系统设计

1.简介 停车场车位管理系统是日常中随处可见的一种智能化车位管理技术&#xff0c;使用该技术可以提高车位管理效率&#xff0c;从而减轻人员车位管理工作负荷。本系统集成车牌识别、自动放行、自助缴费等技术&#xff0c;并且具备车位占用状态实时监测与车位数量实时统计、查询…

Java SpringAOP简介

简介 官方介绍&#xff1a; SpringAOP的全称是&#xff08;Aspect Oriented Programming&#xff09;中文翻译过来是面向切面编程&#xff0c;AOP是OOP的延续&#xff0c;是软件开发中的一个热点&#xff0c;也是Spring框架中的一个重要内容&#xff0c;是函数式编程的一种衍生…

SpringBatch文件读写ItemWriter,ItemReader使用详解

SpringBatch文件读写ItemWriter&#xff0c;ItemReader使用详解 1. ItemReaders 和 ItemWriters1.1. ItemReader1.2. ItemWriter1.3. ItemProcessor 2.FlatFileItemReader 和 FlatFileItemWriter2.1.平面文件2.1.1. FieldSet 2.2. FlatFileItemReader2.3. FlatFileItemWriter 3…

AI 绘画|Midjourney设计Logo提示词

你是否已经看过许多别人分享的 MJ 咒语&#xff0c;却仍无法按照自己的想法画图&#xff1f;通过学习 MJ 的提示词逻辑后&#xff0c;你将能够更好地理解并创作自己的“咒语”。本文将详细拆解使用 MJ 设计 Logo 的逻辑&#xff0c;让你在阅读后即可轻松上手&#xff0c;制作出…

打包一个自己的Vivado IP核

写在前面 模块复用是逻辑设计人员必须掌握的一个基本功&#xff0c;通过将成熟模块打包成IP核&#xff0c;可实现重复利用&#xff0c;避免重复造轮子&#xff0c;大幅提高我们的开发效率。 接下来将之前设计的串口接收模块和串口发送模块打包成IP核&#xff0c;再分别调用…

【深度学习】FaceChain-SuDe,免训练,AI换脸

https://arxiv.org/abs/2403.06775 FaceChain-SuDe: Building Derived Class to Inherit Category Attributes for One-shot Subject-Driven Generation 摘要 最近&#xff0c;基于主体驱动的生成技术由于其个性化文本到图像生成的能力&#xff0c;受到了广泛关注。典型的研…

深度学习入门——神经网络

前言 神经网络可以帮助自动化设定权重 具体地讲&#xff0c;神经网络的一个重要性质是它可以自动地从数据中学习到合适的权重参数 从感知机到神经网络 神经网络的例子 中间层aka隐藏层 复习感知机 偏置b 并没有被画出来。如果要明确地表示出b&#xff0c;可以像图3-3那样做…

Large Language Model系列之一:语言模型与表征学习(Language Models and Representation Learning)

语言模型与表征学习&#xff08;Language Models and Representation Learning&#xff09; 1 语言模型 N-Gram模型 from collections import defaultdictsentences [The swift fox jumps over the lazy dog.,The swift river flows under the ancient bridge.,The swift br…

华为1000人校园实验记录

在这里插入代码片1000人校园区网设计 1、配置Eth-trunk实现链路冗余 vlan 900 管理WLAN #接入SW8 操作&#xff1a;sys undo in en sysname JR-SW8 int Eth-Trunk 1 mode lacp-static trunkport g0/0/1 0/0/2 port link-type trunk port trunk allow-pass vlan 200 900 qu vla…

模拟器小程序/APP抓包(Reqable+MUMU模拟器)

一、使用adb连接上MUMU模拟器 打开多开器点击ADB图标 连接模拟器端口&#xff1a; adb connect 127.0.0.1:16384列出已连接的设备&#xff1a; adb devices正常会显示MuMu的设备已连接 二、下载Reqable 1.下载链接&#xff1a;客户端下载 | Reqable 2.文档链接&#xff1a;…

redis基本类型和订阅

redis-cli -h <host> -p <port> -a <password> 其中&#xff0c;< host>是Redis服务器的主机名或IP地址&#xff0c;< port>是Redis服务器的端口号&#xff0c;< password>是Redis服务器的密码&#xff08;如果有的话&#xff09;。 set …

LLM基础模型系列:Prompt-Tuning

------->更多内容&#xff0c;请移步“鲁班秘笈”&#xff01;&#xff01;<------ 大型预训练语言模型的规模不断扩大&#xff0c;在许多自然语言处理 &#xff08;NLP&#xff09; 基准测试中取得了最先进的结果。自GPT和BERT开发以来&#xff0c;标准做法一直是在下游…