政安晨:【Keras机器学习示例演绎】(二十六)—— 图像相似性搜索的度量学习

目录

概述

设置

数据集

嵌入模型

测试


政安晨的个人主页:政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:在 CIFAR-10 图像上使用相似度量学习的示例。

概述


度量学习旨在训练能将输入嵌入高维空间的模型,从而使训练方案所定义的 "相似 "输入彼此靠近。这些模型一经训练,就能为下游系统生成对这种相似性有用的嵌入模型,例如作为搜索的排名信号,或作为另一种监督问题的预训练嵌入模型。

设置


将 Keras 后端设置为 tensorflow。

import osos.environ["KERAS_BACKEND"] = "tensorflow"import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from collections import defaultdict
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
import keras
from keras import layers

数据集


在本示例中,我们将使用 CIFAR-10 数据集。

from keras.datasets import cifar10(x_train, y_train), (x_test, y_test) = cifar10.load_data()x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)

为了了解数据集,我们可以将 25 个随机例子组成的网格可视化。

height_width = 32def show_collage(examples):box_size = height_width + 2num_rows, num_cols = examples.shape[:2]collage = Image.new(mode="RGB",size=(num_cols * box_size, num_rows * box_size),color=(250, 250, 250),)for row_idx in range(num_rows):for col_idx in range(num_cols):array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)collage.paste(Image.fromarray(array), (col_idx * box_size, row_idx * box_size))# Double size for visualisation.collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))return collage# Show a collage of 5x5 random images.
sample_idxs = np.random.randint(0, 50000, size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)

度量学习提供的训练数据并不是明确的(X,y)对,而是使用以我们想要表达的相似性方式相关的多个实例。

在我们的例子中,我们将使用同一类别的实例来表示相似性;单个训练实例将不是一幅图像,而是同一类别的一对图像。

在提及这对图像时,我们将使用常见的度量学习名称:锚图像(随机选择的图像)和正图像(随机选择的另一张同类图像)。

为此,我们需要建立一种从类到该类实例的查询形式。在生成用于训练的数据时,我们将从该查找表中采样。

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):class_idx_to_train_idxs[y].append(y_train_idx)class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):class_idx_to_test_idxs[y].append(y_test_idx)

在本例中,我们使用的是最简单的训练方法;一个批次将由分布在各个类别中的(锚、正)对组成。

学习的目标是使锚和正对在批次中更接近、更远离其他实例。

在这种情况下,批次大小将由类的数量决定;对于 CIFAR-10,类的数量为 10。

num_classes = 10class AnchorPositivePairs(keras.utils.Sequence):def __init__(self, num_batches):super().__init__()self.num_batches = num_batchesdef __len__(self):return self.num_batchesdef __getitem__(self, _idx):x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)for class_idx in range(num_classes):examples_for_class = class_idx_to_train_idxs[class_idx]anchor_idx = random.choice(examples_for_class)positive_idx = random.choice(examples_for_class)while positive_idx == anchor_idx:positive_idx = random.choice(examples_for_class)x[0, class_idx] = x_train[anchor_idx]x[1, class_idx] = x_train[positive_idx]return x

我们可以用另一张拼贴图来直观地展示一批结果。上排显示从 10 个类别中随机选择的锚点,下排显示相应的 10 个阳性锚点。

examples = next(iter(AnchorPositivePairs(num_batches=1)))show_collage(examples)

嵌入模型


我们定义了一个带有 train_step 的自定义模型,它首先嵌入锚点和正点,然后使用它们的成对点乘作为 softmax 的对数。

class EmbeddingModel(keras.Model):def train_step(self, data):# Note: Workaround for open issue, to be removed.if isinstance(data, tuple):data = data[0]anchors, positives = data[0], data[1]with tf.GradientTape() as tape:# Run both anchors and positives through model.anchor_embeddings = self(anchors, training=True)positive_embeddings = self(positives, training=True)# Calculate cosine similarity between anchors and positives. As they have# been normalised this is just the pair wise dot products.similarities = keras.ops.einsum("ae,pe->ap", anchor_embeddings, positive_embeddings)# Since we intend to use these as logits we scale them by a temperature.# This value would normally be chosen as a hyper parameter.temperature = 0.2similarities /= temperature# We use these similarities as logits for a softmax. The labels for# this call are just the sequence [0, 1, 2, ..., num_classes] since we# want the main diagonal values, which correspond to the anchor/positive# pairs, to be high. This loss will move embeddings for the# anchor/positive pairs together and move all other pairs apart.sparse_labels = keras.ops.arange(num_classes)loss = self.compute_loss(y=sparse_labels, y_pred=similarities)# Calculate gradients and apply via optimizer.gradients = tape.gradient(loss, self.trainable_variables)self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))# Update and return metrics (specifically the one for the loss value).for metric in self.metrics:# Calling `self.compile` will by default add a [`keras.metrics.Mean`](/api/metrics/metrics_wrappers#mean-class) lossif metric.name == "loss":metric.update_state(loss)else:metric.update_state(sparse_labels, similarities)return {m.name: m.result() for m in self.metrics}

接下来,我们将介绍从图像映射到嵌入空间的结构。该模型由一系列 2d 卷积组成,然后进行全局池化,最后线性投影到嵌入空间。按照度量学习的常见方法,我们对嵌入空间进行归一化处理,以便使用简单的点积来衡量相似性。为了简单起见,我们有意缩小了模型的规模。

inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)model = EmbeddingModel(inputs, embeddings)

最后,我们运行训练。在 Google Colab GPU 实例上,这大约需要一分钟。

model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)plt.plot(history.history["loss"])
plt.show()
Epoch 1/2077/1000 ━[37m━━━━━━━━━━━━━━━━━━━  1s 2ms/step - loss: 2.2962WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700589927.295343 3724442 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.1000/1000 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step - loss: 2.2504
Epoch 2/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.1068
Epoch 3/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0646
Epoch 4/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0210
Epoch 5/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9857
Epoch 6/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9543
Epoch 7/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9175
Epoch 8/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8740
Epoch 9/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8474
Epoch 10/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8380
Epoch 11/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8146
Epoch 12/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7658
Epoch 13/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7512
Epoch 14/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7671
Epoch 15/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7245
Epoch 16/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7001
Epoch 17/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7099
Epoch 18/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6775
Epoch 19/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6547
Epoch 20/201000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6356

测试


我们可以将该模型应用于测试集,并考虑嵌入空间中的近邻,从而检验该模型的质量。

首先,我们嵌入测试集并计算所有近邻。回想一下,由于嵌入是单位长度的,我们可以通过点积计算余弦相似度。

near_neighbours_per_example = 10embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step

为了直观地检验这些嵌入,我们可以为 5 个随机例子建立一个近邻拼贴图。下图的第一列是随机选取的图像,随后的 10 列按相似度排序显示了近邻图像。

num_collage_examples = 5examples = np.empty((num_collage_examples,near_neighbours_per_example + 1,height_width,height_width,3,),dtype=np.float32,
)
for row_idx in range(num_collage_examples):examples[row_idx, 0] = x_test[row_idx]anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])for col_idx, nn_idx in enumerate(anchor_near_neighbours):examples[row_idx, col_idx + 1] = x_test[nn_idx]show_collage(examples)

我们还可以通过混淆矩阵来考虑近邻的正确性,从而对性能进行量化。

让我们从 10 个类别中各抽取 10 个例子,并将它们的近邻视为一种预测形式;也就是说,该例子及其近邻是否属于同一类别?

我们观察到,每个动物类别的表现一般都很好,与其他动物类别混淆的情况最多。车辆类别也遵循同样的模式。

confusion_matrix = np.zeros((num_classes, num_classes))# For each class.
for class_idx in range(num_classes):# Consider 10 examples.example_idxs = class_idx_to_test_idxs[class_idx][:10]for y_test_idx in example_idxs:# And count the classes of its near neighbours.for nn_idx in near_neighbours[y_test_idx][:-1]:nn_class_idx = y_test[nn_idx]confusion_matrix[class_idx, nn_class_idx] += 1# Display a confusion matrix.
labels = ["Airplane","Automobile","Bird","Cat","Deer","Dog","Frog","Horse","Ship","Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/5926.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Pytorch深度学习——多层感知机

本文章来源于对李沐动手深度学习代码以及原理的理解,并且由于李沐老师的代码能力很强,以及视频中讲解代码的部分较少,所以这里将代码进行尽量逐行详细解释 并且由于pytorch的语法有些小伙伴可能并不熟悉,所以我们会采用逐行解释小…

Word域代码学习(简单使用)-【SEQ】

Word域代码学习(简单使用)-【SEQ】 快捷键 序号快捷键操作1 Ctrl F9 插入域代码花括号2 F9 显示域代码结果3 Shift F9 切换为域代码4 Windows Alt F9 切换全部域代码 域代码说明 域代码不区分大小写在word中,依次选择插入➡文档部件➡域即可选择插入…

Linux 学习 --- 编辑 vi 命令

1、vi 基本概念(了解) 基本上 vi 可以分为三种状态,分别是命令模式 (command mode)、插入模式 (Insert mode) 和底行模式 (last line mode),各模式的功能区分如下: 命令行模式 command mode)  控制屏幕光标的移动&a…

PotatoPie 4.0 实验教程(31) —— FPGA实现摄像头图像高斯滤波

什么是高斯滤波 高斯滤波是一种常见的图像处理技术,用于去除图像中的噪声和平滑图像。它的原理基于统计学中的高斯分布(也称为正态分布)。 在高斯滤波中,一个二维的高斯核函数被用来对图像中的每个像素进行加权平均。这个高斯核…

jvm 马士兵 01

01.JVM是什么 JVM是一个跨平台的标准 JVM只识别class文件,符合JVM规范的class文件都可以被识别

AI智能名片商城小程序:引领企业迈向第三增长极

随着数字化浪潮的席卷,私域流量的重要性逐渐凸显,为企业增长提供了全新的动力。在这一背景下,AI智能名片商城系统崭露头角,以其独特的优势,引领企业迈向第三增长极。 私域流量的兴起,为企业打开了一扇新的销…

【codeforces】Immobile Knight

Immobile Knight 我感觉自己不太适合写codeforces,简单题也比较考验思维,当时这题看了半天以为是搜索,写了20分钟暴力交了,还好对的,20个人19个人5分钟不到速通第一题,唯留我一人在第一题凌乱。下来看看这…

深度学习中的归一化:BN,LN,IN,GN的优缺点

目录 深度学习中归一化的作用常见归一化的优缺点 深度学习中归一化的作用 加速训练过程 归一化可以加速深度学习模型的训练过程。通过调整输入数据的尺度,归一化有助于改善优化算法的收敛速度。这是因为归一化后的数据具有相似的尺度,使得梯度下降等优化…

密码学基础练习五道 RSA、elgamal、elgamal数字签名、DSA数字签名、有限域(GF)上的四则运算

1.RSA #include <stdlib.h>#include <stdio.h>#include <string.h>#include <math.h>#include <time.h>#define PRIME_MAX 200 //生成素数范围#define EXPONENT_MAX 200 //生成指数e范围#define Element_Max 127 //加密单元的…

dockerfile 搭建lamp 实验模拟

一 实验目的 二 实验 环境 1, 实验环境 192.168.217.88一台机器安装docker 并做mysql nginx php 三台容器 2&#xff0c; 大致框架 3&#xff0c; php php:Nginx服务器不能处理动态页面&#xff0c;需要由 Nginx 把动态请求交给 php-fpm 进程进行解析 php有三…

LT6911UXB HDMI2.0 至四端口 MIPI DSI/CSI,带音频 龙迅方案

1. 描述LT6911UXB 是一款高性能 HDMI2.0 至 MIPI DSI/CSI 转换器&#xff0c;适用于 VR、智能手机和显示应用。HDMI2.0 输入支持高达 6Gbps 的数据速率&#xff0c;可为4k60Hz视频提供足够的带宽。此外&#xff0c;数据解密还支持 HDCP2.2。对于 MIPI DSI / CSI 输出&#xff0…

van-cascader(vant2)异步加载的bug

问题描述&#xff1a;由于一次性返回所有的级联数据的话&#xff0c;数据量太大&#xff0c;接口响应时间太久&#xff0c;因此采用了异步加载的方案&#xff0c;看了vant的官方示例代码&#xff0c;照着改了下&#xff0c;很轻松地实现了功能。正当我感叹世界如此美好的时候&a…

【C++ —— 多态】

C —— 多态 多态的概念多态的定义和实现多态的构成条件虚函数虚函数的重写虚函数重写的两个例外协变&#xff1a;析构函数的重写 C11 override和final重载、覆盖(重写)、隐藏(重定义)的对比 抽象类概念接口继承和实现继承 多态的继承虚函数表多态的原理动态绑定和静态绑定 单继…

数据库(MySQL)基础:多表查询(一)

一、多表关系 概述 项目开发中&#xff0c;在进行数据库表结构设计时&#xff0c;会根据业务需求及业务模块之间的关系&#xff0c;分析并设计表结构&#xff0c;由于业务之间相互关联&#xff0c;所以各个表结构之间也存在着各种联系&#xff0c;基本上分为三种&#xff1a;…

OceanBase开发者大会实录-陈文光:AI时代需要怎样的数据处理技术?

本文来自2024 OceanBase开发者大会&#xff0c;清华大学教授、蚂蚁技术研究院院长陈文光的演讲实录—《AI 时代的数据处理技术》。完整视频回看&#xff0c;请点击这里&#xff1e;> 大家好&#xff0c;我是清华大学、蚂蚁技术研究院陈文光&#xff0c;今天为大家带来《AI 时…

【C语言】atoi和atof函数的使用

人生应该树立目标&#xff0c;否则你的精力会白白浪费。&#x1f493;&#x1f493;&#x1f493; 目录 •&#x1f319;知识回顾 &#x1f34b;知识点一&#xff1a;atoi函数的使用和实现 • &#x1f330;1.函数介绍 • &#x1f330;2.代码演示 • &#x1f330;3.atoi函数的…

Flask框架进阶-Flask流式输出和受访配置--纯净详解版

Flask流式输出&#x1f680; 在工作的项目当中遇到了一种情况&#xff0c;当前端页面需要对某个展示信息进行批量更新&#xff0c;如果直接将全部的数据算完之后&#xff0c;再返回更新&#xff0c;则会导致&#xff0c;前端点击刷新之后等待时间过长&#xff0c;开始考虑到用进…

liceo靶机复现

liceo-hackmyvm 靶机地址&#xff1a;https://hackmyvm.eu/machines/machine.php?vmLiceo 本机环境&#xff1a;NAT模式下&#xff0c;使用VirtualBox 信息收集&#xff1a; 首先局域网内探测靶机IP 发现IP为10.0.2.4 开启nmap扫描一下看看开了什么端口 扫描期间看一下web页…

[蓝桥杯2024]-PWN:fd解析(命令符转义,标准输出重定向,利用system(‘$0‘)获取shell权限)

查看保护 查看ida 这里有一次栈溢出&#xff0c;并且题目给了我们system函数。 这里的知识点没有那么复杂 方法一&#xff08;命令转义&#xff09;&#xff1a; 完整exp&#xff1a; from pwn import* pprocess(./pwn) pop_rdi0x400933 info0x601090 system0x400778payloa…

78、贪心-跳跃游戏

思路 方法1: canJump01 - 使用递归&#xff08;回溯法&#xff09; 这个方法是通过递归实现的&#xff0c;它从数组的第一个位置开始&#xff0c;尝试所有可能的跳跃步数&#xff0c;直到达到数组的最后一个位置或遍历完所有的可能性。 思路&#xff1a; 如果数组为空或者长…