政安晨:【Keras机器学习实践要点】(二十八)—— 使用Reptile进行小样本学习

目录

介绍

定义超参数

准备数据

可视化数据集中的一些示例

建立模型

训练模型

可视化结果


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用 Reptile 对 Omniglot 数据集进行少量分类。

介绍

Reptile算法是由OpenAI开发的,用于执行与模型无关的元学习。具体而言,该算法旨在通过最小的训练量(少样本学习)迅速学习执行新任务。该算法通过使用在一个小批量的从未见过的数据上训练得到的权重与训练前的模型权重之间的差别进行随机梯度下降,经过一定数量的元迭代来工作。

import osos.environ["KERAS_BACKEND"] = "tensorflow"import keras
from keras import layersimport matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
import tensorflow_datasets as tfds

定义超参数

learning_rate = 0.003
meta_step_size = 0.25inner_batch_size = 25
eval_batch_size = 25meta_iters = 2000
eval_iters = 5
inner_iters = 4eval_interval = 1
train_shots = 20
shots = 5
classes = 5

准备数据

Omniglot 数据集由来自 50 种不同字母的 1623 个字符组成,每个字符有 20 个示例。

每个字符的 20 个样本是通过亚马逊的 Mechanical Turk 在线抽取的。

在少量学习任务中,从 n 个随机选择的类别中随机抽取 k 个样本(或 "镜头")。

这 n 个数值用于创建一组新的临时标签,以测试模型在少量示例的情况下学习新任务的能力。

换句话说,如果你在 5 个类别上进行训练,你的新类别标签将是 0、1、2、3 或 4。Omniglot 是完成这项任务的绝佳数据集,因为它有许多不同的类别可供选择,而且每个类别都有合理数量的样本。

class Dataset:# This class will facilitate the creation of a few-shot dataset# from the Omniglot dataset that can be sampled from quickly while also# allowing to create new labels at the same time.def __init__(self, training):# Download the tfrecord files containing the omniglot data and convert to a# dataset.split = "train" if training else "test"ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)# Iterate over the dataset to get each individual image and its class,# and put that data into a dictionary.self.data = {}def extraction(image, label):# This function will shrink the Omniglot images to the desired size,# scale pixel values and convert the RGB image to grayscaleimage = tf.image.convert_image_dtype(image, tf.float32)image = tf.image.rgb_to_grayscale(image)image = tf.image.resize(image, [28, 28])return image, labelfor image, label in ds.map(extraction):image = image.numpy()label = str(label.numpy())if label not in self.data:self.data[label] = []self.data[label].append(image)self.labels = list(self.data.keys())def get_mini_dataset(self, batch_size, repetitions, shots, num_classes, split=False):temp_labels = np.zeros(shape=(num_classes * shots))temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))if split:test_labels = np.zeros(shape=(num_classes))test_images = np.zeros(shape=(num_classes, 28, 28, 1))# Get a random subset of labels from the entire label set.label_subset = random.choices(self.labels, k=num_classes)for class_idx, class_obj in enumerate(label_subset):# Use enumerated index value as a temporary label for mini-batch in# few shot learning.temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx# If creating a split dataset for testing, select an extra sample from each# label to create the test dataset.if split:test_labels[class_idx] = class_idximages_to_split = random.choices(self.data[label_subset[class_idx]], k=shots + 1)test_images[class_idx] = images_to_split[-1]temp_images[class_idx * shots : (class_idx + 1) * shots] = images_to_split[:-1]else:# For each index in the randomly selected label_subset, sample the# necessary number of images.temp_images[class_idx * shots : (class_idx + 1) * shots] = random.choices(self.data[label_subset[class_idx]], k=shots)dataset = tf.data.Dataset.from_tensor_slices((temp_images.astype(np.float32), temp_labels.astype(np.int32)))dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)if split:return dataset, test_images, test_labelsreturn datasetimport urllib3urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)

演绎展示:

Downloading and preparing dataset 17.95 MiB (download: 17.95 MiB, generated: Unknown size, total: 17.95 MiB) to /home/fchollet/tensorflow_datasets/omniglot/3.0.0...Dl Completed...: 0 url [00:00, ? url/s]Dl Size...: 0 MiB [00:00, ? MiB/s]Extraction completed...: 0 file [00:00, ? file/s]Generating splits...:   0%|          | 0/4 [00:00<?, ? splits/s]Generating train examples...:   0%|          | 0/19280 [00:00<?, ? examples/s]Shuffling /home/fchollet/tensorflow_datasets/omniglot/3.0.0.incomplete1MPXME/omniglot-train.tfrecord*...:   0%…Generating test examples...:   0%|          | 0/13180 [00:00<?, ? examples/s]Shuffling /home/fchollet/tensorflow_datasets/omniglot/3.0.0.incomplete1MPXME/omniglot-test.tfrecord*...:   0%|…Generating small1 examples...:   0%|          | 0/2720 [00:00<?, ? examples/s]Shuffling /home/fchollet/tensorflow_datasets/omniglot/3.0.0.incomplete1MPXME/omniglot-small1.tfrecord*...:   0…Generating small2 examples...:   0%|          | 0/3120 [00:00<?, ? examples/s]Shuffling /home/fchollet/tensorflow_datasets/omniglot/3.0.0.incomplete1MPXME/omniglot-small2.tfrecord*...:   0…Dataset omniglot downloaded and prepared to /home/fchollet/tensorflow_datasets/omniglot/3.0.0. Subsequent calls will reuse this data.

可视化数据集中的一些示例

_, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))sample_keys = list(train_dataset.data.keys())for a in range(5):for b in range(5):temp_image = train_dataset.data[sample_keys[a]][b]temp_image = np.stack((temp_image[:, :, 0],) * 3, axis=2)temp_image *= 255temp_image = np.clip(temp_image, 0, 255).astype("uint8")if b == 2:axarr[a, b].set_title("Class : " + sample_keys[a])axarr[a, b].imshow(temp_image, cmap="gray")axarr[a, b].xaxis.set_visible(False)axarr[a, b].yaxis.set_visible(False)
plt.show()

演绎展示:

建立模型

def conv_bn(x):x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same")(x)x = layers.BatchNormalization()(x)return layers.ReLU()(x)inputs = layers.Input(shape=(28, 28, 1))
x = conv_bn(inputs)
x = conv_bn(x)
x = conv_bn(x)
x = conv_bn(x)
x = layers.Flatten()(x)
outputs = layers.Dense(classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile()
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)

训练模型

training = []
testing = []
for meta_iter in range(meta_iters):frac_done = meta_iter / meta_iterscur_meta_step_size = (1 - frac_done) * meta_step_size# Temporarily save the weights from the model.old_vars = model.get_weights()# Get a sample from the full dataset.mini_dataset = train_dataset.get_mini_dataset(inner_batch_size, inner_iters, train_shots, classes)for images, labels in mini_dataset:with tf.GradientTape() as tape:preds = model(images)loss = keras.losses.sparse_categorical_crossentropy(labels, preds)grads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))new_vars = model.get_weights()# Perform SGD for the meta step.for var in range(len(new_vars)):new_vars[var] = old_vars[var] + ((new_vars[var] - old_vars[var]) * cur_meta_step_size)# After the meta-learning step, reload the newly-trained weights into the model.model.set_weights(new_vars)# Evaluation loopif meta_iter % eval_interval == 0:accuracies = []for dataset in (train_dataset, test_dataset):# Sample a mini dataset from the full dataset.train_set, test_images, test_labels = dataset.get_mini_dataset(eval_batch_size, eval_iters, shots, classes, split=True)old_vars = model.get_weights()# Train on the samples and get the resulting accuracies.for images, labels in train_set:with tf.GradientTape() as tape:preds = model(images)loss = keras.losses.sparse_categorical_crossentropy(labels, preds)grads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))test_preds = model.predict(test_images)test_preds = tf.argmax(test_preds).numpy()num_correct = (test_preds == test_labels).sum()# Reset the weights after getting the evaluation accuracies.model.set_weights(old_vars)accuracies.append(num_correct / classes)training.append(accuracies[0])testing.append(accuracies[1])if meta_iter % 100 == 0:print("batch %d: train=%f test=%f" % (meta_iter, accuracies[0], accuracies[1]))

演绎展示:

batch 0: train=0.600000 test=0.200000
batch 100: train=0.800000 test=0.200000
batch 200: train=1.000000 test=1.000000
batch 300: train=1.000000 test=0.800000
batch 400: train=1.000000 test=0.600000
batch 500: train=1.000000 test=1.000000
batch 600: train=1.000000 test=0.600000
batch 700: train=1.000000 test=1.000000
batch 800: train=1.000000 test=0.800000
batch 900: train=0.800000 test=0.600000
batch 1000: train=1.000000 test=0.600000
batch 1100: train=1.000000 test=1.000000
batch 1200: train=1.000000 test=1.000000
batch 1300: train=0.600000 test=1.000000
batch 1400: train=1.000000 test=0.600000
batch 1500: train=1.000000 test=1.000000
batch 1600: train=0.800000 test=1.000000
batch 1700: train=0.800000 test=1.000000
batch 1800: train=0.800000 test=1.000000
batch 1900: train=1.000000 test=1.000000

可视化结果

# First, some preprocessing to smooth the training and testing arrays for display.
window_length = 100
train_s = np.r_[training[window_length - 1 : 0 : -1],training,training[-1:-window_length:-1],
]
test_s = np.r_[testing[window_length - 1 : 0 : -1], testing, testing[-1:-window_length:-1]
]
w = np.hamming(window_length)
train_y = np.convolve(w / w.sum(), train_s, mode="valid")
test_y = np.convolve(w / w.sum(), test_s, mode="valid")# Display the training accuracies.
x = np.arange(0, len(test_y), 1)
plt.plot(x, test_y, x, train_y)
plt.legend(["test", "train"])
plt.grid()train_set, test_images, test_labels = dataset.get_mini_dataset(eval_batch_size, eval_iters, shots, classes, split=True
)
for images, labels in train_set:with tf.GradientTape() as tape:preds = model(images)loss = keras.losses.sparse_categorical_crossentropy(labels, preds)grads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))
test_preds = model.predict(test_images)
test_preds = tf.argmax(test_preds).numpy()_, axarr = plt.subplots(nrows=1, ncols=5, figsize=(20, 20))sample_keys = list(train_dataset.data.keys())for i, ax in zip(range(5), axarr):temp_image = np.stack((test_images[i, :, :, 0],) * 3, axis=2)temp_image *= 255temp_image = np.clip(temp_image, 0, 255).astype("uint8")ax.set_title("Label : {}, Prediction : {}".format(int(test_labels[i]), test_preds[i]))ax.imshow(temp_image, cmap="gray")ax.xaxis.set_visible(False)ax.yaxis.set_visible(False)
plt.show()


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/812151.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库练习

insert into employee(empno,ename,job,mgr,hiredate,sal,comm,deptno) values (1001,甘宁,文员,1013,2000-12-17,8000.00,NULL,20), (1002,黛绮丝,销售员,1006,2001-02-20,16000.00,3000.00,30), (1003,殷天正,销售员,1006,2001-02-22,12500.00,5000.00,30), (1004,刘备,经理,…

kimichat使用技巧:用语音对话聊天

kimichat之前是只能用文字聊天的&#xff0c;不过最近推出了语音新功能&#xff0c;也可以用语音畅快的对话聊天了。 这个功能目前支持手机app版本&#xff0c;所以首先要在手机上下载安装kimi智能助手。已经安装的&#xff0c;要点击检查更新&#xff0c;更新到最新的版本。 …

Ubuntu 20.04 设置开启 root 远程登录连接

Ubuntu默认不设置 root 帐户和密码 Ubuntu默认不设置 root 帐户和密码 Ubuntu默认不设置 root 帐户和密码 如有需要&#xff0c;可在设置中开启允许 root 用户登录。具体操作步骤如下&#xff1a; 操作步骤 1、首先使用普通用户登录 2、设置root密码 macw:~$ sudo passwd …

Llama2模型本地部署(Mac M1 16G)

环境准备 环境&#xff1a;Mac M1 16G、Conda Conda创建环境配置 使用Anaconda-Navigator创建python 3.8环境 切换到新建的conda环境&#xff1a; conda activate llama38 llama.cpp 找一个目录&#xff0c;下载llama.cpp git clone https://github.com/ggerganov/llama.…

读所罗门的密码笔记18_大宪章

1. 大宪章 1.1. 1215年会议开启了一个艰难的谈判过程&#xff0c;充满了紧张和对权力与道德权威的争夺 1.1.1. 这部宪章会赋予各方一系列的权力&#xff0c;对国王的自由裁量权进行制衡 1.2. 《大宪章》还需要300多年的时间和多次迭代&#xff0c;才能成为财产权、公平税收、…

STM32 DCMI 的带宽与性能介绍

1. 引言 随着市场对更高图像质量的需求不断增加&#xff0c;成像技术持续发展&#xff0c;各种新兴技术&#xff08;例如3D、计算、运动和红外线&#xff09;的不断涌现。如今的成像应用对高质量、易用性、能耗效率、高集成度、快速上市和成本效益提出了全面要求。为了满足这些…

【算法一则】做算法学数据结构 - 简化路径 - 【栈】

目录 题目栈代码题解 题目 给你一个字符串 path &#xff0c;表示指向某一文件或目录的 Unix 风格 绝对路径 &#xff08;以 ‘/’ 开头&#xff09;&#xff0c;请你将其转化为更加简洁的规范路径。 在 Unix 风格的文件系统中&#xff0c;一个点&#xff08;.&#xff09;表…

Cesium 无人机航线规划

鉴于大疆司空平台和大疆无人机app高度绑定&#xff0c;导致很多东西没办法定制化。 从去年的时候就打算仿大疆开发一套完整的平台&#xff0c;包括无人机app以及仿司空2的管理平台&#xff0c;集航线规划、任务派发、实时图像、无人机管理等功能的平台。 当前阶段主要实现了&…

突破编程_前端_SVG(circle 圆形)

1 circle 元素的基本属性和用法 SVG 的 <circle> 元素用于在SVG文档中绘制圆形。它具有几个基本属性&#xff0c;允许定义圆形的大小、位置、填充颜色和边框样式。以下是 <circle> 元素的基本属性及其详细解释&#xff1a; 1.1 cx 和 cy 描述&#xff1a;这两个…

记录一次Java中使用P12证书访问https,nginx返回403的问题

目录 1、先使用浏览器导入证书访问&#xff0c;测试证书和密钥是否正确2、编写初始java代码3、结果响应 403 Forbidden4、解决方案 1、先使用浏览器导入证书访问&#xff0c;测试证书和密钥是否正确 成功返回&#xff0c;说明p12证书和密钥是没问题的。 2、编写初始java代码 …

Harmony鸿蒙南向外设驱动开发-Codec

功能简介 OpenHarmony Codec HDI&#xff08;Hardware Device Interface&#xff09;驱动框架基于OpenMax实现了视频硬件编解码驱动&#xff0c;提供Codec基础能力接口给上层媒体服务调用&#xff0c;包括获取组件编解码能力、创建组件、参数设置、数据的轮转和控制、以及销毁…

oracle创建整个数据库的只读账户

在源用户readonly 下创建只读用户 reader readonly 的表空间为AA 一、创建只读用户 create user reader identified by 密码 default tablespace AA; 二、授权 grant connect to reader ; 三、获取原账号readonly 的查询权限 select grant select on ||owner||.||object…

【面试题】redis在工作中的使用场景有哪些?

前言&#xff1a;在实际工作中&#xff0c;Redis作为一种高性能的内存数据库和缓存系统&#xff0c;可以应用于多种场景&#xff0c;同时在面试过程中也经常被问到类似的问题&#xff0c;我们经常会被问的一脸懵逼&#xff0c;那今天我们就来总结一下redis的一些使用场景。 数据…

实战解析:SpringBoot AOP与Redis结合实现延时双删功能

目录 一、业务场景 1、此时存在的问题 2、解决方案 3、为何要延时500毫秒&#xff1f; 4、为何要两次删除缓存&#xff1f; 二、代码实践 1、引入Redis和SpringBoot AOP依赖 2、编写自定义aop注解和切面 3、application.yml 4、user.sql脚本 5、UserController 6、U…

基于ssm微信小程序的医院挂号预约系统

采用技术 基于ssm微信小程序的医院挂号预约系统的设计与实现~ 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringMVCMyBatis 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 页面展示效果 用户管理 医院管理 医生管理 公告资讯管理 科室信息管…

IMU状态预积分的雅克比矩阵

IMU状态预积分的雅克比矩阵 预积分的雅克比矩阵 预积分的雅克比矩阵 最后讨论预积分相对状态变量的雅克比矩阵。由于预积分测量已经归纳了IMU在短时间内的读数&#xff0c;因此残差相对于状态变量的雅克比矩阵推导则简单。 首先考虑旋转。 旋转与Ri,Rj和 b g , i b_{g,i} bg,i…

【拓展技术】——AutoDL服务器训练Pycharm使用注意点Pycharm配置AutoDL

一、AutoDL服务器模型训练 AutoDL是一个为研究人员、开发者和企业提供的平台&#xff0c;它致力于提供一个高效、可靠和易用的环境&#xff0c;以支持复杂的计算任务和AI模型的部署&#xff1a; 高效的并行计算资源&#xff1a;AutoDL拥有强大的计算集群和高性能的计算节点&a…

【QT入门】Qt自定义控件与样式设计之控件提升与自定义控件

【QT入门】Qt自定义控件与样式设计之控件提升与自定义控件 往期回顾 【QT入门】Qt自定义控件与样式设计之QProgressBar用法及qss-CSDN博客 【QT入门】 Qt自定义控件与样式设计之QSlider用法及qss-CSDN博客 【QT入门】Qt自定义控件与样式设计之qss的加载方式-CSDN博客 一、最终…

C++ 类和对象 上

目录 前言 什么是面向对象&#xff1f;什么是面向过程&#xff1f; 面向过程 面向对象 比较 类 引入 定义 实例化 类的大小 this指针 前言 今天我们来进入C类和对象的学习。相信大家一定听说过C语言是面向过程的语言&#xff0c;而C是面向对象的语言&#xff1f;那么他…

启明智显M系列--工业级HMI芯片选型表

本章主要介绍启明智显M系列HMI主控芯片&#xff1a; 纯国产自主&#xff0c; RISC-V 内核&#xff0c;配备强大的 2D 图形加速处理器、PNG/JPEG 解码引擎、H.264解码&#xff1b;工业宽温&#xff0c;提供全开源SDK&#xff1b;1秒快速开机启动的特性&#xff0c;极大地提高了…