Parallel softmax optimization for high-dimensional tensors on Cambricon MLUs

For background on Cambricon programming, see my earlier articles in this series.

A basic implementation of high-dimensional softmax

Softmax over a high-dimensional tensor is more involved. Recall the earlier NVIDIA implementation: take a 6-D tensor of shape [1,2,3,4,5,6] with reduction axis axis = 2 as an example. We compute the length of the reduced dimension dimsize = 3, the product of the remaining dimensions othersize = 1×2×4×5×6 = 240, and the stride stride = 4×5×6 = 120. We launch othersize = 240 thread blocks, each handling one slice, compute int tid = blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) * dimsize, and the global index of the element at position threadIdx.x along the axis is tid + threadIdx.x * stride. The kernels below also rely on the online-softmax recurrence: whenever the running maximum is updated from M_old to M_new, the running sum is rescaled as sum = sum × exp(M_old − M_new) + exp(x_i − M_new). We follow the same scheme to implement high-dimensional softmax on the Cambricon card.
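
To make the index arithmetic concrete, here is a small host-side check (plain C; the variable names mirror the text, everything else is mine) that derives dimsize, stride and othersize from a shape and verifies that tid + i * stride enumerates every element exactly once:

#include <stdio.h>

int main(void) {
    int shape[6] = {1, 2, 3, 4, 5, 6};
    int ndim = 6, axis = 2;
    int dimsize = shape[axis], stride = 1, othersize = 1, num = 1;
    for (int s = ndim - 1; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];      // product of the dims after axis
        if (s != axis) othersize *= shape[s];  // product of all the other dims
    }
    printf("dimsize=%d stride=%d othersize=%d\n", dimsize, stride, othersize); // 3 120 240
    int hit[720] = {0}; // num = 720; check that (otherIdx, i) covers every element exactly once
    for (int otherIdx = 0; otherIdx < othersize; otherIdx++) {
        int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
        for (int i = 0; i < dimsize; i++) hit[tid + i * stride]++;
    }
    int ok = 1;
    for (int k = 0; k < num; k++) if (hit[k] != 1) ok = 0;
    printf("coverage %s\n", ok ? "ok" : "broken");
    return 0;
}

With the indexing settled, the full BANG C program: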

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 4;               // must stay a power of two for the tree-style summation used later
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM
const int warpSize = 32; // __bang_reduce_sum consumes 128 bytes (32 floats) of src per step:
                         // elements 0-31 reduce into index 0, elements 32-63 into index 1, ...

__nram__ float src1[maxNum];           // staging buffer: maxNum elements per copy to NRAM
__nram__ float destSum[maxNum];        // running numerical sum
__nram__ float destSumFinal[warpSize]; // destSum is reduced into destSumFinal[0]
__nram__ float srcMax[2];

__mlu_entry__ void softmaxKernel(float* dst, float* source1, int othersize, int dimsize, int stride) {
    __nram__ float destOldMax;
    __nram__ float destNewMax;
    int liu = false; // toggle: true = plain synchronous variant, false = staged variant
    if (liu) {
        for (int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim) {
            destOldMax = -INFINITY;
            destNewMax = -INFINITY;
            float sum_s = 1.0;
            int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
            for (int i = 0; i < dimsize; i++) {
                __memcpy(src1, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
                if (destNewMax < src1[0]) {
                    destNewMax = src1[0];
                }
                if (i > 0) {
                    sum_s = sum_s * exp(destOldMax - destNewMax) + exp(src1[0] - destNewMax);
                }
                destOldMax = destNewMax;
            }
            float globalSumInv = 1.0 / sum_s;
            for (int i = 0; i < dimsize; i++) {
                __memcpy(src1, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
                src1[0] = exp(src1[0] - destNewMax) * globalSumInv;
                __memcpy(dst + tid + i * stride, src1, sizeof(float), NRAM2GDRAM);
            }
        }
    } else {
        for (int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim) {
            destOldMax = -INFINITY;
            destNewMax = -INFINITY;
            float sum_s = 1.0;
            int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
            for (int i = 0; i < dimsize + 1; i++) { // double-buffered: load element i while processing element i - 1
                if (i < dimsize) {
                    __memcpy_async(src1 + i % 2, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
                }
                if (i > 0) {
                    if (destNewMax < src1[(i - 1) % 2]) {
                        destNewMax = src1[(i - 1) % 2];
                    }
                    if (i > 1) {
                        sum_s = sum_s * exp(destOldMax - destNewMax) + exp(src1[(i - 1) % 2] - destNewMax);
                    }
                    destOldMax = destNewMax;
                }
                __sync_all_ipu();
            }
            float globalSumInv = 1.0 / sum_s;
            for (int i = 0; i < dimsize + 2; i++) { // staged: load element i, transform element i - 1, store element i - 2
                if (i < dimsize) {
                    __memcpy(src1 + i % 3, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
                }
                if (i > 0 && i < dimsize + 1) {
                    src1[(i - 1) % 3] = exp(src1[(i - 1) % 3] - destNewMax) * globalSumInv;
                }
                if (i > 1) {
                    __memcpy(dst + tid + (i - 2) * stride, src1 + (i - 2) % 3, sizeof(float), NRAM2GDRAM);
                }
                __sync_all_ipu();
            }
        }
    }
}

int main(void) {
    int num = 32 * 16 * 64 * 128; // shape = {32, 16, 64, 128}, axis = 2
    int stride = 128;
    int dimsize = 64;
    int othersize = 32 * 16 * 128;
    /*** int num = 24; // shape = {2, 3, 2, 2}, axis = 1
         int stride = 4; int dimsize = 3; int othersize = 8; ***/
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_dst = (float*)malloc(num * sizeof(float));
    float* host_src1 = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src1[i] = i % 4; // host_src1[i] = i;
    float *mlu_dst, *mlu_src1;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_dst, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src1, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src1, host_src1, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernel<<<dim, ktype, queue>>>(mlu_dst, mlu_src1, othersize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_dst, mlu_dst, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_dst[i], host_src1[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_dst);
    cnrtFree(mlu_src1);
    free(host_dst);
    free(host_src1);
    return 0;
}

We let taskId cover othersize, but since taskDim is usually a multiple of 2 or 4 while othersize need not be, we use a grid-stride style loop: for(int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim).
Inside this loop we process the dimsize dimension. Because Cambricon intrinsics essentially operate on whole vectors and cannot address single elements, we again split the dimsize elements into chunks of length maxNum, with any non-divisible remainder handled separately, exactly as for the 1-D vector case. The inner gather loops of the kernel below load the data: in a high-dimensional tensor, consecutive elements along the axis are stride apart in memory, so we must walk the array element by element and collect the values into the NRAM buffer src1. The rest of the processing is as before and is not repeated here.
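
Before moving on, it is convenient to have a plain-C reference softmax to validate the kernels against — a naive three-pass implementation along an arbitrary axis (my own helper, not part of the original code):

#include <math.h>
#include <stdio.h>

// Naive three-pass reference softmax along the chosen axis; dimsize, stride and
// othersize follow the definitions above.
void softmaxRef(float* dst, const float* src, int othersize, int dimsize, int stride) {
    for (int otherIdx = 0; otherIdx < othersize; otherIdx++) {
        int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
        float m = -INFINITY;
        for (int i = 0; i < dimsize; i++) {            // pass 1: maximum
            float v = src[tid + i * stride];
            if (v > m) m = v;
        }
        float sum = 0.0f;
        for (int i = 0; i < dimsize; i++)              // pass 2: sum of exp(x - m)
            sum += expf(src[tid + i * stride] - m);
        for (int i = 0; i < dimsize; i++)              // pass 3: normalize
            dst[tid + i * stride] = expf(src[tid + i * stride] - m) / sum;
    }
}

int main(void) {
    float x[24], y[24];
    for (int i = 0; i < 24; i++) x[i] = i % 4;
    softmaxRef(y, x, 8, 3, 4); // shape {2, 3, 2, 2}, axis = 1
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e\n", i, y[i]);
    return 0;
}

The strided BANG C implementation: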

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 4;               // must stay a power of two for the tree-style summation used later
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM
const int warpSize = 32; // __bang_reduce_sum consumes 128 bytes (32 floats) of src per step:
                         // elements 0-31 reduce into index 0, elements 32-63 into index 1, ...

__nram__ float src1[maxNum];           // staging buffer: maxNum elements per copy to NRAM
__nram__ float destSum[maxNum];        // running numerical sum
__nram__ float destSumFinal[warpSize]; // destSum is reduced into destSumFinal[0]
__nram__ float srcMax[2];

__mlu_entry__ void softmaxKernel(float* dst, float* source1, int othersize, int dimsize, int stride) {
    int remain = dimsize % maxNum;
    int repeat = (dimsize - remain) / maxNum;
    __nram__ float destOldMax;
    __nram__ float destNewMax;
    // taskId covers the other dimensions
    for (int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim) {
        destOldMax = -INFINITY;
        destNewMax = -INFINITY;
        __bang_write_zero(destSum, maxNum);
        int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
        for (int i = 0; i < repeat; i++) {
            for (int j = 0; j < maxNum; j++) { // gather elements from source1, stride apart
                __memcpy(src1 + j, source1 + tid + (i * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
            }
            __bang_argmax(srcMax, src1, maxNum);
            if (destNewMax < srcMax[0]) {
                destNewMax = srcMax[0]; // update the running max
            }
            __bang_sub_scalar(src1, src1, destNewMax, maxNum); // src1 = src1 - max
            __bang_active_exp_less_0(src1, src1, maxNum);      // src1 = exp(src1 - max)
            if (i > 0) {
                __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
            }
            __bang_add(destSum, destSum, src1, maxNum); // destSum += exp(src1 - destNewMax)
            destOldMax = destNewMax;
        }
        if (remain) {
            __bang_write_value(src1, maxNum, -INFINITY); // the padding must be -INFINITY for the max pass
            for (int j = 0; j < remain; j++) {
                __memcpy(src1 + j, source1 + tid + (repeat * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
            }
            __bang_argmax(srcMax, src1, maxNum);
            if (destNewMax < srcMax[0]) {
                destNewMax = srcMax[0];
            }
            __bang_write_value(src1, maxNum, destNewMax); // the padding must be re-initialized to destNewMax
            for (int j = 0; j < remain; j++) {
                __memcpy(src1 + j, source1 + tid + (repeat * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
            }
            __bang_sub_scalar(src1, src1, destNewMax, maxNum); // the trailing maxNum - remain lanes become 0
            __bang_active_exp_less_0(src1, src1, maxNum);      // so the sum picks up maxNum - remain extra ones
            if (repeat > 0) {
                __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
            }
            __bang_add(destSum, destSum, src1, maxNum);
            destOldMax = destNewMax;
        }
        // tree-style reduction of destSum, then __bang_reduce_sum into destSumFinal[0]
        __bang_write_zero(destSumFinal, warpSize);
        int segNum = maxNum / warpSize;
        for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
            for (int i = 0; i < strip; i++) {
                __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
            }
        }
        __bang_reduce_sum(destSumFinal, destSum, warpSize);
        if (remain) {
            destSumFinal[0] = destSumFinal[0] - (maxNum - remain); // remove the padding's contribution
        }
        // at this point destNewMax is the global max and destSumFinal[0] the global sum
        float globalSumInv = 1.0 / destSumFinal[0];
        if (remain) {
            __bang_mul_scalar(src1, src1, globalSumInv, maxNum); // multiply by the reciprocal instead of dividing
            for (int j = 0; j < remain; j++) {
                __memcpy(dst + tid + (repeat * maxNum + j) * stride, src1 + j, sizeof(float), NRAM2GDRAM);
            }
        }
        for (int i = 0; i < repeat; i++) {
            for (int j = 0; j < maxNum; j++) {
                __memcpy(src1 + j, source1 + tid + (i * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
            }
            __bang_sub_scalar(src1, src1, destNewMax, maxNum);
            __bang_active_exp_less_0(src1, src1, maxNum);
            __bang_mul_scalar(src1, src1, globalSumInv, maxNum);
            for (int j = 0; j < maxNum; j++) {
                __memcpy(dst + tid + (i * maxNum + j) * stride, src1 + j, sizeof(float), NRAM2GDRAM);
            }
        }
    }
}

int main(void) {
    int num = 32 * 16 * 64 * 128; // shape = {32, 16, 64, 128}, axis = 2
    int stride = 128;
    int dimsize = 64;
    int othersize = 32 * 16 * 128;
    /*** int num = 24; // shape = {2, 3, 2, 2}, axis = 1
         int stride = 4; int dimsize = 3; int othersize = 8; ***/
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_dst = (float*)malloc(num * sizeof(float));
    float* host_src1 = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src1[i] = i % 4; // host_src1[i] = i;
    float *mlu_dst, *mlu_src1;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_dst, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src1, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src1, host_src1, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernel<<<dim, ktype, queue>>>(mlu_dst, mlu_src1, othersize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_dst, mlu_dst, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_dst[i], host_src1[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_dst);
    cnrtFree(mlu_src1);
    free(host_dst);
    free(host_src1);
    return 0;
}

Coalesced-access acceleration for high-dimensional softmax

The scheme above is the simplest and most obvious one, but it has a problem: every access to the array is strided, so it is extremely slow and unusable for large tensors. To speed things up, we optimize the original scheme. For concreteness, take a 4-D tensor of shape [32,16,64,128] with softmax axis = 2, giving stride = 128, dimsize = 64 and othersize = 32×16×128. The algorithm above lets each taskId cover part of othersize to obtain otherIdx, loops over dimsize, and keeps hopping by stride to gather the elements into an NRAM vector src of length maxNum; after the transformations it writes the elements of src back to dst one by one in another loop. The dominant cost is this strided access. To remove it, we instead read the array in a coalesced fashion. Take a 4-D tensor of shape [A,B,C,D] with global index i(BCD) + j(CD) + k(D) + s; we treat the softmax axis case by case. The idea is as follows:

axis=0

For axis = 0, the part j(CD) + k(D) + s means othersize is exactly BCD, and stride is also BCD. We can therefore read the data as follows: split the tensor into A chunks of length BCD and loop for(i = 0; i < A; i++), each iteration reading source[i×BCD : (i+1)×BCD]. This yields A vectors of length BCD, and corresponding elements of these vectors are exactly stride apart in the original index space, so we can compare the A vectors elementwise to maintain a running maximum, ending with a length-BCD vector tmpMax whose entries are precisely the maxima for each (j,k,s). The sum and the write-back to GDRAM work the same way.
__bang_maxequal below performs the elementwise maximum of two vectors; for the elementwise sums, __bang_add suffices.
(Figure: the BANG C manual's description of __bang_maxequal.)
In this case taskId partitions othersize, mainly because it is exactly the othersize part of the data that is contiguous when reading.
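
The per-lane update that tmpNewMax and tmpSum perform can be sketched on the CPU as follows (plain C, my own toy data); each lane independently runs the online-softmax recurrence:

#include <math.h>
#include <stdio.h>

// CPU sketch of the elementwise online-softmax update used for axis = 0:
// `rows` vectors of n lanes lie `stride` floats apart; tmpMax/tmpSum track a
// per-lane running max and rescaled sum, like tmpNewMax/tmpSum in the kernel.
int main(void) {
    enum { rows = 4, n = 3, stride = 3 };
    float source[rows * stride] = {1, 2, 3,  4, 0, 1,  2, 5, 0,  1, 1, 1};
    float tmpMax[n], tmpSum[n];
    for (int k = 0; k < n; k++) { tmpMax[k] = -INFINITY; tmpSum[k] = 0.0f; }
    for (int i = 0; i < rows; i++) {
        for (int k = 0; k < n; k++) {
            float x = source[i * stride + k];
            float newMax = tmpMax[k] > x ? tmpMax[k] : x;    // __bang_maxequal analogue
            tmpSum[k] = tmpSum[k] * expf(tmpMax[k] - newMax) // rescale the old sum
                      + expf(x - newMax);                    // add exp(x - M)
            tmpMax[k] = newMax;
        }
    }
    for (int k = 0; k < n; k++)
        printf("lane %d: max=%.1f sum=%.6f\n", k, tmpMax[k], tmpSum[k]);
    return 0;
}

The full axis = 0 program: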

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM

__mlu_entry__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) { // axis = 0
    __nram__ float src[maxNum]; // staging buffer: maxNum elements per copy to NRAM
    __nram__ float tmpSum[maxNum];
    __nram__ float tmpNewMax[maxNum];
    __nram__ float tmpOldMax[maxNum];
    int remain = othersize % taskDim;
    int stepEasy = (othersize - remain) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remain ? stepHard : stepEasy); // the leading taskIds take one extra element
    int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
    int remainNram = step % maxNum;
    int repeat = (step - remainNram) / maxNum;
    __bang_printf("taskId:%d, repeat:%d, step:%d, indStart:%d, remainNram:%d\n", taskId, repeat, step, indStart, remainNram);
    for (int j = 0; j < repeat; j++) {
        __bang_write_value(tmpNewMax, maxNum, -INFINITY);
        __bang_write_zero(tmpSum, maxNum);
        for (int i = 0; i < dimsize; i++) {
            __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
            __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); // keep updating the elementwise max
            __bang_sub(src, src, tmpNewMax, maxNum);            // x - M
            __bang_active_exp_less_0(src, src, maxNum);         // exp(x - M)
            if (i > 0) {
                __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);    // oldM = oldM - newM
                __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum); // exp(oldM - newM)
                __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);          // sum = sum * exp(oldM - newM)
            }
            __bang_add(tmpSum, tmpSum, src, maxNum); // sum += exp(x - M)
            __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM); // oldM = newM
        }
        __bang_active_recip(tmpSum, tmpSum, maxNum); // 1 / sum
        // exponential transform and write-back to GDRAM; src still holds the last slice, so reuse it
        __bang_mul(src, src, tmpSum, maxNum);
        __memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
        for (int i = 0; i < dimsize - 1; i++) {
            __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
            __bang_sub(src, src, tmpNewMax, maxNum);
            __bang_active_exp_less_0(src, src, maxNum);
            __bang_mul(src, src, tmpSum, maxNum);
            __memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
        }
    }
    if (remainNram) {
        __bang_write_value(tmpNewMax, maxNum, -INFINITY);
        __bang_write_zero(tmpSum, maxNum);
        __bang_write_zero(src, maxNum);
        for (int i = 0; i < dimsize; i++) {
            __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
            __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
            __bang_sub(src, src, tmpNewMax, maxNum);
            __bang_active_exp_less_0(src, src, maxNum);
            if (i > 0) {
                __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);
                __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);
                __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);
            }
            __bang_add(tmpSum, tmpSum, src, maxNum);
            __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);
        }
        __bang_active_recip(tmpSum, tmpSum, maxNum);
        __bang_mul(src, src, tmpSum, maxNum); // src still holds the last slice
        __memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
        for (int i = 0; i < dimsize - 1; i++) {
            __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
            __bang_sub(src, src, tmpNewMax, maxNum);
            __bang_active_exp_less_0(src, src, maxNum);
            __bang_mul(src, src, tmpSum, maxNum);
            __memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
        }
    }
}

int main(void) {
    // int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 0;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, num:%d\n", axis, dimsize, stride, othersize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i % 4; // host_src[i] = i;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_s<<<dim, ktype, queue>>>(mlu_destination, mlu_src, othersize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

axis = -1

Here the softmax axis is the last one, which is even simpler: split the tensor into ABC chunks of length D and loop for(i = 0; i < ABC; i++), each iteration reading source[i×D : (i+1)×D] and reducing it to its maximum M. After the loop we have the maximum for every (i,j,k), i.e. for all of othersize; the sum and the write-back to GDRAM follow the same pattern. Since the data is contiguous along axis = -1, there are two possible parallel strategies:
Strategy 1: let taskId cover othersize, e.g. for(i = taskId; i < ABC; i += taskDim), reading the corresponding length-D chunk inside each iteration. But D need not be a power of two, and NRAM may not hold a length-D vector at once, so inside the loop we need another loop over source[i×D : (i+1)×D] that reads maxNum elements at a time until the chunk is exhausted.
Strategy 2: process othersize serially, for(i = 0; i < ABC; i++), and inside each iteration split source[i×D : (i+1)×D] across the taskIds. Each taskId then owns only a portion of the chunk (the step of the earlier code), which again need not be a power of two or fit in NRAM; worse, the maximum we need is that of the whole chunk, so after splitting it across taskIds an extra cross-task reduction is required at the end (exactly as for the 1-D vector).
Given this analysis we prefer strategy 1. Note that with for(i = taskId; i < ABC; i += taskDim), from one taskId's point of view successive iterations read chunks that are far apart. If we instead pre-assigned each taskId the contiguous index range [taskId×step : (taskId+1)×step], its successive reads would be closer together (whether that helps would need experimental confirmation). For simplicity we keep the for(i = taskId; i < ABC; i += taskDim) pattern.
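
One detail used throughout these kernels deserves isolating: when the tail chunk holds only remain < maxNum elements, the buffer is padded with -INFINITY for the max pass, then re-padded with the current maximum so that after subtracting the max and exponentiating, every padding lane contributes exactly exp(0) = 1; those extra ones are then subtracted from the final sum. A plain-C sketch with toy sizes (my own names):

#include <math.h>
#include <stdio.h>

#define N 8 // stand-in for maxNum; must be >= remain

int main(void) {
    float buf[N];
    float data[3] = {0.5f, 2.0f, -1.0f}; // remain = 3 real elements
    int remain = 3;
    for (int k = 0; k < N; k++) buf[k] = -INFINITY;         // pad for the max pass
    for (int k = 0; k < remain; k++) buf[k] = data[k];
    float M = -INFINITY;
    for (int k = 0; k < N; k++) if (buf[k] > M) M = buf[k]; // __bang_argmax analogue
    for (int k = 0; k < N; k++) buf[k] = M;                 // re-pad with M
    for (int k = 0; k < remain; k++) buf[k] = data[k];
    float sum = 0.0f;
    for (int k = 0; k < N; k++) sum += expf(buf[k] - M);    // padding lanes add exp(0) = 1
    sum -= (float)(N - remain);                             // remove the padding's contribution
    printf("max=%.3f sum=%.6f\n", M, sum);                  // matches the 3-element softmax sum
    return 0;
}

The axis = -1 program: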

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM
const int warpSize = 32;

__mlu_entry__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize) { // axis = -1
    __nram__ float src[maxNum];
    __nram__ float destSum[maxNum];        // running numerical sum
    __nram__ float destSumFinal[warpSize]; // destSum is reduced into destSumFinal[0]
    __nram__ float srcMax[2];
    __nram__ float destOldMax;
    __nram__ float destNewMax;
    int remain = dimsize % maxNum;
    int repeat = (dimsize - remain) / maxNum;
    for (int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim) {
        int tid = otherIdx * dimsize;
        destOldMax = -INFINITY;
        destNewMax = -INFINITY;
        __bang_write_zero(destSum, maxNum);
        for (int i = 0; i < repeat; i++) {
            __memcpy(src, source + tid + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
            __bang_argmax(srcMax, src, maxNum);
            if (destNewMax < srcMax[0]) {
                destNewMax = srcMax[0]; // update the running max
            }
            __bang_sub_scalar(src, src, destNewMax, maxNum); // src = src - max
            __bang_active_exp_less_0(src, src, maxNum);      // src = exp(src - max)
            if (i > 0) {
                __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
            }
            __bang_add(destSum, destSum, src, maxNum);
            destOldMax = destNewMax;
        }
        if (remain) {
            __bang_write_value(src, maxNum, -INFINITY); // the padding must be -INFINITY for the max pass
            __memcpy(src, source + tid + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
            __bang_argmax(srcMax, src, maxNum);
            if (destNewMax < srcMax[0]) {
                destNewMax = srcMax[0];
            }
            __bang_write_value(src, maxNum, destNewMax); // the padding must be re-initialized to destNewMax
            __memcpy(src, source + tid + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
            __bang_sub_scalar(src, src, destNewMax, maxNum); // the trailing maxNum - remain lanes become 0
            __bang_active_exp_less_0(src, src, maxNum);      // so the sum picks up maxNum - remain extra ones
            if (repeat > 0) {
                __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
            }
            __bang_add(destSum, destSum, src, maxNum);
            destOldMax = destNewMax;
        }
        // tree-style reduction of destSum, then __bang_reduce_sum into destSumFinal[0]
        __bang_write_zero(destSumFinal, warpSize);
        int segNum = maxNum / warpSize;
        for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
            for (int i = 0; i < strip; i++) {
                __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
            }
        }
        __bang_reduce_sum(destSumFinal, destSum, warpSize);
        if (remain) {
            destSumFinal[0] = destSumFinal[0] - (maxNum - remain); // remove the padding's contribution
        }
        float globalSumInv = 1.0 / destSumFinal[0];
        for (int i = 0; i < repeat; i++) {
            __memcpy(src, source + tid + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
            __bang_sub_scalar(src, src, destNewMax, maxNum);
            __bang_active_exp_less_0(src, src, maxNum);
            __bang_mul_scalar(src, src, globalSumInv, maxNum); // multiply by the reciprocal instead of dividing
            __memcpy(destination + tid + i * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
        }
        if (remain) {
            __bang_write_value(src, maxNum, destNewMax);
            __memcpy(src, source + tid + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
            __bang_sub_scalar(src, src, destNewMax, maxNum);
            __bang_active_exp_less_0(src, src, maxNum);
            __bang_mul_scalar(src, src, globalSumInv, maxNum);
            __memcpy(destination + tid + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
        }
    }
}

int main(void) {
    // int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 3;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, num:%d\n", axis, dimsize, stride, othersize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i; // host_src[i] = i % 4;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_e<<<dim, ktype, queue>>>(mlu_destination, mlu_src, othersize, dimsize);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

0 < axis < dim - 1

Let dim denote the number of dimensions of the tensor. This is the most complex case. Building on the axis = 0 and axis = -1 analyses, we consider 0 < axis < dim - 1, and to fix ideas we describe the data access for axis = 1 and axis = 2:
axis = 1: for a tensor [A,B,C,D], set otherIdx = i(BCD) and loop for(j = 0; j < B; j++); each iteration reads the length-CD chunk source[otherIdx + j×stride : otherIdx + j×stride + CD]. For a fixed otherIdx the loop yields dimsize = B vectors of length CD; comparing them elementwise gives a length-CD vector tmpMax whose entries are, for that otherIdx, the maxima for each (k,s). The sum and the write-back follow analogously.
axis = 2: set otherIdx = i(BCD) + j(CD) and loop for(k = 0; k < C; k++); each iteration reads the length-D chunk source[otherIdx + k×stride : otherIdx + k×stride + D]. For a fixed otherIdx the loop yields dimsize = C vectors of length D; comparing them elementwise gives a length-D vector tmpMax whose entries are, for that otherIdx, the maxima for each s. The sum and the write-back follow analogously.
The pattern: when axis is a middle dimension, fix the otherIdx part in front of axis, loop over axis, and read the data behind axis in each iteration. We introduce frontsize and behindsize for the data sizes in front of and behind axis: for axis = 1, frontsize = A and behindsize = CD; for axis = 2, frontsize = AB and behindsize = D.
We now have to decide whether taskId should partition frontsize or behindsize. Both work; we analyze the two strategies using axis = 2 as the example:
Strategy 1: taskId partitions frontsize, i.e. for(ind = taskId; ind < frontsize; ind += taskDim). With axis = 2 we have frontsize = AB, and ind corresponds to the 2-D index (i,j) via ind = iB + j; we further convert it to frontIdx = ind×CD, or in general frontIdx = ind×dimsize×behindsize. Inside this loop we continue with for(k = 0; k < C; k++), reading behindsize elements at a time.
Strategy 2: taskId partitions behindsize, in which case frontsize can only be processed serially: for(ind = 0; ind < frontsize; ind += 1), again with frontIdx = ind×CD (in general ind×dimsize×behindsize). Inside this loop we continue with for(k = 0; k < C; k++); since taskId partitions behindsize, each taskId owns step elements and reads step elements at a time.
At first glance we lean toward strategy 1. Note also that behindsize is exactly stride, so from here on we use the two interchangeably; a host-side sketch of the bookkeeping follows.
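
The middle-axis bookkeeping can be sketched as follows (plain C, mirroring the main() functions below):

#include <stdio.h>

int main(void) {
    int shape[4] = {32, 16, 64, 128}; // [A, B, C, D]
    int axis = 2;                     // 0 < axis < 3
    int dimsize = shape[axis];
    int stride = 1, frontsize = 1;    // behindsize == stride
    for (int s = 3; s >= 0; s--) {
        if (s > axis) stride    *= shape[s];
        if (s < axis) frontsize *= shape[s];
    }
    // each ind in [0, frontsize) owns the slab starting at frontIdx:
    for (int ind = 0; ind < frontsize; ind++) {
        int frontIdx = ind * dimsize * stride;
        (void)frontIdx; // the kernel reads dimsize rows of `stride` elements from here
    }
    printf("dimsize=%d stride=%d frontsize=%d\n", dimsize, stride, frontsize); // 64 128 512
    return 0;
}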
Strategy 1:

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM

__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) { // 0 < axis < dim - 1
    __nram__ float src[maxNum];
    __nram__ float tmpSum[maxNum];
    __nram__ float tmpNewMax[maxNum];
    __nram__ float tmpOldMax[maxNum];
    int remain = stride % maxNum;
    int repeat = (stride - remain) / maxNum;
    for (int ind = taskId; ind < frontsize; ind += taskDim) {
        int frontIdx = ind * dimsize * stride;
        for (int j = 0; j < repeat; j++) {
            __bang_write_value(tmpNewMax, maxNum, -INFINITY);
            __bang_write_zero(tmpSum, maxNum);
            __bang_write_zero(src, maxNum);
            for (int i = 0; i < dimsize; i++) {
                __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
                __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); // keep updating the elementwise max
                __bang_sub(src, src, tmpNewMax, maxNum);            // x - M
                __bang_active_exp_less_0(src, src, maxNum);         // exp(x - M)
                if (i > 0) {
                    __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);    // oldM = oldM - newM
                    __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum); // exp(oldM - newM)
                    __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);          // sum = sum * exp(oldM - newM)
                }
                __bang_add(tmpSum, tmpSum, src, maxNum); // sum += exp(x - M)
                __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM); // oldM = newM
            }
            __bang_active_recip(tmpSum, tmpSum, maxNum); // 1 / sum
            __bang_mul(src, src, tmpSum, maxNum);        // src still holds the last slice, so reuse it
            __memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
            for (int i = 0; i < dimsize - 1; i++) {
                __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
                __bang_sub(src, src, tmpNewMax, maxNum);
                __bang_active_exp_less_0(src, src, maxNum);
                __bang_mul(src, src, tmpSum, maxNum);
                __memcpy(destination + frontIdx + i * stride + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
            }
        }
        if (remain) {
            __bang_write_value(tmpNewMax, maxNum, -INFINITY);
            __bang_write_zero(tmpSum, maxNum);
            __bang_write_value(src, maxNum, -INFINITY);
            for (int i = 0; i < dimsize; i++) {
                __memcpy(src, source + frontIdx + i * stride + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
                __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
                __bang_sub(src, src, tmpNewMax, maxNum);
                __bang_active_exp_less_0(src, src, maxNum);
                if (i > 0) {
                    __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);
                    __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);
                    __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);
                }
                __bang_add(tmpSum, tmpSum, src, maxNum);
                __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);
            }
            __bang_active_recip(tmpSum, tmpSum, maxNum);
            __bang_mul(src, src, tmpSum, maxNum);
            __memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
            for (int i = 0; i < dimsize - 1; i++) {
                __memcpy(src, source + i * stride + frontIdx + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
                __bang_sub(src, src, tmpNewMax, maxNum);
                __bang_active_exp_less_0(src, src, maxNum);
                __bang_mul(src, src, tmpSum, maxNum);
                __memcpy(destination + i * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
            }
        }
    }
}

int main(void) {
    // int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 2;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1, frontsize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s < axis) frontsize *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i; // host_src[i] = i % 4;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_m<<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

Strategy 2:

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM

__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) { // 0 < axis < dim - 1
    __nram__ float src[maxNum];
    __nram__ float tmpSum[maxNum];
    __nram__ float tmpNewMax[maxNum];
    __nram__ float tmpOldMax[maxNum];
    int remain = stride % taskDim;
    int stepEasy = (stride - remain) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remain ? stepHard : stepEasy); // the leading taskIds take one extra element
    int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
    int remainNram = step % maxNum;
    int repeat = (step - remainNram) / maxNum;
    for (int ind = 0; ind < frontsize; ind++) { // frontsize is processed serially
        int frontIdx = ind * dimsize * stride;
        for (int j = 0; j < repeat; j++) {
            __bang_write_value(tmpNewMax, maxNum, -INFINITY);
            __bang_write_zero(tmpSum, maxNum);
            __bang_write_zero(src, maxNum);
            for (int i = 0; i < dimsize; i++) {
                __memcpy(src, source + frontIdx + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
                __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); // keep updating the elementwise max
                __bang_sub(src, src, tmpNewMax, maxNum);            // x - M
                __bang_active_exp_less_0(src, src, maxNum);         // exp(x - M)
                if (i > 0) {
                    __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);    // oldM = oldM - newM
                    __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum); // exp(oldM - newM)
                    __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);          // sum = sum * exp(oldM - newM)
                }
                __bang_add(tmpSum, tmpSum, src, maxNum); // sum += exp(x - M)
                __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM); // oldM = newM
            }
            __bang_active_recip(tmpSum, tmpSum, maxNum); // 1 / sum
            __bang_mul(src, src, tmpSum, maxNum);        // src still holds the last slice, so reuse it
            __memcpy(destination + (dimsize - 1) * stride + frontIdx + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
            for (int i = 0; i < dimsize - 1; i++) {
                __memcpy(src, source + frontIdx + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
                __bang_sub(src, src, tmpNewMax, maxNum);
                __bang_active_exp_less_0(src, src, maxNum);
                __bang_mul(src, src, tmpSum, maxNum);
                __memcpy(destination + frontIdx + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
            }
        }
        if (remainNram) {
            __bang_write_value(tmpNewMax, maxNum, -INFINITY);
            __bang_write_zero(tmpSum, maxNum);
            __bang_write_zero(src, maxNum);
            for (int i = 0; i < dimsize; i++) {
                __memcpy(src, source + frontIdx + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
                __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
                __bang_sub(src, src, tmpNewMax, maxNum);
                __bang_active_exp_less_0(src, src, maxNum);
                if (i > 0) {
                    __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);
                    __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);
                    __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);
                }
                __bang_add(tmpSum, tmpSum, src, maxNum);
                __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);
            }
            __bang_active_recip(tmpSum, tmpSum, maxNum);
            __bang_mul(src, src, tmpSum, maxNum);
            __memcpy(destination + (dimsize - 1) * stride + indStart + frontIdx + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
            for (int i = 0; i < dimsize - 1; i++) {
                __memcpy(src, source + i * stride + frontIdx + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
                __bang_sub(src, src, tmpNewMax, maxNum);
                __bang_active_exp_less_0(src, src, maxNum);
                __bang_mul(src, src, tmpSum, maxNum);
                __memcpy(destination + i * stride + indStart + frontIdx + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
            }
        }
    }
}

int main(void) {
    int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    // int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 1;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1, frontsize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s < axis) frontsize *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {16, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION4;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i; // host_src[i] = i % 4;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_m<<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

Let us look at the speedups the parallel strategies above bring at different sizes. For axis = 1 and axis = 2 the numbers below all refer to strategy 1; strategy 2 performs too poorly to be worth showing:
(Figure: timing comparison of the coalesced kernels across several shapes.)

Further optimization of high-dimensional softmax

axis = -1

From the table above we see that for axis = -1, even though the reads are contiguous, the kernel is still very slow. The main culprit is the memory wasted in the src buffer: in the example above the last dimension has length 32, while src occupies maxNum × sizeof(float) bytes, so each GDRAM-to-NRAM copy fills only 32 of the 32,768 float slots and the rest of the buffer sits idle. To use that memory, we take a different approach below.
The previous scheme essentially lets taskId cover othersize with one src per otherIdx: the buffer holds only the axis = -1 data of a single otherIdx. To fill the memory we instead let one src hold the data of several otherIdx values. Assume for now that dimsize = shape[-1] divides maxNum and that shape[-1] is a power of two, and let multiple = maxNum/shape[-1] = maxNum/dimsize; one src then holds multiple of the othersize length-dimsize rows, and since we use taskDim tasks, one round can hold size = multiple × taskDim such rows. For the discussion we introduce some variables:
multiple = maxNum / shape[-1] = maxNum / dimsize: how many length-dimsize rows one src holds
size = multiple × taskDim: how many length-dimsize rows one round across taskDim tasks holds
remainS = othersize % size: the remainder when othersize does not divide evenly; it is spread over the taskIds, each receiving an extra step rows
taskRepeat = (othersize - remainS) / size: how many rounds are needed to cover the evenly divisible part of othersize
Overall, each taskId processes (taskRepeat × multiple + step) × dimsize elements, from which we can compute each taskId's offset. With the offsets computed, the procedure from a single taskId's point of view is:
First enter a loop for(int s = 0; s < taskRepeat; s++). Within it, the offset for round s relative to the taskId's base is tid = s × multiple × dimsize, and each round copies multiple × dimsize elements from GDRAM into src. An inner loop for(int j = 0; j < multiple; j++) then works on src alone: each iteration takes one length-dimsize row from src, computes its maximum and sum, applies the exponential transform, and writes the result back to GDRAM.
After this double loop, the extra step rows are handled by a single loop for(int s = 0; s < step; s++); each iteration reads a length-dimsize row directly from source, processes it the same way, and writes it back to GDRAM.
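
A host-side sketch of this partition (plain C; it prints what each taskId would own, using the formulas above and the shape from the benchmarks):

#include <stdio.h>

int main(void) {
    int maxNum = 32768, dimsize = 32, othersize = 1024 * 32 * 32, taskDim = 4;
    int multiple   = maxNum / dimsize;             // rows per src buffer
    int size       = taskDim * multiple;           // rows per round across all tasks
    int remainS    = othersize % size;
    int taskRepeat = (othersize - remainS) / size;
    int remainT    = remainS % taskDim;
    int stepEasy   = (remainS - remainT) / taskDim;
    int stepHard   = stepEasy + 1;
    for (int taskId = 0; taskId < taskDim; taskId++) {
        int step = taskId < remainT ? stepHard : stepEasy;
        int indStart = taskId < remainT
            ? taskId * (taskRepeat * multiple + stepHard)
            : remainT * (taskRepeat * multiple + stepHard)
              + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
        printf("taskId %d: rows=%d firstRow=%d\n",
               taskId, taskRepeat * multiple + step, indStart);
    }
    return 0;
}

The optimized axis = -1 program: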

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM
const int warpSize = 32;

// dimS must be >= dimsize, rounded up to the nearest power of two, and at least 32
// because of the reduction below; this kernel only suits dimsize < maxNum
template<int dimS>
__mlu_entry__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize) { // axis = -1
    int multiple = maxNum / dimsize;
    int size = taskDim * multiple;
    int remainS = othersize % size;
    int taskRepeat = (othersize - remainS) / size;
    int remainT = remainS % taskDim;
    int stepEasy = (remainS - remainT) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remainT ? stepHard : stepEasy);
    // each taskId covers taskRepeat * multiple + step rows of othersize,
    // i.e. (taskRepeat * multiple + step) * dimsize elements in total
    int startHard = taskId * (taskRepeat * multiple + stepHard);
    int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
    int indStart = (taskId < remainT ? startHard : startEasy);
    source = source + indStart * dimsize;
    destination = destination + indStart * dimsize;
    __nram__ float src[maxNum];
    __nram__ float tmp[dimS];
    __nram__ float destSum[dimS];          // running numerical sum
    __nram__ float destSumFinal[warpSize]; // destSum is reduced into destSumFinal[0]
    __nram__ float srcMax[2];
    int tid;
    for (int s = 0; s < taskRepeat; s++) {
        tid = s * multiple * dimsize;
        __memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
        for (int j = 0; j < multiple; j++) { // process one length-dimsize row of src at a time
            __bang_write_zero(destSum, dimS);
            __bang_write_zero(destSumFinal, warpSize);
            __bang_write_value(tmp, dimS, -INFINITY); // pad with -INFINITY for the max pass
            __memcpy(tmp, src + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
            __bang_argmax(srcMax, tmp, dimS);
            __bang_write_value(tmp, dimS, srcMax[0]); // the padding must be re-initialized to srcMax[0]
            __memcpy(tmp, src + j * dimsize, dimsize * sizeof(float), NRAM2NRAM); // the data must be re-read
            __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
            __bang_active_exp_less_0(tmp, tmp, dimS); // the dimS - dimsize padding lanes become exp(0) = 1
            __bang_add(destSum, destSum, tmp, dimS);
            int segNum = dimS / warpSize; // reduce the numerical sum
            for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
                for (int i = 0; i < strip; i++) {
                    __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
                }
            }
            __bang_reduce_sum(destSumFinal, destSum, warpSize);
            // destSumFinal[0] now holds the sum over the current length-dimsize row
            destSumFinal[0] = destSumFinal[0] - (dimS - dimsize); // remove the padding's contribution
            float globalSumInv = 1.0 / destSumFinal[0];
            __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
            // the result must go straight back to GDRAM: staging it in src first could let a later
            // iteration modify src before the src -> GDRAM copy has finished
            __memcpy(destination + tid + j * dimsize, tmp, dimsize * sizeof(float), NRAM2GDRAM);
        }
    }
    for (int s = 0; s < step; s++) { // the extra rows from remainS
        tid = taskRepeat * multiple * dimsize + s * dimsize;
        __bang_write_zero(destSum, dimS);
        __bang_write_zero(destSumFinal, warpSize);
        __bang_write_value(tmp, dimS, -INFINITY);
        __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
        __bang_argmax(srcMax, tmp, dimS);
        __bang_write_value(tmp, dimS, srcMax[0]);
        __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
        __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
        __bang_active_exp_less_0(tmp, tmp, dimS); // the dimS - dimsize padding lanes become 1
        __bang_add(destSum, destSum, tmp, dimS);
        int segNum = dimS / warpSize;
        for (int strip = segNum / 2; strip > 0; strip = strip / 2) {
            for (int i = 0; i < strip; i++) {
                __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
            }
        }
        __bang_reduce_sum(destSumFinal, destSum, warpSize);
        destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
        float globalSumInv = 1.0 / destSumFinal[0];
        __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
        __memcpy(destination + tid, tmp, dimsize * sizeof(float), NRAM2GDRAM);
    }
}

int main(void) {
    // int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 3;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, num:%d\n", axis, dimsize, stride, othersize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i % 4; // host_src[i] = i;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_e<32><<<dim, ktype, queue>>>(mlu_destination, mlu_src, othersize, dimsize);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

(Figure: timing of the optimized axis = -1 kernel.)

0 < axis < dim - 1

This case is the most intricate. As analyzed above, when axis is a middle dimension — say the tensor is [A,B,C,D] with axis = 1 and index i(BCD) + j(CD) + k(D) + s — we split the index into three parts: i(BCD) is the frontIdx, k(D) + s is the behindsize part of length CD (recall behindsize = stride), and j(CD) is the middle part. For a fixed frontIdx the behindsize part is contiguous in memory, so we can loop for(j = 0; j < B; j++), each iteration reading [frontIdx + j×CD : frontIdx + (j+1)×CD]; this yields B vectors of length CD, whose elementwise maxima form a length-CD vector tmpNewMax holding, for that frontIdx, the maximum for each (k,s).
As with axis = -1, if behindsize is much smaller than maxNum, the src buffer again wastes a lot of memory, so we want src to load as much data per copy as possible.
We therefore compare maxNum against BCD. For axis = 1, if BCD is close to maxNum, we prefer src to load an entire length-BCD slab at once: src then holds, for a fixed frontIdx, the data for all (k,s). A loop for(j = 0; j < B; j++) over src then reads one length-CD row per iteration, keeps updating the running max, and finally writes back to GDRAM. This suits an axis that sits relatively toward the front, with CD < maxNum and BCD at or near maxNum, since an early axis makes dimsize × stride more likely to exceed maxNum. If stride < maxNum but dimsize × stride > maxNum, we additionally have to split along dimsize; see the code for the details.
If axis sits relatively toward the back, even dimsize × stride is far below maxNum, so loading a single dimsize × stride slab would still waste most of src. We then let src hold several dimsize × stride slabs so the buffer is as full as possible (in the extreme case — the 4-D tensor [A,B,C,D] with axis = 2 where D, CD, BCD and even ABCD are all far below maxNum — src can simply load everything at once). This requires an extra NRAM working vector: the inner computation is the same loop as in the original scheme, except that each length-dimsize × stride slab is now read from the NRAM buffer src instead of from GDRAM. The choice between the two variants is sketched below.
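
Putting the two cases together, the host-side choice between the two middle-axis variants reduces to a single comparison (a sketch; the enum names are mine, and the two variants are the kernels given below):

#include <stdio.h>

typedef enum { MIDDLE_AXIS_FRONT, MIDDLE_AXIS_BACK } MiddleAxisPlan;

// Sketch of the dispatch between the two middle-axis variants described above.
MiddleAxisPlan chooseMiddleAxisPlan(int dimsize, int stride, int maxNum) {
    if (dimsize * stride >= maxNum)
        return MIDDLE_AXIS_FRONT; // split dimsize into chunks of maxNum / stride rows
    return MIDDLE_AXIS_BACK;      // pack maxNum / (dimsize * stride) whole slabs per load
}

int main(void) {
    int maxNum = 32768;
    // shape {1024,32,32,32}, axis = 1: dimsize = 32, stride = 1024 -> "front" variant
    printf("%d\n", chooseMiddleAxisPlan(32, 1024, maxNum)); // 0
    // shape {1024,32,32,32}, axis = 2: dimsize = 32, stride = 32 -> "back" variant
    printf("%d\n", chooseMiddleAxisPlan(32, 32, maxNum));   // 1
    return 0;
}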

Axis relatively near the front

Here stride < maxNum but dimsize × stride ≥ maxNum, so we let src load multiple × stride elements at a time, where multiple = maxNum / stride. The code follows.
One pitfall in this code: after computing frontIdx, do not rebase the pointer with source = source + frontIdx; fold the offset into each data copy instead (i.e. source + frontIdx + ...), otherwise memory corruption results (the root cause of the corruption is still being investigated).

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM

// strideS is the smallest power of two >= stride
template<int strideS>
__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) { // 0 < axis < dim - 1
    __nram__ float src[maxNum];
    __nram__ float tmp[strideS];
    __nram__ float tmpOldMax[strideS];
    __nram__ float tmpNewMax[strideS];
    __nram__ float tmpSum[strideS];
    if (dimsize * stride >= maxNum) {
        int multiple = maxNum / stride;             // rows per src load
        int size = multiple * stride;               // the most data one src load can hold
        int remain = dimsize % multiple;            // leftover rows handled separately
        int repeat = (dimsize - remain) / multiple; // loads needed to cover dimsize
        printf("maxNum:%d, dimsize * stride:%d, multiple:%d, size:%d, repeat:%d,remain:%d\n", maxNum, dimsize * stride, multiple, size, repeat, remain);
        for (int ind = taskId; ind < frontsize; ind += taskDim) {
            int frontIdx = ind * dimsize * stride;
            __bang_write_value(tmpNewMax, strideS, -INFINITY); // must start at -INFINITY
            __bang_write_value(tmp, strideS, -INFINITY);       // must start at -INFINITY
            __bang_write_zero(tmpSum, strideS);                // must start at 0
            for (int j = 0; j < repeat; j++) {
                __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
                for (int m = 0; m < multiple; m++) {
                    __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
                    // the strideS - stride tail of tmpNewMax is never written back, so it cannot affect the result
                    __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);    // x - M; the strideS - stride tail becomes 0
                    __bang_active_exp_less_0(tmp, tmp, strideS); // exp(x - M)
                    if (j != 0 || m != 0) {
                        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);    // oldM = oldM - newM
                        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS); // exp(oldM - newM)
                        __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);          // sum = sum * exp(oldM - newM)
                    }
                    __bang_add(tmpSum, tmpSum, tmp, strideS); // sum += exp(x - M)
                    __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM); // oldM = newM
                }
            }
            if (remain) {
                __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
                for (int m = 0; m < remain; m++) {
                    __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);    // the strideS - stride tail becomes 0
                    __bang_active_exp_less_0(tmp, tmp, strideS);
                    if (repeat != 0 || m != 0) {
                        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);
                        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);
                        __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);
                    }
                    __bang_add(tmpSum, tmpSum, tmp, strideS);
                    __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);
                }
            }
            // tmpNewMax now holds the max, and tmpSum the sum, over the behindsize data of this frontIdx
            __bang_active_recip(tmpSum, tmpSum, strideS);
            if (remain) { // src still holds the remain slab, so write it back first
                for (int m = 0; m < remain; m++) {
                    __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);
                    __bang_active_exp_less_0(tmp, tmp, strideS);
                    __bang_mul(tmp, tmp, tmpSum, strideS);
                    __memcpy(destination + frontIdx + repeat * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
                }
            }
            for (int j = 0; j < repeat; j++) {
                __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
                for (int m = 0; m < multiple; m++) {
                    __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);
                    __bang_active_exp_less_0(tmp, tmp, strideS);
                    __bang_mul(tmp, tmp, tmpSum, strideS);
                    __memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
                }
            }
        }
    }
}

int main(void) {
    // int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 1;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1, frontsize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s < axis) frontsize *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i; // host_src[i] = i % 4;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_m<1024><<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

Axis relatively near the back

Here not only is stride < maxNum, but dimsize × stride < maxNum as well, so we simply define behindsize = dimsize × stride and let src load multiple × behindsize elements at a time, where multiple = maxNum / behindsize. The code follows:

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;             // must stay a power of two for the tree-style summation
const int maxNum = NRAM_MAX_SIZE / sizeof(float); // at most maxNum floats fit on NRAM

// strideS is the smallest power of two >= stride
template<int strideS>
__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) { // 0 < axis < dim - 1
    __nram__ float src[maxNum];
    __nram__ float tmp[strideS];
    __nram__ float tmpOldMax[strideS];
    __nram__ float tmpNewMax[strideS];
    __nram__ float tmpSum[strideS];
    if (dimsize * stride < maxNum) {
        int behindsize = dimsize * stride;
        int multiple = maxNum / behindsize; // how many frontsize slabs one src load covers
        int remainF = frontsize % (taskDim * multiple);
        int remainT = remainF % taskDim;
        int stepEasy = (remainF - remainT) / taskDim;
        int stepHard = stepEasy + 1;
        int step = (taskId < remainT ? stepHard : stepEasy);
        int taskRepeat = (frontsize - remainF) / (taskDim * multiple);
        // per taskId, the share of frontsize is taskRepeat * multiple + step
        int startHard = taskId * (taskRepeat * multiple + stepHard);
        int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
        int indStart = (taskId < remainT ? startHard : startEasy);
        source = source + indStart * behindsize; // indStart * behindsize is this taskId's offset
        destination = destination + indStart * behindsize;
        int tid;
        for (int s = 0; s < taskRepeat; s++) {
            tid = s * multiple * behindsize;
            __memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
            for (int m = 0; m < multiple; m++) {
                __bang_write_zero(tmpSum, strideS);
                __bang_write_value(tmp, strideS, -INFINITY);
                __bang_write_value(tmpNewMax, strideS, -INFINITY);
                for (int i = 0; i < dimsize; i++) {
                    __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);    // x - M
                    __bang_active_exp_less_0(tmp, tmp, strideS); // exp(x - M)
                    if (i > 0) {
                        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);    // oldM = oldM - newM
                        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS); // exp(oldM - newM)
                        __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);          // sum = sum * exp(oldM - newM)
                    }
                    __bang_add(tmpSum, tmpSum, tmp, strideS); // sum += exp(x - M)
                    __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM); // oldM = newM
                }
                __bang_active_recip(tmpSum, tmpSum, strideS);
                __bang_mul(tmp, tmp, tmpSum, strideS); // tmp still holds the last slice, so reuse it
                // accumulate the results in src and write the whole slab back in one copy afterwards
                __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
                for (int i = 0; i < dimsize - 1; i++) {
                    __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);
                    __bang_active_exp_less_0(tmp, tmp, strideS);
                    __bang_mul(tmp, tmp, tmpSum, strideS);
                    __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
                }
            }
            __memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM);
        }
        __bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n", taskId, multiple, taskRepeat, step, indStart * behindsize);
        if (step) { // the extra slabs from remainF
            tid = taskRepeat * multiple * behindsize;
            __memcpy(src, source + tid, step * behindsize * sizeof(float), GDRAM2NRAM);
            for (int m = 0; m < step; m++) {
                __bang_write_zero(tmpSum, strideS);
                __bang_write_value(tmp, strideS, -INFINITY);
                __bang_write_value(tmpNewMax, strideS, -INFINITY);
                for (int i = 0; i < dimsize; i++) {
                    __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);
                    __bang_active_exp_less_0(tmp, tmp, strideS);
                    if (i > 0) {
                        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);
                        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);
                        __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);
                    }
                    __bang_add(tmpSum, tmpSum, tmp, strideS);
                    __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);
                }
                __bang_active_recip(tmpSum, tmpSum, strideS);
                __bang_mul(tmp, tmp, tmpSum, strideS);
                __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
                for (int i = 0; i < dimsize - 1; i++) {
                    __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
                    __bang_sub(tmp, tmp, tmpNewMax, strideS);
                    __bang_active_exp_less_0(tmp, tmp, strideS);
                    __bang_mul(tmp, tmp, tmpSum, strideS);
                    __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
                }
            }
            __memcpy(destination + tid, src, step * behindsize * sizeof(float), NRAM2GDRAM);
        }
    }
}

int main(void) {
    // int shape[4] = {1024, 128, 32, 32};
    // int shape[4] = {1024, 64, 32, 32};
    int shape[4] = {1024, 32, 32, 32};
    // int shape[4] = {2, 3, 2, 2};
    int axis = 2;
    int stride = 1, dimsize = shape[axis], num = 1, othersize = 1, frontsize = 1;
    for (int s = 3; s >= 0; s--) {
        num *= shape[s];
        if (s > axis) stride *= shape[s];
        if (s < axis) frontsize *= shape[s];
        if (s != axis) othersize *= shape[s];
    }
    printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    cnrtDim3_t dim = {4, 1, 1};
    int taskNum = dim.x * dim.y * dim.z;
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
    cnrtNotifier_t start, end;
    CNRT_CHECK(cnrtNotifierCreate(&start));
    CNRT_CHECK(cnrtNotifierCreate(&end));
    float* host_destination = (float*)malloc(num * sizeof(float));
    float* host_src = (float*)malloc(num * sizeof(float));
    for (int i = 0; i < num; i++) host_src[i] = i; // host_src[i] = i % 4;
    float *mlu_destination, *mlu_src;
    CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
    CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
    CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    softmaxKernelAxis_m<1024><<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
    CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    cnrtQueueSync(queue);
    CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
    for (int i = 0; i < 24; i++) printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
    float timeTotal;
    CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
    printf("Total Time: %.3f ms\n", timeTotal / 1000.0);
    CNRT_CHECK(cnrtQueueDestroy(queue));
    cnrtFree(mlu_destination);
    cnrtFree(mlu_src);
    free(host_destination);
    free(host_src);
    return 0;
}

All measurements below use taskDim = 4 and the Union1 task type:
(Figure: timing of the further-optimized kernels.)
