ARJ_DenseNet BMR模型训练

废话不多数,模型训练代码

densenet_arj_BMR.py

import timefrom tensorflow.keras.applications.xception import Xception
from tensorflow.keras.applications.densenet import DenseNet169
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras as keras
from arj_t.plt_graph import show_graph
from common_para import train_dir, val_dir, station, EPOCHS_1,EPOCHS_2, batch_size, CLASS_WEIGHT, classesinput_shape = (224, 224)
date_ = time.strftime('%Y%m%d', time.localtime())
cpkt_path = f'./ckpt/ARJ_Densenet_ckpt{station}_20231017-1.h5'
model_path = f'./ckpt/ARJ_Densenet_MODEL{station}_{date_}.h5'class ArjDensenetModel(object):def __init__(self):self.base_model = DenseNet169(weights='imagenet', include_top=False)# 泛化能力不行,进行图像增强测试self.train_gen = ImageDataGenerator(rescale=1.0 / 255.0,# rotation_range=45,# width_shift_range=0.2,# height_shift_range=0.2,# brightness_range=(0, 0.3),# shear_range=0.2, # 浮点数。剪切强度(以弧度逆时针方向剪切角度)# zoom_range=[0.5, 1.5],  # 小于1.0的缩放将放大图像,大于1.0的缩放将缩小图像。# horizontal_flip=True,# vertical_flip=True,# fill_mode='constant',# cval=0)# self.train_gen = ImageDataGenerator(rescale=1.0 / 255.0)self.val_gen = ImageDataGenerator(rescale=1.0 / 255.0)# 获取本地训练和验证图片,生成generatordef get_local_data(self):self.train_gen = self.train_gen.flow_from_directory(directory=train_dir,target_size=input_shape,batch_size=batch_size,class_mode='binary',  # binary 改为 categoricalshuffle=True,# save_to_dir=r'D:\AOI Gray Image-OA\dataset\BMR\train_trans2',# save_format='jpg',# save_prefix='trans_')self.val_gen = self.val_gen.flow_from_directory(directory=val_dir,target_size=input_shape,batch_size=batch_size,class_mode='binary',  # binary 改为 categorical 2022/5/15shuffle=True)return Nonedef refine_basemode(self):"""获取VGG16 basemode只获取全连接层以前的卷积和池化层,并进行参数冻结,也就是使用原有训练好的参数自主增加隐藏层和全连接层进行训练,获得目标模型:return:"""# 获取除全连接层以外的层数,no-top modelx = self.base_model.outputs[0]# 加入全局池化、隐藏层、全连接层x = keras.layers.GlobalAveragePooling2D()(x)x = keras.layers.Dense(2048, activation='relu')(x)# x = keras.layers.BatchNormalization()(x)x = keras.layers.Dense(1024, activation='relu')(x)out = keras.layers.Dense(2, activation='softmax')(x)# 生成新的模型new_model = keras.models.Model(inputs=self.base_model.inputs, outputs=out)# 冻结vgg模型原有参数self.freeze_base_model()# 对new_model进行编译# 学习效果不佳,初始学习率加大尝试# 初始学习率0.01->0.001opt = keras.optimizers.Adam(learning_rate=0.001)new_model.compile(# optimizer=opt,  # 优化器# # 因为class_mode使用了categorical, 此时返回one-hot编码标签# # 那么这里就需要使用categorical_crossentropy,多类对数交叉熵损失计算# # 如果class_mode使用binary, 此时返回1D的二值标签,loss就需要使用sparse_categorical_crossentropy# loss='sparse_categorical_crossentropy',  # 使用交叉熵损失函数 分类# metrics=['accuracy']# binary_crossentropy与sigmoid联合使用二分类# categorical_crossentropy与softmax联合使用optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])return new_model# 冻结模型训练层数def freeze_base_model(self):for layer in self.base_model.layers:layer.trainable = Falsereturn None# # 对new_model进行trainingdef fit(self, model):# 获取本地数据self.get_local_data()# 定义checkpointckpt = keras.callbacks.ModelCheckpoint(filepath=cpkt_path,monitor='val_accuracy',save_freq='epoch',save_weights_only=True,save_best_only=True)# 早停法用起来el1 = keras.callbacks.EarlyStopping(monitor='val_accuracy',patience=15,verbose=2,mode='auto')# 定义学习率缩小规则rc1 = keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy',factor=0.1,  # 学习率缩小倍数 new_lr = lr*factorpatience=5,  # 耐心吗,5次迭代不增加就缩小学习率mode='auto',verbose=1,  # 1代表更新信息,0代表不更新# epsilon=0.0001,  # 确认是否进入平原区min_lr=0,cooldown=0)# 模型训练# 加入class_weight权重# 暂时注释。his1 = model.fit(self.train_gen, validation_data=self.val_gen,epochs=EPOCHS_1, callbacks=[ckpt, rc1, el1])# his1 = model.fit(self.train_gen, validation_data=self.val_gen,#                  epochs=EPOCHS_1, callbacks=[ckpt, rc1, el1], class_weight=CLASS_WEIGHT)print('first step end')# 解冻所有layer,进行参数微调for layer in model.layers:layer.trainable = True# 早停法用起来el2 = keras.callbacks.EarlyStopping(monitor='val_accuracy',patience=11,verbose=2,mode='auto')# 定义学习率缩小规则rc2 = keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy',factor=0.1,  # 学习率缩小倍数 new_lr = lr*factorpatience=5,  # 耐心吗,5次迭代不增加就缩小学习率mode='auto',verbose=1,  # 1代表更新信息,0代表不更新# epsilon=0.0001,  # 确认是否进入平原区min_lr=0,cooldown=0)opt = keras.optimizers.Adam(learning_rate=0.001)model.compile(optimizer=opt,loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 模型训练# model.load_weights(cpkt_path)his2 = model.fit(self.train_gen, validation_data=self.val_gen,epochs=EPOCHS_2, callbacks=[ckpt, rc2, el2], class_weight={0: 1, 1: 1.5})# # 模型训练# his2 = model.fit(self.train_gen, validation_data=self.val_gen,#                  epochs=EPOCHS_2, callbacks=[ckpt, rc2, el2], class_weight={0: 1, 1: 2, 2: 3})print('END STEP')return his1, his2if __name__ == '__main__':arj_model = ArjDensenetModel()model = arj_model.refine_basemode()his1, his2 = arj_model.fit(model)# # 保存模型# model.save(model_path)show_graph(his1)show_graph(his2)

common_para.py代码

train_dir = r"D:\new_data\BMR_TRAIN\train"
val_dir = r"D:\new_data\BMR_TRAIN\validate"
station = '_ALL_BMR'
batch_size = 32
EPOCHS_1 = 10
EPOCHS_2 = 40
CLASS_WEIGHT = {0: 1., 1: 1., 2: 1.}
threshold_value = 0
classes = 2

模型预测代码 

BMR_IPS_135K_predict.py
import osimport numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_arrayimport densenet_arj_BMR
import inceptionRestnet_arj_t
import resnet101_arj_BMR
import xception_arj_BMRMODEL_NAME = 'densenet'val_path = r'D:\AOI Gray Image-OA\dataset\case1\135K-ISR-IPS\validate'
# val_path = r'D:\AOI Gray Image-OA\dataset\error\W_to_G'
other_path = r'D:\AOI Gray Image-OA\AOI IMAGE-20220513\A6Q\ISR\复判后-G'
test_path = r'D:\new_data\BMR表外测试\BMR\A1A\P'
# TARGET_SIZE = (299, 299)
# DEFECT_TYPE = 'P'
# error_path = fr"D:\AOI Gray Image-OA\dataset\BMR\{MODEL_NAME}"def get_ckptpath_model():# arj_model = tensorflow.keras.models.Model()ckpt_path = ''# target_size = (224, 224)if MODEL_NAME == 'xception':ckpt_path = xception_arj_BMR.cpkt_patharj_model = xception_arj_BMR.ArjResnet101Model()target_size = xception_arj_BMR.input_shapeelif MODEL_NAME == 'inceptionRestnet':ckpt_path = inceptionRestnet_arj_t.cpkt_patharj_model = inceptionRestnet_arj_t.ArjInceptionRestnetModel()target_size = inceptionRestnet_arj_t.input_shapeelif MODEL_NAME == 'densenet':ckpt_path = densenet_arj_BMR.cpkt_patharj_model = densenet_arj_BMR.ArjDensenetModel()target_size = densenet_arj_BMR.input_shapeelif MODEL_NAME == 'resnet101':ckpt_path = resnet101_arj_BMR.cpkt_patharj_model = resnet101_arj_BMR.ArjResnet101Model()target_size = resnet101_arj_BMR.input_shapereturn ckpt_path, arj_model, target_size# 获取想要预测的图片绝对路径,包含文件名
def get_img_paths(defect_type, path):img_path = os.path.join(path, defect_type)img_paths = []for root, dirs, files in os.walk(img_path):for file in files:# print(file[-3:])if file[-3:] == 'jpg':img_paths.append(os.path.join(root, file))return img_pathsdef bmr_ips_predict(img_paths, error_path, defect_type='G'):ckpt_path, arj_model, input_shape = get_ckptpath_model()model = arj_model.refine_basemode()print(ckpt_path)model.load_weights(ckpt_path)print(model.summary())predict_dict = {0: 'G', 1: 'P', 2: 'W'}# 加载图片,预测white_cnt = 0good_cnt = 0repair_cnt = 0threshold_ls = []for img_path in img_paths:img_arr = load_img(img_path, target_size=input_shape)img = img_arr# print(img_path)# 转化为矩阵img_arr = img_to_array(img_arr)# print(img.shape)# 归一化# img_arr = preprocess_input(img_arr)img_arr /= 255.# print(type(img_arr))# img_arr = preprocess_input(img_arr)# img_arr /= 127.5# img_arr -= 1.# 形状修改img_arr = img_arr.reshape(1, img_arr.shape[0], img_arr.shape[1], img_arr.shape[2])# print(img.shape)# print(img_arr)y_predict = model.predict(img_arr)index = np.argmax(y_predict)# 加入阈值threshold = y_predict[0][index]# print(img_path.split('\\')[-1])# print(y_predict[0], ' >> ', threshold)# threshold_ls.append(threshold)# print(y_predict)y_predict = predict_dict[index]# print(index)# print(y_predict)# if index == 0:#     good_cnt += 1# else:#     repair_cnt += 1# 保存判错的图片# 预测结果G# save_img_name = str(round(threshold,2))+'_'+img_path.split('\\')[-1]save_img_name = img_path.split('\\')[-1]if index == 0:# 加入阈值调节判G能力if threshold > 0:good_cnt += 1# print(good_cnt)# print(img_path[-10:])# 如果原本P文件夹if defect_type == 'P':threshold_ls.append(threshold)img.save(os.path.join(error_path, 'AI_P_TO_G', save_img_name))os.remove(img_path)# 如果原本W文件夹if defect_type == 'W':img.save(os.path.join(error_path, 'AI_W_TO_G', save_img_name))os.remove(img_path)else:repair_cnt += 1elif index == 1:repair_cnt += 1if defect_type == 'G':threshold_ls.append(threshold)img.save(os.path.join(error_path, 'AI_G_TO_P', save_img_name))os.remove(img_path)if defect_type == 'W':img.save(os.path.join(error_path, 'AI_W_TO_P', save_img_name))os.remove(img_path)elif index == 2:white_cnt += 1if defect_type == 'G':img.save(os.path.join(error_path, 'AI_G_TO_W', save_img_name))os.remove(img_path)if defect_type == 'P':img.save(os.path.join(error_path, 'AI_P_TO_W', save_img_name))os.remove(img_path)else:print('还有第四种可能??!!')# print(y_predict)# print('**************************')# pd.DataFrame(data=threshold_ls).to_csv('./threshold.csv', encoding='utf-8')print(threshold_ls)print('good_cnt :  %d' % good_cnt)print('repair_cnt :  %d' % repair_cnt)# print('white_cnt :  %d' % white_cnt)# if __name__ == '__main__':
#     paths = get_img_paths(DEFECT_TYPE, test_path)
#     bmr_ips_predict(paths,error_path)

模型总预测代码

predict_all.py

import BMR_IPS_135K_predict# 此程序用来进行所有模型预测2023/10/17img_path = r'D:\new_data\BMR_TRAIN\test\WHITE'DEFECT_TYPE = 'P'paths = BMR_IPS_135K_predict.get_img_paths(DEFECT_TYPE, img_path)BMR_IPS_135K_predict.bmr_ips_predict(paths, img_path, DEFECT_TYPE)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/110634.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电脑桌面记事本便签软件哪个好?

很多人的电脑或者手机上都离不开一款好用的便签软件,使用便签软件可以帮助大家记事,提醒大家按时完成各项任务,但是自带的记事本便签软件不论从外观还是功能方面都有一定的欠缺,在使用过程中很容易耽误事情。 功能全面外观好看的…

HTTP 协议的基本格式(部分)

要想了解HTTP,得先知道什么是HTTP,那么HTTP是什么呢?HTTP (全称为 "超文本传输协议") 是一种应用非常广泛的 应用层协议。那什么是超文本呢?那就是除了文本,还有图片,声音,视频等。 …

『C++成长记』C++入门——命名空间缺省参数

🔥博客主页:小王又困了 📚系列专栏:C 🌟人之为学,不日近则日退 ❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、C的认识 📒1.1什么是C 📒1.2C的发展 二、C关键字 三…

论文阅读之《Kindling the Darkness: A Practical Low-light Image Enhancer》

目录 摘要 介绍 已有方法回顾 普通方法 基于亮度的方法 基于深度学习的方法 基于图像去噪的方法 提出的方法 2.1 Layer Decomposition Net 2.2 Reflectance Restoration Net 2.3 Illumination Adjustment Net 实验结果 总结 Kindling the Darkness: A Practical L…

轮转数组------题解报告

题目:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 题解: 如果直接暴力双循环会时间超限,所以我选择了一个空间复杂度比较高的方法。直接再创建一个数组,然后对应位置替换,最后把值赋给原…

Spring framework Day24:定时任务

前言 在我们的日常生活和工作中,时间管理是一项至关重要的技能。随着各种复杂任务的增加和时间压力的不断增加,如何更好地分配和利用时间成为了一项迫切需要解决的问题。在这样的背景下,定时任务成为了一种非常有效的解决方案。 定时任务&a…

Kubernetes技术与架构-服务

从软件系统架构设计分层的角度看,Kubernetes的Service是基于Pod的上层,业务应用部署在Pod中,使用Service绑定Pod部署的应用,Service可以对外或者对上层提供服务,当Pod集群在系统调度的过程中发生弹性伸缩的时候&#x…

Python中的With ...as... 作用

Python中的with … as …作用: 1、通过with语句可以得到一个上下文管理器 2、执行对象 3、加载__enter__方法 4、加载__exit__方法 5、执行__enter__方法 6、as 可以得到enter的返回值 7、拿到对象执行相关操作 8、执行完了之后调用__exit__方法 9、如果遇到异常&a…

1024程序员节特辑 | ELK+ 用户画像构建个性化推荐引擎,智能实现“千人千面”

专栏集锦,赶紧收藏以备不时之需 Spring Cloud实战专栏:https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏:https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏:https://blog.…

易点易动上线招标管理模块:提升企业高效招标管理的解决方案

在当今竞争激烈的商业环境下,招标管理对于企业的成功至关重要。为了帮助企业实现高效的招标管理,易点易动固定资产管理系统上线了全新的招标管理模块。该模块涵盖了供应商资质审核、采购询价单、重新报价单、招标结果单、招标作废单等功能,为…

python爬虫采集企查查数据

企查查,一个查询企业信息的网站,这个网站也是网络爬虫选择采集的对象,这个网站反爬提别厉害,没有一定的爬虫技术,是无法采集成功的。 网络爬虫从企查查采集企业信息,如果想要看到完成的企业信息就需要登录后…

Python 爬虫实战之爬拼多多商品并做数据分析

Python爬虫可以用来抓取拼多多商品数据,并对这些数据进行数据分析。以下是一个简单的示例,演示如何使用Python爬取拼多多商品数据并进行数据分析。 首先,需要使用Python的requests库和BeautifulSoup库来抓取拼多多商品页面。以下是一个简单的…

10月份stable diffusion animatediff等插件使用指南,又来更新了

插件一直会更新,包含了基本市面上流行的90%插件,好用的插件更是不会错过,往期插件请看往期文章,如果你没有时间一直关注sd更新的进展,请关注我,一个月用几个小时看一下我的文章,最短时间跟进sd。…

【微服务】spring webflux响应式编程使用详解

目录 一、webflux介绍 1.1 什么是webflux 1.2 什么是响应式编程 1.3 webflux特点 二、Java9中响应式编程 2.1 定义事件流源 2.2 实现订阅者 三、Spring Webflux介绍 四、Reactor 介绍 五、Reactor 常用API操作 5.1 Flux 创建流操作API 5.2 Flux响应流的订阅 5.3 Fl…

Mybatis对数据库进行增删查改以及单元测试

这篇写的草率了,是好几天前学到,以后用来自己复习 UserInfo import lombok.Data;Data public class UserInfo {private int id;private String name;private int age;private String email;//LocalDateTime可用于接收 时间}Mapper UserMapper pack…

软考 系统架构设计师系列知识点之软件构件(1)

所属章节: 第2章. 计算机系统基础知识 第3节. 计算机软件 2.3.7 软件构件 1. 概述 构件又称为组件,是一个自包容、可复用的程序集。构建是一个程序集、或者说是一组程序的集合。这个集合可能会以各种方式体现出来,如源程序或二进制代码。这…

2023年中国多功能折叠刀产量、销量及市场规模分析[图]

多功能折叠刀是一种集多种功能于一身的刀具,通常包括切割、开瓶、剥皮、锯木等功能,可以通过折叠和展开的方式来实现不同的功能,具有便携、多用途、安全等特点,广泛应用于户外探险、露营、自驾旅行等场景。 多功能折叠刀行业分类…

Simian使用方法

1.下载 链接1:官网下载 链接2:压缩包 2.操作 1.双击exe启动 2.打开控制台,winR 输入cmd 3.输入操作语句 G:\1111\simian-2.5.10\bin\simian-2.5.10.exe -includes"G:\1111\test\*.cpp" -threshold3 > output.txt G:\1111\si…

利用TypeScript 和 jsdom 库实现自动化抓取数据

以下是一个使用 TypeScript 和 jsdom 库的下载器程序,用于下载zhihu的内容。此程序使用了 duoip.cn/get_proxy 这段代码。 import { JSDOM } from jsdom; import { getProxy } from https://www.duoip.cn/get_proxy;const zhihuUrl https://www.zhihu.com;(async (…

璞华科技再次赋能,助力成都市温江区“码”上维权不烦“薪” !

科技赋能护“薪”行动 “码”上维权不烦“薪” 为保障劳动者工资收入的合法权益,提升人社部门智能化咨询服务能力,2023年10月17日,成都市温江区人力资源和社会保障局发布“码上护薪”小程序,助力劳动者“码”上维权不烦”薪”。…