SPGEMM_example解析

// 设备端并行求偏移数组 不判断当前列号是否出现过也添加进去
// 列号排序
// 然后计算
#include <hip/hip_runtime.h>// 核函数每个线程负责一行 计算当前行中有多少个元素 并先存入相应的偏移量数组中行号的+1位置 (不判断列号是否重复的版本)
__global__ void getRowNnz(const int *dptr_offset_A, const int *dptr_offset_B,const int *dptr_colindex_A, const int *dptr_colindex_B, int *dptr_offset_C, int m)
{int rowindex = threadIdx.x + blockDim.x * blockIdx.x;if (rowindex < m){int row_nnz = 0; // row_nnz记录当前第i行一共有多少个元素 初始设为0int A_begin = dptr_offset_A[rowindex];int A_end = dptr_offset_A[rowindex + 1];for (int jj = A_begin; jj < A_end; jj++){								 // jj为当前第rowindex行的非0元素在value数组与col数组中的起始位置int j = dptr_colindex_A[jj]; // j为当前A的第rowindex行中非0元素所处于的列号  然后找B中第j行的非0元素int B_begin = dptr_offset_B[j];int B_end = dptr_offset_B[j + 1];row_nnz += B_end - B_begin;}dptr_offset_C[rowindex + 1] = row_nnz; // 得到每行有多少个元素先存入相应的偏移量数组中行号的+1位置}
}// 核函数每个线程负责一行 标识当前元素的行号 并且对每行列索引数组中的相应区域进行排序
__global__ void SortAndRow(const int *dptr_offset_A, const int *dptr_offset_B,const int *dptr_colindex_A, const int *dptr_colindex_B,int *dptr_colindex_C, int *dptr_rowindex_C, int *dptr_offset_C, int m)
{int rowindex = threadIdx.x + blockDim.x * blockIdx.x;if (rowindex < m){// 找到当前行元素再col和val数组的下标范围int left = dptr_offset_C[rowindex];int right = dptr_offset_C[rowindex + 1];// 将遍历得到的列数组的列号存储到相应位置// 先设置初始插入的位置// 所有元素插入完成后再排序int pos = left; // 插入位置初始为left位置int i = left;int A_begin = dptr_offset_A[rowindex];int A_end = dptr_offset_A[rowindex + 1];for (int jj = A_begin; jj < A_end; jj++){								 // jj为当前第i行的非0元素在value数组与col数组中的起始位置int j = dptr_colindex_A[jj]; // j为当前A的第i行中非0元素所处于的列号  然后找B中第j行的非0元素int B_begin = dptr_offset_B[j];int B_end = dptr_offset_B[j + 1];for (int kk = B_begin; kk < B_end; kk++){								 // kk为当前B中第j行中的非0元素在value数组与col数组中的起始位置int k = dptr_colindex_B[kk]; // k为当前B的第j行中非0元素所处的列号 即最终结果元素所处的列号dptr_colindex_C[pos] = k;pos++;}}// 排序算法// // 插入排序// for (int i = left + 1; i < right; i++)// {// 	int key = dptr_colindex_C[i];// 	int j = i - 1;// 	while (j >= left && dptr_colindex_C[j] > key)// 	{// 		dptr_colindex_C[j + 1] = dptr_colindex_C[j];// 		j--;// 	}// 	dptr_colindex_C[j + 1] = key;// }// 希尔排序int n = right - left;int p, q, gap;for (gap = n / 2; gap > 0; gap /= 2){for (p = 0; p < gap; p++){for (q = p + gap + left; q < n + left; q += gap){if (dptr_colindex_C[q] < dptr_colindex_C[q - gap]){int tmp = dptr_colindex_C[q];int k = q - gap;while (k >= left && dptr_colindex_C[k] > tmp){dptr_colindex_C[k + gap] = dptr_colindex_C[k];k -= gap;}dptr_colindex_C[k + gap] = tmp;}}}}// 归并排序 待补充//  初始化行号数组的值for (int i = left; i < right; i++){dptr_rowindex_C[i] = rowindex;}}
}
// 核函数每个线程负责结果C中的每个位置 通过此位置对用的行号和列号 去遍历A和B中相应的元素乘积再相加 得到的结果存到当前位置
__global__ void cal(int *dptr_rowindex_C, int *dptr_colindex_C, double *dptr_value_C,const int *dptr_offset_A, const int *dptr_offset_B,const int *dptr_colindex_A, const int *dptr_colindex_B,const double *dptr_value_A, const double *dptr_value_B,int nonzero, double alpha)
{int idx = threadIdx.x + blockDim.x * blockIdx.x; // 对应在value_C数组的下标if (idx < nonzero){if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1]){dptr_value_C[idx] = 0.0;}else{int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素double sum = 0;					// 记录当前位置存入的最终结果double value_A = 0;double value_B = 0;int A_begin = dptr_offset_A[row];int A_end = dptr_offset_A[row + 1];for (int jj = A_begin; jj < A_end; jj++){								 // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置value_A = dptr_value_A[jj];	 // 当前A的值int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果int left = dptr_offset_B[j];int right = dptr_offset_B[j + 1] - 1;int mid = 0;while (left <= right){int mid = left + (right - left) / 2;if (dptr_colindex_B[mid] < col){left = mid + 1;}else if (dptr_colindex_B[mid] > col){right = mid - 1;}else if (dptr_colindex_B[mid] == col){value_B = dptr_value_B[mid];sum = sum + value_A * value_B;break;}}}dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数}}
}void call_device_spgemm(const int transA,const int transB,const dtype alpha,const size_t m,const size_t n,const size_t k,const size_t nnz_A,const csrIdxType *dptr_offset_A,const csrIdxType *dptr_colindex_A,const dtype *dptr_value_A,const size_t nnz_B,const csrIdxType *dptr_offset_B,const csrIdxType *dptr_colindex_B,const dtype *dptr_value_B,size_t *ptr_nnz_C,csrIdxType *dptr_offset_C,csrIdxType **pdptr_colindex_C,dtype **pdptr_value_C)
// device_valueC的值指向设备端中的存储位置首地址 传入进来的&device_valueC是指向device_valueC指针存储位置的指针
{dim3 dimBlock(256);dim3 dimGrid((m + dimBlock.x - 1) / dimBlock.x);getRowNnz<<<dimGrid, dimBlock>>>(dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_offset_C, m);// 主机端申请m+1大小的C偏移量数组将设备端的内容传回 得到dptr_offset_C[m]即nnzint *hptr_offset_C;HIP_CHECK(hipHostMalloc(&hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipHostRegisterDefault));HIP_CHECK(hipMemcpy(hptr_offset_C, dptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyDeviceToHost));// 求前缀和hptr_offset_C[0] = 0;for (int i = 1; i <= m; i++){hptr_offset_C[i] = hptr_offset_C[i] + hptr_offset_C[i - 1];}// 求完前缀和再传回设备HIP_CHECK(hipMemcpy(dptr_offset_C, hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyHostToDevice));// 得到结果非0元总数int nnz = hptr_offset_C[m];int nonzero = nnz;*ptr_nnz_C = nnz;// 释放主机端空间HIP_CHECK(hipHostFree(hptr_offset_C));// Malloc pdptr_colindex_C and pdptr_value_C    C有多少个非0元跟对应CSR格式中 列索引数组和值数组的大小相关HIP_CHECK(hipMalloc((void **)pdptr_colindex_C, nonzero * sizeof(csrIdxType)));HIP_CHECK(hipMalloc((void **)pdptr_value_C, nonzero * sizeof(dtype)));HIP_CHECK(hipMemset(*pdptr_value_C, 0.0, nonzero * sizeof(dtype)));// 核函数每个线程负责一行 对每行列索引数组中的相应区域进行排序// 并且将额外辅助标识行号的行数组初始化 得到每个元素对应在哪一行的位置int *dptr_rowindex_C;HIP_CHECK(hipMalloc((void **)&dptr_rowindex_C, nonzero * sizeof(int)));SortAndRow<<<dimGrid, dimBlock>>>(dptr_offset_A, dptr_offset_B,dptr_colindex_A, dptr_colindex_B, *pdptr_colindex_C, dptr_rowindex_C, dptr_offset_C, m);// 核函数 每个线程负责一个元素 计算最终结果dim3 dimBlocks(256);dim3 dimGrids((nonzero + dimBlocks.x - 1) / dimBlocks.x);cal<<<dimGrids, dimBlocks>>>(dptr_rowindex_C, *pdptr_colindex_C, *pdptr_value_C,dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_value_A, dptr_value_B, nonzero, alpha);// 释放额外的行号数组空间HIP_CHECK(hipFree(dptr_rowindex_C));
}

1、核函数 getRowNnz: 这个核函数的作用是计算结果矩阵 C 的每一行中非零元素的数量,并将这些数量存储在偏移量数组 dptr_offset_C 中。这个数组实际上用于构建结果矩阵 C 的压缩行存储格式(CSR格式)。每个线程处理一个行,计算出对应行的非零元素个数,然后将这些个数存储在偏移量数组中。

2、核函数 SortAndRow: 这个核函数的主要作用是构建结果矩阵 C 的压缩行存储格式。它为矩阵 C 的列索引数组 dptr_colindex_C 赋值,同时为行索引数组 dptr_rowindex_C 赋值。首先,它计算出每行在结果矩阵 C 中的存储位置范围(起始位置和结束位置),然后将遍历得到的矩阵 A 的非零元素所在的列索引存储在 dptr_colindex_C 数组中,并在 dptr_rowindex_C 数组中为每个元素记录对应的行号。这些操作将结果矩阵 C 转化为了压缩行存储格式。

通过上面这个方法可以过滤掉某一行为0或者某一列为0的情况

3、核函数 cal: 这个核函数的作用是执行稀疏矩阵乘法的计算,并将结果存储在结果矩阵 C 的值数组 dptr_value_C 中。每个线程处理结果矩阵中的一个位置,计算出对应位置的值。它在计算之前会检查相邻位置的行列索引是否相同,以避免重复计算。对于每个位置,它通过遍历矩阵 A 的非零元素和矩阵 B 的对应元素,按照稀疏矩阵乘法规则进行相乘累加操作,并将最终结果存储在 dptr_value_C 中。

if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1]){dptr_value_C[idx] = 0.0;}

如果是这样的话,那就代表有空行或者有空列,那就代表是0嘛,相当于提前判断了一下。

之后就开始正式的计算:

else{int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素double sum = 0;					// 记录当前位置存入的最终结果double value_A = 0;double value_B = 0;int A_begin = dptr_offset_A[row];int A_end = dptr_offset_A[row + 1];for (int jj = A_begin; jj < A_end; jj++){								 // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置value_A = dptr_value_A[jj];	 // 当前A的值int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果int left = dptr_offset_B[j];int right = dptr_offset_B[j + 1] - 1;int mid = 0;while (left <= right){int mid = left + (right - left) / 2;if (dptr_colindex_B[mid] < col){left = mid + 1;}else if (dptr_colindex_B[mid] > col){right = mid - 1;}else if (dptr_colindex_B[mid] == col){value_B = dptr_value_B[mid];sum = sum + value_A * value_B;break;}}}dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数}

上面这个程序也不难理解,就是找到对应行找到对应列进行相乘相加最后得到结果,但是得益于之前进行了判断,所以会减少很多的计算量。

综合来看,这三个核函数一起完成了稀疏矩阵乘法的计算过程。首先,getRowNnz 核函数确定了结果矩阵 C 的压缩行存储格式所需的偏移量信息。然后,SortAndRow 核函数构建了矩阵 C 的压缩行存储格式,并在其中执行了排序操作。最后,cal 核函数进行了实际的稀疏矩阵乘法计算,并将结果存储在矩阵 C 的值数组中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/27638.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【前端】html

HTML标签&#xff08;上&#xff09; 目标&#xff1a; -能够说出标签的书写注意规范 -能够写出HTML骨架标签 -能够写出超链接标签 -能够写出图片标签并说出alt和title的区别 -能够说出相对路径的三种形式 目录&#xff1a; HTML语法规范HTML基本结构标签开发工具HTML常用标…

PY32F003 FLASH

了解py32芯片的flash内容&#xff0c;对于py32进行api升级有更好的了解的操作 //uiOffset 0(4MHz), 1(8MHz), 2(16MHz), 3(22.12MHz), 4(24MHz) void SetFlashParameter(uint32_t uiOffset) {WRITE_REG(FLASH->KEYR, FLASH_KEY1);WRITE_REG(FLASH->KEYR, FLASH_KEY2); …

责任链模式(Chain of Responsibility)

责任链模式是一种行为设计模式&#xff0c;允许将请求沿着处理者链进行发送。收到请求后&#xff0c;每个处理者均可对请求进行处理&#xff0c;或将其传递给链上的下个处理者。职责链模式使多个对象都有机会处理请求&#xff0c;从而避免请求的发送者和接受者之间的耦合关系。…

qt中cmake自动处理ui文件的前提

说明&#xff1a;个人理解&#xff0c;未必正确 参考了下面的网址 http://cn.voidcc.com/question/p-wpcanvtj-tn.html http://cn.voidcc.com/question/p-wpcanvtj-tn.html cmake中将set(CMAKE_AUTOUIC ON)打开 set(CMAKE_AUTOUIC ON) # 自动处理ui文件, 自动处理ui文件是有…

构建未来移动应用:探索安卓、iOS和HarmonyOS的技术之旅

安卓、iOS和HarmonyOS的比较分析 在移动应用开发领域&#xff0c;安卓、iOS和HarmonyOS是三个常见的操作系统。本文将对它们进行比较分析&#xff0c;并展示一些相关的代码示例。 安卓&#xff08;Android&#xff09; 安卓是由Google开发的移动操作系统&#xff0c;基于Lin…

在外SSH远程连接Ubuntu系统

在外SSH远程连接Ubuntu系统【无公网IP】 文章目录 在外SSH远程连接Ubuntu系统【无公网IP】前言1. 在Ubuntu系统下安装cpolar软件2. 完成安装后打开cpolar客户端web—UI界面3. 创建隧道取得连接Ubuntu系统公网地址4. 打开Windows的命令界面并输入命令 前言 随着科技和经济的发展…

Synchronized同步锁的优化方法 待完工

Synchronized 和后来出的这个lock锁的区别 在并发编程中&#xff0c;多个线程访问同一个共享资源时&#xff0c;我们必须考虑如何维护数据的原子性。在 JDK1.5 之前&#xff0c;Java 是依靠 Synchronized 关键字实现锁功能来做到这点的。Synchronized 是 JVM 实现的一种内置锁…

[RTKLIB]模糊度固定相关问题(二)

文章目录 一、固定模糊度的前置工作1. 做好固定模糊度的准备2. 建立双差模糊度3. 问题与总结 版权声明&#xff1a;本文为原创文章&#xff0c;版权归 Winston Qu 所有&#xff0c;转载请注明出处。 在上一篇文章中&#xff0c;介绍了RTKLIB中manage_amb_LAMBDA()函数&#xff…

SSL介绍

1. SSL工作过程是什么&#xff1f; 当客户端向一个 https 网站发起请求时&#xff0c;服务器会将 SSL 证书发送给客户端进行校验&#xff0c;SSL 证书中包含一个公钥。校验成功后&#xff0c;客户端会生成一个随机串&#xff0c;并使用受访网站的 SSL 证书公钥进行加密&#xf…

论文阅读 RRNet: A Hybrid Detector for Object Detection in Drone-captured Images

文章目录 RRNet: A Hybrid Detector for Object Detection in Drone-captured ImagesAbstract1. Introduction2. Related work3. AdaResampling4. Re-Regression Net4.1. Coarse detector4.2. Re-Regression 5. Experiments5.1. Data augmentation5.2. Network details5.3. Tra…

NeRF室内重建对比:Nerfstudio vs. Luma AI vs. Instant-NGP

十年前&#xff0c;Matterport 改变了房地产业&#xff0c;让房地产买家可以进行数字旅游。 买家可以在房产内从一个点移动到另一个点并环顾四周。 与 2D 照片库相比&#xff0c;这是一个巨大的改进。 然而&#xff0c;买家仍然被房产内的一系列问题所困扰。 推荐&#xff1a;用…

rk3399移植linux kernel

rk3399移植linux kernel 0.前言一、移植ubuntu根文件系统二、移植linux1.支持NFS(可选)2.配置uevent helper3.支持etx4文件系统(默认已支持)4.配置DRM驱动5.有线网卡驱动6.无线网卡驱动 三、设备树四、内核镜像文件制作五、烧录六、总结 参考文章&#xff1a; 1.RK3399移植u-bo…

TypeScript 中【class类】与 【 接口 Interfaces】的联合搭配使用解读

导读&#xff1a; 前面章节&#xff0c;我们讲到过 接口&#xff08;Interface&#xff09;可以用于对「对象的形状&#xff08;Shape&#xff09;」进行描述。 本章节主要介绍接口的另一个用途&#xff0c;对类的一部分行为进行抽象。 类配合实现接口 实现&#xff08;impleme…

如何用正确的姿势监听Android屏幕旋转

作者&#xff1a;37手游移动客户端团队 背景 关于个人&#xff0c;前段时间由于业务太忙&#xff0c;所以一直没有来得及思考并且沉淀点东西&#xff1b;同时组内一个个都在业务上能有自己的思考和总结&#xff0c;在这样的氛围下&#xff0c;不由自主的驱使周末开始写点东西&…

QT生成Debug和Release发布版后,运行exe缺少dll问题

在QT Creator生成debug和release的exe执行文件后&#xff0c;运行时&#xff0c;报错缺少*.dll.解决办法1&#xff1a; 在系统环境变量中添加D:\Qt\Qt5.13.2\Tools\mingw730_64\bin后&#xff0c;即可运行。 当使用此方法时&#xff0c;将exe拷贝到其他电脑中运行时&#xff0c…

软件性能测试有哪些测试指标?性能测试报告对软件产品起到的作用

在软件开发过程中&#xff0c;性能测试是一个至关重要的环节&#xff0c;主要关注软件系统在不同负载条件下的表现&#xff0c;以评估其稳定性、可扩展性和响应能力。它可以帮助开发人员评估软件系统的质量和性能。 一、软件性能测试的测试指标 性能测试的测试指标直接影响着…

【代码解读】RRNet: A Hybrid Detector for Object Detection in Drone-captured Images

文章目录 1. train.py2. DistributedWrapper类2.1 init函数2.2 train函数2.3 dist_training_process函数 3. RRNetOperator类3.1 init函数3.1.1 make_dataloader函数 3.2 training_process函数3.2.1 criterion函数 4. RRNet类&#xff08;网络模型类&#xff09;4.1 init函数4.…

计算机视觉--距离变换算法的实战应用

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 计算机视觉CV是人工智能一个非常重要的领域。 在本次的距离变换任务中&#xff0c;我们将使用D4距离度量方法来对图像进行处理。通过这次实验&#xff0c;我们可以更好地理解距离度量在计算机视觉中的应用。希望大家对计算…

IPC之一:使用匿名管道进行父子进程间通信的例子

IPC 是 Linux 编程中一个重要的概念&#xff0c;IPC 有多种方式&#xff0c;本文主要介绍匿名管道(又称管道、半双工管道)&#xff0c;尽管很多人在编程中使用过管道&#xff0c;但一些特殊的用法还是鲜有文章涉及&#xff0c;本文给出了多个具体的实例&#xff0c;每个实例均附…

Maven的<relativePath/>标签

maven配置文件 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.4.3</version><relativePath/> <!-- lookup parent from repository --> </p…