机器学习-基于attention机制来实现对Image Caption图像描述实验

机器学习-基于attention机制来实现对Image Caption图像描述实验

实验目的

基于attention机制来实现对Image Caption图像描述

实验内容

1.了解一下RNN的Encoder-Decoder结构

在最原始的RNN结构中,输入序列和输出序列必须是严格等长的。但在机器翻译等任务中,源语言句子的长度和目标语言句子的长度往往不同,因此我们需要将原始序列映射为一个不同长度的序列。Encoder-Decoder模型就解决了这样一个长度不一致的映射问题。

1

2.模型架构训练

在Image Caption输入的图像代替了之前机器翻译中的输入的单词序列,图像是一系列的像素值,我们需要从使用图像特征提取常用的CNN从图像中提取出相应的视觉特征,然后使用Decoder将该特征解码成输出序列,下图是论文的网络结构,特征提取采用的是CNN,Decoder部分,将RNN换成了性能更好的LSTM,输入还是word embedding,每步的输出是单词表中所有单词的概率。

实验数据和程序清单

import json# 加载数据集标注
with open("annotations/captions_train2014.json", "r") as f:annotations = json.load(f)# 提取图像文件名和描述
image_path_to_caption = {}
for val in annotations["annotations"]:caption = f"<start> {val['caption']} <end>"image_path = "train2014/" + "COCO_train2014_" + "%012d.jpg" % (val["image_id"])if image_path in image_path_to_caption:image_path_to_caption[image_path].append(caption)else:image_path_to_caption[image_path] = [caption]image_paths = list(image_path_to_caption.keys())
归一化处理
import tensorflow as tfdef load_image(image_path):img = tf.io.read_file(image_path)img = tf.image.decode_jpeg(img, channels=3)img = tf.image.resize(img, (299, 299))img = tf.keras.applications.inception_v3.preprocess_input(img)
return img, image_path
模型构建
from tensorflow.keras.applications import InceptionV3encoder = InceptionV3(weights="imagenet", include_top=False)
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Modelembedding_dim = 256
vocab_size = 10000  # 您可以根据需要调整词汇表大小
max_length = 40  # 您可以根据需要调整最大描述长度# 解码器输入
input_caption = Input(shape=(max_length,))
embedding = Embedding(vocab_size, embedding_dim)(input_caption)
lstm_output = LSTM(256)(embedding)
output_caption = Dense(vocab_size, activation="softmax")(lstm_output)# 定义解码器模型
decoder = Model(inputs=input_caption, outputs=output_caption)
模型训练
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")def loss_function(real, pred):mask = tf.math.logical_not(tf.math.equal(real, 0))loss_ = loss_object(real, pred)mask = tf.cast(mask, dtype=loss_.dtype)loss_ *= maskreturn tf.reduce_mean(loss_)
@tf.function
def train_step(img_tensor, target):loss = 0hidden = decoder.reset_state(batch_size=target.shape[0])dec_input = tf.expand_dims([tokenizer.word_index["<start>"]] * target.shape[0], 1)with tf.GradientTape() as tape:features = encoder(img_tensor)for i in range(1, target.shape[1]):predictions = decoder([features, hidden, dec_input])loss += loss_function(target[:, i], predictions)dec_input = tf.expand_dims(target[:, i], 1)total_loss = loss / int(target.shape[1])trainable_variables = encoder.trainable_variables + decoder.trainable_variablesgradients = tape.gradient(loss, trainable_variables)optimizer.apply_gradients(zip(gradients, trainable_variables))return loss, total_loss
import timeepochs = 10
batch_size = 64
buffer_size = 1000dataset = tf.data.Dataset.from_tensor_slices((image_paths, captions))
dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)for epoch in range(epochs):start = time.time()total_loss = 0for (batch, (img_tensor, target)) in enumerate(dataset):batch_loss, t_loss = train_step(img_tensor, target)total_loss += t_lossif batch % 100 == 0:print(f"Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy() / int(target.shape[1]):.4f}")print(f"Epoch {epoch+1} Loss {total_loss/len(image_paths):.6f}")print(f"Time taken for 1 epoch: {time.time() - start:.2f} sec\n")

可视化:


import matplotlib.pyplot as plt
import numpy as npdef plot_attention(image_path, result, attention_plot):img = plt.imread(image_path)fig = plt.figure(figsize=(10, 10))len_result = len(result)for i in range(len_result):temp_att = np.resize(attention_plot[i], (8, 8))grid_size = max(np.ceil(len_result / 2), 2)ax = fig.add_subplot(grid_size, grid_size, i + 1)ax.set_title(result[i])imgplot = ax.imshow(img)ax.imshow(temp_att, cmap="gray", alpha=0.6, extent=imgplot.get_extent())plt.tight_layout()plt.show()plot_attention(image_path, result, attention_plot)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/601729.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea中使用Lombok 失效,@Slf4j 找不到符号的解决办法

文章目录 一、前言二、问题排查和解决方案三、 其他解决方案3.1 另一种解决方案3.2 参考文章 一、前言 今天在一个多module工程中&#xff0c;新增了一个 springboot&#xff08;版本 2.2.4.RELEASE&#xff09; module&#xff0c;像往常一样&#xff0c;我引入了lombok依赖&…

selenium 用webdriver.Chrome 访问网页闪退解决方案

1.1.1. 解决方案&#xff1a; 1.1.1.1. 移动插件到谷歌的安装目录下 1.1.1.2. 设置环境变量 1.1.1.3. 重启电脑检查成功 解决时间&#xff1a;5min

58.网游逆向分析与插件开发-游戏增加自动化助手接口-游戏菜单文字资源读取的逆向分析

内容来源于&#xff1a;易道云信息技术研究院VIP课 之前的内容&#xff1a;接管游戏的自动药水设定功能-CSDN博客 码云地址&#xff08;master分支&#xff09;&#xff1a;https://gitee.com/dye_your_fingers/sro_-ex.git 码云版本号&#xff1a;34b9c1d43b512d0b4a3c395b…

Springboot支付宝沙箱支付(完整详细步骤)

Springboot支付宝沙箱支付&#xff08;完整详细步骤&#xff09; 网页操作步骤1.进入支付宝开发平台—沙箱环境2.点击沙箱进入沙箱环境3.进入沙箱&#xff0c;配置接口加签方式4.配置应用网关5.生成自己的密钥 IntelliJ IDEA 操作步骤1.导入依赖2.在 application.yml 里面进行配…

java基于SSM的毕业生就业管理系统+vue论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本毕业生就业管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信…

TypeScript接口、对象

目录 1、TypeScript 接口 1.1、实例 1.2、联合类型和接口 1.3、接口和数组 1.4、接口和继承 1.5、单继承实例 1.6、多继承实例 2、TypeScript 对象 2.2、对象实例 2.3、TypeScript类型模板 2.4、鸭子类型&#xff08;Duck typing&#xff09; 1、TypeScript 接口 接口…

Mac启动时候出现禁止符号

Mac启动时候出现禁止符号 启动时候出现禁止符号,意味着 选定的启动磁盘 包含 Mac 操作系统&#xff0c;但它不是 您的 Mac 可以使用的 macOS 。您应该在这个磁盘上 重新安装 macOS 。 可以尝试以下苹果提供的方法&#xff1a; Mac启动时候出现禁止符号 不要轻易抹除磁盘&am…

idea将本地编译好的代码上传到hub镜像仓库

第一步&#xff1a;编译打包本地的文件 package 第二步&#xff1a;执行docker bulid打包命令 docker build -t sunyuhua/algo-ability:1.0.0 .sunyuhuasunyuhua-HKF-WXX:~/workspace/shbgit/algo-ability$ docker build -t sunyuhua/algo-ability:1.0.0 . [] Building 141.…

C语言编译器(C语言编程软件)完全攻略

介绍常用C语言编译器的安装、配置和使用。 常用的C语言编译器&#xff08;编程软件&#xff09;介绍&#xff0c;同时附带下载地址、详细的安装教程和使用教程。我们还对比了不同C语言编译器&#xff08;C语言编程软件&#xff09;的优缺点&#xff0c;让初学者知道该如何选择…

差分电路原理以及为什么输出电压要偏移

我们在使用放大器芯片的时候&#xff0c;除了对放大器芯片本身应用外&#xff0c;通常还需要搭建一些外围电路来满足放大器芯片的使用条件&#xff0c;最终满足应用的功能&#xff0c;下面通过一个差分电路来熟悉这些应用。 差分运算放大电路&#xff0c;对共模信号得到有效抑…

C# Image Caption

目录 介绍 效果 模型 decoder_fc_nsc.onnx encoder.onnx 项目 代码 下载 C# Image Caption 介绍 地址&#xff1a;https://github.com/ruotianluo/ImageCaptioning.pytorch I decide to sync up this repo and self-critical.pytorch. (The old master is in old ma…

实战演练 | Navicat 中编辑器设置的配置

Navicat 是一款功能强大的数据库管理工具&#xff0c;为开发人员和数据库管理员提供稳健的环境。其中&#xff0c;一个重要功能是 SQL 编辑器&#xff0c;用户可以在 SQL 编辑器中编写和执行 SQL 查询。Navicat 的编辑器设置可让用户自定义编辑器环境&#xff0c;以满足特定的团…

ejs默认配置 造成原型链污染

文章目录 ejs默认配置 造成原型链污染漏洞背景漏洞分析漏洞利用 例题 [SEETF 2023]Express JavaScript Security ejs默认配置 造成原型链污染 参考文章 漏洞背景 EJS维护者对原型链污染的问题有着很好的理解&#xff0c;并使用非常安全的函数清理他们创建的每个对象 利用Re…

鸿蒙应用中的通知

目录 1、通知流程 2、发布通知 2.1、发布基础类型通知 2.1.1、接口说明 2.1.2、普通文本类型通知 2.1.3、长文本类型通知 2.1.4、多行文本类型通知 2.1.5、图片类型通知 2.2、发布进度条类型通知 2.2.1、接口说明 2.2.2、示例 2.3、为通知添加行为意图 2.3.1、接…

Python基础知识总结1-Python基础概念搞定这一篇就够了

时隔多年不用忘却了很多&#xff0c;再次进行python的汇总总结。好记性不如烂笔头&#xff01; PYTHON基础 Python简介python是什么&#xff1f;Python特点Python应用场景Python版本和兼容问题解决方案python程序基本格式 Python程序的构成代码的组织和缩进使用\行连接符 对象…

解决ChatGPT4.0无法上传文件

问题描述 ChatGPT4.0&#xff1a;上传文件时出错 解决方案&#xff1a; 仔细检查文件的编码格式&#xff0c;他似乎目前只能接受utf-8的编码&#xff0c;所以把文件的编码改为UTF-8即可成功上传

【十六】【动态规划】97. 交错字符串、712. 两个字符串的最小ASCII删除和、718. 最长重复子数组,三道题目深度解析

动态规划 动态规划就像是解决问题的一种策略&#xff0c;它可以帮助我们更高效地找到问题的解决方案。这个策略的核心思想就是将问题分解为一系列的小问题&#xff0c;并将每个小问题的解保存起来。这样&#xff0c;当我们需要解决原始问题的时候&#xff0c;我们就可以直接利…

Hadolint:Lint Dockerfile 的完整指南

想学习如何使用 Hadolint 对 Dockerfile 进行 lint 处理吗&#xff1f;这篇博文将向您展示如何操作。这是关于 Dockerfile linting 的完整指南。 通过对 Dockerfile 进行 lint 检查&#xff0c;您可以及早发现错误和问题&#xff0c;并确保它们遵循最佳实践。 什么是Hadolint…

坐标转换 | EXCEL中批量将经纬度坐标(EPSG:4326)转换为墨卡托坐标(EPSG:3857)

1 需求 坐标系概念&#xff1a; 经纬度坐标&#xff08;EPSG:4326&#xff09;&#xff1a;WGS84坐标系&#xff08;World Geodetic System 1984&#xff09;是一种用于地球表面点的经纬度坐标系。它是美国国防部于1984年建立的&#xff0c;用于将全球地图上的点定位&#xff0…

Vue-2、初识Vue

1、helloword小案列 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>初始Vue</title><!--引入vue--><script type"text/javascript" src"https://cdn.jsdelivr.n…