【Image captioning】ruotianluo/self-critical.pytorch之1—数据集的加载与使用

【Image captioning】ruotianluo/self-critical.pytorch之1—数据集的加载与使用

作者:安静到无声 个人主页

数据加载程序示意图

数据集程序

使用方法

示例代码

#%%from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriterimport numpy as npimport time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' ##from six.moves import cPickle
import traceback
from collections import defaultdictimport captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
import skimage.io
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapperimport sys
sys.path.append("..")
import time
#%%opt = opts.parse_opt()
opt.input_json = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json'
opt.input_label_h5 = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_label.h5'
opt.input_fc_dir = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_fc'
opt.input_att_dir = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_att'
opt.batch_size = 1
opt.train_only = 1opt.use_att = True
opt.use_att = True
opt.use_box = 0#%%
print(opt.input_json)
print(opt.batch_size)  #批量化为16
loader = DataLoader(opt)  # 数据加载
#打印字内容
#print(loader.get_vocab())  #返回字典
for i in range(2):data = loader.get_batch('train')print('———————————————————※输入的信息特征※——————————————————')  #[1,2048] 全连接特征print('全连接特征【fc_feats】的形状:', data['fc_feats'].shape)  #[1,2048] 全连接特征print('全连接特征【att_feats】的形状:', data['att_feats'].shape)  #[1,2048] 注意力特征print('att_masks', data['att_masks'])print('含有的信息infos:', data['infos'])  #infos [{'ix': 117986, 'id': 495956, 'file_path': 'train2014/COCO_train2014_000000495956.jpg'}]print('———————————————————※标签信息※——————————————————')  #[1,2048] 全连接特征print('labels', data['labels'])     #添加了一些0print('gts:', data['gts'])          #没有添加的原始句子print('masks', data['masks'])print('———————————————————※记录遍历的位置※——————————————————')  #[1,2048] 全连接特征print('bounds', data['bounds'])time.sleep(1)print(data.keys())

输出结果:

Hugginface transformers not installed; please visit https://github.com/huggingface/transformers
meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`
Warning: coco-caption not available
cider or coco-caption missing
/home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json
1
是否使用【注意力特征[use_fc]: True
是否使用【注意力特征[use_att]: True
是否在注意力特征中使用【检测框特征[use_box]: 0
DataLoader loading json file:  /home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json
vocab size is  9487
DataLoader loading h5 file:  /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_fc /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_att data/cocotalk_box /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_label.h5
max sequence length in data is 16
read 123287 image features
assigned 82783 images to split train(训练集有多少图片)
assigned 5000 images to split val(验证集有多少图片)
assigned 5000 images to split test(测试集有多少图片)
———————————————————※输入的信息特征※——————————————————
全连接特征【fc_feats】的形状: torch.Size([1, 2048])
全连接特征【att_feats】的形状: torch.Size([1, 196, 2048])
att_masks None
含有的信息infos: [{'ix': 60494, 'id': 46065, 'file_path': 'train2014/COCO_train2014_000000046065.jpg'}]
———————————————————※标签信息※——————————————————
labels tensor([[[   0,    1,  271,   17, 7068,   35,   98,    6,    1,  102,    3,912,    0,    0,    0,    0,    0,    0],[   0,  995, 2309, 2308,  609,    6,    1,  271,  119,  912,    0,0,    0,    0,    0,    0,    0,    0],[   0, 2309, 9487,  179,   98,    6,    1,   46,  271,    0,    0,0,    0,    0,    0,    0,    0,    0],[   0,  182,   35,  995, 7068,    6,    1,  271,    3,   60,  678,32,   14,   29,    0,    0,    0,    0],[   0,  995,  915,   17, 2309, 3130,    6,    1,   46,  271,    0,0,    0,    0,    0,    0,    0,    0]]])
gts: [array([[   1,  271,   17, 7068,   35,   98,    6,    1,  102,    3,  912,0,    0,    0,    0,    0],[ 995, 2309, 2308,  609,    6,    1,  271,  119,  912,    0,    0,0,    0,    0,    0,    0],[2309, 9487,  179,   98,    6,    1,   46,  271,    0,    0,    0,0,    0,    0,    0,    0],[ 182,   35,  995, 7068,    6,    1,  271,    3,   60,  678,   32,14,   29,    0,    0,    0],[ 995,  915,   17, 2309, 3130,    6,    1,   46,  271,    0,    0,0,    0,    0,    0,    0]], dtype=uint32)]
masks tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.]]])
———————————————————※记录遍历的位置※——————————————————
bounds {'it_pos_now': 1, 'it_max': 82783, 'wrapped': False}
dict_keys(['fc_feats', 'att_feats', 'att_masks', 'labels', 'masks', 'gts', 'bounds', 'infos'])
———————————————————※输入的信息特征※——————————————————
全连接特征【fc_feats】的形状: torch.Size([1, 2048])
全连接特征【att_feats】的形状: torch.Size([1, 196, 2048])
att_masks None
含有的信息infos: [{'ix': 106440, 'id': 151264, 'file_path': 'train2014/COCO_train2014_000000151264.jpg'}]
———————————————————※标签信息※——————————————————
labels tensor([[[   0,    1,  230,    6,   14,  230,  237,   32, 1086,  627,    0,0,    0,    0,    0,    0,    0,    0],[   0,    1, 6035,  230,   35,  274,  127,  225, 1598,  335,    1,940,    0,    0,    0,    0,    0,    0],[   0,    1,  230,   35,  900,   32,  307,  756,   61,  607,    0,0,    0,    0,    0,    0,    0,    0],[   0,    1,  230,   35,   98,   79,    1,  230,  224,    0,    0,0,    0,    0,    0,    0,    0,    0],[   0,    1,   46, 1109,  230, 1596,  245,    1,  224,    0,    0,0,    0,    0,    0,    0,    0,    0]]])
gts: [array([[   1,  230,    6,   14,  230,  237,   32, 1086,  627,    0,    0,0,    0,    0,    0,    0],[   1, 6035,  230,   35,  274,  127,  225, 1598,  335,    1,  940,0,    0,    0,    0,    0],[   1,  230,   35,  900,   32,  307,  756,   61,  607,    0,    0,0,    0,    0,    0,    0],[   1,  230,   35,   98,   79,    1,  230,  224,    0,    0,    0,0,    0,    0,    0,    0],[   1,   46, 1109,  230, 1596,  245,    1,  224,    0,    0,    0,0,    0,    0,    0,    0]], dtype=uint32)]
masks tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,0.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,0.]]])
———————————————————※记录遍历的位置※——————————————————
bounds {'it_pos_now': 2, 'it_max': 82783, 'wrapped': False}
dict_keys(['fc_feats', 'att_feats', 'att_masks', 'labels', 'masks', 'gts', 'bounds', 'infos'])

推荐专栏

🔥 手把手实现Image captioning

💯CNN模型压缩

💖模式识别与人工智能(程序与算法)

🔥FPGA—Verilog与Hls学习与实践

💯基于Pytorch的自然语言处理入门与实践

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/40477.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源,微信小程序 美食便签地图(FoodNoteMap)的设计与开发

目录 0 前言 1 美食便签地图简介 2 美食便签地图小程序端开发 2.1技术选型 2.2前端UI设计 2.3主页界面 2.4个人信息界面 2.5 添加美食界面 2.6美食便签界面 2.8 美食好友界面 2.9 美食圈子界面 2.10 子页面-店铺详情界面 2.11 后台数据缓存 2.12 订阅消息通知 2.1…

探索区块链世界:去中心化应用(DApp)的崭新前景

随着科技的不断发展,区块链技术逐渐引领着数字时代的潮流。在这个充满创新和变革的领域中,去中心化应用(DApp)成为了备受瞩目的焦点。DApp 不仅改变了传统应用程序的范式,还在金融、社交、游戏等多个领域展现出了广阔的…

GRPC 链接 NODE 和 GOLANG

GRPC 链接 NODE 和 GOLANG GRPC 了解 什么是GRPC gRPC 采用了 Protocol Buffers 作为数据序列化和反序列化的协议,可以更快速地传输数据,并支持多种编程语言的跨平台使用gRPC 提供“统一水平层”来对此类问题进行抽象化。 开发人员在本机平台中编写专…

打造专属照片分享平台:快速上手Piwigo网页搭建

文章目录 通过cpolar分享本地电脑上有趣的照片:部署piwigo网页前言1.Piwigo2. 使用phpstudy网页运行3. 创建网站4. 开始安装Piwogo 总结 🍀小结🍀 🎉博客主页:小智_x0___0x_ 🎉欢迎关注:&#x…

深度学习1:通过模型评价指标优化训练

P(Positive)表示预测为正样本,N(negative)表示预测为负样本,T(True)表示预测正确,F(False)表示预测错误。 TP:正样本预测正确的数量(正确检测) FP:负样本预测正确数量(误检测) TN…

卷积神经网络CNN

卷积神经网络CNN 1 应用领域1 检测任务2 分类和检索3 超分辨率重构4 医学任务5 无人驾驶6 人脸识别 2 卷积的作用3 卷积特征值计算方法4 得到特征图表示5 步长和卷积核大小对结果的影响1 步长2 卷积核 6 边缘填充方法7 特征图尺寸计算与参数共享8 池化层的作用9 整体网络架构10…

【GitLab私有仓库】如何在Linux上用Gitlab搭建自己的私有库并配置cpolar内网穿透?

文章目录 前言1. 下载Gitlab2. 安装Gitlab3. 启动Gitlab4. 安装cpolar5. 创建隧道配置访问地址6. 固定GitLab访问地址6.1 保留二级子域名6.2 配置二级子域名 7. 测试访问二级子域名 前言 GitLab 是一个用于仓库管理系统的开源项目,使用Git作为代码管理工具&#xf…

软考笔记——10.项目管理

进度管理 进度管理就是采用科学的方法,确定进度目标,编制进度计划和资源供应计划,进行进度控制,在与质量、成本目标协调的基础上,实现工期目标。 具体来说,包括以下过程: (1) 活动定义&#…

HLS实现FIR低通滤波器+System Generator仿真

硬件:ZYNQ7010 软件:MATLAB 2019b、Vivado 2017.4、HLS 2017.4、System Generator 2017.4 1、MATLAB设计低通滤波器 FPGA系统时钟 50MHz,也是采样频率。用 MATLAB 生成 1MHz 和 10MHz 的正弦波叠加的信号,并量化为 14bit 整数。把…

Web网页浏览器远程访问jupyter notebook服务器【内网穿透】

文章目录 前言1. Python环境安装2. Jupyter 安装3. 启动Jupyter Notebook4. 远程访问4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5. 固定公网地址 前言 Jupyter Notebook,它是一个交互式的数据科学和计算环境,支持多种编程语言,如…

Hyper-v导致Vmware window无法启动崩溃记录

最近有几次vmware启动window10直接崩溃情况,显示蓝屏报错。一开始没在意,以为是因为固态硬盘错了几个字节导致的? 但后来想想不对啊。vmware用了也有10来年了,稳得一笔,在仔细思考了一下后发现打不开的win10这三个虚拟…

设计模式之构建器(Builder)C++实现

1、构建器提出 在软件功能开发中,有时面临“一个复杂对象”的创建工作,该对象的每个功能接口由于需求的变化,会使每个功能接口发生变化,但是该对象使用每个功能实现一个接口的流程是稳定的。构建器就是解决该类现象的。构建就是定…

【Java】项目管理工具Maven的安装与使用

文章目录 1. Maven概述2. Maven的下载与安装2.1 下载2.2 安装 3. Maven仓库配置3.1 修改本地仓库配置3.2 修改远程仓库配置3.3 修改后的settings.xml 4. 使用Maven创建项目4.1 手工创建Java项目4.2 原型创建Java项目4.3 原型创建Web项目 5. Tomcat启动Web项目5.1 使用Tomcat插件…

【CTF-web】备份是个好习惯(查找备份文件、双写绕过、md5加密绕过)

题目链接:https://ctf.bugku.com/challenges/detail/id/83.html 经过扫描可以找到index.php.bak备份文件,下载下来后打开发现是index.php的原代码,如下图所示。 由代码可知我们要绕过md5加密,两数如果满足科学计数法的形式的话&a…

模型预测笔记(一):数据清洗及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)

模型预测 一、导入关键包二、如何载入、分析和保存文件三、修改缺失值3.1 众数3.2 平均值3.3 中位数3.4 0填充 四、修改异常值4.1 删除4.2 替换 五、数据绘图分析5.1 饼状图5.1.1 绘制某一特征的数值情况(二分类) 5.2 柱状图5.2.1 单特征与目标特征之间的…

OpenCV基本操作——算数操作

目录 图像的加法图像的混合 图像的加法 两个图像应该具有相同的大小和类型,或者第二个图像可以是标量值 注意:OpenCV加法和Numpy加法之间存在差异。OpenCV的加法是饱和操作,而Numpy添加的是模运算 import numpy as np import cv2 as cv imp…

[数据集][目标检测]钢材表面缺陷目标检测数据集VOC格式2279张10类别

数据集格式:Pascal VOC格式(不包含分割路径的txt文件和yolo格式的txt文件,仅仅包含jpg图片和对应的xml) 图片数量(jpg文件个数):2279 标注数量(xml文件个数):2279 标注类别数:10 标注类别名称:["yueyawan",&…

jenkins 连接服务器,提示Can‘t connect to server

在Jenkins 添加服务器时,提示 Cant connect to server,如图 搞了好久,不知道为什么不行~原来是行的,现在删了 新建一个也不行。

2023牛客暑期多校训练营8-C Clamped Sequence II

2023牛客暑期多校训练营8-C Clamped Sequence II https://ac.nowcoder.com/acm/contest/57362/C 文章目录 2023牛客暑期多校训练营8-C Clamped Sequence II题意解题思路代码 题意 解题思路 先考虑不加紧密度的情况,要支持单点修改,整体查询&#xff0…

[C++]笔记 - 知识点积累

一.运算符的优先级 一共15个级别 最高优先级 : () []最低优先级 :逗号表达式倒数第二低优先级 : 赋值和符合赋值(,,-...) ! >算术运算符 > 关系运算符 > && >> || >赋值运算符 二.数据类型转换 隐式类型转换 算数转换 char int long longlong flo…