卷积神经网络(CNN):艺术作品识别

文章目录

  • 一、前言
  • 一、设置GPU
  • 二、导入数据
    • 1. 导入数据
    • 2. 检查数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keraswarnings.filterwarnings("ignore")#忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

二、导入数据

1. 导入数据

import pathlibdata_dir = "./27-data/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 3776
batch_size = 16
img_height = 224
img_width  = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 3776 files belonging to 10 classes.
Using 3021 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 3776 files belonging to 10 classes.
Using 755 files for validation.
class_names = train_ds.class_names
print(class_names)
['Alfred_Sisley', 'Edgar_Degas', 'Francisco_Goya', 'Marc_Chagall', 'Pablo_Picasso', 'Paul_Gauguin', 'Peter_Paul_Rubens', 'Rembrandt', 'Titian', 'Vincent_van_Gogh']

2. 检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(16, 224, 224, 3)
(16,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEdef train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(2000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(2000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

在这里插入图片描述

三、构建模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout,BatchNormalization,Activation# Load pre-trained model
base_model = keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(img_width,img_height,3))for layer in base_model.layers:layer.trainable = True# Add layers at the end
X = base_model.output
X = Flatten()(X)X = Dense(512, kernel_initializer='he_uniform')(X)
#X = Dropout(0.5)(X)
X = BatchNormalization()(X)
X = Activation('relu')(X)X = Dense(16, kernel_initializer='he_uniform')(X)
#X = Dropout(0.5)(X)
X = BatchNormalization()(X)
X = Activation('relu')(X)output = Dense(len(class_names), activation='softmax')(X)model = Model(inputs=base_model.input, outputs=output)

四、编译

optimizer = tf.keras.optimizers.Adam(lr=1e-4)model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateSchedulerNO_EPOCHS = 15
PATIENCE  = 5
VERBOSE   = 1# 设置动态学习率
# annealer = LearningRateScheduler(lambda x: 1e-3 * 0.99 ** (x+NO_EPOCHS))# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)# 
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=VERBOSE,save_best_only=True,save_weights_only=True)
train_model  = model.fit(train_ds,epochs=NO_EPOCHS,verbose=1,validation_data=val_ds,callbacks=[earlystopper, checkpointer])

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']loss = train_model.history['loss']
val_loss = train_model.history['val_loss']epochs_range = range(len(acc))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('混淆矩阵',fontsize=15)plt.ylabel('真实值',fontsize=14)plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵for image, label in zip(images, labels):# 需要给图片增加一个维度img_array = tf.expand_dims(image, 0) # 使用模型预测图片中的人物prediction = model.predict(img_array)val_pre.append(class_names[np.argmax(prediction)])val_label.append(class_names[label])
plot_cm(val_label, val_pre)

3. 各项指标评估

from sklearn import metricsdef test_accuracy_report(model):print(metrics.classification_report(val_label, val_pre, target_names=class_names)) score = model.evaluate(val_ds, verbose=0)print('Loss function: %s, accuracy:' % score[0], score[1])test_accuracy_report(model)
											precision    recall  f1-score   supportAlfred_Sisley       0.76      0.98      0.86        53Edgar_Degas       0.89      0.94      0.92       132Francisco_Goya       0.89      0.69      0.77        70Marc_Chagall       0.85      0.94      0.89        48Pablo_Picasso       0.89      0.74      0.81        90Paul_Gauguin       0.94      0.84      0.89        57
Peter_Paul_Rubens       0.71      0.86      0.78        29Rembrandt       0.66      0.92      0.77        48Titian       0.90      0.72      0.80        65Vincent_van_Gogh       0.88      0.87      0.87       163accuracy                           0.85       755macro avg       0.84      0.85      0.84       755weighted avg       0.86      0.85      0.85       755Loss function: 0.5761227011680603, accuracy: 0.8490065932273865

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/196949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS开发员,月薪过万不是梦

最近爆出消息,安卓与鸿蒙将不再兼容!这意味着华为已经搭建了完整的鸿蒙生态,不再需要依赖于安卓生态。据统计,鸿蒙生态设备已经达到了7亿台,开发者人数也达到了220万人 此外,华为对鸿蒙系统的性能和体验有…

语音识别从入门到精通——1-基本原理解释

文章目录 语音识别算法1. 语音识别简介1.1 **语音识别**1.1.1 自动语音识别1.1.2 应用 1.2 语音识别流程1.2.1 预处理1.2.2 语音检测和断句1.2.3 音频场景分析1.2.4 识别引擎(**语音识别的模型**)1. 传统语音识别模型2. 端到端的语音识别模型基于Transformer的ASR模型基于CNN的…

unity学习笔记18

模型文件属性简介 1.动画类型:一共有四种:无 表示没有动画,旧版 就表示这个模型文件里面的动画片段可以用animation组件来播放的,最后两个 ”泛型“和“人形”都是animator组件来播放的。区别是泛型支持所有类型的动画播放&#x…

浅析Hotspot的经典7种垃圾收集器原理特点与组合搭配

# 浅析Hotspot的经典7种垃圾收集器原理特点与组合搭配 HotSpot共有7种垃圾收集器,3个新生代垃圾收集器,3个老年代垃圾收集器,以及G1,一共构成7种可供选择的垃圾收集器组合。 新生代与老年代垃圾收集器之间形成6种组合&#xff0c…

Tecplot绘制涡结构(Q准则)

文章目录 目的步骤1步骤2步骤3步骤4步骤5步骤6结果 目的 Tecplot绘制涡结构(Q准则判别)并用温度进行染色 Q准则计算公式 步骤1 步骤2 步骤3 步骤4 步骤5 步骤6 结果

鸿蒙4.0开发笔记之ArkTS装饰器语法基础之发布者订阅者模式@Provide和@Consume(十三)

1、定义 在鸿蒙系统的官方语言ArkTS中,有一套类似于发布者和订阅的模式,使用Provide、Consume两个装饰器来实现。 Provide、Consume:Provide/Consume装饰的变量用于跨组件层级(多层组件)同步状态变量,可以…

5.【自动驾驶与机器人中的SLAM技术】2D点云的scan matching算法 和 检测退化场景的思路

目录 1. 基于优化的点到点/线的配准2. 对似然场图像进行插值,提高匹配精度3. 对二维激光点云中会对SLAM功能产生退化场景的检测4. 在诸如扫地机器人等这样基于2D激光雷达导航的机器人,如何处理悬空/低矮物体5. 也欢迎大家来我的读书号--过千帆&#xff0…

Android wifi 框架以及Enable流程

Android P相比于Android O的变化 多了WifiStateMachinePrime(状态机的前处理机制),wifiService的相关cmd 不再是直接send 给WifiStateMachine,而是被送到WifiStateMachinePrime先进行处理后,再送往WifiStateMachine也…

Java微信支付对帐,微信账单下载并读取到实体Bean,并保存至数据库

最近公司的项目需要微信对帐功能&#xff0c;这里展示了简单的微信账单下载并读取到数据库方法&#xff0c;有问题或者更好的想法的可以在评论区交流哟。 一、依赖 <!-- 微信支付 --> <dependency><groupId>com.github.wechatpay-apiv3</groupId><…

全网最新最全的自动化测试教程:python+pytest接口自动化-测试函数、测试类/测试方法的封装

前言 在pythonpytest 接口自动化系列中&#xff0c;我们之前的文章基本都没有将代码进行封装&#xff0c;但实际编写自动化测试脚本中&#xff0c;我们都需要将测试代码进行封装&#xff0c;才能被测试框架识别执行。 例如单个接口的请求代码如下&#xff1a; import reques…

深入理解JVM虚拟机第二十七篇:详解JVM当中InvokeDynamic字节码指令,Java是动态类型语言么?

😉😉 学习交流群: ✅✅1:这是孙哥suns给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783824 📚📚 工作微信:BigTreeJava 拉你进微信群,免费领取! 🍎🍎4:本文章内容出自上述:Sp…

YOLO5Face算法解读

论文&#xff1a;YOLO5Face: Why Reinventing a Face Detector 链接&#xff1a;https://arxiv.org/abs/2105.12931v1 机构&#xff1a;深圳神目科技&LinkSprite Technologies&#xff08;美国&#xff09; 开源代码&#xff1a;https://github.com/deepcam-cn/yolov5-face…

如何定位当生产环境CPU飙升的时候的问题

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、排查思路 二、预防CPU飙升 三、总结 前言 在当今的信息化时代&#xff0c;计算机系统在各行各业都发挥着重要的作用。然而&a…

DeepStream--测试PCB-Defect-Detection

GitHub - clintonoduor/PCB-Defect-Detection-using-Deepstream: PCB defect detection using deepstream & YoloV5我参考了了这个代码&#xff0c;作者基于YoloV5&#xff0c;训练一个电路板检测的模型&#xff0c;训练数据集来自https://robotics.pkusz.edu.cn/resources…

BearPi Std 板从入门到放弃 - 后天篇(1)(I2C1 读取 光照强度)

简介 基于 BearPi Std 板从入门到放弃 - 引气入体篇&#xff08;5&#xff09;(printf打印到串口), 通过I2C接口&#xff0c;读取光照强度并打印到串口; 开发板 &#xff1a; Bearpi Std(小熊派标准板) 主芯片: STM32L431RCT6 LED : PC13 \ 推挽输出即可 \ 高电平点亮 串口: U…

SpringBoot整合RocketMQ

SpringBoot整合RocketMQ 文章目录 SpringBoot整合RocketMQ下载安装SpringBoot整合RocketMQ导坐标改配置实现消息生产与消费 下载安装 教程地址&#xff1a;https://www.bilibili.com/video/BV15b4y1a7yG/?p132&spm_id_from333.1007.top_right_bar_window_history.content.…

11. 哈希冲突

上一节提到&#xff0c;通常情况下哈希函数的输入空间远大于输出空间&#xff0c;因此理论上哈希冲突是不可避免的。比如&#xff0c;输入空间为全体整数&#xff0c;输出空间为数组容量大小&#xff0c;则必然有多个整数映射至同一桶索引。 哈希冲突会导致查询结果错误&#…

大数据技术学习笔记(四)—— HDFS

目录 1 HDFS 概述1.1 HDFS 背景与定义1.2 HDFS 优缺点1.3 HDFS 组成架构1.4 HDFS 文件块大小 2 HDFS的shell操作2.1 上传2.2 下载2.3 HDFS直接操作 3 HDFS的客户端操作3.1 Windows 环境准备3.2 获取 HDFS 的客户端连接对象3.3 HDFS文件上传3.4 HDFS文件下载3.5 HDFS删除文件和目…

最强AI之风袭来,你爱了吗?

2017年&#xff0c;柯洁同阿尔法狗人机大战&#xff0c;AlphaGo以3比0大获全胜&#xff0c;一代英才泪洒当场...... 2019年&#xff0c;换脸哥视频“杨幂换朱茵”轰动全网&#xff0c;时至今日AI换脸仍热度只增不减&#xff1b; 2022年&#xff0c;ChatGPT一经发布便轰动全球&a…

【hacker送书第8期】Java从入门到精通(第7版)

第8期图书推荐 内容简介编辑推荐作者简介图书目录参与方式 内容简介 《Java从入门到精通&#xff08;第7版&#xff09;》从初学者角度出发&#xff0c;通过通俗易懂的语言、丰富多彩的实例&#xff0c;详细讲解了使用Java语言进行程序开发需要掌握的知识。全书分为4篇共24章&a…