基于Keras的模型剪枝(Pruning)

目录

  • 设置
  • 在不修剪的情况下为 MNIST 训练模型
  • 评估基线测试准确性并保存模型以供以后使用
  • 预训练模型 Pruning
  • 根据 baseline 训练和评估模型
  • Create 3x smaller models
  • Create a 10x smaller model from combining pruning and quantization
  • See persistence of accuracy from TF to TFLite

设置

!pip install -q tensorflow-model-optimizationimport tempfile
import osimport tensorflow as tf
import numpy as npfrom tensorflow_model_optimization.python.core.keras.compat import keras%load_ext tensorboard

 

在不修剪的情况下为 MNIST 训练模型

# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0# Define the model architecture.
model = keras.Sequential([keras.layers.InputLayer(input_shape=(28, 28)),keras.layers.Reshape(target_shape=(28, 28, 1)),keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Flatten(),keras.layers.Dense(10)
])# Train the digit classification model
model.compile(optimizer='adam',loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.fit(train_images,train_labels,epochs=4,validation_split=0.1,
)

 

评估基线测试准确性并保存模型以供以后使用

_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)

 

预训练模型 Pruning

Start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.
动态剪枝:PolynomialDecay策略意味着剪枝率不是一个固定的数字,而是随着训练步骤的增加而动态调整。这种动态调整允许模型在训练初期保持更多的连接,随着模型对剪枝的适应,逐渐增加剪枝的强度,这样可以帮助模型保持一定的性能,同时实现模型大小和计算资源的优化。通过这种方式,模型在剪枝过程结束时,能达到较高的稀疏度,同时在训练的早期阶段避免了剪枝过多可能导致的信息损失。这种逐渐增加稀疏度的策略,允许网络在训练过程中逐步适应这些改变,从而可能达到更好的最终性能。

import tensorflow_model_optimization as tfmotprune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude # 移除权重中幅度最小的部分# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# end_step 用于定义PolynomialDecay策略中剪枝率达到最终稀疏度的时刻。这个计算确保了剪枝过程能够在指定的训练时长内完成。# Define model for pruning.
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,final_sparsity=0.80,begin_step=0,  #模型一开始训练就剪枝end_step=end_step)
}model_for_pruning = prune_low_magnitude(model, **pruning_params) #模型定义# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model_for_pruning.summary()

 

根据 baseline 训练和评估模型

Fine tune with pruning for two epochs.

logdir = tempfile.mkdtemp()callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]model_for_pruning.fit(train_images, train_labels,batch_size=batch_size, epochs=epochs, validation_split=validation_split,callbacks=callbacks)#For this example, there is minimal loss in test accuracy after pruning, compared to the baseline._, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)# The logs show the progression of sparsity on a per-layer basis.#docs_infra: no_execute
%tensorboard --logdir={logdir}

 

Create 3x smaller models

上面只是用了标准的剪枝,实际应用还有两个方面可以继续压缩模型:

  • tfmot.sparsity.keras.strip_pruning 函数被用于去除模型中与修剪相关的所有临时变量(例如,修剪掩码),因为这些变量在训练之后用于推断不再需要,但会增加模型的大小。
  • 修剪操作通常会导致模型中许多权重变为零(这是通过将不重要的权重设为零来实现的)。序列化(保存到文件中)后的权重矩阵尺寸与修剪之前相同,尽管它包含了许多零值。标准压缩算法(如gzip)**可以识别这些冗余的零值,并通过仅存储非零信息来进一步压缩模型文件。
# First, create a compressible model for TensorFlow.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)_, pruned_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)# Then, create a compressible model for TFLite.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()_, pruned_tflite_file = tempfile.mkstemp('.tflite')with open(pruned_tflite_file, 'wb') as f:f.write(pruned_tflite_model)print('Saved pruned TFLite model to:', pruned_tflite_file)# Define a helper function to actually compress the models via gzip and measure the zipped size.def get_gzipped_model_size(file):# Returns size of gzipped model, in bytes.import osimport zipfile_, zipped_file = tempfile.mkstemp('.zip')with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:f.write(file)return os.path.getsize(zipped_file)print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))

 

Create a 10x smaller model from combining pruning and quantization

除了剪枝,我们可以继续用量化技术(PTQ)来压缩模型。
converter.optimizations = [tf.lite.Optimize.DEFAULT] 这行API代码执行了量化操作。

You can apply post-training quantization to the pruned model for additional benefits.

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')with open(quantized_and_pruned_tflite_file, 'wb') as f:f.write(quantized_and_pruned_tflite_model)print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))

 

See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TF Lite model on the test dataset.

import numpy as npdef evaluate_model(interpreter):input_index = interpreter.get_input_details()[0]["index"]output_index = interpreter.get_output_details()[0]["index"]# Run predictions on ever y image in the "test" dataset.prediction_digits = []for i, test_image in enumerate(test_images):if i % 1000 == 0:print('Evaluated on {n} results so far.'.format(n=i))# Pre-processing: add batch dimension and convert to float32 to match with# the model's input data format.test_image = np.expand_dims(test_image, axis=0).astype(np.float32)interpreter.set_tensor(input_index, test_image)# Run inference.interpreter.invoke()# Post-processing: remove batch dimension and find the digit with highest# probability.output = interpreter.tensor(output_index)digit = np.argmax(output()[0])prediction_digits.append(digit)print('\n')# Compare prediction results with ground truth labels to calculate accuracy.prediction_digits = np.array(prediction_digits)accuracy = (prediction_digits == test_labels).mean()return accuracy# You evaluate the pruned and quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()test_accuracy = evaluate_model(interpreter)print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/750534.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ARP和DDOS攻击防御介绍

学习目标: 1. 如何利用ARP漏洞进行攻击? 2. 怎样有效地防御ARP攻击? 3. 如何应对DDOS攻击? ARP攻击如何产生? ARP如何进行有效防御? ARP基础工作原理: 交换机会根据mac地址表,进行转…

JAVA中类,方法,构造函数,对象区别

在JAVA中,类(Class)是一个蓝图或模板,用于描述如何创建对象。它定义了对象的属性和方法。 方法(Method)是类或对象中可执行的操作。它用于执行特定的任务或操作。方法定义在类中,可以被对象调用…

C++之priority_queue实现

闲话少说&#xff0c;代码起步&#xff01;&#xff01;&#xff01; #pragma oncenamespace cx {template<class T,class containervector<T>, class Compare less<T>>class priority_queue{public:void adjust_up(int child){Compare com;int parent (c…

pytorch 入门基础知识一(Pytorch 01)

一 深度学习基础相关 深度学习三个主要的方向&#xff1a;计算机视觉&#xff0c;自然语言&#xff0c;语音识别。 机器学习核心组件&#xff1a;1 数据集(data)&#xff0c;2 前向传播的model(net)&#xff0c;3 目标函数(loss)&#xff0c; 4 调整模型参数和优化函数的算法…

【STM32定时器(一)内部时钟定时与外部时钟 TIM小总结】

STM32 TIM详解 TIM介绍定时器类型基本定时器通用定时器高级定时器常用名词时序图预分频时序计数器时序图 定时器中断配置图定时器定时 代码调试代码案例1代码案例2 TIM介绍 定时器&#xff08;Timer&#xff09;是微控制器中的一个重要模块&#xff0c;用于生成定时和延时信号…

mybatis源码阅读系列(一)

源码下载 mybatis 初识mybatis MyBatis 是一个优秀的持久层框架&#xff0c;它支持定制化 SQL、存储过程以及高级映射。MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集。MyBatis 可以使用简单的 XML 或注解用于配置和原始映射&#xff0c;将接口和 Java 的…

什么是软件开发?软件开发阶段划分是什么?并以LabVIEW为例进行说明

软件开发是一种创建、设计、编码、测试和维护应用程序、框架或其他软件组件的过程。它涉及从理解需求到设计、实现、测试、部署和最终维护的全过程。软件开发可以用来创建新的软件应用、系统软件、游戏、或开发网络应用等。 软件开发过程通常可以分为以下几个阶段&#xff1a;…

[蓝桥杯 2021 省 AB2] 完全平方数

题目链接 [蓝桥杯 2021 省 AB2] 完全平方数 题目描述 一个整数 a a a 是一个完全平方数&#xff0c;是指它是某一个整数的平方&#xff0c;即存在一个 整数 b b b&#xff0c;使得 a b 2 a b^2 ab2。 给定一个正整数 n n n&#xff0c;请找到最小的正整数 x x x&#…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的自动驾驶目标检测系统详解(深度学习+Python代码+PySide6界面+训练数据集)

摘要&#xff1a;开发自动驾驶目标检测系统对于提高车辆的安全性和智能化水平具有至关重要的作用。本篇博客详细介绍了如何运用深度学习构建一个自动驾驶目标检测系统&#xff0c;并提供了完整的实现代码。该系统基于强大的YOLOv8算法&#xff0c;并对比了YOLOv7、YOLOv6、YOLO…

相机与相机模型(针孔/鱼眼/全景相机)

本文旨在较为直观地介绍相机成像背后的数学模型&#xff0c;主要的章节组织如下&#xff1a; 第1章用最简单的针孔投影模型为例讲解一个三维点是如何映射到图像中的一个像素 第2章介绍除了针孔投影模型外其他一些经典投影模型&#xff0c;旨在让读者建立不同投影模型之间的建模…

RabbitMQ高级-高级特性

1.消息可靠性传递 在使用RabbitMQ的时候&#xff0c;作为消息发送方希望杜绝任何消息丢失或者投递失败场景。RabbitMQ为我们提供了两种方式来控制消息的投递可靠性模式 1.confirm 确认模式 确认模式是由exchange决定的 2.return 退回模式 回退模式是由routing…

LRC转SRT

最近看到一首很好的英文MTV原版&#xff0c;没又字幕&#xff0c;自己找字幕&#xff0c;只找到LRC&#xff0c;ffmpeg不支持LRC&#xff0c;网上在线转了SRT。 Subtitle Converter | Free tool | GoTranscript 然后用 ffmpeg 加字幕 ffmpeg -i LoveMeLikeYouDo.mp4 -vf sub…

nginx配置websocket

非加密的WebSocket连接。 #ws# 这是一个ws配置示例&#xff0c;表示使用非加密的WebSocket连接。 server { listen 8080; server_name example.com; location /websocket { proxy_pass http://backend-server; proxy_http_version 1.1; proxy_set_header Upgrade $http…

PyTorch学习笔记之基础函数篇(四)

文章目录 2.8 torch.logspace函数讲解2.9 torch.ones函数2.10 torch.rand函数2.11 torch.randn函数2.12 torch.zeros函数 2.8 torch.logspace函数讲解 torch.logspace 函数在 PyTorch 中用于生成一个在对数尺度上均匀分布的张量&#xff08;tensor&#xff09;。这意味着张量中…

【vue2源码】模版编译

文章目录 一、mount 基本流程二、执行 $mount 方法三、模版编译1、入口代码2、parse2.1 parseHTML2.2 parseText 3、generategenElement 函数 4、createCompileToFunctionFn 4、mountComponent 一、mount 基本流程 在执行 _init (new Vue时) 的方法中&#xff0c;调用了 vm.$m…

力扣热题100_矩阵_240_搜索二维矩阵 II

文章目录 题目链接解题思路解题代码 题目链接 240. 搜索二维矩阵 II 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性&#xff1a; 每行的元素从左到右升序排列。 每列的元素从上到下升序排列。 示例 1&#xff1a; 输入&#xf…

pytorch DDP模式下, 获取数据的的preftech + stream

直接上代码 - DDP forward if self.device_ids:if len(self.device_ids) 1:inputs, kwargs self.to_kwargs(inputs, kwargs, self.device_ids[0])output self.module(*inputs[0], **kwargs[0])else:inputs, kwargs self.scatter(inputs, kwargs, self.device_ids)outputs …

GAMES104-现代游戏引擎 1

主要学习重点还是面向就业&#xff0c;重点复习八股和算法 每天早上八点到九点用来学习这个课程 持续更新中... 第一节 游戏引擎导论 第二节 引擎架构分层

OLLAMA:如何像云端一样运行本地大语言模型

简介&#xff1a;揭开 OLLAMA 本地大语言模型的神秘面纱 您是否曾发现自己被云端语言模型的网络所缠绕&#xff0c;渴望获得更本地化、更具成本效益的解决方案&#xff1f;那么&#xff0c;您的探索到此结束。欢迎来到 OLLAMA 的世界&#xff0c;这个平台将彻底改变我们与大型…

橡胶工厂5G智能制造数字孪生可视化平台,推进橡胶工业数字化转型

橡胶5G智能制造工厂数字孪生可视化平台&#xff0c;推进橡胶工业数字化转型。随着信息技术的迅猛发展和智能制造的不断推进&#xff0c;数字化转型已成为制造业转型升级的重要方向。橡胶工业作为传统制造业的重要领域&#xff0c;正面临着产业升级和转型的迫切需求。橡胶5G智能…