// 设备端并行求偏移数组 不判断当前列号是否出现过也添加进去
// 列号排序
// 然后计算
#include <hip/hip_runtime.h>// 核函数每个线程负责一行 计算当前行中有多少个元素 并先存入相应的偏移量数组中行号的+1位置 (不判断列号是否重复的版本)
__global__ void getRowNnz(const int *dptr_offset_A, const int *dptr_offset_B,const int *dptr_colindex_A, const int *dptr_colindex_B, int *dptr_offset_C, int m)
{int rowindex = threadIdx.x + blockDim.x * blockIdx.x;if (rowindex < m){int row_nnz = 0; // row_nnz记录当前第i行一共有多少个元素 初始设为0int A_begin = dptr_offset_A[rowindex];int A_end = dptr_offset_A[rowindex + 1];for (int jj = A_begin; jj < A_end; jj++){ // jj为当前第rowindex行的非0元素在value数组与col数组中的起始位置int j = dptr_colindex_A[jj]; // j为当前A的第rowindex行中非0元素所处于的列号 然后找B中第j行的非0元素int B_begin = dptr_offset_B[j];int B_end = dptr_offset_B[j + 1];row_nnz += B_end - B_begin;}dptr_offset_C[rowindex + 1] = row_nnz; // 得到每行有多少个元素先存入相应的偏移量数组中行号的+1位置}
}// 核函数每个线程负责一行 标识当前元素的行号 并且对每行列索引数组中的相应区域进行排序
__global__ void SortAndRow(const int *dptr_offset_A, const int *dptr_offset_B,const int *dptr_colindex_A, const int *dptr_colindex_B,int *dptr_colindex_C, int *dptr_rowindex_C, int *dptr_offset_C, int m)
{int rowindex = threadIdx.x + blockDim.x * blockIdx.x;if (rowindex < m){// 找到当前行元素再col和val数组的下标范围int left = dptr_offset_C[rowindex];int right = dptr_offset_C[rowindex + 1];// 将遍历得到的列数组的列号存储到相应位置// 先设置初始插入的位置// 所有元素插入完成后再排序int pos = left; // 插入位置初始为left位置int i = left;int A_begin = dptr_offset_A[rowindex];int A_end = dptr_offset_A[rowindex + 1];for (int jj = A_begin; jj < A_end; jj++){ // jj为当前第i行的非0元素在value数组与col数组中的起始位置int j = dptr_colindex_A[jj]; // j为当前A的第i行中非0元素所处于的列号 然后找B中第j行的非0元素int B_begin = dptr_offset_B[j];int B_end = dptr_offset_B[j + 1];for (int kk = B_begin; kk < B_end; kk++){ // kk为当前B中第j行中的非0元素在value数组与col数组中的起始位置int k = dptr_colindex_B[kk]; // k为当前B的第j行中非0元素所处的列号 即最终结果元素所处的列号dptr_colindex_C[pos] = k;pos++;}}// 排序算法// // 插入排序// for (int i = left + 1; i < right; i++)// {// int key = dptr_colindex_C[i];// int j = i - 1;// while (j >= left && dptr_colindex_C[j] > key)// {// dptr_colindex_C[j + 1] = dptr_colindex_C[j];// j--;// }// dptr_colindex_C[j + 1] = key;// }// 希尔排序int n = right - left;int p, q, gap;for (gap = n / 2; gap > 0; gap /= 2){for (p = 0; p < gap; p++){for (q = p + gap + left; q < n + left; q += gap){if (dptr_colindex_C[q] < dptr_colindex_C[q - gap]){int tmp = dptr_colindex_C[q];int k = q - gap;while (k >= left && dptr_colindex_C[k] > tmp){dptr_colindex_C[k + gap] = dptr_colindex_C[k];k -= gap;}dptr_colindex_C[k + gap] = tmp;}}}}// 归并排序 待补充// 初始化行号数组的值for (int i = left; i < right; i++){dptr_rowindex_C[i] = rowindex;}}
}
// 核函数每个线程负责结果C中的每个位置 通过此位置对用的行号和列号 去遍历A和B中相应的元素乘积再相加 得到的结果存到当前位置
__global__ void cal(int *dptr_rowindex_C, int *dptr_colindex_C, double *dptr_value_C,const int *dptr_offset_A, const int *dptr_offset_B,const int *dptr_colindex_A, const int *dptr_colindex_B,const double *dptr_value_A, const double *dptr_value_B,int nonzero, double alpha)
{int idx = threadIdx.x + blockDim.x * blockIdx.x; // 对应在value_C数组的下标if (idx < nonzero){if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1]){dptr_value_C[idx] = 0.0;}else{int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素double sum = 0; // 记录当前位置存入的最终结果double value_A = 0;double value_B = 0;int A_begin = dptr_offset_A[row];int A_end = dptr_offset_A[row + 1];for (int jj = A_begin; jj < A_end; jj++){ // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置value_A = dptr_value_A[jj]; // 当前A的值int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果int left = dptr_offset_B[j];int right = dptr_offset_B[j + 1] - 1;int mid = 0;while (left <= right){int mid = left + (right - left) / 2;if (dptr_colindex_B[mid] < col){left = mid + 1;}else if (dptr_colindex_B[mid] > col){right = mid - 1;}else if (dptr_colindex_B[mid] == col){value_B = dptr_value_B[mid];sum = sum + value_A * value_B;break;}}}dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数}}
}void call_device_spgemm(const int transA,const int transB,const dtype alpha,const size_t m,const size_t n,const size_t k,const size_t nnz_A,const csrIdxType *dptr_offset_A,const csrIdxType *dptr_colindex_A,const dtype *dptr_value_A,const size_t nnz_B,const csrIdxType *dptr_offset_B,const csrIdxType *dptr_colindex_B,const dtype *dptr_value_B,size_t *ptr_nnz_C,csrIdxType *dptr_offset_C,csrIdxType **pdptr_colindex_C,dtype **pdptr_value_C)
// device_valueC的值指向设备端中的存储位置首地址 传入进来的&device_valueC是指向device_valueC指针存储位置的指针
{dim3 dimBlock(256);dim3 dimGrid((m + dimBlock.x - 1) / dimBlock.x);getRowNnz<<<dimGrid, dimBlock>>>(dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_offset_C, m);// 主机端申请m+1大小的C偏移量数组将设备端的内容传回 得到dptr_offset_C[m]即nnzint *hptr_offset_C;HIP_CHECK(hipHostMalloc(&hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipHostRegisterDefault));HIP_CHECK(hipMemcpy(hptr_offset_C, dptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyDeviceToHost));// 求前缀和hptr_offset_C[0] = 0;for (int i = 1; i <= m; i++){hptr_offset_C[i] = hptr_offset_C[i] + hptr_offset_C[i - 1];}// 求完前缀和再传回设备HIP_CHECK(hipMemcpy(dptr_offset_C, hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyHostToDevice));// 得到结果非0元总数int nnz = hptr_offset_C[m];int nonzero = nnz;*ptr_nnz_C = nnz;// 释放主机端空间HIP_CHECK(hipHostFree(hptr_offset_C));// Malloc pdptr_colindex_C and pdptr_value_C C有多少个非0元跟对应CSR格式中 列索引数组和值数组的大小相关HIP_CHECK(hipMalloc((void **)pdptr_colindex_C, nonzero * sizeof(csrIdxType)));HIP_CHECK(hipMalloc((void **)pdptr_value_C, nonzero * sizeof(dtype)));HIP_CHECK(hipMemset(*pdptr_value_C, 0.0, nonzero * sizeof(dtype)));// 核函数每个线程负责一行 对每行列索引数组中的相应区域进行排序// 并且将额外辅助标识行号的行数组初始化 得到每个元素对应在哪一行的位置int *dptr_rowindex_C;HIP_CHECK(hipMalloc((void **)&dptr_rowindex_C, nonzero * sizeof(int)));SortAndRow<<<dimGrid, dimBlock>>>(dptr_offset_A, dptr_offset_B,dptr_colindex_A, dptr_colindex_B, *pdptr_colindex_C, dptr_rowindex_C, dptr_offset_C, m);// 核函数 每个线程负责一个元素 计算最终结果dim3 dimBlocks(256);dim3 dimGrids((nonzero + dimBlocks.x - 1) / dimBlocks.x);cal<<<dimGrids, dimBlocks>>>(dptr_rowindex_C, *pdptr_colindex_C, *pdptr_value_C,dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_value_A, dptr_value_B, nonzero, alpha);// 释放额外的行号数组空间HIP_CHECK(hipFree(dptr_rowindex_C));
}
1、核函数 getRowNnz
: 这个核函数的作用是计算结果矩阵 C 的每一行中非零元素的数量,并将这些数量存储在偏移量数组 dptr_offset_C
中。这个数组实际上用于构建结果矩阵 C 的压缩行存储格式(CSR格式)。每个线程处理一个行,计算出对应行的非零元素个数,然后将这些个数存储在偏移量数组中。
2、核函数 SortAndRow
: 这个核函数的主要作用是构建结果矩阵 C 的压缩行存储格式。它为矩阵 C 的列索引数组 dptr_colindex_C
赋值,同时为行索引数组 dptr_rowindex_C
赋值。首先,它计算出每行在结果矩阵 C 中的存储位置范围(起始位置和结束位置),然后将遍历得到的矩阵 A 的非零元素所在的列索引存储在 dptr_colindex_C
数组中,并在 dptr_rowindex_C
数组中为每个元素记录对应的行号。这些操作将结果矩阵 C 转化为了压缩行存储格式。
通过上面这个方法可以过滤掉某一行为0或者某一列为0的情况
3、核函数 cal
: 这个核函数的作用是执行稀疏矩阵乘法的计算,并将结果存储在结果矩阵 C 的值数组 dptr_value_C
中。每个线程处理结果矩阵中的一个位置,计算出对应位置的值。它在计算之前会检查相邻位置的行列索引是否相同,以避免重复计算。对于每个位置,它通过遍历矩阵 A 的非零元素和矩阵 B 的对应元素,按照稀疏矩阵乘法规则进行相乘累加操作,并将最终结果存储在 dptr_value_C
中。
if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1]){dptr_value_C[idx] = 0.0;}
如果是这样的话,那就代表有空行或者有空列,那就代表是0嘛,相当于提前判断了一下。
之后就开始正式的计算:
else{int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素double sum = 0; // 记录当前位置存入的最终结果double value_A = 0;double value_B = 0;int A_begin = dptr_offset_A[row];int A_end = dptr_offset_A[row + 1];for (int jj = A_begin; jj < A_end; jj++){ // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置value_A = dptr_value_A[jj]; // 当前A的值int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果int left = dptr_offset_B[j];int right = dptr_offset_B[j + 1] - 1;int mid = 0;while (left <= right){int mid = left + (right - left) / 2;if (dptr_colindex_B[mid] < col){left = mid + 1;}else if (dptr_colindex_B[mid] > col){right = mid - 1;}else if (dptr_colindex_B[mid] == col){value_B = dptr_value_B[mid];sum = sum + value_A * value_B;break;}}}dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数}
上面这个程序也不难理解,就是找到对应行找到对应列进行相乘相加最后得到结果,但是得益于之前进行了判断,所以会减少很多的计算量。
综合来看,这三个核函数一起完成了稀疏矩阵乘法的计算过程。首先,getRowNnz
核函数确定了结果矩阵 C 的压缩行存储格式所需的偏移量信息。然后,SortAndRow
核函数构建了矩阵 C 的压缩行存储格式,并在其中执行了排序操作。最后,cal
核函数进行了实际的稀疏矩阵乘法计算,并将结果存储在矩阵 C 的值数组中。