卷积神经网络(CNN):艺术作品识别

文章目录

  • 一、前言
  • 一、设置GPU
  • 二、导入数据
    • 1. 导入数据
    • 2. 检查数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keraswarnings.filterwarnings("ignore")#忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

二、导入数据

1. 导入数据

import pathlibdata_dir = "./27-data/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 3776
batch_size = 16
img_height = 224
img_width  = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 3776 files belonging to 10 classes.
Using 3021 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 3776 files belonging to 10 classes.
Using 755 files for validation.
class_names = train_ds.class_names
print(class_names)
['Alfred_Sisley', 'Edgar_Degas', 'Francisco_Goya', 'Marc_Chagall', 'Pablo_Picasso', 'Paul_Gauguin', 'Peter_Paul_Rubens', 'Rembrandt', 'Titian', 'Vincent_van_Gogh']

2. 检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(16, 224, 224, 3)
(16,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEdef train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(2000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(2000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

在这里插入图片描述

三、构建模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout,BatchNormalization,Activation# Load pre-trained model
base_model = keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(img_width,img_height,3))for layer in base_model.layers:layer.trainable = True# Add layers at the end
X = base_model.output
X = Flatten()(X)X = Dense(512, kernel_initializer='he_uniform')(X)
#X = Dropout(0.5)(X)
X = BatchNormalization()(X)
X = Activation('relu')(X)X = Dense(16, kernel_initializer='he_uniform')(X)
#X = Dropout(0.5)(X)
X = BatchNormalization()(X)
X = Activation('relu')(X)output = Dense(len(class_names), activation='softmax')(X)model = Model(inputs=base_model.input, outputs=output)

四、编译

optimizer = tf.keras.optimizers.Adam(lr=1e-4)model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateSchedulerNO_EPOCHS = 15
PATIENCE  = 5
VERBOSE   = 1# 设置动态学习率
# annealer = LearningRateScheduler(lambda x: 1e-3 * 0.99 ** (x+NO_EPOCHS))# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)# 
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=VERBOSE,save_best_only=True,save_weights_only=True)
train_model  = model.fit(train_ds,epochs=NO_EPOCHS,verbose=1,validation_data=val_ds,callbacks=[earlystopper, checkpointer])

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']loss = train_model.history['loss']
val_loss = train_model.history['val_loss']epochs_range = range(len(acc))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('混淆矩阵',fontsize=15)plt.ylabel('真实值',fontsize=14)plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵for image, label in zip(images, labels):# 需要给图片增加一个维度img_array = tf.expand_dims(image, 0) # 使用模型预测图片中的人物prediction = model.predict(img_array)val_pre.append(class_names[np.argmax(prediction)])val_label.append(class_names[label])
plot_cm(val_label, val_pre)

3. 各项指标评估

from sklearn import metricsdef test_accuracy_report(model):print(metrics.classification_report(val_label, val_pre, target_names=class_names)) score = model.evaluate(val_ds, verbose=0)print('Loss function: %s, accuracy:' % score[0], score[1])test_accuracy_report(model)
											precision    recall  f1-score   supportAlfred_Sisley       0.76      0.98      0.86        53Edgar_Degas       0.89      0.94      0.92       132Francisco_Goya       0.89      0.69      0.77        70Marc_Chagall       0.85      0.94      0.89        48Pablo_Picasso       0.89      0.74      0.81        90Paul_Gauguin       0.94      0.84      0.89        57
Peter_Paul_Rubens       0.71      0.86      0.78        29Rembrandt       0.66      0.92      0.77        48Titian       0.90      0.72      0.80        65Vincent_van_Gogh       0.88      0.87      0.87       163accuracy                           0.85       755macro avg       0.84      0.85      0.84       755weighted avg       0.86      0.85      0.85       755Loss function: 0.5761227011680603, accuracy: 0.8490065932273865

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/196949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS开发员,月薪过万不是梦

最近爆出消息,安卓与鸿蒙将不再兼容!这意味着华为已经搭建了完整的鸿蒙生态,不再需要依赖于安卓生态。据统计,鸿蒙生态设备已经达到了7亿台,开发者人数也达到了220万人 此外,华为对鸿蒙系统的性能和体验有…

服务器感染了.halo勒索病毒,如何确保数据文件完整恢复?

尊敬的读者: 随着数字化的快速发展,网络安全威胁也愈演愈烈。其中,.halo勒索病毒是一种带有恶意目的的恶意软件,对用户的数据构成巨大威胁。本文将深入介绍.halo勒索病毒的特征,探讨如何有效恢复被其加密的数据&#…

spring boot配置文件格式 ${}和@@

${}和都是springboot引用属性变量的方式&#xff0c;具体区别与用法&#xff1a; 1、${}常用于pom.xml&#xff0c;和 src/main/resources/application.properties等默认配置文件的属性变量引用。 语法为&#xff1a;field_name${field_value} pom.xml示例&#xff1a; <…

Kotlin学习之04

集合的变换操作 filter&#xff1a;保留满足条件的元素 map&#xff1a;集合中所有元素映射到其他元素构成新集合&#xff08;就是转换每个元素&#xff0c;然后再组成一个新的结果&#xff09; flatMap&#xff1a;集合中所有元素映射到新集合并合并这些集合得到新集合&…

语音识别从入门到精通——1-基本原理解释

文章目录 语音识别算法1. 语音识别简介1.1 **语音识别**1.1.1 自动语音识别1.1.2 应用 1.2 语音识别流程1.2.1 预处理1.2.2 语音检测和断句1.2.3 音频场景分析1.2.4 识别引擎(**语音识别的模型**)1. 传统语音识别模型2. 端到端的语音识别模型基于Transformer的ASR模型基于CNN的…

unity学习笔记18

模型文件属性简介 1.动画类型&#xff1a;一共有四种&#xff1a;无 表示没有动画&#xff0c;旧版 就表示这个模型文件里面的动画片段可以用animation组件来播放的&#xff0c;最后两个 ”泛型“和“人形”都是animator组件来播放的。区别是泛型支持所有类型的动画播放&#x…

【开题报告】基于SpringBoot的献爱心公益平台的设计与实现

1.研究背景 随着社会的进步和发展&#xff0c;公益事业在社会中扮演着越来越重要的角色。公益活动能够帮助弱势群体解决问题&#xff0c;改善社会环境&#xff0c;推动社会进步。然而&#xff0c;传统的公益活动组织和管理方式存在一些问题&#xff0c;如信息不透明、资源分散…

CPP-SCNUOJ-Problem P29. [算法课指针] 颜色分类,小白偏题超简单方法

Problem P29. [算法课指针] 颜色分类 给定一个包含红色、白色和蓝色、共 n 个元素的数组 nums &#xff0c;原地对它们进行排序&#xff0c;使得相同颜色的元素相邻&#xff0c;并按照红色、白色、蓝色顺序排列。 我们使用整数 0、 1 和 2 分别表示红色、白色和蓝色。 输入 …

浅析Hotspot的经典7种垃圾收集器原理特点与组合搭配

# 浅析Hotspot的经典7种垃圾收集器原理特点与组合搭配 HotSpot共有7种垃圾收集器&#xff0c;3个新生代垃圾收集器&#xff0c;3个老年代垃圾收集器&#xff0c;以及G1&#xff0c;一共构成7种可供选择的垃圾收集器组合。 新生代与老年代垃圾收集器之间形成6种组合&#xff0c…

Java读取邮件并生成邮件文件eml

1.JavaMail的关键对象 Properties&#xff1a;属性对象 Properties props new Properties(); props.put("mail.smtp.host", "smtp.sina.com.cn"); props.put("mail.smtp.auth", "true");针对不同的的邮件协议&#xff0c;JavaMail规…

Tecplot绘制涡结构(Q准则)

文章目录 目的步骤1步骤2步骤3步骤4步骤5步骤6结果 目的 Tecplot绘制涡结构(Q准则判别)并用温度进行染色 Q准则计算公式 步骤1 步骤2 步骤3 步骤4 步骤5 步骤6 结果

C#的方法使用

为何使用方法&#xff1a; 在C#方法是一组执行特定任务的语句的组合。使用方法可以提高代码的可重用性和模块化。 以下是在C#中使用方法的步骤&#xff1a; 1. 方法的定义&#xff1a; 使用 method 关键字来定义一个方法&#xff0c;然后指定方法的访问修饰符&#xff08;如 …

鸿蒙4.0开发笔记之ArkTS装饰器语法基础之发布者订阅者模式@Provide和@Consume(十三)

1、定义 在鸿蒙系统的官方语言ArkTS中&#xff0c;有一套类似于发布者和订阅的模式&#xff0c;使用Provide、Consume两个装饰器来实现。 Provide、Consume&#xff1a;Provide/Consume装饰的变量用于跨组件层级&#xff08;多层组件&#xff09;同步状态变量&#xff0c;可以…

5.【自动驾驶与机器人中的SLAM技术】2D点云的scan matching算法 和 检测退化场景的思路

目录 1. 基于优化的点到点/线的配准2. 对似然场图像进行插值&#xff0c;提高匹配精度3. 对二维激光点云中会对SLAM功能产生退化场景的检测4. 在诸如扫地机器人等这样基于2D激光雷达导航的机器人&#xff0c;如何处理悬空/低矮物体5. 也欢迎大家来我的读书号--过千帆&#xff0…

Android wifi 框架以及Enable流程

Android P相比于Android O的变化 多了WifiStateMachinePrime&#xff08;状态机的前处理机制&#xff09;&#xff0c;wifiService的相关cmd 不再是直接send 给WifiStateMachine&#xff0c;而是被送到WifiStateMachinePrime先进行处理后&#xff0c;再送往WifiStateMachine也…

Java微信支付对帐,微信账单下载并读取到实体Bean,并保存至数据库

最近公司的项目需要微信对帐功能&#xff0c;这里展示了简单的微信账单下载并读取到数据库方法&#xff0c;有问题或者更好的想法的可以在评论区交流哟。 一、依赖 <!-- 微信支付 --> <dependency><groupId>com.github.wechatpay-apiv3</groupId><…

【qml入门教程系列】:qml property使用介绍

作者:令狐掌门 技术交流QQ群:675120140 博客地址:https://mingshiqiang.blog.csdn.net/ 文章目录 属性的定义property基本用法属性变更事件通知属性绑定属性别名只读属性默认属性 default property访问和修改属性方式1:使用setProperty方法方式2:使用QQmlContext设置属性自定…

全网最新最全的自动化测试教程:python+pytest接口自动化-测试函数、测试类/测试方法的封装

前言 在pythonpytest 接口自动化系列中&#xff0c;我们之前的文章基本都没有将代码进行封装&#xff0c;但实际编写自动化测试脚本中&#xff0c;我们都需要将测试代码进行封装&#xff0c;才能被测试框架识别执行。 例如单个接口的请求代码如下&#xff1a; import reques…

深入理解JVM虚拟机第二十七篇:详解JVM当中InvokeDynamic字节码指令,Java是动态类型语言么?

😉😉 学习交流群: ✅✅1:这是孙哥suns给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783824 📚📚 工作微信:BigTreeJava 拉你进微信群,免费领取! 🍎🍎4:本文章内容出自上述:Sp…

YOLO5Face算法解读

论文&#xff1a;YOLO5Face: Why Reinventing a Face Detector 链接&#xff1a;https://arxiv.org/abs/2105.12931v1 机构&#xff1a;深圳神目科技&LinkSprite Technologies&#xff08;美国&#xff09; 开源代码&#xff1a;https://github.com/deepcam-cn/yolov5-face…