python+pytorch人脸表情识别

概述

基于深度学习的人脸表情识别,数据集采用公开数据集fer2013,可直接运行,效果良好,可根据需求修改训练代码,自己训练模型。

详细

一、概述

本项目以PyTorch为框架,搭建卷积神经网络模型,训练后可直接调用py文件进行人脸检测与表情识别,默认开启摄像头实时检测识别。效果良好,可根据个人需求加以修改。

二、演示效果:

image.png

三、实现过程

1. 搭建网络

def __init__(self):super(FaceCNN, self).__init__()# 第一次卷积、池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),  # 卷积层# BatchNorm2d进行数据的归一化处理,这使得数据在进行Relu之前不会因为数据过大而导致网络性能的不稳定nn.BatchNorm2d(num_features=64),  # 归一化nn.RReLU(inplace=True),  # 激活函数nn.MaxPool2d(kernel_size=2, stride=2),  # 最大值池化)# 第二次卷积、池化self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(num_features=128),nn.RReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)# 第三次卷积、池化self.conv3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(num_features=256),nn.RReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)# 参数初始化self.conv1.apply(gaussian_weights_init)self.conv2.apply(gaussian_weights_init)self.conv3.apply(gaussian_weights_init)# 全连接层self.fc = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=256 * 6 * 6, out_features=4096),nn.RReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(in_features=4096, out_features=1024),nn.RReLU(inplace=True),nn.Linear(in_features=1024, out_features=256),nn.RReLU(inplace=True),nn.Linear(in_features=256, out_features=7),)

2. 训练模型

# 载入数据并分割batch
train_loader = data.DataLoader(train_dataset, batch_size)
# 损失函数
loss_function = nn.CrossEntropyLoss()
# 学习率衰减
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
device = "cuda" if torch.cuda.is_available() else 'cpu'
# 构建模型
model = FaceCNN().to(device)
# 优化器
optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=wt_decay)
# 逐轮训练
for epoch in range(epochs):if (epoch + 1) % 10 == 0:learning_rate = learning_rate * 0.1# 记录损失值loss_rate = 0# scheduler.step() # 学习率衰减model.train()  # 模型训练for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 梯度清零optimizer.zero_grad()# 前向传播output = model.forward(images)# 误差计算loss_rate = loss_function(output, labels)# 误差的反向传播loss_rate.backward()# 更新参数optimizer.step()

3. 模型预测

with torch.no_grad():pred = model(face)probability = torch.nn.functional.softmax(pred, dim=1)probability = np.round(probability.cpu().detach().numpy(), 3)max_prob = np.max(probability)# print(max_prob)predicted = classes[torch.argmax(pred[0])]cv2.putText(img, predicted + " " + str(max_prob), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 255, 0), 1, cv2.LINE_AA)
cv2.imshow('frame', img)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/136634.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WebGL智慧城市软件项目

WebGL开发智慧城市项目时,需要考虑多个方面,包括技术、隐私、安全和可持续性。以下是一些需要注意的关键问题,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作。 1.隐私和数据安全…

[100天算法】-定长子串中元音的最大数目(day 67)

题目描述 给你字符串 s 和整数 k 。请返回字符串 s 中长度为 k 的单个子字符串中可能包含的最大元音字母数。英文中的 元音字母 为(a, e, i, o, u)。示例 1:输入:s "abciiidef", k 3 输出:3 解释&#xf…

Java 设计模式——解释器模式

目录 1.概述2.结构3.案例实现3.1.抽象表达式类3.2.终结表达式3.3.非终结表达式3.4.环境类3.5.测试 4.优缺点5.使用场景 1.概述 (1)如下图,设计一个软件用来进行加减计算。我们第一想法可能就是使用工具类,提供对应的加法和减法的…

el-table实现单选框+隐藏多选框+回显

0 效果 1 单选框 2 隐藏多选框 3 回显 回显数据要在el-table中添加两个属性

Django文件配置、request对象、连接MySQL、ORM

文章目录 Django静态文件及相关配置静态文件前言静态文件相关配置 form表单request对象request请求结果GET请求POST请求 pycharm连接数据库Django连接MySQLDjango ORM简介 Django静态文件及相关配置 在此篇博客我将以一个用户登录页面来引入相关知识 首先我们先编写一个html页面…

【JavaEESpring】Spring Web MVC⼊⻔

Spring Web MVC 1. 什么是 Spring Web MVC1.1 什么是 MVC ?1.2 是什么 Spring MVC? 2. 学习 Spring MVC2.1 建立连接2.2 请求2.3 响应 3. 相关代码链接 1. 什么是 Spring Web MVC 官⽅对于 Spring MVC 的描述是这样的: 1.1 什么是 MVC ? MVC 是 Model View C…

Java算法(六):模拟评委打分案例 方法封装抽离实现 程序的节流处理

Java算法(六) 评委打分 需求: 在编程竞赛中,有 6 个评委为参赛选手打分,分数为 0 - 100 的整数分。 选手的最后得分为:去掉一个最高分和一个最低分后 的 4个评委的平均值。 注意程序的节流 package c…

qt-C++笔记之Qt中的时间与定时器

qt-C笔记之Qt中的时间与定时器 code review! 文章目录 qt-C笔记之Qt中的时间与定时器一.Qt中的日期时间数据1.1.QTime:获取当前时间1.2.QDate:获取当前日期1.3.QDateTime:获取当前日期和时间1.4.QTime类详解1.5.QDate类详解1.6..QDateTime类…

17 Linux 中断

一、Linux 中断简介 1. Linux 中断 API 函数 ① 中断号 每个中断都有一个中断号,通过中断号可以区分出不同的中断。在 Linux 内核中使用一个 int 变量表示中断号。 ② request_irq 函数 在 Linux 中想要使用某个中断是需要申请的,request_irq 函数就是…

Docker 学习路线 13:部署容器

部署容器是使用Docker和容器化管理应用程序更高效、易于扩展和确保跨环境一致性性能的关键步骤。本主题将为您概述如何部署Docker容器以创建和运行应用程序。 概述 Docker容器是轻量级、可移植且自我包含的环境,可以运行应用程序及其依赖项。部署容器涉及启动、管…

【Python】数据分析案例:世界杯数据可视化

文章目录 前期数据准备导入数据 分析:世界杯中各队赢得的比赛数分析:先打或后打的比赛获胜次数分析:世界杯中的抛硬币决策分析:2022年T20世界杯的最高得分者分析:世界杯比赛最佳球员奖分析:最适合先击球或追…

苹果Ios系统app应用程序开发者如何获取IPA文件签名证书时需要注意什么?

今天呢想和大家介绍介绍苹果App开发者如何获取IPA文件签名证书的步骤和注意事项。对于苹果应用程序开发者而言,获取IPA文件签名证书是发布应用程序至App Store的重要步骤之一。签名证书能够确保应用程序的安全性和可信度,并使其能够在设备上正确运行。 …

VR全景技术,为养老院宣传推广带来全新变革

现如今,人口老龄化的现象加剧,养老服务行业也如雨后春笋般不断冒头,但是市面上各式的养老院被包装的五花八门,用户实际参访后却差强人意,如何更好的给父母挑选更为舒心的养老环境呢?可以利用720度VR全景技术…

iOS代码混淆----自动

先大致解释一下“编译"、"反编译": 编译:就是把千千万万行字符串(也叫代码,或者源文件),变成010101010101(机器码,也叫目标代码) 编译过程:预处理-编译-汇编-链接 我的脚本运行在预处理阶段。 反编…

什么是数据库?数据库有哪些基本分类和主要特点?

数据库是以某种有组织的方式存储的数据集合。本文从数据库的基本概念出发,详细解读了数据库的主要类别和基本特点,并就大模型时代备受瞩目的数据库类型——向量数据库进行了深度剖析,供大家在了解数据库领域的基本概念时起到一点参考作用。 …

计算机视觉驾驶行为识别应用简述

一、什么是计算机视觉识别? 计算机视觉识别是一种基于图像处理和机器学习的人工智能应用技术,可以用于多个场景。常见应用场景包括人脸识别、场景识别、OCR识别以及商品识别等。今天以咱们国产系统豌豆云为例,为大家梳理一下在车辆驾驶行为中…

Kafka -- 架构、分区、副本

1、Kafka的架构: 1、producer:消息的生产者 2、consumer:消息的消费者 3、broker:kafka集群的服务者,一个broker就是一个节点,主要是负责处理消息的读、写的请求和存储消息。在kafka cluster中包含很多的br…

雷达波形之一——LFM线性调频波形

文章目录 前言一、线性调频信号的形式1、原理2、时域表达式3、频域表达式 二、MATLAB 仿真1、涅菲尔积分①、MATLAB 源码②、仿真结果 2、LFM①、MATLAB 源码②、仿真结果1) 典型 LFM 波形,实部2) 典型 LFM 波形,虚部3) LFM 波形的典型谱 前言 线性调频…

亚马逊云科技海外服务器初体验

目录 前言亚马逊云科技海外服务器概述注册使用流程实例创建性能表现用户体验服务支持初体验总结 前言 随着云原生技术的飞速发展,越来越多的企业和开发者选择云服务器来作为自己的使用工具,云原生技术的发展也促进了云服务厂商的产品发展,所…

Java自学第6课:电商项目(2)

1 创建工具类并连接数据库 在工程src右键单击new,新建util包 再创建DBUtil类 数据库交互需要有数据库支持的包,这是官方给出的类库。 先声明1个代码块 // 静态代码块 只加载1次static{try {Class.forName("com.mysql.jdbc.Driver");} catch (…