高精度矢量内积计算方法 (单精度浮点, 超长矢量)
对于单精度浮点类型的超长矢量(超过1亿元素)内积计算,累加误差确实是一个重要问题。以下是几种减少误差的方法:
1. Kahan求和算法
这是最常用的补偿求和算法,可以有效减少累加误差:
float kahan_dot_product(const float* a, const float* b, size_t n) {float sum = 0.0f;float c = 0.0f; // 补偿变量for (size_t i = 0; i < n; i++) {float product = a[i] * b[i];float y = product - c;float t = sum + y;c = (t - sum) - y;sum = t;}return sum;
}
2. 成对求和 (Pairwise Summation)
将求和过程组织为二叉树形式,减少大数吃小数的机会:
float pairwise_dot(const float* a, const float* b, size_t start, size_t end) {if (end - start <= 128) { // 基础块大小可调整float sum = 0.0f;for (size_t i = start; i < end; i++) {sum += a[i] * b[i];}return sum;}size_t mid = start + (end - start) / 2;float left = pairwise_dot(a, b, start, mid);float right = pairwise_dot(a, b, mid, end);return left + right;
}
3. 多精度累加器
使用双精度或更高精度累加器:
float high_precision_dot(const float* a, const float* b, size_t n) {double sum = 0.0;for (size_t i = 0; i < n; i++) {sum += (double)a[i] * (double)b[i];}return (float)sum;
}
4. 分块并行计算
对于超长矢量,可以结合并行计算和误差控制:
float parallel_dot(const float* a, const float* b, size_t n, int num_threads) {std::vector<double> partial_sums(num_threads, 0.0);#pragma omp parallel for num_threads(num_threads)for (int t = 0; t < num_threads; t++) {size_t start = t * n / num_threads;size_t end = (t + 1) * n / num_threads;double local_sum = 0.0;for (size_t i = start; i < end; i++) {local_sum += (double)a[i] * (double)b[i];}partial_sums[t] = local_sum;}double total = 0.0;for (double s : partial_sums) {total += s;}return (float)total;
}
实际应用建议
- 对于GPU计算:使用Kahan求和或成对求和,因为双精度运算在GPU上可能性能较差
- 对于CPU计算:优先考虑双精度累加器方法,简单且有效
- 混合精度:可以考虑使用融合乘加(FMA)指令结合补偿算法
- 内存访问优化:确保数据连续访问,减少缓存未命中
误差分析
对于1亿个单精度数的内积:
- 普通累加:相对误差可能达到1e-4量级
- Kahan求和:误差可降低到1e-8量级
- 双精度累加:误差可降低到1e-16量级
选择哪种方法取决于您的精度要求和计算环境限制。
测试代码
#include <cmath>
#include <iostream>
#include <iomanip>
#include <omp.h>
#include <cstdlib>
#include <immintrin.h>using namespace std;//编译: g++ -DAVX2 -DAVX512F -std=c++17 -Ofast -march=native -fopenmp test_dot.cpp/********************************************************<x,y>内积计算用double双精度做累加类型,保证数值稳定
**********************************************************/template<typename F,int P=0>
F dot(int n, const F *x, const F *y)
{if constexpr (P==0){double s_time=omp_get_wtime();//累加用单精度F s=0;for(int i=0; i<n; i++){s+=x[i]*y[i];}double e_time=omp_get_wtime();cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;return s;}else if constexpr (P==1){double s_time=omp_get_wtime();//累加用双精度,乘法用单精度double s=0;for(int i=0; i<n; i++){s+=x[i]*y[i];}double e_time=omp_get_wtime();cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;return s;}else if constexpr (P==2){double s_time=omp_get_wtime();//累加用双精度,乘法用双精度double s=0;for(int i=0; i<n; i++){double a=x[i];double b=y[i];s+=a*b;}double e_time=omp_get_wtime();cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;return s;}
#ifdef AVX2 else if constexpr(P==3){static_assert(is_same_v<F,float>);/********************************************************************OpenMP多线程,AVX2计算<x,x>,<x,y>累加和乘法都用双精度 *********************************************************************/double s_time=omp_get_wtime();__m256d xx_sum,xy_sum;xx_sum=xy_sum=_mm256_setzero_pd();#pragma omp parallel{__m256d sum_xx=_mm256_setzero_pd();__m256d sum_xy=_mm256_setzero_pd();#pragma omp for nowaitfor(int i=0; i<n; i+=8){__m256 x8=_mm256_loadu_ps(x+i);__m256 y8=_mm256_loadu_ps(y+i);__m128 lo,hi;__m256d t1,t2,t3,t4;lo=_mm256_extractf128_ps(x8,0);hi=_mm256_extractf128_ps(x8,1);t1=_mm256_cvtps_pd(lo);t2=_mm256_cvtps_pd(hi);lo=_mm256_extractf128_ps(y8,0);hi=_mm256_extractf128_ps(y8,1);t3=_mm256_cvtps_pd(lo);t4=_mm256_cvtps_pd(hi);
#if 0 sum_xx=_mm256_add_pd(sum_xx,_mm256_mul_pd(t1,t1));sum_xx=_mm256_add_pd(sum_xx,_mm256_mul_pd(t2,t2));sum_xy=_mm256_add_pd(sum_xy,_mm256_mul_pd(t1,t3));sum_xy=_mm256_add_pd(sum_xy,_mm256_mul_pd(t2,t4));
#else/**********************************FMA***********************************/sum_xx=_mm256_fmadd_pd(t1,t1,sum_xx);sum_xx=_mm256_fmadd_pd(t2,t2,sum_xx);sum_xy=_mm256_fmadd_pd(t1,t3,sum_xy);sum_xy=_mm256_fmadd_pd(t2,t4,sum_xy);
#endif}#pragma omp critical{xx_sum=_mm256_add_pd(xx_sum,sum_xx);xy_sum=_mm256_add_pd(xy_sum,sum_xy);}}double tmp[4];_mm256_storeu_pd(tmp,xy_sum);double xy=tmp[0]+tmp[1]+tmp[2]+tmp[3];_mm256_storeu_pd(tmp,xx_sum);double xx=tmp[0]+tmp[1]+tmp[2]+tmp[3];for(int i=n&~7; i<n; i++){double a=x[i];double b=y[i];xx+=a*a;xy+=a*b;}double e_time=omp_get_wtime();cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;cout<<"xx="<<xx<<endl;return xy;}//P==3
#endif
#ifdef AVX512F else if constexpr(P==4){static_assert(is_same_v<F,float>);/********************************************************************OpenMP多线程,AVX512F计算<x,x>,<x,y>累加和乘法都用双精度 *********************************************************************/double s_time=omp_get_wtime();__m512d xx_sum=_mm512_setzero_pd();__m512d xy_sum=_mm512_setzero_pd();#pragma omp parallel{__m512d sum_xx=_mm512_setzero_pd();__m512d sum_xy=_mm512_setzero_pd();#pragma omp for nowaitfor(int i=0; i<n; i+=16){__m512 x16=_mm512_loadu_ps(x+i);__m512 y16=_mm512_loadu_ps(y+i);__m256 lo,hi;__m512d t1,t2,t3,t4;lo=_mm512_extractf32x8_ps(x16,0);hi=_mm512_extractf32x8_ps(x16,1);t1=_mm512_cvtps_pd(lo);t2=_mm512_cvtps_pd(hi);lo=_mm512_extractf32x8_ps(y16,0);hi=_mm512_extractf32x8_ps(y16,1);t3=_mm512_cvtps_pd(lo);t4=_mm512_cvtps_pd(hi);
#if 0 sum_xx=_mm512_add_pd(sum_xx,_mm512_mul_pd(t1,t1));sum_xx=_mm512_add_pd(sum_xx,_mm512_mul_pd(t2,t2));sum_xy=_mm512_add_pd(sum_xy,_mm512_mul_pd(t1,t3));sum_xy=_mm512_add_pd(sum_xy,_mm512_mul_pd(t2,t4));
#else/***********************************FMA************************************/sum_xx=_mm512_fmadd_pd(t1,t1,sum_xx);sum_xx=_mm512_fmadd_pd(t2,t2,sum_xx);sum_xy=_mm512_fmadd_pd(t1,t3,sum_xy);sum_xy=_mm512_fmadd_pd(t2,t4,sum_xy);
#endif}#pragma omp critical{xx_sum=_mm512_add_pd(xx_sum,sum_xx);xy_sum=_mm512_add_pd(xy_sum,sum_xy);}}double tmp[8];_mm512_storeu_pd(tmp,xy_sum);double xy=tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];_mm512_storeu_pd(tmp,xx_sum);double xx=tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];for(int i=n&~15; i<n; i++){double a=x[i];double b=y[i];xx+=a*a;xy+=a*b;}double e_time=omp_get_wtime();cout<<"P="<<P<<": time used "<<e_time-s_time<<endl;cout<<"xx="<<xx<<endl;return xy;}
#endif return(0);}const int N=50000000;void test()
{using FLOAT=float;FLOAT *x=new FLOAT[N];FLOAT *y=new FLOAT[N];for(int i=0; i<N; i++){FLOAT t=0.001*sqrtf(FLOAT(i));FLOAT s=sqrt(sqrt(FLOAT(i)));x[i]=(rand()<RAND_MAX/2)?t:-t;y[i]=(rand()<RAND_MAX/2)?s:-s;}cout<<setprecision(15)<<endl;cout<<(double)dot<FLOAT,0>(N,x,y)<<endl;cout<<(double)dot<FLOAT,1>(N,x,y)<<endl;cout<<(double)dot<FLOAT,2>(N,x,y)<<endl;#ifdef AVX2 cout<<(double)dot<FLOAT,3>(N,x,y)<<endl;
#endif #ifdef AVX512F cout<<(double)dot<FLOAT,4>(N,x,y)<<endl;
#endif }int main(int argc, char **argv)
{test();return(0);
}
资料
Intel /intrinsics-guide