【快捷测试模型是否可以跑通】设置一张图片的张量形式,送入自己写的模型进行测试

文章目录

  • 1.


1.

import torch.nn as nn
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch.nn.functional as Fclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class PPM(nn.Module):def __init__(self, pooling_sizes=(1, 3, 5)):super().__init__()self.layer = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size=(size, size)) for size in pooling_sizes])def forward(self, feat):b, c, h, w = feat.shapeoutput = [layer(feat).view(b, c, -1) for layer in self.layer]output = torch.cat(output, dim=-1)return output# Efficient self attention
class ESA_layer(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim=-1)self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)self.ppm = PPM(pooling_sizes=(1, 3, 5))self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):# input x (b, c, h, w)b, c, h, w = x.shapeq, k, v = self.to_qkv(x).chunk(3, dim=1)  # q/k/v shape: (b, inner_dim, h, w)q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)  # q shape: (b, head, n_q, d)k, v = self.ppm(k), self.ppm(v)  # k/v shape: (b, inner_dim, n_kv)k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)  # k shape: (b, head, n_kv, d)v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)  # v shape: (b, head, n_kv, d)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # shape: (b, head, n_q, n_kv)attn = self.attend(dots)out = torch.matmul(attn, v)  # shape: (b, head, n_q, d)out = rearrange(out, 'b head n d -> b n (head d)')return self.to_out(out)class ESA_blcok(nn.Module):def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):super().__init__()self.ESAlayer = ESA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))def forward(self, x):b, c, h, w = x.shapeout = rearrange(x, 'b c h w -> b (h w) c')out = self.ESAlayer(x) + outout = self.ff(out) + outout = rearrange(out, 'b (h w) c -> b c h w', h=h)return out+x# return outdef MaskAveragePooling(x, mask):mask = torch.sigmoid(mask)b, c, h, w = x.shapeeps = 0.0005x_mask = x * maskh, w = x.shape[2], x.shape[3]area = F.avg_pool2d(mask, (h, w)) * h * w + epsx_feat = F.avg_pool2d(x_mask, (h, w)) * h * w / areax_feat = x_feat.view(b, c, -1)return x_feat# Lesion-aware Cross Attention
class LCA_layer(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim=-1)self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x, mask):# input x (b, c, h, w)b, c, h, w = x.shapeq, k, v = self.to_qkv(x).chunk(3, dim=1)  # q/k/v shape: (b, inner_dim, h, w)q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)  # q shape: (b, head, n_q, d)k, v = MaskAveragePooling(k, mask), MaskAveragePooling(v, mask)  # k/v shape: (b, inner_dim, 1)k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)  # k shape: (b, head, 1, d)v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)  # v shape: (b, head, 1, d)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # shape: (b, head, n_q, n_kv)attn = self.attend(dots)out = torch.matmul(attn, v)  # shape: (b, head, n_q, d)out = rearrange(out, 'b head n d -> b n (head d)')return self.to_out(out)class LCA_blcok(nn.Module):def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):super().__init__()self.LCAlayer = LCA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))def forward(self, x, mask):b, c, h, w = x.shapeout = rearrange(x, 'b c h w -> b (h w) c')out = self.LCAlayer(x, mask) + outout = self.ff(out) + outout = rearrange(out, 'b (h w) c -> b c h w', h=h)return out# test
if __name__ == '__main__':x = torch.rand((4, 3, 320, 320))mask = torch.rand(4, 1, 320, 320)lca = LCA_blcok(dim=3)esa = ESA_blcok(dim=3)print(lca(x, mask).shape)print(esa(x).shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/113089.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

excel怎么固定前几行前几列不滚动?

在Excel中,如果你想固定前几行或前几列不滚动,可以通过以下几种方法来实现。详细的介绍如下: **固定前几行不滚动:** 1. 选择需要固定的行数。例如,如果你想要固定前3行,应该选中第4行的单元格。 2. 在E…

C++ 友元函数和友元类

前言 在本文中,您将学习在C 中创建友元函数和友元类,并在程序中有效地使用它们。OOP的重要概念之一是数据隐藏,即非成员函数无法访问对象的私有或受保护的数据。但是,有时这种限制可能迫使程序员编写冗长而复杂的代码。因此&#…

leetcode_171Excel表列序号

1. 题意 把excel中列序号字符串转换为10进制数。 Excel表列序号 2. 题解 26进制转10进制 class Solution { public:int titleToNumber(string columnTitle) {int sz columnTitle.size();int ans 0;int base 1;for ( int i sz - 1; ~i; --i){int v columnTitle[i] - A …

使用 ClickHouse 深入了解 Apache Parquet (一)

​ 【squids.cn】 全网zui低价RDS,免费的迁移工具DBMotion、数据库备份工具DBTwin、SQL开发工具等 自2013年作为Hadoop的列存储发布以来,Parquet几乎已经成为一种无处不在的文件交换格式,它提供了高效的存储和检索。这种采纳使其成为更近期的…

JUC并发编程——Volatile详解(基于狂神说的学习笔记)

Volatile Volatile 是Java虚拟机提供的轻量级的同步机制 1、保证可见性 public class JMMDemo {// 在num前添加关键字volatile,保证num在所有线程可见,即修改就被通知private volatile static int num 0;public static void main(String[] args) thr…

数字电路学习

资料 元器件 电流、电压、电阻、电容、电感、保险丝、熔断器、接插件、蜂鸣器、继电器、三极管、mos管、 型号、特性、参数 数据手册 立创商城:https://www.szlcsc.com/?cZH 华秋商城:https://www.hqchip.com/ 公式 欧姆定律 IU/R 仿真软件 mu…

Crypto(5)2023xctf ezCrypto(待补)

下载地址: https://adworld.xctf.org.cn/match/list?event_hasha37c4ee0-1808-11ee-ab28-000c29bc20bf 题目代码分析: #这两行导入了Python标准库中的 random 和 string 模块,用于生成随机数和处理字符串 import random import stringcha…

【六:pytest框架介绍】

常见的请求对象requests.get()requests.post()requests.delete()requests.put()requests.request()常见的响应对象reprequests.request()//返回字符串格式数据print(req.text)//返回字节格式数据print(req.content)//返回字典格式数据print(req.json)#状态码print(req.status_c…

LLMs之RAG:利用langchain实现RAG应用五大思路步骤—基于langchain使用LLMs(ChatGPT)构建一个问题回答文档的应用程序实战代码

LLMs之RAG:利用langchain实现RAG应用五大思路步骤—基于langchain使用LLMs(ChatGPT)构建一个问题回答文档的应用程序实战代码 目录 相关文章

基于STM32设计的小龙虾养殖系统(带手机APP)

一、项目介绍 随着人们对健康生活需求的提高,小龙虾逐渐成为现代消费者餐桌上的一道风味佳肴,并且市场需求不断扩大。然而,小龙虾的养殖需要注意许多因素,其中最重要的就是水质条件。水质不良会导致小龙虾死亡率增加,降低养殖效益。因此,为了保证小龙虾的健康生长,必须…

神经网络的发展历史

神经网络的发展历史可以追溯到上世纪的数学理论和生物学研究。以下是神经网络发展史的详细概述: 早期的神经元模型: 1943年,Warren McCulloch和Walter Pitts提出了一种神经元模型,被称为MCP神经元模型,它模拟了生物神经…

v-model修饰符 .lazy .number .trim

1、v-model.lazy“xxx” 默认情况下,v-model它是在每次输入数据时触发input事件来更新数据的 使用 .lazy 修饰符后,当改变数据失去焦点-触发change事件来进行更新数据 2、v-model.number"xxx" 它会自动将输入的值自动转成number 类型&#x…

使用高防服务器有什么好处?103.216.155.x

为什么建议租用高防服务器 第一,高防服务器由于业务的特殊性,本身机器的配置要求高,服务器的价格相比普通的贵,而且,机器还有维护费、托管费等,这会让运营的成本上升。 第二,租用高防服务器&a…

GC overhead limit exceeded问题

1.问题现象 程序包运行时候发生了java.lang.OutOfMemoryError: GC overhead limit exceeded异常, 详细信息如下 org.apache.ibatis.exceptions.PersistenceException: ### Error querying database. Cause: org.jboss.util.NestedSQLException: Error; - nested t…

ELK之LogStash插件grok和geoip的配置使用

本文针对LogStash常用插件grok和geoip的使用进行说明: 一、使用grok输出结构化数据 编辑 first-pipeline.conf 文件,修改为如下内容: input{#stdin{type > stdin}file {# 读取文件的路径path > ["/tmp/access.log"]start_…

【斗罗二】冰帝两次险些杀死雨浩,天梦哥求助伊老遭拒绝,霍云儿现身救儿子

Hello,小伙伴们,我是小郑继续为大家深度解析绝世唐门。 斗罗大陆动画第二部绝世唐门已经更新了,霍雨浩与冰帝完美融合,成功觉醒了第二武魂,霍挂的时代正式到来。只是在整个第19集中,官方做了大量的改编,不但…

Ubuntu 20.04 上安装和配置 neo4j

1. 进入要安装neo4j的ubuntu环境。 2. 添加Debian资源库。 java 1.8.xx版本对应neo4j 3.xx版本(jdk 11版本对应neo4j 4.xx版本): (1)wget -O - https://debian.neo4j.com/neotechnology.gpg.key | sudo apt-key add…

Yolov8-pose关键点检测:模型轻量化创新 |多尺度空洞注意力(MSDA)结合C2f | 中科院一区顶刊 DilateFormer 2023.9

💡💡💡本文解决什么问题:多尺度空洞注意力(MSDA)采用多头的设计,在不同的头部使用不同的空洞率执行滑动窗口膨胀注意力(SWDA),全网独家首发,创新力度十足,适合科研 1)与C2f结合; MSDA | GFLOPs从9.6降低至8.5, mAP50从0.921降低至0.909,mAP50-95从0.697提…

uniapp缓存对象数组

需求:使用uniapp,模拟key(表名)增删改查对象数组,每个key可以单独操作,并模拟面对对象对应表,每个key对应的baseInstance 类似一个操作类,当然如果你场景比较简单,可以改…

AC修炼计划(AtCoder Regular Contest 167)

传送门:AtCoder Regular Contest 167 - AtCoder 再次感谢樱雪喵大佬的题解,讲的很详细,Orz。 大佬的博客链接如下:Atcoder Regular Contest 167 - 樱雪喵 - 博客园 (cnblogs.com) 第一题很签到,就省略掉了。 第二题…