pytorch异常——RuntimeError:Given groups=1, weight of size..., expected of...

文章目录

    • 省流
    • 异常报错
    • 异常截图
    • 异常代码
    • 原因解释
    • 修正代码
    • 执行结果

省流

  • nn.Conv2d 需要的输入张量格式为 (batch_size, channels, height, width),但您的示例输入张量 x 是 (batch_size, height, width, channels)。因此,需要对输入张量进行转置。

  • 注意,TensorFlow使用"NHWC"(批次、高度、宽度、通道)格式,而PyTorch使用"NCHW"(批次、通道、高度、宽度)格式

异常报错

RuntimeError: Given groups=1, weight of size [16, 3, 2, 3], 
expected input[8, 65, 66, 3] to have 3 channels, 
but got 65 channels instead

异常截图

在这里插入图片描述

异常代码

def down_shifted_conv2d(x , num_filters , filters_size = [2,3],stride = 1, **kwargs):batch_size,H,W,channels = x.shapepadding = (0,0,int(((filters_size[1]) - 1) / 2 ) , int((int(filters_size[1]) - 1) / 2),int(filters_size[0]) - 1 , 0,0,0)x_paded = nn.functional.pad(x, padding)print(x_paded.shape)conv_layer = nn.Conv2d(in_channels=channels, out_channels=num_filters, kernel_size=filters_size,stride=stride, **kwargs)return conv_layer(x_paded)
# Example usage
x = torch.randn(8, 64, 64, 3)  # Example input with batch size 8, height and width 64, and 3 channels
num_filters = 16
output = down_shifted_conv2d(x, num_filters)
print(output.shape)

原因解释

  • 在pytorch中,“nn.Conv2d”需要输入的张量格式为(batch_size,channels,height,width),原图输入的x的格式是(batch_size,height ,weight,channel)所以需要对tensor进行转置。

  • 矩阵交换维度的函数permute,按照编号,将新的顺序填好即可

def down_shifted_conv2d(x , num_filters , filters_size = [2,3], stride = 1, **kwargs):batch_size, H, W, channels = x.shape# Transpose the input tensor to (batch_size, channels, height, width)x = x.permute(0, 3, 1, 2)# Paddingpadding = (int((filters_size[1] - 1) / 2), int((filters_size[1] - 1) / 2),filters_size[0] - 1, 0)x_paded = F.pad(x, padding)

修正代码

def down_shifted_conv2d(x , num_filters , filters_size = [2,3],stride = 1, **kwargs):batch_size,H,W,channels = x.shape# 按照顺序对4个维度分别进行填充padding = (0,0,int(((filters_size[1]) - 1) / 2 ) , int((int(filters_size[1]) - 1) / 2),int(filters_size[0]) - 1 , 0,0,0)x_paded = nn.functional.pad(x, padding)x_paded = x_paded.permute(0,3,1,2)# 进行卷积conv_layer = nn.Conv2d(in_channels=channels, out_channels=num_filters, kernel_size=filters_size,stride=stride, **kwargs)return conv_layer(x_paded)
# Example usage
x = torch.randn(8, 64, 64, 3)  
num_filters = 16
output = down_shifted_conv2d(x, num_filters)
print(output.shape)

执行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/65160.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LLM学习笔记(1)

学习链接 ChatGPT Prompt Engineering for Developers - DeepLearning.AI 一、prompt engineering for developer 1、原则 prompting principles and iterative pattern 2、用于summarize 环境与helper functions import openai import osfrom dotenv import load_dotenv…

[C++] STL_list常用接口的模拟实现

文章目录 1、list的介绍与使用1.1 list的介绍1.2 list的使用 2、list迭代器3、list的构造4、list常用接口的实现4.1 list capacity4.2 插入删除、交换、清理4.2.1 insert任意位置插入4.2.2 push_front头插4.2.3 push_back尾插4.2.4 erase任意位置删除4.2.5 pop_front头删4.2.6 …

Redis之管道解读

目录 基本介绍 使用例子 管道对比 管道与原生批量命令对比 管道与事务对比 使用pipeline注意事项 基准测试 基本介绍 Redis是一种基于客户端-服务端模型以及请求/响应协议的TCP服务器。 这意味着请求通常按如下步骤处理: 客户端发送一个请求到服务器&am…

java 八股文 基础 每天笔记随机刷

Component 和 PostConstruct 搭配使用 被Component注解标识的类在应用程序启动时会被实例化,并由Spring容器进行管理。PostConstruct是一个Java注解,用于标记一个方法在类被实例化后自动执行。该方法必须是非静态的,没有参数,且不…

数字货币量化交易平台

数字货币量化交易平台是近年来金融科技领域迅速崛起的一种创新型交易方式。它通过应用数学模型和算法策略,实现对数字货币市场的自动交易和风险控制。然而,要在这个竞争激烈的领域中脱颖而出,一个数字货币量化交易平台需要具备足够的专业性&a…

正中优配:A股早盘三大股指微涨 华为概念表现活跃

周三(8月30日),到上午收盘,三大股指团体收涨。其间上证指数涨0.06%,报3137.72点;深证成指和创业板指别离涨0.33%、0.12%;沪深两市合计成交额6423.91亿元,总体来看,两市个…

java-数组

数组静态初始化写法: //静态初始化数组 int[] age new int[] {7,18,19}; double[] scores new double[]{67.5,77.8,94.2,99};//静态初始化数组简化写法 int[] age1 {7,18,19}; double[] scores2 {67.5,77.8,94.2,99};数组在内存中定义方式: 1.在内…

opencv的haarcascade_frontalface_default.xml等文件

文章目录 GitHub下载在安装好的OpenCV文件夹下寻找opencv-python中获取 GitHub下载 下载地址:https://github.com/opencv/opencv/tree/master/data/haarcascades 在安装好的OpenCV文件夹下寻找 路径如下: 你安装的opencv路径\OpenCV\opencv\build\et…

ELK安装、部署、调试(一)设计规划及准备

一、整体规划如图: 【filebeat】 需要收集日志的服务器,安装filebeat软件,用于收集日志。logstash也可以收集日志,但是占用的系统资源过大,所以使用了filebeat来收集日志。 【kafka】 接收filebeat的日志&#xff…

Can‘t connect to local MySQL server through socket ‘/tmp/mysql.sock‘

最近在用django框架开发后端时,在运行 $python manage.py makemigrations 命令时,报了以上错误,错误显示连接mysql数据库失败,查看了mysql数据库初始化配置文件my.cnf,我的mysql.sock文件存放路径配置在了/usr/local…

查看GPU占用率

如何监控NVIDIA GPU 的运行状态和使用情况_nvidia 85c_LiBiGo的博客-CSDN博客设备跟踪和管理正成为机器学习工程的中心焦点。这个任务的核心是在模型训练过程中跟踪和报告gpu的使用效率。有效的GPU监控可以帮助我们配置一些非常重要的超参数,例如批大小,…

java八股文面试[数据库]——MySQL索引的数据结构

知识点: 【2023年面试】mysql索引的基本原理_哔哩哔哩_bilibili 【2023年面试】mysql索引结构有哪些,各自的优劣是什么_哔哩哔哩_bilibili

【MySQL学习笔记】(七)内置函数

内置函数 日期函数示例案例-1案例-2 字符串函数示例 数学函数其他函数 日期函数 示例 获得当前年月日 mysql> select current_date(); ---------------- | current_date() | ---------------- | 2023-09-03 | ---------------- 1 row in set (0.00 sec)获得当前时分秒…

java 批量下载将多个文件(minio中存储)压缩成一个zip包

我的需求是将minio中存储的文件按照查询条件查询出来统一压成一个zip包然后下载下来。 思路:针对这个需求,其实可以有多个思路,不过也大同小异,一般都是后端返回流文件前端再处理下载,也有少数是压缩成zip包之后直接给…

C++算法 —— 动态规划(1)斐波那契数列模型(包含动规思路总介绍)

文章目录 1、动规思路简介2、第N个泰波那契数列3、三步问题4、使用最小花费爬楼梯5、解码方法 1、动规思路简介 动规的思路有五个步骤,且最好画图来理解细节,不要怕麻烦。当你开始画图,仔细阅读题时,学习中的沉浸感就体验到了。 …

Linux常用命令——cupsdisable命令

在线Linux命令查询工具 cupsdisable 停止指定的打印机 补充说明 cupsdisable命令用于停止指定的打印机。 语法 cupsdisable(选项)(参数)选项 -E:当连接到服务器时强制使用加密; -U:指定连接服务器时使用的用户名; -u&#…

git的常用命令

初始化git,以及如何提交代码 1、配置用户信息 git config --global user.name zhangsan # 设置用户签名 git config --global user.email zhangsanqq.com # 设置用户邮箱(不会验证,可以不存在)1.1、查看是否已经添加用户配置 在…

长城网络靶场,第一题笔记

黑客使用了哪款扫描工具对论坛进行了扫描?(小写简称) 第一关,第三小题的答案是awvs 思路是先统计查询 然后过滤ip检查流量 过滤语句:tcp and ip.addr ip 114.240179.133没有 第二个101.36.79.67 之后找到了一个…

可扩展的Blender插件开发汇总

成熟的 Blender 3D 插件是令人惊奇的事情。作为 Python 和 Blender 的新手,我经常发现自己被社区中的人们创造的强大的东西弄得目瞪口呆。坦率地说,其中一些包看起来有点神奇,当自我怀疑或冒名顶替综合症的唠叨声音被打破时,很容易想到“如果有人能做出可以做xxx的东西就好…

AI:06-基于OpenCV的二维码识别技术的研究

二维码作为一种广泛应用于信息传递和识别的技术,具有识别速度快、容错率高等优点。本文探讨如何利用OpenCV库实现二维码的快速、准确识别,通过多处代码实例展示技术深度。 二维码作为一种矩阵型的条码,广泛应用于各个领域,如商品追溯、移动支付、活动签到等。二维码的快速…