把一个dataset的表放在另一个dataset里面_现在开始:用你的Mac训练和部署一个图像分类模型...

e5593ccfe5a769493288d74fcc57d179.png

可能有些同学学习机器学习的时候比较迷茫,不知道该怎么上手,看了很多经典书籍介绍的各种算法,但还是不知道怎么用它来解决问题,就算知道了,又发现需要准备环境、准备训练和部署的机器,啊,好麻烦。

今天,我来给大家介绍一种容易上手的方法,给你现成的样本和代码,按照步骤操作,就可以在自己的 Mac 上体验运用机器学习的全流程啦~~~

下面的 Demo, 最终的效果是给定一张图片,可以预测图片的类别。比如我们训练模型用的样本是猫啊狗啊,那模型能学到的认识的就是猫啊狗啊, 如果用的训练样本是按钮啊搜索框啊,那模型能学到的认识的就是这个按钮啊搜索框啊。

如果想了解用机器学习是怎么解决实际问题的,可以看这篇:如何使用深度学习识别UI界面组件?从问题定义、算法选型、样本准备、模型训练、模型评估、模型服务部署、到模型应用都有介绍。

环境准备

安装 Anaconda

下载地址: https://www.anaconda.com/products/individual

49bd417fe5db82befe869f9f12430f01.png

安装成功后,在终端命令行执行以下命令,使环境变量立即生效:

$ source ~/.bashrc

可以执行以下命令,查看环境变量

$ cat ~/.bashrc

可以看到 anaconda 的环境变量已经自动添加到 .bashrc 文件了

9b22435303d64d388c19e441e54b4bf8.png

执行以下命令:

$ conda list

可以看到 Anaconda 中有很多已经安装好的包,如果有使用到这些包的就不需要再安装了,python 环境也装好了。

05d33de37982d8d4ecf973e505fa000d.png

注意:如果安装失败,重新安装,在提示安装在哪里时,选择「更改安装位置」,安装位置选择其他地方不是用默认的,安装在哪里自己选择,可以放在「应用程序」下。

bf7235b765144b9d625757f2695d91e8.png

安装相关依赖

anaconda 中没有 keras、tensorflow 和 opencv-python, 需要单独安装。

$ pip install keras
$ pip install tensorflow
$ pip install opencv-python

样本准备

这里只准备了 4 个分类: button、keyboard、searchbar、switch, 每个分类 200 个左右的样本。

cd11b65d9fed71d34d961e1984c5efce.png

fed0e71aa3ce229f347bf1a67349b414.png

6f3ecc3a95afcdff6c0ec32609bf4fc7.png

6a5c654f37e6641fac1e2157c189d5fc.png

模型训练

开发训练逻辑

新建一个项目 train-project, 文件结构如下:

.
├── CNN_net.py
├── dataset
├── nn_train.py
└── utils_paths.py

入口文件代码如下,这里的逻辑是将准备好的样本输入给图像分类算法 SimpleVGGNet, 并设置一些训练参数,例如学习率、Epoch、Batch Size, 然后执行这段训练逻辑,最终得到一个模型文件。

# nn_train.py
from CNN_net import SimpleVGGNet
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_reportfrom keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import utils_paths
import matplotlib.pyplot as plt
from cv2 import cv2
import numpy as np
import argparse
import random
import pickleimport os# 读取数据和标签
print("------开始读取数据------")
data = []
labels = []# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images('./dataset')))
random.seed(42)
random.shuffle(imagePaths)image_size = 256
# 遍历读取数据
for imagePath in imagePaths:# 读取图像数据image = cv2.imread(imagePath)image = cv2.resize(image, (image_size, image_size))data.append(image)# 读取标签label = imagePath.split(os.path.sep)[-2]labels.append(label)data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)# 数据集切分
(trainX, testX, trainY, testY) = train_test_split(data,labels, test_size=0.25, random_state=42)# 转换标签为one-hot encoding格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)# 数据增强处理
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,horizontal_flip=True, fill_mode="nearest")# 建立卷积神经网络
model = SimpleVGGNet.build(width=256, height=256, depth=3,classes=len(lb.classes_))# 设置初始化超参数# 学习率
INIT_LR = 0.01
# Epoch  
# 这里设置 5 是为了能尽快训练完毕,可以设置高一点,比如 30
EPOCHS = 5   
# Batch Size
BS = 32# 损失函数,编译模型
print("------开始训练网络------")
opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt,metrics=["accuracy"])# 训练网络模型
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,epochs=EPOCHS
)# 测试
print("------测试网络------")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), target_names=lb.classes_))# 绘制结果曲线
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig('./output/cnn_plot.png')# 保存模型
print("------保存模型------")
model.save('./cnn.model.h5')
f = open('./cnn_lb.pickle', "wb")
f.write(pickle.dumps(lb))
f.close()

对于实际应用场景下,数据集很大,epoch 也会设置比较大,并在高性能的机器上训练。现在要在本机 Mac 上完成训练任务,我们只给了很少的样本来训练模型,epoch 也很小(为 5),当然这样模型的识别准确率也会很差,但我们此篇文章的目的是为了在本机完成一个机器学习的任务。

开始训练

执行以下命令开始训练:

$ python nn_train.py

训练过程日志如下:

a8e2e0c1e0a2832a062f2a38e445472a.png

训练结束后,在当前目录下会生成两个文件: 模型文件 cnn.model.h5 和 损失函数曲线 output/cnn_plot.png

669e2095810152ddd0a713229bd1aa1b.png

d0e4da4f127e323fa8b79b498e30e9aa.png

模型评估

现在,我们拿到了模型文件 cnn.model.h5, 可以写一个预测脚本,本地执行脚本预测一张图片的分类。

$ python predict.py
# predict.py
import allspark
import io
import numpy as np
import json
from PIL import Image
import requests
import threading
import cv2
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
import timemodel = load_model('./train/cnn.model.h5')
# pred的输入应该是一个images的数组,而且图片都已经转为numpy数组的形式
# pred = model.predict(['./validation/button/button-demoplus-20200216-16615.png'])#这个顺序一定要与label.json顺序相同,模型输出是一个数组,取最大值索引为预测值
Label = ["button","keyboard","searchbar","switch"]
testPath = "./test/button.png"images = []
image = cv2.imread(testPath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)image = cv2.resize(image,(256,256))
images.append(image)
images = np.asarray(images)pred = model.predict(images)print(pred)max_ = np.argmax(pred)
print('预测结果为:',Label[max_])

如果想要知道这个模型的准确率,也可以给模型输入一批带有已知分类的数据,通过模型预测后,将模型预测的分类与真实的分类比较,计算出准确率和召回率。

模型服务部署

开发模型服务

但在实际应用中,我们预测一张图片的类别, 是通过给定一张图片,请求一个 API 来拿到返回结果的。我们需要编写一个模型服务,然后部署到远端,拿到一个部署之后的模型服务 API。

现在,我们可以编写一个模型服务,然后在本地部署。

# 模型服务 app.py
import allspark
import io
import numpy as np
import json
from PIL import Image
import requests
import threading
import cv2
import tensorflow as tf
from tensorflow.keras.models import load_modelwith open('label.json') as f:mp = json.load(f)
labels = {value:key for key,value in mp.items()}def create_opencv_image_from_stringio(img_stream, cv2_img_flag=-1):img_stream.seek(0)img_array = np.asarray(bytearray(img_stream.read()), dtype=np.uint8)image_temp = cv2.imdecode(img_array, cv2_img_flag)if image_temp.shape[2] == 4:image_channel3 = cv2.cvtColor(image_temp, cv2.COLOR_BGRA2BGR)image_mask = image_temp[:,:,3] #.reshape(image_temp.shape[0],image_temp.shape[1], 1)image_mask = np.stack((image_mask, image_mask, image_mask), axis = 2)index_mask = np.where(image_mask == 0)image_channel3[index_mask[0], index_mask[1], index_mask[2]] = 255return image_channel3else:return image_tempdef get_string_io(origin_path):r = requests.get(origin_path, timeout=2)stringIo_content = io.BytesIO(r.content)return stringIo_contentdef handleReturn(pred, percent, msg_length):result = {"content":[]}argm = np.argsort(-pred, axis = 1)for i in range(msg_length):label = labels[argm[i, 0]]index = argm[i, 0]if(pred[i, index] > percent):confident = Trueelse:confident = Falseresult['content'].append({'isConfident': confident, 'label': label})return resultdef process(msg, model):msg_dict = json.loads(msg)percent = msg_dict['threshold']msg_dict = msg_dict['images']msg_length = len(msg_dict)desire_size = 256images = []for i in range(msg_length):image_temp = create_opencv_image_from_stringio(get_string_io(msg_dict[i]))image_temp = cv2.cvtColor(image_temp, cv2.COLOR_BGR2RGB)image = cv2.resize(image_temp, (256, 256))  images.append(image)images = np.asarray(images)pred = model.predict(images)return bytes(json.dumps(handleReturn(pred, percent, msg_length)) ,'utf-8')  def worker(srv, thread_id, model):while True:msg = srv.read()try:rsp = process(msg, model)srv.write(rsp)except Exception as e:srv.error(500,bytes('invalid data format', 'utf-8'))if __name__ == '__main__':desire_size = 256model = load_model('./cnn.model.h5')context = allspark.Context(4)queued = context.queued_service()workers = []for i in range(10):t = threading.Thread(target=worker, args=(queued, i, model))t.setDaemon(True)t.start()workers.append(t)for t in workers:t.join()

部署模型服务

模型服务编写完成后,在本地部署,需要安装环境。首先创建一个模型服务项目: deploy-project, 将 cnn.model.h5 拷贝到此项目中, 并在此项目下安装环境。

.
├── app.py
├── cnn.model.h5
└── label.json

安装环境

可以看下阿里云的模型服务部署文档:3、Python语言-3.2 构建开发环境-3.2.3 使用预构建的开发镜像(推荐)

安装 Docker

可以直接查看 Mac Docker 安装文档

# 用 Homebrew 安装 需要先现状 Homebrew: https://brew.sh
$ brew cask install docker

安装完之后,桌面上会出现 Docker 的图标。

创建 anaconda 的虚拟环境

# 使用conda创建python环境,目录需指定固定名字:ENV
$ conda create -p ENV python=3.7# 安装EAS python sdk
$ ENV/bin/pip install http://eas-data.oss-cn-shanghai.aliyuncs.com/sdk/allspark-0.9-py2.py3-none-any.whl# 安装其它依赖包
$ ENV/bin/pip install tensorflow keras opencv-python# 激活虚拟环境
$ conda activate ./ENV# 退出虚拟环境(不使用时)
$ conda deactivate

运行 Docker 环境

/Users/chang/Desktop/ml-test/deploy-project 换成自己的项目路径

sudo docker run -ti -v  /Users/chang/Desktop/ml-test/deploy-project:/home -p 8080:8080  
registry.cn-shanghai.aliyuncs.com/eas/eas-python-base-image:py3.6-allspark-0.8

本地部署

现在可以本地部署了,执行以下命令:

cd /home
./ENV/bin/python app.py

下面的日志可以看到部署成功。

d62a5a5badf8b8199db0623796e4be02.png

部署成功后,可以通过 localhost:8080/predict 访问模型服务了。

我们用 curl 命令来发一个 post 请求, 预测图片分类:

curl -X POST 'localhost:8080/predict' 
-H 'Content-Type: application/json' 
-d '{"images": ["https://img.alicdn.com/tfs/TB1W8K2MeH2gK0jSZJnXXaT1FXa-638-430.png"],"threshold": 0.5
}'

得到预测结果:

{"content": [{"isConfident": true, "label": "keyboard"}]}

完整代码

可以直接 clone 代码仓库:https://github.com/imgcook/ml-mac-classify

在安装好环境后,直接按以下命令运行。

# 1、训练模型
$ cd train-project
$ python nn_train.py# 生成模型文件:cnn.model.h5# 2、将模型文件拷贝到 deploy-project 中,部署模型服务
# 先安装模型服务运行环境
$ conda activate ./ENV
$ sudo docker run -ti -v  /Users/chang/Desktop/ml-test/deploy-project:/home -p 8080:8080  registry.cn-shanghai.aliyuncs.com/eas/eas-python-base-image:py3.6-allspark-0.8
$ cd /home
$ ./ENV/bin/python app.py# 得到模型服务 API: localhost:8080/predict# 3、访问模型服务
curl -X POST 'localhost:8080/predict' 
-H 'Content-Type: application/json' 
-d '{"images": ["https://img.alicdn.com/tfs/TB1W8K2MeH2gK0jSZJnXXaT1FXa-638-430.png"],"threshold": 0.5
}'

最后

好啦,总结一下这里使用深度学习的流程。我们选用了 SimpleVGGNet 作为图像分类算法(相当于一个函数),将准备好的数据传给这个函数,运行这个函数(学习数据集的特征和标签)得到一个输出,就是模型文件 model.h5。

18332d42a9f72191df8d6d2d7b5d1b1a.png

这个模型文件可以接收一张图片作为输入,并预测这张图片是什么,输出预测结果。但如果想要让模型可以在线上跑,需要写一个模型服务(API)并部署到线上以得到一个 HTTP API,我们可以在生产环境直接调用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/501351.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android studio的布局总结

首先我们的安卓的页面实际上的组件就是需要一些东西控制住他们 这些东西是什么呢,叫做布局管理器,一开始的开发中有常用的5款布局管理器 下面我来一一介绍其中的功能和一些关键字属性 1.relativelayout 相对布局管理器 属性:android:gra…

安装redis提示[test] error 2_安装PHP Redis扩展

安装PHP Redis扩展1、查看本机已经安装的Redis版本brew info redisredis: stable 5.0.7 (bottled), HEAD Persistent key-value database, with built-in net interface https://redis.io//usr/local/Cellar/redis/5.0.7 (13 files, 3.1MB) * Poured from bottle on 2020-02-19…

Android studio的监听器初学者要懂

首先,什么是监听器呢?监听器的作用是什么呢?我们如何去使用他? 1.什么是监听器:监听器的作用是什么呢? 监听器顾名思义,一旦你的按钮或者其他组件被你用例如鼠标点击,就会产生一个…

zincrby redis python_【Redis数据结构 序】使用redis-py操作Redis数据库

想要看更加舒服的排版、更加准时的推送关注公众号“不太灵光的程序员”每日八点有干货推送同时发布《【Redis数据结构 1序】1使用redis-py操作Redis数据库》本文依旧会对学习内容进行拆分,建议阅读时间基本保持10分钟内,想学习之前章节内容点击《你不了解…

Android studio的UI组件

1.文本框组件 掌管文字大小&#xff0c;文字来源&#xff0c;文字是否以行的形式显示&#xff0c;对齐方式居中 9patch图片拉伸不变形&#xff0c;需要放在drawable中 <TextViewandroid:layout_width"wrap_content"android:layout_height"wrap_content"…

visual paradigm 表示选择关系_知识获取的新挑战—远程监督关系抽取

本文主要介绍远程监督关系抽取任务上两个最新的工作。远程监督&#xff08;Distantly Supervised&#xff09;是关系抽取&#xff08;Relation Extraction&#xff09;的一种主要实现方法。关系抽取是指获得文本中的三元组&#xff08;triple&#xff09;&#xff0c;包括实体对…

Android studio的Activity详解

Activity就相当于我们的手机界面&#xff0c;里面包含着各个组件 Activity 的4种状态 运行状态&#xff1a;屏幕可视&#xff0c;且可以进行操作 暂停状态&#xff1a;返回退出的时候&#xff0c;询问是否退出运行状态&#xff0c;此时属于暂停状态 ------------------------…

arraylist 的扩容机制_每天都用ArrayList,你读过它的源码么?

作者&#xff1a;陌北有棵树&#xff0c;玩Java&#xff0c;架构师社区合伙人&#xff01;【一】关于扩容如果没有指定初始容量&#xff0c;则设置为10/** * Default initial capacity. */private static final int DEFAULT_CAPACITY 10;ArrayList的扩容比较简单&#xff0c;容…

JAVA入门级教学之(IDEA工具的快捷键和简单设置)

1.字体font file-->settings-->输入font-->设置字体样式以及字号大小 2.快速生成main方法 psvm 3.快速输出Systm.out.println(); sout 4.删除一行 ctrly 5.怎么运行 代码删右键run 或者点击右上角箭头 shiftf10(不同电脑可能不一样) 6.左侧窗口中的列表怎么展开…

java selenium_selenium 常见面试题以及答案(Java版)

1.怎么 判断元素是否存在&#xff1f;判断元素是否存在和是否出现不同&#xff0c; 判断是否存在意味着如果这个元素压根就不存在&#xff0c; 就会抛出NoSuchElementException这样就可以使用try catch&#xff0c;如果catch到NoSuchElementException 就返回false2.如何判断元素…

关于HTML的盒子的一些小问题

最近在开发的时候发现一个小问题&#xff0c;<DIV>我们很熟悉的一个盒子元素 关于他的描述 1.按照我们正常人的思维逻辑 编写好一个DIV盒子&#xff0c;然后再在盒子里面添加边框border、内边距padding、内容&#xff0c;这是我们的思维逻辑 但是DIV的编写会随着你添加…

语义网络分析图怎么做_怎么去分辨化工壶,光说可能大家还是会有疑惑,所以做了几个图...

网友们经常会拿一些壶出来&#xff0c;拍图给我看&#xff0c;问我会不会是化工壶&#xff0c;说到底&#xff0c;还是不放心自己手头上的紫砂壶&#xff0c;怕对自身健康造成影响&#xff0c;在这里&#xff0c;小编特地编辑这一段&#xff0c;教大家怎么去分辨化工壶&#xf…

CSS3特效之转化(transform)和过渡(transition)

CSS3特效之转化&#xff08;transform&#xff09;和过渡&#xff08;transition&#xff09; 在对动画深入之前&#xff0c;我们需要先了解它的一些特性&#xff0c;CSS3的转化&#xff08;transform&#xff09;和过渡&#xff08;transition&#xff09;。有人可能会有疑…

java如何保证redis设置过期时间的原子性_分布式锁用 Redis 还是 Zookeeper

在讨论这个问题之前&#xff0c;我们先来看一个业务场景&#xff1a;系统A是一个电商系统&#xff0c;目前是一台机器部署&#xff0c;系统中有一个用户下订单的接口&#xff0c;但是用户下订单之前一定要去检查一下库存&#xff0c;确保库存足够了才会给用户下单。由于系统有一…

转 安卓解决 IDEA 下 struts.xml 中 extends=“struts-default“ 报红的问题

解决 IDEA 下 struts.xml 中 extends"struts-default" 报红的问题 现象 在IDEA中配置struts.xml时 extends"struts-default" 报红&#xff0c;配置拦截器时属性无预选项提示&#xff0c;也爆红。 struts.xml本身的配置并没有错误。 解决办法 CtrlShiftAl…

系统新模块增加需要哪些步骤_人工智能之父的问题解决策略:模块化

最近主题阅读马文明斯基(Marvin Minsky) 和西摩佩珀特(Seymour Papert)两位人工智能大师&#xff0c;关于思维&#xff0c;关于教育的书籍。其中马文被称为「人工智能之父」。两人都非常重视过程模块化。复杂问题的解决需要系统性&#xff0c;也很少一次做对&#xff0c;要通过…

小白学Linux(一:开门见山)

目录 1.javaEE&#xff0c;先搭环境再敲码 2.Linux大数据 3.Python 4. Linux的学习方向 5. Linux的进阶段位 6.下面开始进入实际操作环节 第一步.安装虚拟机软件 第二步.在虚拟机里面安装一个别人开发好的Centos系统&#xff08;可以在此系统中写Linux指令&#xff0c;…

vb net 模拟 ctrl+c_8款优秀的.NET开发工具,收藏了

NET是一个重要的应用程序开发平台&#xff0c;因为它安全、稳定、易于学习和实现。今天小编给就给大家介绍8款优秀的.NET开发工具&#xff0c;有需要的小伙伴可以收藏转发哦。1、ChocolatyChocolaty是一个Windows软件包管理器&#xff0c;这个工具的重要之处在于&#xff0c;它…

卸载后以前拍的视频会删除吗_可立拍!苹果自己的视频编辑App是一个被忽视的好工具...

手机预装应用总是不如三方产品&#xff1f;看到这个问题&#xff0c;你是不是会下意识反驳&#xff1a;iPhone自带 app 就很好用啊&#xff01;的确如此&#xff0c;iPhone 的《Pages》《备忘录》&#xff0c;这些 app 的优秀表现改变了不少人「拿到新机就想卸载预装应用」的想…

解决:Linux中的CentOS 7的火狐浏览器不能访问服务器

今天安装CentOS 7的时候配置好环境&#xff0c;发现火狐不能连网 分析了一些可能是我的虚拟机网络配置没开&#xff0c;因此我总结了两个方法 1.检查虚拟机的编辑--》虚拟网络编辑器--》看看是否是NAT连接 2.搜索计算机的服务--》找到VMware DHCP Service和VMware NAT Servi…