机器学习-基于attention机制来实现对Image Caption图像描述实验

机器学习-基于attention机制来实现对Image Caption图像描述实验

实验目的

基于attention机制来实现对Image Caption图像描述

实验内容

1.了解一下RNN的Encoder-Decoder结构

在最原始的RNN结构中,输入序列和输出序列必须是严格等长的。但在机器翻译等任务中,源语言句子的长度和目标语言句子的长度往往不同,因此我们需要将原始序列映射为一个不同长度的序列。Encoder-Decoder模型就解决了这样一个长度不一致的映射问题。

1

2.模型架构训练

在Image Caption输入的图像代替了之前机器翻译中的输入的单词序列,图像是一系列的像素值,我们需要从使用图像特征提取常用的CNN从图像中提取出相应的视觉特征,然后使用Decoder将该特征解码成输出序列,下图是论文的网络结构,特征提取采用的是CNN,Decoder部分,将RNN换成了性能更好的LSTM,输入还是word embedding,每步的输出是单词表中所有单词的概率。

实验数据和程序清单

import json# 加载数据集标注
with open("annotations/captions_train2014.json", "r") as f:annotations = json.load(f)# 提取图像文件名和描述
image_path_to_caption = {}
for val in annotations["annotations"]:caption = f"<start> {val['caption']} <end>"image_path = "train2014/" + "COCO_train2014_" + "%012d.jpg" % (val["image_id"])if image_path in image_path_to_caption:image_path_to_caption[image_path].append(caption)else:image_path_to_caption[image_path] = [caption]image_paths = list(image_path_to_caption.keys())
归一化处理
import tensorflow as tfdef load_image(image_path):img = tf.io.read_file(image_path)img = tf.image.decode_jpeg(img, channels=3)img = tf.image.resize(img, (299, 299))img = tf.keras.applications.inception_v3.preprocess_input(img)
return img, image_path
模型构建
from tensorflow.keras.applications import InceptionV3encoder = InceptionV3(weights="imagenet", include_top=False)
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Modelembedding_dim = 256
vocab_size = 10000  # 您可以根据需要调整词汇表大小
max_length = 40  # 您可以根据需要调整最大描述长度# 解码器输入
input_caption = Input(shape=(max_length,))
embedding = Embedding(vocab_size, embedding_dim)(input_caption)
lstm_output = LSTM(256)(embedding)
output_caption = Dense(vocab_size, activation="softmax")(lstm_output)# 定义解码器模型
decoder = Model(inputs=input_caption, outputs=output_caption)
模型训练
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")def loss_function(real, pred):mask = tf.math.logical_not(tf.math.equal(real, 0))loss_ = loss_object(real, pred)mask = tf.cast(mask, dtype=loss_.dtype)loss_ *= maskreturn tf.reduce_mean(loss_)
@tf.function
def train_step(img_tensor, target):loss = 0hidden = decoder.reset_state(batch_size=target.shape[0])dec_input = tf.expand_dims([tokenizer.word_index["<start>"]] * target.shape[0], 1)with tf.GradientTape() as tape:features = encoder(img_tensor)for i in range(1, target.shape[1]):predictions = decoder([features, hidden, dec_input])loss += loss_function(target[:, i], predictions)dec_input = tf.expand_dims(target[:, i], 1)total_loss = loss / int(target.shape[1])trainable_variables = encoder.trainable_variables + decoder.trainable_variablesgradients = tape.gradient(loss, trainable_variables)optimizer.apply_gradients(zip(gradients, trainable_variables))return loss, total_loss
import timeepochs = 10
batch_size = 64
buffer_size = 1000dataset = tf.data.Dataset.from_tensor_slices((image_paths, captions))
dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)for epoch in range(epochs):start = time.time()total_loss = 0for (batch, (img_tensor, target)) in enumerate(dataset):batch_loss, t_loss = train_step(img_tensor, target)total_loss += t_lossif batch % 100 == 0:print(f"Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy() / int(target.shape[1]):.4f}")print(f"Epoch {epoch+1} Loss {total_loss/len(image_paths):.6f}")print(f"Time taken for 1 epoch: {time.time() - start:.2f} sec\n")

可视化:


import matplotlib.pyplot as plt
import numpy as npdef plot_attention(image_path, result, attention_plot):img = plt.imread(image_path)fig = plt.figure(figsize=(10, 10))len_result = len(result)for i in range(len_result):temp_att = np.resize(attention_plot[i], (8, 8))grid_size = max(np.ceil(len_result / 2), 2)ax = fig.add_subplot(grid_size, grid_size, i + 1)ax.set_title(result[i])imgplot = ax.imshow(img)ax.imshow(temp_att, cmap="gray", alpha=0.6, extent=imgplot.get_extent())plt.tight_layout()plt.show()plot_attention(image_path, result, attention_plot)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/601729.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea中使用Lombok 失效,@Slf4j 找不到符号的解决办法

文章目录 一、前言二、问题排查和解决方案三、 其他解决方案3.1 另一种解决方案3.2 参考文章 一、前言 今天在一个多module工程中&#xff0c;新增了一个 springboot&#xff08;版本 2.2.4.RELEASE&#xff09; module&#xff0c;像往常一样&#xff0c;我引入了lombok依赖&…

selenium 用webdriver.Chrome 访问网页闪退解决方案

1.1.1. 解决方案&#xff1a; 1.1.1.1. 移动插件到谷歌的安装目录下 1.1.1.2. 设置环境变量 1.1.1.3. 重启电脑检查成功 解决时间&#xff1a;5min

58.网游逆向分析与插件开发-游戏增加自动化助手接口-游戏菜单文字资源读取的逆向分析

内容来源于&#xff1a;易道云信息技术研究院VIP课 之前的内容&#xff1a;接管游戏的自动药水设定功能-CSDN博客 码云地址&#xff08;master分支&#xff09;&#xff1a;https://gitee.com/dye_your_fingers/sro_-ex.git 码云版本号&#xff1a;34b9c1d43b512d0b4a3c395b…

Elasticsearch地理位置数据索引

地理位置数据索引 在 Elasticsearch 中&#xff0c;地理位置数据的索引涉及两种主要的字段类型&#xff1a;geo_point 和 geo_shape。这些字段类型允许 Elasticsearch 存储和查询地理空间数据&#xff0c;如坐标点、线和多边形。 geo_point Elasticsearch的geo_point字段类型…

Springboot支付宝沙箱支付(完整详细步骤)

Springboot支付宝沙箱支付&#xff08;完整详细步骤&#xff09; 网页操作步骤1.进入支付宝开发平台—沙箱环境2.点击沙箱进入沙箱环境3.进入沙箱&#xff0c;配置接口加签方式4.配置应用网关5.生成自己的密钥 IntelliJ IDEA 操作步骤1.导入依赖2.在 application.yml 里面进行配…

java基于SSM的毕业生就业管理系统+vue论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本毕业生就业管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信…

TypeScript接口、对象

目录 1、TypeScript 接口 1.1、实例 1.2、联合类型和接口 1.3、接口和数组 1.4、接口和继承 1.5、单继承实例 1.6、多继承实例 2、TypeScript 对象 2.2、对象实例 2.3、TypeScript类型模板 2.4、鸭子类型&#xff08;Duck typing&#xff09; 1、TypeScript 接口 接口…

问题 H: 取余运算

题目描述 输入b&#xff0c;p&#xff0c;k的值&#xff0c;求b^p mod k的值&#xff08;即b的p次方除以k的余数&#xff09;。其中b&#xff0c;p&#xff0c;k*k为32位整数。 输入 输入b&#xff0c;p&#xff0c;k的值 输出 输出b^p mod k的值 样例输入 2 10 9 样例…

Backtrader 文档学习-Strategy(中)

Backtrader 文档学习-Strategy&#xff08;中&#xff09; Strategy是BT的核心模块&#xff0c;需要慢慢摸索内部设计的流程&#xff0c;方法的用途&#xff0c;使用场景&#xff0c;还要一些常用的状态值。 本章主要介绍关于strategy的生存周期&#xff0c;notify_order方法&…

npm指令

1、npm install express&#xff1a;安装Node模块 安装完毕后会产生一个node_modules目录&#xff0c;其目录下就是安装的各个node模块。 2、npm view express&#xff1a;查看node模块的package.json文件夹 注意事项&#xff1a;如果想要查看package.json文件夹下某个标签的…

Mac启动时候出现禁止符号

Mac启动时候出现禁止符号 启动时候出现禁止符号,意味着 选定的启动磁盘 包含 Mac 操作系统&#xff0c;但它不是 您的 Mac 可以使用的 macOS 。您应该在这个磁盘上 重新安装 macOS 。 可以尝试以下苹果提供的方法&#xff1a; Mac启动时候出现禁止符号 不要轻易抹除磁盘&am…

指针和引用

指针使用操作符 "*" 和 "->"&#xff0c;引用使用操作符"."&#xff0c;他们具有相同的功能&#xff0c;都是间接的引用其他对象。 指针和引用的选择 在任何情况下都不能使用指向空值的引用&#xff0c;引用总是必须指向某些对象。如果你使…

鸡兔同笼问题加强版

描述 已知鸡和兔的总数量为 n,总腿数为 m。输入 n 和 m,依次输出鸡和兔的数目&#xff0c;如果无解&#xff0c;则输出 “No answer”(不要引号)。 输入描述 第一行输入一个数据 a,代表接下来共有几组数据&#xff0c;在接下来的 (a≤100000) a 行里&#xff0c;每行都有一个…

GhostscriptExample GS

1.导出图片 package cn.net.haotuo.pojo; import java.io.*;public class GhostscriptExample {public static void main(String[] args) { String gsPath "C:/Program Files/gs/gs9.54.0/bin/gswin64c.exe";String inputFilePath "C:\\Users\\Admini…

idea将本地编译好的代码上传到hub镜像仓库

第一步&#xff1a;编译打包本地的文件 package 第二步&#xff1a;执行docker bulid打包命令 docker build -t sunyuhua/algo-ability:1.0.0 .sunyuhuasunyuhua-HKF-WXX:~/workspace/shbgit/algo-ability$ docker build -t sunyuhua/algo-ability:1.0.0 . [] Building 141.…

C语言编译器(C语言编程软件)完全攻略

介绍常用C语言编译器的安装、配置和使用。 常用的C语言编译器&#xff08;编程软件&#xff09;介绍&#xff0c;同时附带下载地址、详细的安装教程和使用教程。我们还对比了不同C语言编译器&#xff08;C语言编程软件&#xff09;的优缺点&#xff0c;让初学者知道该如何选择…

差分电路原理以及为什么输出电压要偏移

我们在使用放大器芯片的时候&#xff0c;除了对放大器芯片本身应用外&#xff0c;通常还需要搭建一些外围电路来满足放大器芯片的使用条件&#xff0c;最终满足应用的功能&#xff0c;下面通过一个差分电路来熟悉这些应用。 差分运算放大电路&#xff0c;对共模信号得到有效抑…

微软的一些公开课,Python、机器学习、SQL、AI,全部免费

大家好&#xff0c;我是老章&#xff0c;刷X看到一位博主Alif Hossain⚡alifcoder总结了微软的一些公开课&#xff0c;全部免费&#xff0c;蛮不错的。感兴趣可以学一波&#xff0c;还能领徽章。 1. 机器学习简介 本课程是学习机器学习基础知识和用例的好方法。 → 11 个模块…

C# Image Caption

目录 介绍 效果 模型 decoder_fc_nsc.onnx encoder.onnx 项目 代码 下载 C# Image Caption 介绍 地址&#xff1a;https://github.com/ruotianluo/ImageCaptioning.pytorch I decide to sync up this repo and self-critical.pytorch. (The old master is in old ma…

实战演练 | Navicat 中编辑器设置的配置

Navicat 是一款功能强大的数据库管理工具&#xff0c;为开发人员和数据库管理员提供稳健的环境。其中&#xff0c;一个重要功能是 SQL 编辑器&#xff0c;用户可以在 SQL 编辑器中编写和执行 SQL 查询。Navicat 的编辑器设置可让用户自定义编辑器环境&#xff0c;以满足特定的团…