六、项目实战---识别猫和狗

一、准备数据集

kagglecatsanddogs网上一搜一大堆,这里我就不上传了,需要的话可以私信
在这里插入图片描述
导包

import os
import zipfile
import random
import shutil
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from shutil import copyfile

猫和狗的照片各12500张

print(len(os.listdir('./temp/cats/')))
print(len(os.listdir('./temp/dogs/')))
"""
12500
12500
"""

生成训练数据文件夹和测试数据文件夹

import os
import zipfile
import random
import shutil
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from shutil import copyfiledef create_dir(file_dir):if os.path.exists(file_dir):print("True")shutil.rmtree(file_dir)#删除再创建os.makedirs(file_dir)else:os.makedirs(file_dir)cat_source_dir = "./temp/cats/"
train_cats_dir = "./temp/train/cats/"
test_cats_dir = "./temp/test/cats/"dot_source_dir = "./temp/dogs/"
train_dogs_dir = "./temp/train/dogs/"
test_dogs_dir = "./temp/test/dogs/"create_dir(train_cats_dir)#创建猫的训练集文件夹
create_dir(test_cats_dir)#创建猫的测试集文件夹
create_dir(train_dogs_dir)#创建狗的训练集文件夹
create_dir(test_dogs_dir)#创建狗的测试集文件夹"""
True
True
True
True
"""

在这里插入图片描述
将总的猫狗图像按9:1分成训练集和测试集,猫和狗各12500张
最终temp/train/catstemp/train/dogs两个文件夹下各12500 * 0.9=11250张
temp/test/catstemp/test/dogs这两个文件夹下各12500 * 0.1=1250张
cats和dogs为总共的猫狗图像
test和train为准备的数据集文件

import os
import zipfile
import random
import shutil
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from shutil import copyfiledef split_data(source,train,test,split_size):files = []for filename in os.listdir(source):file = source + filenameif os.path.getsize(file)>0:files.append(filename)else:print(filename + "is zero file,please ignoring")train_length = int(len(files)*split_size)test_length = int(len(files)-train_length)shuffled_set = random.sample(files,len(files))train_set = shuffled_set[0:train_length]test_set = shuffled_set[-test_length:]for filename in train_set:this_file = source + filenamedestination = train + filenamecopyfile(this_file,destination)for filename in test_set:this_file = source + filenamedestination = test + filenamecopyfile(this_file,destination)cat_source_dir = "./temp/cats/"
train_cats_dir = "./temp/train/cats/"
test_cats_dir = "./temp/test/cats/"dot_source_dir = "./temp/dogs/"
train_dogs_dir = "./temp/train/dogs/"
test_dogs_dir = "./temp/test/dogs/"split_size = 0.9
split_data(cat_source_dir,train_cats_dir,test_cats_dir,split_size)
split_data(dog_source_dir,train_dogs_dir,test_dogs_dir,split_size)

二、模型的搭建和训练

先对数据进行归一化操作,预处理进行优化一下

import os
import zipfile
import random
import shutil
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from shutil import copyfiletrain_dir = "./temp/train/"
train_datagen = ImageDataGenerator(rescale=1.0/255.0)#优化网络,先进行归一化操作
train_generator = train_datagen.flow_from_directory(train_dir,batch_size=100,class_mode='binary',target_size=(150,150))#二分类,训练样本的输入的要一致validation_dir = "./temp/test/"
validation_datagen = ImageDataGenerator(rescale=1.0/255.0)
validation_generator = validation_datagen.flow_from_directory(validation_dir,batch_size=100,class_mode='binary',target_size=(150,150))
"""
Found 22500 images belonging to 2 classes.
Found 2500 images belonging to 2 classes.
"""

搭建模型架构

model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(16,(3,3),activation='relu',input_shape=(150,150,3)),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(32,(3,3),activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(64,(3,3),activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Flatten(),tf.keras.layers.Dense(512,activation='relu'),tf.keras.layers.Dense(1,activation='sigmoid')
])
model.compile(optimizer=RMSprop(lr=0.001),loss='binary_crossentropy',metrics=['acc'])

训练模型
225:因为数据一共22500张,猫和狗各12500张,其对于训练集个11250张,故训练集共22500张,在预处理第一段代码中,batch_size=100设置了一批100个,故总共应该有225批
epochs=2:两轮,也就是所有的样本全部训练一次
每轮包含225批,每一批有100张样本

history = model.fit_generator(train_generator,epochs=2,#进行2轮训练,每轮255批verbose=1,#要不记录每次训练的日志,1表示记录validation_data=validation_generator)"""
Instructions for updating:
Use tf.cast instead.
Epoch 1/2
131/225 [================>.............] - ETA: 2:03 - loss: 0.7204 - acc: 0.6093
"""

history是模型运行过程的结果

三、分析训练结果

import matplotlib.image as mpimg
import matplotlib.pyplot as pltacc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))

epoch太少了,导致是直线,多训练几轮实际应该是折线图
准确率

plt.plot(epochs,acc,'r',"training accuracy")
plt.plot(epochs,val_acc,'b',"validation accuracy")
plt.title("training and validation accuracy")
plt.figure()

在这里插入图片描述
损失值

plt.plot(epochs,loss,'r',"training loss")
plt.plot(epochs,val_loss,'b',"validation loss")
plt.figure()

在这里插入图片描述

四、模型的使用验证

import numpy as np
from google.colab import files
from tensorflow.keras.preprocessing import imageuploaded = files.upload()
for fn in uploaded.keys():path = 'G:/Juptyer_workspace/Tensorflow_mooc/sjj/test/' + fn#该路径为要用模型测试的路径img = image.load_img(path,target_size=(150,150))x = image.img_to_array(img)#多维数组x = np.expand_dims(x,axis=0)#拉伸images = np.vstack([x])#水平方向拉直classes = model.predict(images,batch_size=10)print(classes[0])if classes[0]>0.5:print(fn + "it is a dog")else:print(fn + "it is a cat")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/377632.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

修改shell终端提示信息

PS1:就是用户平时的提示符。PS2:第一行没输完,等待第二行输入的提示符。公共设置位置:/etc/profile echo $PS1可以看到当前提示符设置例如:显示绿色,并添加时间和shell版本export PS1"\[\e[32m\][\uyou are right…

java 字谜_计算字谜的出现次数

java 字谜Problem statement: 问题陈述: Given a string S and a word C, return the count of the occurrences of anagrams of the word in the text. Both string and word are in lowercase letter. 给定一个字符串S和一个单词C ,返回该单词在文本…

Origin绘制热重TG和微分热重DTG曲线

一、导入数据 二、传到Origin中 三、热重TG曲线 temp为横坐标、mass为纵坐标 绘制折线图 再稍微更改下格式 字体加粗,Times New Roman 曲线宽度设置为2 横纵坐标数值格式为Times New Roman 根据实际情况改下横纵坐标起始结束位置 四、微分热重DTG曲线 点击曲线…

【嵌入式系统复习】嵌入式网络与协议栈

目录开放式系统互连模型总线通信的报文组形式以及传递方式报文组形式报文传递方式网络分配与调度嵌入式TCP/IP蓝牙技术蓝牙的节能状态纠错方案蓝牙协议栈开放式系统互连模型 ISO/OSI七层模型展示了网络结构与各层的功能。 应用层: 提供了终端用户程序和网络之间的应…

代码兼容、技巧

代码兼容、技巧 前端开发中,一个头疼的事,就是代码的不兼容,这里贴出自己在前端开发中的一些解决经验。除了其浏览器本身的BUG外,不建议使用CSS hack来解决兼容性问题的。 IE和FF下对”li“的的高度解析不同 可以不定义高度&#…

Windows Phone 7 自定义事件

在Windows Phone的应用开发里面,对于事件这种东西我们可以随处可见,系统本来就已经封装好了各种各样的事件机制,如按钮的单击事件等等的。在实际的开发中,我们需要自己去给相关的类自定义一些事件来满足业务的要求,特别…

getcwd函数_PHP getcwd()函数与示例

getcwd函数PHP getcwd()函数 (PHP getcwd() function) The full form of getcwd is "Get Current Working Directory", the function getcwd() is used to get the name of the current working directory, it does not accept any parameter and returns the curren…

十四、数据库的导出和导入的两种方法

一、以SQL脚本格式导出(推荐) 导出 右击需要导出的数据库,任务—>生成脚本 下一步 选择要导出的数据库,下一步 内容根据需求修改,没啥需求直接下一步 勾选 表 勾选需要导出的数据库中的表 选择脚本保存的路…

Apache中 RewriteCond 规则参数介绍

RewriteCond就像我们程序中的if语句一样,表示如果符合某个或某几个条件则执行RewriteCond下面紧邻的RewriteRule语句,这就是RewriteCond最原始、基础的功能,为了方便理解,下面来看看几个例子。RewriteEngine onRewriteCond %{HTT…

【C++grammar】文件I/O流的基本用法

目录1、输入输出类介绍1.C/C文件操作对比2.什么是流?3.C I/O流类层次4.带缓冲的输入输出5.gcc编译器cin.in_avail()2、向文件写入数据1.写文件小练习2.如何将信息同时输出到文件和屏幕?3、从文件读数据1.检测文件是否成功打开2.检测是否已到文件末尾3.读…

作业2 分支循环结构

书本第39页 习题2 1.输入2个整数num1和num2.计算并输出它们的和&#xff0c;差&#xff0c;积&#xff0c;商&#xff0c;余数。 //输入2个整数num1和num2.计算并输出它们的和&#xff0c;差&#xff0c;积&#xff0c;商&#xff0c;余数。//#include<stdio.h> int main…

求一个序列中最大的子序列_最大的斐波那契子序列

求一个序列中最大的子序列Problem statement: 问题陈述&#xff1a; Given an array with positive number the task to find the largest subsequence from array that contain elements which are Fibonacci numbers. 给定一个具有正数的数组&#xff0c;任务是从包含菲波纳…

十三、系统优化

系统整体框架图 程序运行进入纺织面料库存管理系统主页面 用户子系统功能演示&#xff1a; 1&#xff0c;点击用户登录进入用户登录页面&#xff0c;可以注册和找回密码 2&#xff0c;注册新用户&#xff0c;账号、密码、性别、手机号均有限制&#xff0c;用户注册需要按指定…

时间工具类[DateUtil]

View Code 1 package com.ly.util;2 3 import java.text.DateFormat;4 import java.text.ParseException;5 import java.text.SimpleDateFormat;6 import java.util.Calendar;7 import java.util.Date;8 9 /**10 * 11 * 功能描述12 * 13 * authorAdministrator14 * Date Jul 19…

JQuery delegate多次绑定的解决办法

我用delegate来控制分页&#xff0c;查询的时候会造成多次绑定 //前一页、后一页触发 1 $("body").delegate("#tableFoot a:not(a.btn)", "click", function () { 2 _options.page $(this).attr("page"); 3 loadTmpl(_option…

leetcode 45. 跳跃游戏 II 思考分析

题目 给定一个非负整数数组&#xff0c;你最初位于数组的第一个位置。 数组中的每个元素代表你在该位置可以跳跃的最大长度。 你的目标是使用最少的跳跃次数到达数组的最后一个位置。 示例: 输入: [2,3,1,1,4] 输出: 2 解释: 跳到最后一个位置的最小跳跃数是 2。 从下标为 …

C程序实现冒泡排序

Bubble Sort is a simple, stable, and in-place sorting algorithm. 气泡排序是一种简单&#xff0c;稳定且就地的排序算法。 A stable sorting algorithm is the one where two keys having equal values appear in the same order in the sorted output array as it is pre…

一、爬虫基本概念

一、爬虫根据使用场景分类 爬虫&#xff1a; 通过编写程序&#xff0c;模拟浏览器上网&#xff0c;让其去互联网上抓取数据的过程。 ① 通用爬虫&#xff1a;抓取系统重要的组成部分&#xff0c;抓取的是一整张页面的数据 ② 聚焦爬虫&#xff1a;建立在通用爬虫的基础之上&am…

经营你的iOS应用日志(二):异常日志

如果你去4S店修车&#xff0c;给小工说你的车哪天怎么样怎么样了&#xff0c;小工有可能会立即搬出一台电脑&#xff0c;插上行车电脑把日志打出来&#xff0c;然后告诉你你的车发生过什么故障。汽车尚且如此&#xff0c;何况移动互联网应用呢。 本文第一篇&#xff1a;经营你的…

Discuz 升级X3问题汇总整理

最近一段时间公司的社区垃圾帖数量陡然上涨&#xff0c;以至于社区首页的推荐版块满满都是垃圾帖的身影&#xff0c;为了进一步解决垃圾帖问题我们整整花了1天时间删垃圾贴&#xff0c;清除不良用户&#xff0c;删的手都酸了&#xff0c;可见垃圾帖的数量之多&#xff01;可耻的…