The 47_ampere_gemm_universal_streamk Example in CUTLASS

The previous post introduced the paper Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU; this post analyzes its implementation in code.

The examples/47_ampere_gemm_universal_streamk directory in cutlass demonstrates the Stream-K GEMM algorithm on the Ampere architecture, comparing the performance of plain GEMM and Split-K against Stream-K:

  • At the Device level, GemmUniversal supports the plain GEMM, Split-K, and Stream-K algorithms behind one interface; most of the implementation lives in its base class GemmUniversalBase;
  • At the Kernel level, GemmUniversal implements plain GEMM and Split-K, while GemmUniversalStreamk implements Stream-K:
    • The two share many components and configurations through DefaultGemmUniversal, DefaultGemm, Gemm, and so on; that is, a Gemm kernel is constructed, but only its components are used;
    • The generic kernel template function Kernel2 calls GemmUniversal::invoke and GemmUniversalStreamk::invoke, whose main implementations are the GemmUniversal::run_with_swizzle and GemmUniversalStreamk::gemm functions;
  • At the Threadblock level, there are likewise two branches, GemmIdentityThreadblockSwizzle and ThreadblockSwizzleStreamK:
    • A four-stage MmaMultistage is used, with PredicatedTileAccessIterator as the iterator for the A and B matrices;
    • The Epilogue inherits from both EpilogueBase and EpilogueBaseStreamK, and DefaultEpilogueTensorOp specifies PredicatedTileIterator as the output iterator.

The hierarchy shown in CUTLASS GEMM Components is a helpful reference for understanding these layers.

Note

  • The current cutlass repository contains both old and new code; this example uses the 2.x API.
  • Both Data-parallel and Fixed-split in the paper map to the kGemm mode, while the kGemmSplitKParallel mode corresponds to GemmSplitKParallel; the mode enumeration is sketched below.
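
For reference, the mode enumeration is defined in cutlass/gemm/gemm.h; abridged here to the enumerators this example exercises (the comments are my annotations):

enum class GemmUniversalMode {
  kGemm,                // data-parallel GEMM; doubles as serial split-K when the tile-splitting factor > 1
  kGemmSplitKParallel,  // parallel split-K: partial sums land in workspace and are reduced by a second kernel
  kBatched,             // batched strided GEMM
  kArray                // batched GEMM addressed through pointer arrays
};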

ampere_gemm_universal_streamk.cu

Check the CUDA Toolkit version.

/// Program entrypoint
int main(int argc, const char **argv)
{
  // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ >= 11)) {
    std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;

    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }

cudaGetDevice is a CUDA Runtime API that returns the device currently in use.
cudaGetDeviceProperties returns information about the compute device in a cudaDeviceProp.

Check the device's compute capability; SM80 or above is required here.

  // Current device must have compute capability at least 80
  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (!((props.major * 10 + props.minor) >= 80))
  {
    std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
              << std::endl;

    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }

Create an Options struct.
Options::parse parses the command-line arguments via the CommandLine struct.

  // Parse commandline options
  Options options("ampere_streamk_gemm");
  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  std::cout <<
    options.iterations << " timing iterations of " <<
    options.problem_size.m() << " x " <<
    options.problem_size.n() << " x " <<
    options.problem_size.k() << " matrix-matrix multiply" << std::endl;

  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

HostTensor::resize resizes the logical tensor.
HostTensor::host_view returns a TensorView object.
The TensorFillRandomUniform function generates random numbers via std::rand.

  //
  // Initialize GEMM datasets
  //

  // Initialize tensors using CUTLASS helper functions
  options.tensor_a.resize(options.problem_size.mk());       // <- Create matrix A with dimensions M x K
  options.tensor_b.resize(options.problem_size.kn());       // <- Create matrix B with dimensions K x N
  options.tensor_c.resize(options.problem_size.mn());       // <- Create matrix C with dimensions M x N
  options.tensor_d.resize(options.problem_size.mn());       // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
  options.tensor_ref_d.resize(options.problem_size.mn());   // <- Create matrix D with dimensions M x N used to store output from reference kernel

  // Fill matrix A on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_a.host_view(),
      1,
      ElementA(2),
      ElementA(-2),
      0);

  // Fill matrix B on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_b.host_view(),
      1,
      ElementB(2),
      ElementB(-2),
      0);

  // Fill matrix C on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_c.host_view(),
      1,
      ElementC(2),
      ElementC(-2),
      0);

HostTensor::sync_device copies the data to the device.
HostTensor::device_ref returns a TensorRef object.
DeviceGemmReference is a Gemm; it invokes the reference kernel to compute the result.
HostTensor::sync_host copies the data back to the host.

  //
  // Compute reference output
  //

  // Copy data from host to GPU
  options.tensor_a.sync_device();
  options.tensor_b.sync_device();
  options.tensor_c.sync_device();

  // Zero-initialize reference output matrix D
  cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
  options.tensor_ref_d.sync_device();

  // Create instantiation for device reference gemm kernel
  DeviceGemmReference gemm_reference;

  // Launch device reference gemm kernel
  gemm_reference(
      options.problem_size,
      ElementAccumulator(options.alpha),
      options.tensor_a.device_ref(),
      options.tensor_b.device_ref(),
      ElementAccumulator(options.beta),
      options.tensor_c.device_ref(),
      options.tensor_ref_d.device_ref());

  // Wait for kernels to finish
  CUDA_CHECK(cudaDeviceSynchronize());

  // Copy output data from reference kernel to host for comparison
  options.tensor_ref_d.sync_host();

When options.split_k_factor == 1, Basic-DP and StreamK are compared.
The run template function is called to execute the kernel instantiated with these parameters.
DeviceGemmBasic and DeviceGemmStreamK are both GemmUniversal; the former merely uses GemmIdentityThreadblockSwizzle while the latter uses ThreadblockSwizzleStreamK (see the sketch after the code below).
options.split_k_factor is then incremented.

  //
  // Evaluate CUTLASS kernels
  //

  // Test default operation
  if (options.split_k_factor == 1)
  {
    // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
    Result basic_dp         = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
    Result streamk_default  = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);

    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));

    // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
    options.avail_sms       = 1;        // Set loadbalancing width to 1 SM (no load balancing)
    Result streamk_dp       = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
    options.avail_sms       = -1;       // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)

    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));

    options.split_k_factor++;     // Increment splitting factor for next evaluation
  }
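
For context, the two device-level aliases used by the run calls above differ only in their swizzle template argument. A rough sketch with the parameter lists abridged (the full instantiations are in the example source):

using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal<
    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,  // classic tile-to-CTA mapping
    NumStages, AlignmentA, AlignmentB>;

using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape, WarpShape, InstructionShape, EpilogueOp,
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,         // Stream-K work scheduler
    NumStages, AlignmentA, AlignmentB>;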

When options.split_k_factor is greater than 1, Basic-SplitK and SplitK-StreamK are compared.

  // Show that StreamK can emulate "Split-K" with a tile-splitting factor
  Result basic_splitk = run<DeviceGemmBasic>(
      std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
      options);

  Result streamk_splitk = run<DeviceGemmStreamK>(
      std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
      options);

  printf("  Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));

  return 0;
}

run

TensorFill fills the tensor with a scalar element.

/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options)
{
  // Display test description
  std::cout << std::endl << description << std::endl;

  // Zero-initialize test output matrix D
  cutlass::reference::host::TensorFill(options.tensor_d.host_view());
  options.tensor_d.sync_device();

Create a GemmUniversal object.
args_from_options has two overloads, one for DeviceGemmBasic and one for DeviceGemmStreamK. It builds a GemmUniversal::Arguments from the Options; at device scope this is GemmUniversalBase::Arguments, which in turn is the kernel-level GemmUniversal::Arguments.
GemmUniversalBase::get_workspace_size returns the workspace size (in bytes) needed for the problem geometry expressed by these arguments.
allocation is DeviceAllocation; its constructor calls allocate to request memory.
GemmUniversalBase::can_implement checks whether the grid would exceed hardware limits and whether the shape satisfies the alignment requirements.
GemmUniversalBase::initialize initializes the parameters.

  // Instantiate CUTLASS kernel depending on templates
  DeviceGemmT device_gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
  auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check the problem size is supported or not
  CUTLASS_CHECK(device_gemm.can_implement(arguments));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));

Run a functional check.
This calls the GemmUniversalBase::operator() overload that takes no arguments.
TensorEquals checks that every element of the output equals the reference. Can strict equality really be expected? Here it can: the inputs are small integers drawn from [-2, 2], so every product and partial sum is an integer that is exactly representable in the accumulator; as long as the accumulated values stay in that exactly representable range, the result is independent of summation order and the two kernels are bit-identical.

  // Correctness / Warmup iteration
  CUTLASS_CHECK(device_gemm());

  // Copy output data from CUTLASS and reference kernel to host for comparison
  options.tensor_d.sync_host();

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  Result result;
  result.passed = cutlass::reference::host::TensorEquals(
      options.tensor_d.host_view(),
      options.tensor_ref_d.host_view());

  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

Performance measurement.
GpuTimer measures time via cudaEvent; a minimal sketch follows the function below.
gflops is the achieved compute throughput; for GEMM the conventional flop count is 2·M·N·K, so gflops = 2·M·N·K / elapsed seconds / 10^9.

  // Run profiling loop
  if (options.iterations > 0)
  {
    GpuTimer timer;
    timer.start();
    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(device_gemm());
    }
    timer.stop();

    // Compute average runtime and GFLOPs.
    float elapsed_ms = timer.elapsed_millis();
    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
    std::cout << "  GFLOPs: " << result.gflops << std::endl;
  }

  if (!result.passed) {
    exit(-1);
  }

  return result;
}
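
GpuTimer lives in the example's common helper headers; presumably it wraps a pair of CUDA events. A minimal sketch of what start/stop/elapsed_millis amount to:

// Event-based timing, equivalent in spirit to GpuTimer.
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start);       // enqueued on the same stream as the kernels
// ... launch the kernels being timed ...
cudaEventRecord(stop);

cudaEventSynchronize(stop);   // block until 'stop' has actually been reached
float elapsed_ms = 0.f;
cudaEventElapsedTime(&elapsed_ms, start, stop);

cudaEventDestroy(start);
cudaEventDestroy(stop);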

GemmUniversal

GemmUniversal
GemmUniversalBase

GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation (problem geometry and data references), it can be reused across different GEMM problems with the same geometry. (Once initialized, details regarding problem geometry and references to workspace memory cannot be updated.) The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and batched array variants.

The bulk of the implementation is in GemmUniversalBase.

DefaultGemmUniversal::GemmKernel is GemmUniversal or GemmUniversalStreamk.
DefaultGemmConfiguration::EpilogueOutputOp is LinearCombination.

/*! GemmUniversal is a stateful, reusable GEMM handle.  Once initialized for a given GEMM computation
  (problem geometry and data references), it can be reused across different GEMM problems having the
  geometry.  (Once initialized, details regarding problem geometry and references to workspace memory
  cannot be updated.)

  The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and batched array variants.
*/
template </// Element type for A matrix operandtypename ElementA_,/// Layout type for A matrix operandtypename LayoutA_,/// Element type for B matrix operandtypename ElementB_,/// Layout type for B matrix operandtypename LayoutB_,/// Element type for C and D matrix operandstypename ElementC_,/// Layout type for C and D matrix operandstypename LayoutC_,/// Element type for internal accumulationtypename ElementAccumulator_ = ElementC_,/// Operator class tagtypename OperatorClass_ = arch::OpClassSimt,/// Tag indicating architecture to tune for.  This is the minimum SM that/// supports the intended feature. The device kernel can be built/// targeting any SM larger than this number.typename ArchTag_ = arch::Sm70,/// Threadblock-level tile size (concept: GemmShape)typename ThreadblockShape_ = typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,ElementAccumulator_>::ThreadblockShape,/// Warp-level tile size (concept: GemmShape)typename WarpShape_ = typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,ElementAccumulator_>::WarpShape,/// Instruction-level tile size (concept: GemmShape)typename InstructionShape_ = typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,ElementAccumulator_>::InstructionShape,/// Epilogue output operatortypename EpilogueOutputOp_ = typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,ElementAccumulator_>::EpilogueOutputOp,/// Threadblock-level swizzling operatortypename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,/// Number of stages used in the pipelined mainloopint Stages =DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,ElementC_, ElementAccumulator_>::kStages,/// Access granularity of A matrix in units of elementsint AlignmentA =DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,ElementC_, ElementAccumulator_>::kAlignmentA,/// Access granularity of B matrix in units of elementsint AlignmentB =DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,ElementC_, ElementAccumulator_>::kAlignmentB,/// Operation performed by GEMMtypename Operator_ = typename DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,ElementAccumulator_>::Operator,/// Complex elementwise transformation on A operandComplexTransform TransformA = ComplexTransform::kNone,/// Complex elementwise transformation on B operandComplexTransform TransformB = ComplexTransform::kNone,/// Gather operand A by using an index arraybool GatherA = false,/// Gather operand B by using an index arraybool GatherB = false,/// Scatter result D by using an index arraybool ScatterD = false,/// Permute result Dtypename PermuteDLayout_ = layout::NoPermute,/// Permute operand Atypename PermuteALayout_ = layout::NoPermute,/// Permute operand Btypename PermuteBLayout_ = layout::NoPermute
>
class GemmUniversal : public GemmUniversalBase<typename kernel::DefaultGemmUniversal<ElementA_,LayoutA_,TransformA,AlignmentA,ElementB_,LayoutB_,TransformB,AlignmentB,ElementC_,LayoutC_,ElementAccumulator_,OperatorClass_,ArchTag_,ThreadblockShape_,WarpShape_,InstructionShape_,EpilogueOutputOp_,ThreadblockSwizzle_,Stages,Operator_,SharedMemoryClearOption::kNone,GatherA,GatherB,ScatterD,PermuteDLayout_,PermuteALayout_,PermuteBLayout_>::GemmKernel> {public:using ElementAccumulator = ElementAccumulator_;using OperatorClass = OperatorClass_;using ArchTag = ArchTag_;using ThreadblockShape = ThreadblockShape_;using WarpShape = WarpShape_;using InstructionShape = InstructionShape_;using EpilogueOutputOp = EpilogueOutputOp_;using ThreadblockSwizzle = ThreadblockSwizzle_;using Operator = Operator_;using PermuteDLayout = PermuteDLayout_;using PermuteALayout = PermuteALayout_;using PermuteBLayout = PermuteBLayout_;static int const kStages = Stages;static int const kAlignmentA = AlignmentA;static int const kAlignmentB = AlignmentB;static int const kAlignmentC = EpilogueOutputOp::kCount;static ComplexTransform const kTransformA = TransformA;static ComplexTransform const kTransformB = TransformB;

GemmUniversal::GemmKernel is GemmUniversalBase::GemmKernel, i.e. DefaultGemmUniversal::GemmKernel; the latter is selected according to the ThreadblockSwizzle template parameter passed in.

  using Base = GemmUniversalBase<typename kernel::DefaultGemmUniversal<ElementA_,LayoutA_,TransformA,AlignmentA,ElementB_,LayoutB_,TransformB,AlignmentB,ElementC_,LayoutC_,ElementAccumulator_,OperatorClass_,ArchTag_,ThreadblockShape_,WarpShape_,InstructionShape_,EpilogueOutputOp_,ThreadblockSwizzle_,Stages,Operator_,SharedMemoryClearOption::kNone,GatherA,GatherB,ScatterD,PermuteDLayout_,PermuteALayout_,PermuteBLayout_>::GemmKernel>;using Arguments = typename Base::Arguments;using GemmKernel = typename Base::GemmKernel;
};

GemmUniversalBase

It uses the information in the GemmUniversal or GemmUniversalStreamk kernel passed in.

template <typename GemmKernel_>
class GemmUniversalBase {
public:

  using GemmKernel = GemmKernel_;

  /// Boolean indicating whether the CudaHostAdapter is enabled
  static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;

  using ThreadblockShape = typename GemmKernel::Mma::Shape;

  using ElementA = typename GemmKernel::ElementA;
  using LayoutA = typename GemmKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = GemmKernel::kTransformA;

  using ElementB = typename GemmKernel::ElementB;
  using LayoutB = typename GemmKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = GemmKernel::kTransformB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename GemmKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  /// Numerical accumulation element type
  using ElementAccumulator = typename GemmKernel::Mma::ElementC;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
  using Operator = typename GemmKernel::Operator;

Arguments takes its type from the GemmKernel struct passed in, i.e. GemmUniversal::Arguments or GemmUniversalStreamk::Arguments.
device_ordinal_ is initialized to -1.

  /// Argument structure
  using Arguments = typename GemmKernel::Arguments;

  /// Index of the GEMM Kernel within the CudaHostAdapter
  static int32_t const kGemmKernelIndex = 0;

  /// Kernel dynamic shared memory allocation requirement
  /// Update the kernel function's shared memory configuration for the current device
  static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage);

protected:

  //
  // Device properties (uniform across all instances of the current thread)
  //

  // Device ordinal
  CUTLASS_THREAD_LOCAL static int device_ordinal_;

  /// Device SM count
  CUTLASS_THREAD_LOCAL static int device_sms_;

  /// Kernel SM occupancy (in thread blocks)
  CUTLASS_THREAD_LOCAL static int sm_occupancy_;

GemmUniversalBase::init_device_props

Initializes device_sms_ and sm_occupancy_, and configures dynamic Shared Memory.

Initializes the static thread-local members for the thread's current device, if necessary.
CUTLASS_TRACE_HOST prints the file name and line number in debug mode.

protected:

  /// Initialize static thread-local members for the thread's current device,
  /// if necessary.
  static Status init_device_props()
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()");

cudaGetDevice returns the device currently in use.
If the current device has already been initialized, return directly.

    cudaError_t cudart_result;

    // Get current device ordinal
    int current_ordinal;
    cudart_result = cudaGetDevice(&current_ordinal);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  cudaGetDevice() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;
    }

    // Done if matches the current static member
    if (current_ordinal == device_ordinal_) {
      // Already initialized
      return Status::kSuccess;
    }

cudaDeviceGetAttribute returns information about the device.

    // Update SM count member
    cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;
    }

cudaFuncSetAttribute sets attributes for the given function.
If the SharedMemory requirement exceeds 48KB, the maximum capacity of the function's dynamically allocated shared memory is raised.

    // If requires more than 48KB: configure for extended, dynamic shared memory
    if constexpr (kSharedStorageSize >= (48 << 10))
    {
      cudart_result = cudaFuncSetAttribute(
          Kernel2<GemmKernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          kSharedStorageSize);
      if (cudart_result != cudaSuccess) {
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
        return Status::kErrorInternal;
      }
    }

cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags is a CUDA Runtime API that returns the maximum number of thread blocks that can be active on one SM when running this kernel function.

    // Update SM occupancy member
    cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
        &sm_occupancy_,
        Kernel2<GemmKernel>,
        GemmKernel::kThreadCount,
        kSharedStorageSize,
        cudaOccupancyDisableCachingOverride);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;
    }

    // Update device ordinal member on success
    device_ordinal_ = current_ordinal;

    CUTLASS_TRACE_HOST("  "
      "device_ordinal: (" << device_ordinal_ << "), "
      "device_sms: (" << device_sms_ << "), "
      "sm_occupancy: (" << sm_occupancy_ << ") "
      "smem_size: (" << kSharedStorageSize << ") "
      "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");

    return Status::kSuccess;
  }

Depending on the kernel, this is either GemmUniversal::Params or GemmUniversalStreamk::Params.


protected:

  //
  // Instance data members
  //

  /// Kernel parameters
  typename GemmKernel::Params params_;

GemmUniversalBase::init_params

GemmUniversalBase::init_params
GemmUniversalBase::init_device_props

Initializes params_.

  /// Initialize params member
  Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    int32_t device_sms = 0;
    int32_t sm_occupancy = 0;

kEnableCudaHostAdapter takes its value from the macro CUTLASS_ENABLE_CUDA_HOST_ADAPTER, which is not enabled here.
The CudaHostAdapter class is not implemented, either.

    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);

      //
      // Occupancy query using CudaHostAdapter::query_occupancy().
      //
      if (cuda_adapter) {
        Status status = cuda_adapter->query_occupancy(
            &device_sms,
            &sm_occupancy,
            kGemmKernelIndex,
            GemmKernel::kThreadCount,
            kSharedStorageSize);

        CUTLASS_ASSERT(status == Status::kSuccess);

        if (status != Status::kSuccess) {
          return status;
        }
      }
      else {
        return Status::kErrorInternal;
      }
    }

Hence the GemmUniversalBase::init_device_props function is called to obtain the SM count and the maximum number of thread blocks per SM.

    else {
      CUTLASS_ASSERT(cuda_adapter == nullptr);

      // Initialize static device properties, if necessary
      Status result = init_device_props();

      if (result != Status::kSuccess) {
        return result;
      }

      //
      // Use thread-local static members for occupancy query initialized by call to
      // `init_device_props()`
      //

      device_sms   = device_sms_;
      sm_occupancy = sm_occupancy_;
    }

This produces a GemmUniversal::Params or GemmUniversalStreamk::Params object.

    // Initialize params member
    params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy);

    return Status::kSuccess;
  }

GemmUniversalBase::can_implement

GemmUniversalBase::can_implement
GemmUniversal::can_implement
GemmUniversalStreamk::can_implement

Calls the kernel's GemmUniversal::can_implement or GemmUniversalStreamk::can_implement for further checks.

public:

  //---------------------------------------------------------------------------------------------
  // Stateless API
  //---------------------------------------------------------------------------------------------

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()");

    dim3 grid = get_grid_shape(args, cuda_adapter);

    if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
          grid.z <= std::numeric_limits<uint16_t>::max()))
    {
      return Status::kErrorInvalidProblem;
    }

    return GemmKernel::can_implement(args);
  }

GemmUniversalBase::get_workspace_size

Returns the workspace size (in bytes) needed for the problem geometry expressed by these arguments.

GemmUniversalBase::get_workspace_size
GemmUniversalBase::init_params
UniversalParamsBase::get_workspace_size
GemmUniversalStreamk::Params::get_workspace_size
  /// Returns the workspace size (in bytes) needed for the problem
  /// geometry expressed by these arguments
  static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");

First, create a GemmUniversalBase object.
Then call GemmUniversalBase::init_params to initialize the parameters.

    // Initialize parameters from args
    GemmUniversalBase base;
    if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
      return 0;
    }

Calls UniversalParamsBase::get_workspace_size or GemmUniversalStreamk::Params::get_workspace_size to obtain the amount of global-memory workspace the kernel requires.

    // Get size from parameters
    size_t workspace_bytes = base.params_.get_workspace_size();

    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);
    return workspace_bytes;
  }

GemmUniversalBase::get_grid_shape

  /// Returns the grid extents in thread blocks to launch
  static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");

First, create a GemmUniversalBase object.
Then call GemmUniversalBase::init_params to initialize the parameters.

    // Initialize parameters from args
    GemmUniversalBase base;
    if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
      return dim3(0,0,0);
    }

Calls UniversalParamsBase::get_grid_dims or GemmUniversalStreamk::Params::get_grid_dims to obtain the grid dimensions.

    // Get dims from parameters
    dim3 grid_dims = base.params_.get_grid_dims();

    CUTLASS_TRACE_HOST(
        "  tiled_shape: " << base.params_.get_tiled_shape()  << "\n"
     << "  grid_dims: {" << grid_dims << "}");

    return grid_dims;
  }

GemmUniversalBase::maximum_active_blocks

Similar to what GemmUniversalBase::init_params does.

  /// Returns the maximum number of active thread blocks per multiprocessorstatic int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr){CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");int32_t device_sms   = 0;int32_t sm_occupancy = 0;if constexpr (kEnableCudaHostAdapter) {CUTLASS_ASSERT(cuda_adapter);if (cuda_adapter) {Status status = cuda_adapter->query_occupancy(&device_sms,&sm_occupancy,kGemmKernelIndex,GemmKernel::kThreadCount,kSharedStorageSize);CUTLASS_ASSERT(status == Status::kSuccess);if (status != Status::kSuccess) {return -1;}}else {return -1;}}else {CUTLASS_ASSERT(cuda_adapter == nullptr);// Initialize static device properties, if necessaryif (init_device_props() != Status::kSuccess) {return -1;}sm_occupancy = sm_occupancy_;}CUTLASS_TRACE_HOST("  max_active_blocks: " << sm_occupancy_);return sm_occupancy;}

GemmUniversalBase::initialize

GemmUniversalBase::initialize
GemmUniversalBase::init_params
UniversalParamsBase::init_workspace
GemmUniversalStreamk::Params::init_workspace
  // Stateful API
  //---------------------------------------------------------------------------------------------

  /// Initializes GEMM state from arguments and workspace memory
  Status initialize(
      Arguments const &args,
      void *workspace = nullptr,
      cudaStream_t stream = nullptr,
      CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

Calls GemmUniversalBase::init_params to obtain the GemmUniversal::Params or GemmUniversalStreamk::Params.

    // Initialize parameters from args
    Status result = init_params(args, cuda_adapter);
    if (result != Status::kSuccess) {
      return result;
    }

Calls UniversalParamsBase::init_workspace or GemmUniversalStreamk::Params::init_workspace to zero the workspace.

    // Assign and prepare workspace memory
    if (args.mode == GemmUniversalMode::kGemm) {
      return params_.init_workspace(workspace, stream);
    }

    return Status::kSuccess;
  }

GemmUniversalBase::update

GemmUniversalBase::update
GemmUniversal::Params::update
GemmUniversalStreamk::Params::update

Calls GemmUniversal::Params::update or GemmUniversalStreamk::Params::update to update the parameters.

  /// Lightweight update given a subset of arguments.
  Status update(Arguments const &args)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
    params_.update(args);
    return Status::kSuccess;
  }

GemmUniversalBase::run

GemmUniversalBase::run
UniversalParamsBase::invoke
GemmUniversalStreamk::invoke

The CUTLASS_TRACE_HOST macro is active in debug builds.
GemmUniversal::kThreadCount and GemmUniversalStreamk::kThreadCount are both derived from WarpCount; the latter is MmaBase::WarpCount, determined by the ThreadblockShape and WarpShape the application passes in.
Calls UniversalParamsBase::get_grid_dims or GemmUniversalStreamk::Params::get_grid_dims to obtain the grid dimensions.

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::run()");

    // Configure grid and block dimensions
    dim3 block(GemmKernel::kThreadCount, 1, 1);
    dim3 grid = params_.get_grid_dims();

CUTLASS_ASSERT is an assertion.
Kernel2 calls the GemmUniversal::invoke or GemmUniversalStreamk::invoke function.
The kernel function's parameter is a GemmUniversal::Params or GemmUniversalStreamk::Params object (see the Kernel2 sketch after the code below).

    // Launch kernel
    CUTLASS_TRACE_HOST("  "
      "grid: (" << grid << "), "
      "block: (" << block << "), "
      "SMEM: (" << kSharedStorageSize << ")");

    if constexpr (kEnableCudaHostAdapter) {
      CUTLASS_ASSERT(cuda_adapter);

      if (cuda_adapter) {
        void* kernel_params[] = {&params_};
        return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0);
      }
      else {
        return Status::kErrorInternal;
      }
    }
    else {
      CUTLASS_ASSERT(cuda_adapter == nullptr);

      Kernel2<GemmKernel><<<grid, block, kSharedStorageSize, stream>>>(params_);

      // Query for errors
      cudaError_t result = cudaGetLastError();
      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }
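
For reference, Kernel2 is the generic launch shim; roughly as defined in cutlass/device_kernel.h (quoted from memory, so treat it as a sketch):

template <typename Operator>
CUTLASS_GLOBAL
void Kernel2(typename Operator::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Reinterpret the dynamic shared memory as the kernel's SharedStorage
  typename Operator::SharedStorage *shared_storage =
      reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);

  // Forward to the kernel-level factory entry point
  Operator::invoke(params, *shared_storage);
}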

GemmUniversalBase::operator()

The overloaded operator calls GemmUniversalBase::run.

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
  {
    return run(stream, cuda_adapter);
  }

GemmUniversalBase::operator()

The overload that accepts arguments first calls GemmUniversalBase::initialize and then GemmUniversalBase::run.

  /// Runs the kernel using initialized state.
  Status operator()(
      Arguments const &args,
      void *workspace = nullptr,
      cudaStream_t stream = nullptr,
      CudaHostAdapter *cuda_adapter = nullptr)
  {
    Status status = initialize(args, workspace, stream, cuda_adapter);

    if (status == Status::kSuccess) {
      status = run(stream, cuda_adapter);
    }

    return status;
  }
};

UniversalParamsBase

/// Parameters structure
template <
  typename ThreadblockSwizzle,
  typename ThreadblockShape,
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutA,
  typename LayoutB>
struct UniversalParamsBase
{
  //
  // Data members
  //

  GemmCoord problem_size{};
  GemmCoord grid_tiled_shape{};
  int swizzle_log_tile{0};

  GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;

  int batch_count {0};
  int gemm_k_size {0};

  int64_t batch_stride_D {0};

  int *semaphore = nullptr;

  //
  // Host dispatch API
  //

  /// Default constructor
  UniversalParamsBase() = default;

UniversalParamsBase::UniversalParamsBase

UniversalParamsBase::UniversalParamsBase
UniversalParamsBase::init_grid_tiled_shape

The constructor calls UniversalParamsBase::init_grid_tiled_shape to compute the shape of the tiled grid.

  /// Constructor
  UniversalParamsBase(
    UniversalArgumentsBase const &args,  /// GEMM application arguments
    int device_sms,                      /// Number of SMs on the device
    int sm_occupancy)                    /// Kernel SM occupancy (in thread blocks)
  :
    problem_size(args.problem_size),
    mode(args.mode),
    batch_count(args.batch_count),
    batch_stride_D(args.batch_stride_D),
    semaphore(nullptr)
  {
    init_grid_tiled_shape();
  }

UniversalParamsBase::get_workspace_size

GemmSplitKParallel requires a workspace of problem.m() * problem.n() * k_slice elements.

  /// Returns the workspace size (in bytes) needed for this problem geometry
  size_t get_workspace_size() const
  {
    size_t workspace_bytes = 0;
    if (mode == GemmUniversalMode::kGemmSplitKParallel)
    {
      // Split-K parallel always requires a temporary workspace
      workspace_bytes =
        sizeof(ElementC) *
        size_t(batch_stride_D) *
        size_t(grid_tiled_shape.k());
    }

In the serial case, the workspace size corresponds to the number of output tiles, since each output tile needs one synchronization semaphore for the reduction. A worked example follows the code below.

    else if (mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
    {
      // Serial split-K only requires a temporary workspace if the number of partitions along the
      // GEMM K dimension is greater than one.
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    }

    return workspace_bytes;
  }
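
A worked example under assumed values: a 1024 x 1024 x 4096 problem, 128 x 128 threadblock tiles, tile-splitting factor 4, ElementC = float, and batch_stride_D = M * N as the split-K modes set it:

// Parallel split-K: each of the grid_tiled_shape.k() = 4 partials needs a full M x N buffer:
//   workspace = sizeof(float) * (1024 * 1024) * 4 = 16 MiB
// Serial split-K: one int semaphore per output tile:
//   workspace = sizeof(int) * (1024 / 128) * (1024 / 128) = 4 * 8 * 8 = 256 bytes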

UniversalParamsBase::init_workspace

Calls UniversalParamsBase::get_workspace_size to obtain the size.
cudaMemsetAsync zeroes the synchronization semaphores.
Assigns and initializes the specified workspace buffer, assuming the memory allocated for the workspace is at least as large as get_workspace_size().

  /// Assign and initialize the specified workspace buffer.  Assumes
  /// the memory allocated to workspace is at least as large as get_workspace_size().
  Status init_workspace(
    void *workspace,
    cudaStream_t stream = nullptr)
  {
    semaphore = static_cast<int *>(workspace);

    // Zero-initialize entire workspace
    if (semaphore)
    {
      size_t workspace_bytes = get_workspace_size();

      CUTLASS_TRACE_HOST("  Initialize " << workspace_bytes << " workspace bytes");

      cudaError_t result = cudaMemsetAsync(
          semaphore,
          0,
          workspace_bytes,
          stream);

      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

UniversalParamsBase::get_tiled_shape

  /// Returns the GEMM volume in thread block tiles
  GemmCoord get_tiled_shape() const
  {
    return grid_tiled_shape;
  }

UniversalParamsBase::get_grid_blocks

UniversalParamsBase::get_grid_blocks
UniversalParamsBase::get_grid_dims

Returns the total number of thread blocks to launch.
UniversalParamsBase::get_grid_dims returns the grid dimensions.

  /// Returns the total number of thread blocks to launch
  int get_grid_blocks() const
  {
    dim3 grid_dims = get_grid_dims();
    return grid_dims.x * grid_dims.y * grid_dims.z;
  }

UniversalParamsBase::get_grid_dims

UniversalParamsBase::get_grid_dims
GemmIdentityThreadblockSwizzle::get_grid_shape

GemmIdentityThreadblockSwizzle::get_grid_shape computes the CUDA grid dimensions, in units of logical tiles, from the grid_tiled_shape passed in.

  /// Returns the grid extents in thread blocks to launch
  dim3 get_grid_dims() const
  {
    return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape);
  }

UniversalParamsBase::init_grid_tiled_shape

UniversalParamsBase::init_grid_tiled_shape
GemmIdentityThreadblockSwizzle::get_tiled_shape
GemmIdentityThreadblockSwizzle::get_log_tile

Calls GemmIdentityThreadblockSwizzle::get_tiled_shape to obtain the problem shape in units of logical tiles.
GemmIdentityThreadblockSwizzle::get_log_tile computes the best rasterization width.

private:

  CUTLASS_HOST_DEVICE
  void init_grid_tiled_shape() {
    // Get GEMM volume in thread block tiles
    grid_tiled_shape = ThreadblockSwizzle::get_tiled_shape(
        problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        batch_count);

    swizzle_log_tile = ThreadblockSwizzle::get_log_tile(grid_tiled_shape);

    // Determine extent of K-dimension assigned to each block
    gemm_k_size = problem_size.k();

In the Gemm and GemmSplitKParallel modes, grid_tiled_shape.k() is adjusted.
is_continous_k_aligned checks whether the k dimension is aligned.
const_max returns the larger of two integers.
CACHELINE_BYTES is 128, though the code is written to accommodate larger values.
ceil_div is division rounding up.
gemm_k_size is the k extent each slice covers in the GEMM; from it, the number of slices along the k dimension is derived from the problem size. A worked example follows the code below.

    if (mode == GemmUniversalMode::kGemm || mode == GemmUniversalMode::kGemmSplitKParallel)
    {
      static const uint32_t CACHELINE_BYTES = 128;
      static const size_t element_bytes_a = sizeof(ElementA);
      static const size_t element_bytes_b = sizeof(ElementB);
      static const size_t cacheline_elements_a = CACHELINE_BYTES / element_bytes_a;
      static const size_t cacheline_elements_b = CACHELINE_BYTES / element_bytes_b;

      const bool cacheline_alignment_needed =
          util::is_continous_k_aligned<LayoutA, LayoutB>(problem_size, cacheline_elements_a, cacheline_elements_b);

      int const kAlignK = const_max(
          const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value),
          cacheline_alignment_needed ? const_max(cacheline_elements_a, cacheline_elements_b) : 1);

      gemm_k_size = round_up(ceil_div(problem_size.k(), batch_count), kAlignK);

      if (gemm_k_size) {
        grid_tiled_shape.k() = ceil_div(problem_size.k(), gemm_k_size);
      }
    }
  }
};
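
A worked example under assumed values: k = 4096, tile-splitting factor batch_count = 3, half-precision A and B (sizeof_bits = 16, cacheline_elements = 128 / 2 = 64), and the cache-line alignment check passing:

//   kAlignK     = const_max(const_max(128 / 16, 128 / 16), 64) = 64
//   gemm_k_size = round_up(ceil_div(4096, 3), 64) = round_up(1366, 64) = 1408
//   grid_tiled_shape.k() = ceil_div(4096, 1408) = 3
// The last slice covers only 4096 - 2 * 1408 = 1280 of the k extent.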

DefaultGemmUniversal

/////////////////////////////////////////////////////////////////////////////////////////////////

//
// Real-valued GEMM kernels
//

template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Layout type for C and D matrix operands
    typename LayoutC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Warp-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB,
    /// Scatter result D by using an index array
    bool ScatterD,
    /// Permute result D
    typename PermuteDLayout,
    /// Permute operand A
    typename PermuteALayout,
    /// Permute operand B
    typename PermuteBLayout
>
struct DefaultGemmUniversal<
    ElementA,
    LayoutA,
    ComplexTransform::kNone,   // transform A
    kAlignmentA,
    ElementB,
    LayoutB,
    ComplexTransform::kNone,   // transform B
    kAlignmentB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator,
    SharedMemoryClear,
    GatherA,
    GatherB,
    ScatterD,
    PermuteDLayout,
    PermuteALayout,
    PermuteBLayout,
    typename platform::enable_if< ! cutlass::is_complex<ElementAccumulator>::value>::type
> {

DefaultGemmKernel is DefaultGemm::GemmKernel, i.e. Gemm.
DefaultGemmKernel::Mma is Gemm::Mma, i.e. DefaultGemm::Mma, i.e. DefaultMma::ThreadblockMma, i.e. MmaMultistage, since the application specifies NumStages equal to 4.
DefaultGemmKernel::Epilogue is Gemm::Epilogue, i.e. DefaultGemm::Epilogue, i.e. DefaultGemm::RegularEpilogue, i.e. DefaultEpilogueTensorOp::Epilogue, i.e. Epilogue.

  using DefaultGemmKernel = typename kernel::DefaultGemm<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    true,
    Operator,
    SharedMemoryClear,
    GatherA,
    GatherB,
    ScatterD,
    PermuteDLayout,
    PermuteALayout,
    PermuteBLayout
  >::GemmKernel;

SelectBase inherits from either GemmUniversal or GemmUniversalStreamk,
deduced from whether the ThreadblockSwizzle passed in is GemmIdentityThreadblockSwizzle or ThreadblockSwizzleStreamK.

  /// Universal kernel without StreamkFeature member type
  template <class SwizzleT, class Enable = void>
  class SelectBase :
    public kernel::GemmUniversal<
      typename DefaultGemmKernel::Mma,
      typename DefaultGemmKernel::Epilogue,
      SwizzleT>
  {};

  /// Universal kernel with StreamkFeature member type
  template <class SwizzleT>
  class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :
    public kernel::GemmUniversalStreamk<
      typename DefaultGemmKernel::Mma,
      typename DefaultGemmKernel::Epilogue,
      SwizzleT>
  {};

  /// Select kernel by ThreadblockSwizzle's support for StreamkFeature
  using GemmKernel = SelectBase<ThreadblockSwizzle>;
};

GemmUniversal

Mma::Policy is DefaultMmaCore::MmaPolicy, i.e. DefaultMmaTensorOp::Policy, i.e. MmaTensorOpPolicy.
Mma::Operator is DefaultMmaCore::MmaTensorOp, i.e. DefaultMmaTensorOp::type, i.e. MmaTensorOp.

template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_    ///! Threadblock swizzling function
>
class GemmUniversal<
  Mma_,
  Epilogue_,
  ThreadblockSwizzle_,
  void,
  // 3.x kernels use the first template argument to define the ProblemShape
  // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API
  cute::enable_if_t<not (cute::is_tuple<Mma_>::value || IsCutlass3ArrayKernel<Mma_>::value)>
> {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;

  using Operator = typename Mma::Operator;
  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(
    128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);

GemmUniversal::Arguments

GemmUniversal::Arguments
UniversalArgumentsBase

The main implementation is in the base class UniversalArgumentsBase.

  //
  // Structures
  //

  /// Argument structure
  struct Arguments : UniversalArgumentsBase
  {
    //
    // Data members
    //

    typename EpilogueOutputOp::Params epilogue;

    void const * ptr_A;
    void const * ptr_B;
    void const * ptr_C;
    void * ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C;

    typename LayoutA::Stride stride_a;
    typename LayoutB::Stride stride_b;
    typename LayoutC::Stride stride_c;
    typename LayoutC::Stride stride_d;

    typename LayoutA::Stride::LongIndex lda;
    typename LayoutB::Stride::LongIndex ldb;
    typename LayoutC::Stride::LongIndex ldc;
    typename LayoutC::Stride::LongIndex ldd;

    int const * ptr_gather_A_indices;
    int const * ptr_gather_B_indices;
    int const * ptr_scatter_D_indices;
GemmUniversal::Arguments::Arguments
    //
    // Methods
    //

    Arguments():
      ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
      ptr_gather_A_indices(nullptr),
      ptr_gather_B_indices(nullptr),
      ptr_scatter_D_indices(nullptr)
    {}
GemmUniversal::Arguments::Arguments
    /// constructs an arguments structureArguments(GemmUniversalMode mode,GemmCoord problem_size,int batch_count,typename EpilogueOutputOp::Params epilogue,void const * ptr_A,void const * ptr_B,void const * ptr_C,void * ptr_D,int64_t batch_stride_A,int64_t batch_stride_B,int64_t batch_stride_C,int64_t batch_stride_D,typename LayoutA::Stride stride_a,typename LayoutB::Stride stride_b,typename LayoutC::Stride stride_c,typename LayoutC::Stride stride_d,int const *ptr_gather_A_indices = nullptr,int const *ptr_gather_B_indices = nullptr,int const *ptr_scatter_D_indices = nullptr):UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),ptr_scatter_D_indices(ptr_scatter_D_indices){lda = 0;ldb = 0;ldc = 0;ldd = 0;CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);}
GemmUniversal::Arguments::Arguments
    /// constructs an arguments structureArguments(GemmUniversalMode mode,GemmCoord problem_size,int batch_count,typename EpilogueOutputOp::Params epilogue,void const * ptr_A,void const * ptr_B,void const * ptr_C,void * ptr_D,int64_t batch_stride_A,int64_t batch_stride_B,int64_t batch_stride_C,int64_t batch_stride_D,typename LayoutA::Stride::LongIndex lda,typename LayoutB::Stride::LongIndex ldb,typename LayoutC::Stride::LongIndex ldc,typename LayoutC::Stride::LongIndex ldd,int const *ptr_gather_A_indices = nullptr,int const *ptr_gather_B_indices = nullptr,int const *ptr_scatter_D_indices = nullptr):UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),epilogue(epilogue),ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),ptr_scatter_D_indices(ptr_scatter_D_indices){stride_a = make_Coord(lda);stride_b = make_Coord(ldb);stride_c = make_Coord(ldc);stride_d = make_Coord(ldd);CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);}
GemmUniversal::Arguments::transposed_problem
    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const
    {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_A, args.ptr_B);
      std::swap(args.lda, args.ldb);
      std::swap(args.stride_a, args.stride_b);
      std::swap(args.batch_stride_A, args.batch_stride_B);
      std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);

      return args;
    }
  };

GemmUniversal::Params

GemmUniversal::Params
UniversalParamsBase

Likewise, the main implementation is in the base class UniversalParamsBase.

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params : UniversalParamsBase<
    ThreadblockSwizzle,
    ThreadblockShape,
    ElementA,
    ElementB,
    ElementC,
    LayoutA,
    LayoutB>
  {
    using ParamsBase = UniversalParamsBase<
      ThreadblockSwizzle,
      ThreadblockShape,
      ElementA,
      ElementB,
      ElementC,
      LayoutA,
      LayoutB>;

    //
    // Data members
    //

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename Epilogue::OutputTileIterator::Params params_C;
    typename Epilogue::OutputTileIterator::Params params_D;

    typename EpilogueOutputOp::Params output_op;

    void * ptr_A;
    void * ptr_B;
    void * ptr_C;
    void * ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C;

    int * ptr_gather_A_indices;
    int * ptr_gather_B_indices;
    int * ptr_scatter_D_indices;

    //
    // Host dispatch API
    //

    /// Default constructor
    Params() = default;
GemmUniversal::Params::Params
    /// Constructor
    Params(
      Arguments const &args,  /// GEMM application arguments
      int device_sms,         /// Number of SMs on the device
      int sm_occupancy)       /// Kernel SM occupancy (in thread blocks)
    :
      ParamsBase(args, device_sms, sm_occupancy),
      params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
      params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
      params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
      params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
      output_op(args.epilogue),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      ptr_C(const_cast<void *>(args.ptr_C)),
      ptr_D(args.ptr_D),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      batch_stride_C(args.batch_stride_C),
      ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
      ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
      ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
    {}
GemmUniversal::Params::update

Updates the data pointers and batch strides.

    /// Lightweight update given a subset of arguments.
    void update(Arguments const &args)
    {
      CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");

      // Update input/output pointers
      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);
      ptr_C = const_cast<void *>(args.ptr_C);
      ptr_D = args.ptr_D;

      batch_stride_A = args.batch_stride_A;
      batch_stride_B = args.batch_stride_B;
      batch_stride_C = args.batch_stride_C;
      this->batch_stride_D = args.batch_stride_D;

      ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
      ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
      ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);

      output_op = args.epilogue;
    }
  };

The mainloop and the epilogue use the same Shared Memory; the two phases do not overlap in time, so a union lets them share the allocation.

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

GemmUniversal::can_implement

Checks whether the problem dimensions satisfy the alignment requirements of the three matrices' layouts.

public://// Host dispatch API///// Determines whether kernel satisfies alignmentstatic Status can_implement(cutlass::gemm::GemmCoord const & problem_size){CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");static int const kAlignmentA = (cute::is_same<LayoutA,layout::ColumnMajorInterleaved<32>>::value)? 32: (cute::is_same<LayoutA,layout::ColumnMajorInterleaved<64>>::value)? 64: Mma::IteratorA::AccessType::kElements;static int const kAlignmentB = (cute::is_same<LayoutB,layout::RowMajorInterleaved<32>>::value)? 32: (cute::is_same<LayoutB,layout::RowMajorInterleaved<64>>::value)? 64: Mma::IteratorB::AccessType::kElements;static int const kAlignmentC = (cute::is_same<LayoutC,layout::ColumnMajorInterleaved<32>>::value)? 32: (cute::is_same<LayoutC,layout::ColumnMajorInterleaved<64>>::value)? 64: Epilogue::OutputTileIterator::kElementsPerAccess;
    bool isAMisaligned = false;bool isBMisaligned = false;bool isCMisaligned = false;if (cute::is_same<LayoutA, layout::RowMajor>::value) {isAMisaligned = problem_size.k() % kAlignmentA;} else if (cute::is_same<LayoutA, layout::ColumnMajor>::value) {isAMisaligned = problem_size.m() % kAlignmentA;} else if (cute::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value|| cute::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {isAMisaligned = problem_size.k() % kAlignmentA;}if (cute::is_same<LayoutB, layout::RowMajor>::value) {isBMisaligned = problem_size.n() % kAlignmentB;} else if (cute::is_same<LayoutB, layout::ColumnMajor>::value) {isBMisaligned = problem_size.k() % kAlignmentB;} else if (cute::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value|| cute::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {isBMisaligned = problem_size.k() % kAlignmentB;}if (cute::is_same<LayoutC, layout::RowMajor>::value) {isCMisaligned = problem_size.n() % kAlignmentC;} else if (cute::is_same<LayoutC, layout::ColumnMajor>::value) {isCMisaligned = problem_size.m() % kAlignmentC;} else if (cute::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value|| cute::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {isCMisaligned = problem_size.n() % kAlignmentC;}if (isAMisaligned) {CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");return Status::kErrorMisalignedOperand;}if (isBMisaligned) {CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");return Status::kErrorMisalignedOperand;}if (isCMisaligned) {CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");return Status::kErrorMisalignedOperand;}CUTLASS_TRACE_HOST("  returning kSuccess");return Status::kSuccess;}

GemmUniversal::can_implement

  static Status can_implement(Arguments const &args) {
    return can_implement(args.problem_size);
  }

GemmUniversal::invoke

A static class method provides factory-style invocation; GemmUniversal::operator() is the actual implementation.

public:

  //
  // Device-only API
  //

  // Factory invocation
  CUTLASS_DEVICE
  static void invoke(
    Params const &params,
    SharedStorage &shared_storage)
  {
    GemmUniversal op;
    op(params, shared_storage);
  }

GemmUniversal::operator()

GemmUniversal::operator
GemmUniversal::run_with_swizzle

Calls the GemmUniversal::run_with_swizzle function.

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {
    ThreadblockSwizzle threadblock_swizzle;
    run_with_swizzle(params, shared_storage, threadblock_swizzle);
  }

GemmUniversal::run_with_swizzle

GemmUniversal::run_with_swizzle
GemmIdentityThreadblockSwizzle::get_tile_offset
MmaMultistage
Epilogue

The implementation of the Gemm mode.

Calls GemmIdentityThreadblockSwizzle::get_tile_offset to obtain the swizzled CTA tile coordinates.
If the CTA is out of range, it returns immediately.

  /// Executes one GEMM with an externally-provided swizzling function
  CUTLASS_DEVICE
  void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
      return;
    }

offset_k is the current CTA's offset along the k dimension.
problem_size_k is the k extent of the problem handled by the current CTA.
In the Gemm and GemmSplitKParallel modes, multiple CTAs share the k dimension.
The Batched and Array modes adjust ptr_A and ptr_B to the current matrices.
Why is the __syncthreads() needed here?

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm ||
      params.mode == GemmUniversalMode::kGemmSplitKParallel) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
      ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
    }

    __syncthreads();

Compute the CTA's logical coordinates within the A and B matrices.
MmaMultistage::IteratorA is DefaultMma::IteratorA, i.e. PredicatedTileAccessIterator; IteratorB has the same type.

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A,
      params.ptr_gather_A_indices);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B,
      params.ptr_gather_B_indices);

canonical_warp_idx_sync obtains the warp index.

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = canonical_warp_idx_sync();

    int lane_idx = threadIdx.x % 32;

gemm_k_iterations is the number of mainloop iterations the CTA runs along the k dimension, i.e. the k extent it covers divided by Mma::Shape::kK, rounded up.
MmaMultistage executes the main loop.

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

The epilogue.
EpilogueOutputOp is Epilogue::OutputOp, i.e. EpilogueOp, i.e. LinearCombination.

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    //assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

Create the Semaphore used for inter-CTA synchronization.
In the serial split-K case (kSplitKSerial), Semaphore::fetch fetches the synchronization lock up front but does not block.
LinearCombination::set_k_partition sets beta_ according to the CTA's k index within the reduction: it is 1 for every CTA except the first.
Recall the GEMM formula:
D = αAB + βC
So the first CTA handles the C matrix as usual, while every subsequent CTA loads D from Global Memory and accumulates its partial sum onto it.
Epilogue::OutputTileIterator is DefaultEpilogueTensorOp::OutputTileIterator, i.e. DefaultEpilogueTensorOp::PackedOutputTileIterator, i.e. PredicatedTileIterator.

    //
    // Fetch pointers based on mode.
    //

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    if (params.mode == GemmUniversalMode::kGemm) {
      // If performing a reduction via split-K, fetch the initial synchronization
      if (params.grid_tiled_shape.k() > 1) {
        // Fetch the synchronization lock initially but do not block.
        semaphore.fetch();

        // Indicate which position in a serial reduction the output operator is currently updating
        output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
      }
    }
    else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_C = static_cast<ElementC * const *>(params.ptr_C)[threadblock_tile_offset.k()];
      ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
    }

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      ptr_C,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset,
      params.ptr_scatter_D_indices
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      ptr_D,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset,
      params.ptr_scatter_D_indices
    );

Create an Epilogue object.
CTAs other than the first must switch their source matrix so they continue from the result computed by the previous thread block.
Semaphore::wait waits until the semaphore reaches k.

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_offset.k());
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(output_op, iterator_D, accumulators, iterator_C);

Semaphore::release releases the semaphore. The overall protocol is summarized after the code below.

    //
    // Release the semaphore
    //

    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_offset.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};

GemmUniversalStreamk

template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_    ///! Threadblock mapping function
>
struct GemmUniversalStreamk {
public:

  //
  // Types and constants
  //

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;

  /// The per-thread tile of raw accumulators
  using AccumulatorTile = typename Mma::FragmentC;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;

  using Operator = typename Mma::Operator;
  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

__NV_STD_MAX is a macro for use in constant expressions, dating from before the C++ constexpr equivalent could be assumed from every compiler; such macros are prefixed with __NV_STD_*.
kWorkspaceBytesPerBlock takes the larger of the Mma accumulator-tile footprint and the Epilogue workspace requirement.

  /// Workspace bytes per thread block
  static size_t const kWorkspaceBytesPerBlock =
    __NV_STD_MAX(
      kThreadCount * sizeof(AccumulatorTile),
      Epilogue::kWorkspaceBytesPerBlock);

  /// Block-striped reduction utility
  using BlockStripedReduceT = BlockStripedReduce<kThreadCount, AccumulatorTile>;

GemmUniversalStreamk::Arguments

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode = GemmUniversalMode::kGemm;
    GemmCoord problem_size {};
    int batch_count {1};        // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor

    typename EpilogueOutputOp::Params epilogue{};

    void const * ptr_A = nullptr;
    void const * ptr_B = nullptr;
    void const * ptr_C = nullptr;
    void * ptr_D = nullptr;

    int64_t batch_stride_A{0};
    int64_t batch_stride_B{0};
    int64_t batch_stride_C{0};
    int64_t batch_stride_D{0};

    typename LayoutA::Stride stride_a{0};
    typename LayoutB::Stride stride_b{0};
    typename LayoutC::Stride stride_c{0};
    typename LayoutC::Stride stride_d{0};

    typename LayoutA::Stride::LongIndex lda{0};
    typename LayoutB::Stride::LongIndex ldb{0};
    typename LayoutC::Stride::LongIndex ldc{0};
    typename LayoutC::Stride::LongIndex ldd{0};

    int avail_sms{-1};          /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)

    //
    // Methods
    //

    /// Default Constructor
    Arguments() = default;
GemmUniversalStreamk::Arguments::Arguments

This overload takes strides of type RowMajor::Stride, i.e. a Coord.

    /// Constructor
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_split,                              /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
      typename EpilogueOutputOp::Params epilogue,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C,
      void * ptr_D,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C,
      int64_t batch_stride_D,
      typename LayoutA::Stride stride_a,
      typename LayoutB::Stride stride_b,
      typename LayoutC::Stride stride_c,
      typename LayoutC::Stride stride_d,
      int avail_sms = -1                            /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
    ):
      mode(mode),
      problem_size(problem_size),
      batch_count(batch_split),
      epilogue(epilogue),
      ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
      batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
      stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms)
    {
      CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size);
    }
GemmUniversalStreamk::Arguments::Arguments

This overload takes leading dimensions of type RowMajor::Stride::LongIndex.

    /// Constructor
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_split,                              /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
      typename EpilogueOutputOp::Params epilogue,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C,
      void * ptr_D,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C,
      int64_t batch_stride_D,
      typename LayoutA::Stride::LongIndex lda,
      typename LayoutB::Stride::LongIndex ldb,
      typename LayoutC::Stride::LongIndex ldc,
      typename LayoutC::Stride::LongIndex ldd,
      int avail_sms = -1                            /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
    ):
      mode(mode),
      problem_size(problem_size),
      batch_count(batch_split),
      epilogue(epilogue),
      ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
      batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
      lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms)
    {
      stride_a = make_Coord(lda);
      stride_b = make_Coord(ldb);
      stride_c = make_Coord(ldc);
      stride_d = make_Coord(ldd);
      CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size);
    }
GemmUniversalStreamk::Arguments::transposed_problem

Swaps the roles of the A and B matrices (together with the m/n extents, strides, and batch strides).

    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const
    {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_A, args.ptr_B);
      std::swap(args.lda, args.ldb);
      std::swap(args.stride_a, args.stride_b);
      std::swap(args.batch_stride_A, args.batch_stride_B);

      return args;
    }
  };
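The swap is justified by transposing the GEMM update:

$$D = \alpha A B + \beta C \;\Longrightarrow\; D^\top = \alpha\, B^\top A^\top + \beta\, C^\top$$

so exchanging A and B (along with m/n and the corresponding strides) computes the transposed problem with the same kernel. Using this to serve a row-major output through a column-major-oriented kernel is our reading of the intent, not something stated in this file.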

GemmUniversalStreamk::Params

  /// Parameters structure
  struct Params
  {
  public:

    //
    // Data members
    //

    void * ptr_A = nullptr;
    void * ptr_B = nullptr;
    typename Mma::IteratorA::Params params_A{};
    typename Mma::IteratorB::Params params_B{};
    int64_t batch_stride_A{0};
    int64_t batch_stride_B{0};

    GemmUniversalMode mode = GemmUniversalMode::kGemm;

    ThreadblockSwizzle block_mapping{};

    void *barrier_workspace = nullptr;
    void *partials_workspace = nullptr;

    typename EpilogueOutputOp::Params output_op{};

    void * ptr_D = nullptr;
    void * ptr_C = nullptr;

    typename Epilogue::OutputTileIterator::Params params_D{};
    typename Epilogue::OutputTileIterator::Params params_C{};

    int64_t batch_stride_D{0};
    int64_t batch_stride_C{0};
GemmUniversalStreamk::Params::cacheline_align_up

It defines a local constant CACHELINE_SIZE and pads a given allocation size up to the nearest cache-line boundary, so that the sub-allocations carved out of the workspace do not share cache lines.

  protected:

    //
    // Host-only dispatch-utilities
    //

    /// Pad the given allocation size up to the nearest cache line
    static size_t cacheline_align_up(size_t size)
    {
      static const int CACHELINE_SIZE = 128;
      return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE;
    }
GemmUniversalStreamk::Params::get_barrier_workspace_size

Computes the workspace size needed for barrier flags.
ThreadblockSwizzleStreamK::sk_regions returns the number of SK regions.
ThreadblockSwizzleStreamK::sk_blocks_per_region returns the number of SK CTAs per region.
For atomic reduction, every SK CTA needs one synchronization flag;
for parallel reduction, every reduction CTA needs its own flag.

    /// Get the workspace size needed for barrier
    size_t get_barrier_workspace_size() const
    {
      // For atomic reduction, each SK-block needs a synchronization flag.  For parallel reduction,
      // each reduction block needs its own synchronization flag.
      int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
      int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks);

      return cacheline_align_up(sizeof(typename Barrier::T) * num_flags);
    }
GemmUniversalStreamk::Params::get_partials_workspace_size

ThreadblockSwizzleStreamK::sk_regions returns the number of SK regions.
ThreadblockSwizzleStreamK::sk_blocks_per_region returns the number of SK CTAs per region.
kWorkspaceBytesPerBlock is the space each CTA needs to stage its accumulators.

    /// Get the workspace size needed for intermediate partial sums
    size_t get_partials_workspace_size() const
    {
      int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
      return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
    }
  public:

    //
    // Host dispatch API
    //

    /// Default constructor
    Params() = default;
GemmUniversalStreamk::Params::Params
    /// Constructor
    Params(
      Arguments const &args,  /// GEMM application arguments
      int device_sms,         /// Number of SMs on the device
      int sm_occupancy)       /// Kernel SM occupancy (in thread blocks)
    :
      params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
      params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
      params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
      params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
      output_op(args.epilogue),
      mode(args.mode),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      ptr_C(const_cast<void *>(args.ptr_C)),
      ptr_D(args.ptr_D),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      batch_stride_C(args.batch_stride_C),
      batch_stride_D(args.batch_stride_D),
      barrier_workspace(nullptr),
      partials_workspace(nullptr)
    {
      // Number of SMs to make available for StreamK decomposition
      int avail_sms = (args.avail_sms == -1) ?
        device_sms :
        fast_min(args.avail_sms, device_sms);

Constructs a ThreadblockSwizzleStreamK object.

      // Initialize the block mapping structure
      block_mapping = ThreadblockSwizzle(
        args.mode,
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.batch_count,
        sm_occupancy,
        device_sms,
        avail_sms,
        sizeof(ElementA),
        sizeof(ElementB),
        sizeof(ElementC),
        Epilogue::kAccumulatorFragments);
    }
GemmUniversalStreamk::Params::get_workspace_size

Calls GemmUniversalStreamk::Params::get_barrier_workspace_size and GemmUniversalStreamk::Params::get_partials_workspace_size, returning their sum as the workspace size.

    /// Returns the workspace size (in bytes) needed for these parameters
    size_t get_workspace_size() const
    {
      return
        get_barrier_workspace_size() +
        get_partials_workspace_size();
    }
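To make the sizing concrete, here is a standalone host-side sketch of the same arithmetic with made-up values; the flag size and per-block workspace below are assumptions for illustration, not values taken from this example:

    #include <algorithm>
    #include <cstdio>

    // Same padding rule as Params::cacheline_align_up.
    static size_t cacheline_align_up(size_t size) {
      const int CACHELINE_SIZE = 128;
      return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE;
    }

    int main() {
      int sk_blocks = 6;                 // sk_regions() * sk_blocks_per_region(), hypothetical
      int reduction_blocks = 16;         // hypothetical
      size_t flag_bytes = 8;             // sizeof(Barrier::T), assumed
      size_t block_bytes = 72 * 1024;    // kWorkspaceBytesPerBlock, assumed

      size_t barrier  = cacheline_align_up(flag_bytes * std::max(sk_blocks, reduction_blocks));
      size_t partials = cacheline_align_up(block_bytes * sk_blocks);

      // get_workspace_size() would return the sum of the two aligned regions.
      std::printf("barrier=%zu partials=%zu total=%zu\n", barrier, partials, barrier + partials);
      return 0;
    }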
GemmUniversalStreamk::Params::init_workspace
    /// Assign and initialize the specified workspace buffer.  Assumes
    /// the memory allocated to workspace is at least as large as get_workspace_size().
    Status init_workspace(
      void *workspace,
      cudaStream_t stream = nullptr)
    {
      uint8_t *ptr = static_cast<uint8_t*>(workspace);

Calls GemmUniversalStreamk::Params::get_partials_workspace_size to obtain the size.

      // Establish partials workspace
      partials_workspace = nullptr;
      size_t partials_workspace_bytes = get_partials_workspace_size();
      if (partials_workspace_bytes > 0)
      {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }
        partials_workspace = ptr;
        ptr += partials_workspace_bytes;
      }

The workspace is laid out as the partials workspace followed by the barrier workspace, i.e. workspace = partials_workspace + barrier_workspace, with the latter sized by GemmUniversalStreamk::Params::get_barrier_workspace_size.

      // Establish barrier workspace
      barrier_workspace = nullptr;
      size_t barrier_workspace_bytes = get_barrier_workspace_size();
      if (barrier_workspace_bytes > 0)
      {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }
        barrier_workspace = ptr;
        ptr += barrier_workspace_bytes;
      }

barrier_workspace_bytes is redundantly recomputed here.
The barrier workspace is then zeroed so all flags start unsignaled.

      // Zero-initialize barrier workspace
      if (barrier_workspace)
      {
        size_t barrier_workspace_bytes = get_barrier_workspace_size();

        CUTLASS_TRACE_HOST("  Initialize " << barrier_workspace_bytes << " barrier bytes");

        cudaError_t result = cudaMemsetAsync(
          barrier_workspace,
          0,
          barrier_workspace_bytes,
          stream);

        if (result != cudaSuccess) {
          CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
          return Status::kErrorInternal;
        }
      }

      return Status::kSuccess;
    }
GemmUniversalStreamk::Params::get_tiled_shape

Calls ThreadblockSwizzleStreamK::tiled_shape, which returns the tile counts in all three dimensions.

    /// Returns the GEMM volume in thread block tiles
    cutlass::gemm::GemmCoord get_tiled_shape() const
    {
      return block_mapping.tiled_shape();
    }
GemmUniversalStreamk::Params::get_grid_blocks

GemmUniversalStreamk::Params::get_grid_dims yields the grid dimensions; their product is the total block count.

    /// Returns the total number of thread blocks to launch
    int get_grid_blocks() const
    {
      dim3 grid_dims = get_grid_dims();
      return grid_dims.x * grid_dims.y * grid_dims.z;
    }
GemmUniversalStreamk::Params::get_grid_dims

Forwards to ThreadblockSwizzleStreamK::get_grid_dims.

    /// Returns the grid extents in thread blocks to launch
    dim3 get_grid_dims() const
    {
      return block_mapping.get_grid_dims();
    }
GemmUniversalStreamk::Params::update

Updates the pointers, strides, and the epilogue output operator.

    /// Lightweight update given a subset of arguments.
    void update(Arguments const &args)
    {
      CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()");

      // Update input/output pointers
      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);
      ptr_C = const_cast<void *>(args.ptr_C);
      ptr_D = args.ptr_D;

      batch_stride_A = args.batch_stride_A;
      batch_stride_B = args.batch_stride_B;
      batch_stride_C = args.batch_stride_C;
      batch_stride_D = args.batch_stride_D;

      output_op = args.epilogue;
    }
  };

GemmUniversalStreamk::TileWorkDesc

The descriptor holds the linear tile index, the tile's coordinates, the first global MAC-iteration, the k-axis begin/end indices, and the number of remaining MAC-iterations along k.

  /// Tile work descriptor
  struct TileWorkDesc
  {
    /// The linear tile index
    int tile_idx;

    /// The location of this tile (in threadblock-tile coordinates) in the output matrix
    cutlass::gemm::GemmCoord tiled_coord;

    // The first global-scoped MAC-iteration this threadblock will perform for this tile
    int iter_begin;

    // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
    int k_begin;

    // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile
    int k_end;

    /// The number of remaining MAC-iterations this threadblock will perform for this tile
    int k_iters_remaining;
GemmUniversalStreamk::TileWorkDesc::tile_started

Tests whether this CTA performs the tile's first MAC iteration.

    // Whether this block will perform the first iteration of this tile
    CUTLASS_DEVICE
    bool tile_started()
    {
      return (k_begin == 0);
    }
GemmUniversalStreamk::TileWorkDesc::tile_finished

Tests whether this CTA performs the tile's last MAC iteration.

    // Whether this block will perform the last iteration of this tile
    CUTLASS_DEVICE
    bool tile_finished(Params const &params)
    {
      return (k_end == params.block_mapping.problem_size.k());
    }
  };
  /// Shared memory storage structure
  union SharedStorage
  {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

protected:

  //
  // Data members
  //

  /// GEMM problem parameters
  Params params;

  /// Shared storage reference
  SharedStorage &shared_storage;

  /// ID within the threadblock
  int thread_idx;

  /// ID of warp
  int warp_idx;

  /// ID of each thread within a warp
  int lane_idx;

  /// Threadblock scoped epilogue
  Epilogue epilogue;

GemmUniversalStreamk::can_implement

Checks whether the problem extents satisfy the alignment requirements of the three matrix layouts.

public:

  //
  // Host-only dispatch API
  //

  /// Determines whether the GEMM problem size satisfies this kernel's
  /// alignment requirements
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size)
  {
    CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");

    static int const kAlignmentA = (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value)
      ? 32
      : (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
        ? 64
        : Mma::IteratorA::AccessType::kElements;

    static int const kAlignmentB = (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value)
      ? 32
      : (platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
        ? 64
        : Mma::IteratorB::AccessType::kElements;

    static int const kAlignmentC = (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value)
      ? 32
      : (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
        ? 64
        : Epilogue::OutputTileIterator::kElementsPerAccess;

    bool isAMisaligned = false;
    bool isBMisaligned = false;
    bool isCMisaligned = false;

    if (platform::is_same<LayoutA, layout::RowMajor>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }

    if (platform::is_same<LayoutB, layout::RowMajor>::value) {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
            || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }

    if (platform::is_same<LayoutC, layout::RowMajor>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }

    if (isAMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isBMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isCMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
      return Status::kErrorMisalignedOperand;
    }

    CUTLASS_TRACE_HOST("  returning kSuccess");
    return Status::kSuccess;
  }

GemmUniversalStreamk::can_implement

  /// Determines whether the GEMM problem satisfies this kernel's
  /// alignment requirements
  static Status can_implement(Arguments const &args) {
    return can_implement(args.problem_size);
  }

GemmUniversalStreamk::init_iterator_A

Initializes matrix A's iterator, a PredicatedTileAccessIterator, from the pointer and extent information implied by GemmUniversalStreamk::TileWorkDesc.

protected:

  //
  // Device-only utility methods
  //

  /// Iterator for fetching tile fragments from A
  CUTLASS_DEVICE
  typename Mma::IteratorA init_iterator_A(
    TileWorkDesc &tile_work,
    GemmUniversalMode mode)
  {
    // The input A matrix
    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);

In kBatched mode, ptr_A is advanced to the batch selected by the k coordinate; in kArray mode, params.ptr_A is an array of pointers and the k coordinate indexes it.

    // Update input pointers based on batched/array mode
    if (mode == GemmUniversalMode::kBatched) {
      ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A;
    }
    if (mode == GemmUniversalMode::kArray) {
      ptr_A = static_cast<ElementA * const *>(params.ptr_A)[tile_work.tiled_coord.k()];
    }

MmaMultistage::IteratorA is DefaultMma::IteratorA, i.e. PredicatedTileAccessIterator.

    int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM;
    int m_end = params.block_mapping.problem_size.m();
    return typename Mma::IteratorA(
      params.params_A,
      ptr_A,
      { m_end, tile_work.k_end },
      threadIdx.x,
      { m_begin, tile_work.k_begin });
  }

GemmUniversalStreamk::init_iterator_B

  /// Iterator for fetching tile fragments from B
  CUTLASS_DEVICE
  typename Mma::IteratorB init_iterator_B(
    TileWorkDesc &tile_work,
    GemmUniversalMode mode)
  {
    // The input B matrix
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    // Update input pointers based on batched/array mode
    if (mode == GemmUniversalMode::kBatched) {
      ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B;
    }
    if (mode == GemmUniversalMode::kArray) {
      ptr_B = static_cast<ElementB * const *>(params.ptr_B)[tile_work.tiled_coord.k()];
    }

    int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN;
    int n_end = params.block_mapping.problem_size.n();
    return typename Mma::IteratorB(
      params.params_B,
      ptr_B,
      { tile_work.k_end, n_end },
      threadIdx.x,
      { tile_work.k_begin, n_begin });
  }

GemmUniversalStreamk::init_dp_tile_work

Initializes the work descriptor GemmUniversalStreamk::TileWorkDesc for a DP tile; a single CTA processes the whole tile.
k_iters_remaining is the number of MAC iterations the threadblock still has to perform for the current tile.

  CUTLASS_DEVICE
  void init_dp_tile_work(
    TileWorkDesc &tile_work,
    int tile_idx)
  {
    // The linear tile index
    tile_work.tile_idx = tile_idx;

    // The first global-scoped MAC-iteration this threadblock will perform for this tile
    tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();

    // The number of MAC-iterations this threadblock will perform for this tile
    tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();

A DP block covers the tile's entire k extent.

    // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
    tile_work.k_begin = 0;

    // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile
    tile_work.k_end = params.block_mapping.problem_size.k();

ThreadblockSwizzleStreamK::get_tile_offset computes the threadblock's two-dimensional tiled coordinates within the grid.

    // The location of this tile (in threadblock-tile coordinates) in the output matrix
    tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx);
  }

GemmUniversalStreamk::init_sk_tile_work

Initializes the work descriptor GemmUniversalStreamk::TileWorkDesc for an SK tile.

  CUTLASS_DEVICE
  void init_sk_tile_work(
    TileWorkDesc &tile_work,
    int tile_idx,
    int block_iter_begin,
    int block_iter_end)
  {
    // The linear tile index
    tile_work.tile_idx = tile_idx;

    // The first global-scoped MAC-iteration for this tile
    int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile();

A tile may be processed by several CTAs, so the first iteration this CTA performs for its first tile need not be the tile's first iteration. block_iter_begin is where the CTA's global iteration range starts and tile_iter_begin is where the tile's iteration range starts; tile_work.iter_begin is where this CTA actually starts on this tile (a worked example follows the function below).
k_iter_begin is the CTA's starting iteration local to the current tile.
tile_work.k_iters_remaining is the number of iterations the CTA still has to process for this tile.

    // The first global-scoped MAC-iteration this threadblock will perform for this tile
    tile_work.iter_begin = max(block_iter_begin, tile_iter_begin);

    // The first tile-scoped MAC-iteration this threadblock will perform for this tile
    int k_iter_begin = tile_work.iter_begin - tile_iter_begin;

    // The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile
    int k_iter_end = block_iter_end - tile_iter_begin;

    // The number of MAC-iterations this threadblock will perform for this tile
    tile_work.k_iters_remaining = k_iter_end - k_iter_begin;

tile_work.k_begin and tile_work.k_end delimit the k-axis range of this piece of tile work.

    // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
    tile_work.k_begin = k_iter_begin * Mma::Shape::kK;

    // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile
    tile_work.k_end = min(
      params.block_mapping.problem_size.k(),            // extent of k domain
      (k_iter_end * Mma::Shape::kK));                   // extent of the threadblock's global iteration assignment

ThreadblockSwizzleStreamK::get_tile_offset returns the tile's position in the output matrix.

    // The location of this tile (in threadblock-tile coordinates) in the output matrix
    tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx);
  }
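A worked example of the arithmetic above, under assumed numbers:

    // Sketch: assume iters_per_tile = 8, Mma::Shape::kK = 32, problem k = 250,
    // and an SK CTA with global iteration range [block_iter_begin, block_iter_end) = [20, 31).
    // For tile_idx = 3 (the tile owning iterations [24, 32)):
    //   tile_iter_begin      = 3 * 8 = 24
    //   tile_work.iter_begin = max(20, 24) = 24
    //   k_iter_begin         = 24 - 24 = 0
    //   k_iter_end           = 31 - 24 = 7
    //   k_iters_remaining    = 7
    //   k_begin              = 0 * 32 = 0
    //   k_end                = min(250, 7 * 32) = 224
    // tile_started() is true but tile_finished() is false: a peer CTA owns iteration 31
    // (k range [224, 250)) and will finish the tile.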

GemmUniversalStreamk::share_accumulators

Cooperating CTAs aggregate their partial sums into accum_tile_workspace.

  /// Share accumulators with peers
  CUTLASS_DEVICE
  void share_accumulators(
    AccumulatorTile const &accumulator_tile,
    int block_idx,
    int first_block_idx)
  {
    AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);

    int accum_tile_offset = first_block_idx * kThreadCount;

If this is the first CTA covering the tile,

  • BlockStriped::store saves the partial sums in accumulator_tile to accum_tile_workspace.

Otherwise,

  • wait for the other CTAs:
    • atomic strategy: GenericBarrier::wait_lt waits for the flag to exceed 0, i.e. for the first CTA to finish its store;
    • non-atomic (turnstile) strategy: GenericBarrier::wait_eq waits until all preceding CTAs are done;
  • BlockStripedReduce::reduce accumulates this CTA's accumulator_tile into accum_tile_workspace.
    if (block_idx == first_block_idx)
    {
      // First peer initializes the workspace partials
      BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
    }
    else
    {
      // Subsequent peers atomically accumulate into the workspace partials
      if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic)
      {
        // Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them
        Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1);
      }
      else
      {
        // Turnstile reduction order: wait until the previous peer has written
        int wait_count = block_idx - first_block_idx;
        Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count);
      }

      // Perform reduction in workspace
      BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
    }

GenericBarrier::arrive_inc uses thread 0 to increment the flag's arrival count.

    // Signal our arrival
    Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx);
  }

GemmUniversalStreamk::acquire_accumulators

GenericBarrier::wait_eq_reset uses thread 0 to wait until the preceding num_carry_in CTAs have staged their partials.

  /// Acquire accumulators from peers
  CUTLASS_DEVICE
  void acquire_accumulators(
    AccumulatorTile &accumulator_tile,
    int block_idx,
    int first_block_idx)
  {
    AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);

    // Wait for arrival
    int num_carry_in = block_idx - first_block_idx;
    Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in);

BlockStripedReduce::load_add adds the partial sums staged at accum_tile_workspace + accum_tile_offset into the local accumulator_tile.

    // Load and add peer-partials accumulator tile to local accumulator tile
    int accum_tile_offset = first_block_idx * kThreadCount;
    BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx);
  }

GemmUniversalStreamk::do_epilogue

Adjusts the pointers to the proper matrix locations.

  /// Perform epilogue computations and output
  CUTLASS_DEVICE
  void do_epilogue(
    TileWorkDesc &tile_work,
    AccumulatorTile &accumulator_tile)
  {
    ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

    // Update pointers for batched/array mode(s)
    if (params.mode == GemmUniversalMode::kBatched) {
      ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C;
      ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D;
    }
    if (params.mode == GemmUniversalMode::kArray) {
      ptr_C = static_cast<ElementC * const *>(params.ptr_C)[tile_work.tiled_coord.k()];
      ptr_D = static_cast<ElementC * const *>(params.ptr_D)[tile_work.tiled_coord.k()];
    }

Determine the tile's position within the matrix.
Construct the PredicatedTileIterator iterators for the C and D matrices.

    // Location of this tile in item-coords
    MatrixCoord threadblock_item_begin(
      tile_work.tiled_coord.m() * Mma::Shape::kM,
      tile_work.tiled_coord.n() * Mma::Shape::kN
    );

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      ptr_C,
      params.block_mapping.problem_size.mn(),
      thread_idx,
      threadblock_item_begin);

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      ptr_D,
      params.block_mapping.problem_size.mn(),
      thread_idx,
      threadblock_item_begin);

The Epilogue finishes the tile.

    // Execute the epilogue operator to update the destination tensor.
    epilogue(
      EpilogueOutputOp(params.output_op),
      iterator_D,
      accumulator_tile,
      iterator_C);
  }

GemmUniversalStreamk::separate_reduction


reduce_idx determines the tile index reduce_tile_idx and the fragment index reduce_fragment_idx (a worked example follows the code below).

  CUTLASS_DEVICE
  void separate_reduction(int reduce_idx)
  {
    int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx;

    // Reduce by sk-tile (every tile contributed to by one or more blocks)
    reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
    reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;
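For example, with Epilogue::kAccumulatorFragments equal to 8 (the value reported for this example):

    // reduce_idx = 19  ->  reduce_tile_idx = 19 / 8 = 2, reduce_fragment_idx = 19 % 8 = 3,
    // i.e. this reduction block finalizes fragment 3 of the third SK tile.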

ThreadblockSwizzleStreamK::iters_per_tile is the number of iterations per tile.
Compute the first and last iteration covered by this reduction.
ThreadblockSwizzleStreamK::get_sk_block_idx maps an iteration to the index of the SK CTA that performs it.
peer_idx_begin and peer_idx_last are therefore the first and last SK CTAs working on this tile, used for the subsequent synchronization and reduction.

    int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile();
    int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1;

    peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first);
    peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last);

Barrier is GenericBarrier.
GenericBarrier::wait_eq_reset uses thread 0 to wait until all num_peers SK CTAs have staged their partials.

    // Wait for peers to complete
    int peer_idx_end = peer_idx_last + 1;
    int num_peers = peer_idx_end - peer_idx_begin;
    Barrier::wait_eq_reset(
      params.barrier_workspace,
      thread_idx,
      (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx,
      num_peers);

ThreadblockSwizzleStreamK::get_tile_offset computes the tile's two-dimensional tiled coordinates from reduce_tile_idx.

    /// The location of this tile (in threadblock-tile coordinates) in the output matrix
    GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx);

    // Location of this tile in item-coords
    MatrixCoord threadblock_item_begin(
      tiled_coord.m() * Mma::Shape::kM,
      tiled_coord.n() * Mma::Shape::kN
    );

    ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      ptr_C,
      params.block_mapping.problem_size.mn(),
      thread_idx,
      threadblock_item_begin);

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      ptr_D,
      params.block_mapping.problem_size.mn(),
      thread_idx,
      threadblock_item_begin);

Epilogue::reduce reduces the accumulator fragments from the peer blocks on-chip, applies the epilogue computation, and writes the final result to the output matrix.
It is the counterpart of EpilogueBaseStreamK::share.

    // Execute the epilogue operator to update the destination tensor.
    epilogue.reduce(
      peer_idx_begin,
      peer_idx_end,
      reduce_fragment_idx,
      params.partials_workspace,
      EpilogueOutputOp(params.output_op),
      iterator_D,
      iterator_C);
  }

GemmUniversalStreamk::process_tile


Calls GemmUniversalStreamk::init_iterator_A and GemmUniversalStreamk::init_iterator_B to initialize the input iterators.
AccumulatorTile is MmaMultistage::FragmentC, i.e. Mma::FragmentC, i.e. an Array.
Construct an MmaMultistage object.
MmaMultistage::operator() performs this tile's range of multiply-accumulate (MAC) iterations, accumulating into accumulator_tile.

  CUTLASS_DEVICE
  void process_tile(
    TileWorkDesc tile_work,
    int block_idx,
    int dp_start_block_idx,
    int block_iter_begin)
  {
    // Initialize input iterators
    typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
    typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);

    // Initialize accumulators
    AccumulatorTile accumulator_tile;
    accumulator_tile.clear();

    // Initialize MMA abstraction
    Mma mma(
      shared_storage.main_loop,
      thread_idx,
      warp_idx,
      lane_idx);

    // Perform this tile's range of multiply-accumulate (MAC) iterations
    mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);

If the reduction strategy is atomic, or there are no reduction blocks, or this is a DP CTA:

  • ThreadblockSwizzleStreamK::get_first_block_idx yields the index of the first block covering tile tile_work.tile_idx;
  • GemmUniversalStreamk::TileWorkDesc::tile_finished tells whether this CTA finishes the tile:
    • if it does not, GemmUniversalStreamk::share_accumulators merges its partial sums into partials_workspace;
    • if it is a DP CTA or the finishing SK CTA,
      • unless this CTA also started the tile, GemmUniversalStreamk::acquire_accumulators first folds partials_workspace into accumulator_tile;
      • GemmUniversalStreamk::do_epilogue runs the Epilogue to finish the tile.
    if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
        (params.block_mapping.reduction_blocks == 0) ||
        (block_idx >= dp_start_block_idx))
    {
      //
      // Cooperative SK peer reduction or DP block
      //

      int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx);

      if (!tile_work.tile_finished(params)) {
        // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace
        share_accumulators(accumulator_tile, block_idx, first_block_idx);
      }
      else
      {
        // DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile
        if (!tile_work.tile_started())
        {
          // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
          acquire_accumulators(accumulator_tile, block_idx, first_block_idx);
        }

        do_epilogue(tile_work, accumulator_tile);
      }
    }

Otherwise, as an SK CTA under separate reduction,

  • EpilogueBaseStreamK::share splits accumulator_tile into fragments and stages them in the global scratch workspace params.partials_workspace;
  • GenericBarrier::arrive_range_inc increments Epilogue::kAccumulatorFragments flags.
    else
    {
      //
      // Separate peer reduction
      //

      // Share accumulator partial sums with peer threadblock(s) through scratch workspace
      epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started());

      // Signal arrival
      Barrier::arrive_range_inc(
        params.barrier_workspace,
        thread_idx,
        tile_work.tile_idx * Epilogue::kAccumulatorFragments,
        Epilogue::kAccumulatorFragments);
    }
  }

GemmUniversalStreamk::gemm

(Call graph: GemmUniversalStreamk::gemm dispatches Reduce blocks to GemmUniversalStreamk::separate_reduction and DP or SK blocks to GemmUniversalStreamk::process_tile.)

ThreadblockSwizzleStreamK::get_block_idx returns the linear CTA index.
ThreadblockSwizzleStreamK::sk_regions returns the number of SK regions.
ThreadblockSwizzleStreamK::sk_blocks_per_region returns the number of SK CTAs per region.
sk_padding_start_block_idx is one past the last SK CTA.
The three kinds of CTAs are ordered SK, DP, Reduce (see the sketch after the code below).
grid_padding_start_block_idx is one past the last Reduce CTA.

  /// Executes one GEMM
  CUTLASS_DEVICE
  void gemm()
  {
    // Initialize block's iteration range
    int tile_idx = 0;
    int block_iter_begin = 0;
    int block_iters_remaining = 0;

    int block_idx = params.block_mapping.get_block_idx();

    int sk_padding_start_block_idx =  params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
    int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
    int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
    int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
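A hypothetical layout of the block ranges (numbers invented for illustration):

    // Assume avail_sms = 4, sk_regions() * sk_blocks_per_region() = 3, sk_waves = 1,
    // dp_blocks = 8, reduction_blocks = 0:
    //   sk_padding_start_block_idx   = 3            // blocks [0, 3) are SK blocks
    //   dp_start_block_idx           = 1 * 4 = 4    // block 3 is padding; [4, 12) are DP blocks
    //   reduce_start_block_idx       = 4 + 8 = 12   // empty range: no reduction wave here
    //   grid_padding_start_block_idx = 12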

Create a GemmUniversalStreamk::TileWorkDesc and classify the CTA; each kind of CTA initializes the descriptor differently.

    // Initialize tile work descriptor
    TileWorkDesc tile_work;

    bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
    bool sk_block = (block_idx < sk_padding_start_block_idx);
    bool reduce_block = (block_idx >= reduce_start_block_idx) &&
                        (block_idx < grid_padding_start_block_idx) &&
                        (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);

For a data-parallel block,

  • dp_block_idx is this CTA's index within the DP blocks;
  • tile_idx is the tile to process and tile_allottment the number of tiles allotted. For blocks beyond the first wave,
    • each DP CTA is allotted exactly one tile;
    • tile_idx is advanced past the tiles consumed by the first wave.
  • block_iters_remaining is the number of iterations to process;
  • GemmUniversalStreamk::init_dp_tile_work initializes the DP tile work descriptor;
  • finally, check whether the DP tile overlaps an SK tile (possible only under cohort rasterization) or falls outside the matrix.
    if (dp_block)
    {
      // This is a DP block
      int dp_block_idx = block_idx - dp_start_block_idx;
      int first_dp_tile = (params.block_mapping.cohort_raster) ? 0 : params.block_mapping.sk_tiles;

      // Blocks in first DP wave get configured number of tiles
      tile_idx = first_dp_tile + dp_block_idx;
      int tile_allottment = params.block_mapping.dp_first_wave_tiles;

      // Blocks in subsequent DP waves get 1 tile
      if (dp_block_idx >= params.block_mapping.avail_sms) {
        tile_allottment = 1;
        tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
      }

      block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;

      init_dp_tile_work(tile_work, tile_idx);

      // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
      if ((tile_idx < params.block_mapping.sk_tiles) ||
          (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
          (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
      {
        return;
      }
    }

For a stream-k block,

  • ThreadblockSwizzleStreamK::get_iter_extents determines this SK CTA's iteration range in the work assignment;
  • ThreadblockSwizzleStreamK::get_sk_tile_idx maps the range's last iteration back to its tile index;
  • GemmUniversalStreamk::init_sk_tile_work initializes the SK tile work descriptor.
    else if (sk_block)
    {
      // This is a SK block
      int block_iter_end;
      params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
      block_iters_remaining = block_iter_end - block_iter_begin;

      tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
      init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
    }

For a reduction block,

  • reduce_block_idx is this CTA's index relative to the reduction blocks;
  • GemmUniversalStreamk::separate_reduction combines the partial results contributed by the multiple threadblocks covering a tile into the final result;
  • the CTA then returns without entering the loop.
    else
    {
      if (reduce_block)
      {
        // This is a reduction threadblock
        int reduce_block_idx = block_idx - reduce_start_block_idx;
        separate_reduction(reduce_block_idx);
      }

      return;
    }

GemmUniversalStreamk::process_tile performs the computation described by tile_work (tile coordinates, iteration range, and so on), consuming tile_work.k_iters_remaining iterations per pass.
The loop exits when block_iters_remaining reaches 0.

    // Iteration-processing loop body
    CUTLASS_PRAGMA_NO_UNROLL
    while (true)
    {
      // Perform this block's share of work for this tile
      process_tile(
        tile_work,
        block_idx,
        dp_start_block_idx,
        block_iter_begin);

      block_iters_remaining -= tile_work.k_iters_remaining;

      if (block_iters_remaining == 0)
      {
        break;
      }

      // Continue to next tile
      __syncthreads();

To move on to the next tile,

  • a DP CTA strides forward to its tile in the next wave and calls GemmUniversalStreamk::init_dp_tile_work;
  • an SK CTA consumes its tiles in backwards order and calls GemmUniversalStreamk::init_sk_tile_work.
      if (block_idx >= dp_start_block_idx)
      {
        // DP block consume their tiles at stride
        tile_idx += params.block_mapping.avail_sms;
        init_dp_tile_work(tile_work, tile_idx);
      }
      else
      {
        // SK blocks consume their tiles in backwards order
        tile_idx--;
        init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
      }
    }
  }

GemmUniversalStreamk::invoke


This static function constructs a GemmUniversalStreamk object and invokes GemmUniversalStreamk::operator().

public:

  //
  // Device-only API
  //

  // Factory invocation
  CUTLASS_DEVICE
  static void invoke(
    Params const &params,
    SharedStorage &shared_storage)
  {
    GemmUniversalStreamk op(params, shared_storage);
    op();
  }

GemmUniversalStreamk::GemmUniversalStreamk

  // Constructor
  CUTLASS_DEVICE
  GemmUniversalStreamk(
    Params const &params,
    SharedStorage &shared_storage)
  :
    params(params),
    shared_storage(shared_storage),
    thread_idx(threadIdx.x),
    warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)),   // broadcast the warp_id computed by lane 0 to ensure dependent code
    lane_idx(threadIdx.x % 32),
    epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx)
  {}

GemmUniversalStreamk::operator()


Calls GemmUniversalStreamk::gemm.

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()()
  {
    // Generic SK code path
    gemm();
  }
};

ThreadblockSwizzleStreamK

/// Threadblock mapping control for GEMMs
struct ThreadblockSwizzleStreamK {

  /// Advertise StreamkFeature
  using StreamkFeature = void;

  /// Kernel traits
  template <typename GemmKernel>
  struct KernelTraits {};

Three reduction strategies are defined; this example uses the kMixed mode.

  /// Reduction strategy
  enum ReductionStrategy
  {
    kNone,      // Data-parallel strategy (no seams, fixup, etc.)

    kAtomic,    // Non-deterministic reduction of SK-block partials using atomic aggregation in L2

    kMixed,     // Deterministic reduction of SK-block partials employing either:
                //   (a) A separate wave of reduction thread blocks" (for scenarios with lots of
                //       SK-blocks per SK-tile)
                //   (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few
                //       SK-blocks per SK-tile)
  };

  static ReductionStrategy const kReductionStrategy = kMixed;

kDpEfficiencyThreshold is never used.
A rasterization cohort is 8x4 CTAs, i.e. 32 CTAs per cohort.
kFixupStartupIterEquiv and kFixupPeerIterEquiv are never used.

  //
  // Heuristics
  //

  /// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel)
  static float constexpr kDpEfficiencyThreshold = 0.92f;

  /// Minimum number of MAC-iterations per streamk block
  static int const kMinItersPerSkBlock = 2;

  /// Height in CTAs of a grid rasterization cohort
  static int const kCohortCtasM = 8;

  /// Width in CTAs of a grid rasterization cohort
  static int const kCohortCtasN = 4;

  /// Number of CTAs per cohort
  static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM;

  /// Cost-equivalent number of SM-iterations for fixup I/O
  static int const kFixupStartupIterEquiv = 10;
  static int const kFixupPeerIterEquiv = 3;
  //
  // Member state
  //

  /// The 3D value-extents of the GEMM computation volume (m,n,k)
  GemmCoord problem_size;

  /// Div/mod accelerators
  FastDivmod div_mod_tiled_shape_m;
  FastDivmod div_mod_tiled_shape_n;
  FastDivmod div_mod_tiled_cohort_shape_n;
  FastDivmod div_mod_iters_per_tile;

  /// Whether to perform cohort CTA rasterization
  bool cohort_raster;

  // Whether to pad and remap block indices
  bool remap_block_indices;

  /// CTA occupancy per SM
  int sm_occupancy;

  /// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size)
  int avail_sms;

  int dp_blocks;                            /// Number of data-parallel thread blocks in the grid
  int dp_first_wave_tiles;                  /// Number of output tiles each CTA in the first DP wave will produce

  /// Number of reduction blocks in the grid
  int reduction_blocks;

  int sk_waves;
  int sk_tiles;
  int sk_big_blocks_per_region;
  int sk_iters_per_region;

  /// Div/mod accelerators
  FastDivmod div_mod_sk_iters_per_normal_block;
  FastDivmod div_mod_sk_iters_per_big_block;
  FastDivmod div_mod_sk_iters_per_region;
  FastDivmod div_mod_sk_regions;                      //!! used in block map
  FastDivmod div_mod_sk_blocks_per_region;            //!! used in block map

  /// The batch count
  int batch_count;

  //
  // Host+device interface
  //

  /// Constructor
  ThreadblockSwizzleStreamK() = default;

ThreadblockSwizzleStreamK::tiled_shape

Returns a GemmCoord; batch_count occupies the k dimension.

  /// Returns the GEMM volume in thread block tiles
  CUTLASS_HOST_DEVICE
  GemmCoord tiled_shape() const
  {
    return GemmCoord(
      static_cast<int>(div_mod_tiled_shape_m),
      static_cast<int>(div_mod_tiled_shape_n),
      batch_count);
  }

ThreadblockSwizzleStreamK::iters_per_tile

Converting a FastDivmod to int recovers the original divisor (a usage sketch follows the code below).

  /// Number of iterations per output tile
  CUTLASS_HOST_DEVICE
  int iters_per_tile() const
  {
    return static_cast<int>(div_mod_iters_per_tile);
  }
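FastDivmod trades the hardware divide for a precomputed multiply/shift. A minimal usage sketch based on the three interfaces this file relies on (operator(), div, and the int conversion):

    #include "cutlass/fast_math.h"

    void fast_divmod_demo() {
      cutlass::FastDivmod div_mod(24);          // e.g. iters_per_tile = 24

      int q, r;
      div_mod(q, r, 100);                       // q = 100 / 24 = 4, r = 100 % 24 = 4
      int q2 = div_mod.div(100);                // quotient only: 4
      int divisor = static_cast<int>(div_mod);  // recovers the original divisor, 24
      (void)q; (void)r; (void)q2; (void)divisor;
    }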

ThreadblockSwizzleStreamK::sk_iters_per_normal_block

  /// Number of iterations for normal SK-blocks
  CUTLASS_HOST_DEVICE
  int sk_iters_per_normal_block() const
  {
    return static_cast<int>(div_mod_sk_iters_per_normal_block);
  }

ThreadblockSwizzleStreamK::sk_regions

  /// Number of SK regions
  CUTLASS_HOST_DEVICE
  int sk_regions() const
  {
    return static_cast<int>(div_mod_sk_regions);
  }

ThreadblockSwizzleStreamK::sk_blocks_per_region

  /// Number of SK blocks per region (splitting factor)
  CUTLASS_HOST_DEVICE
  int sk_blocks_per_region() const
  {
    return static_cast<int>(div_mod_sk_blocks_per_region);
  }

ThreadblockSwizzleStreamK::Print

tiles = dp_tiles + sk_tiles

  //
  // Host-side interface
  //

  /// Debug print
  void Print()
  {
#ifndef __CUDA_ARCH__
    auto tiles = tiled_shape().mn().product();
    std::cout <<
      "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" <<
      ", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" <<
      ", tiles: " << tiles <<
      ", dp_tiles: " << tiles - sk_tiles <<
      ", sk_tiles: " << sk_tiles <<
      ", iters_per_tile: " << iters_per_tile() <<
      ", reduction_blocks: " << reduction_blocks <<
      ", dp_blocks: " << dp_blocks <<
      ", dp_waves: " << dp_blocks / avail_sms <<
      ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
      ", sk_blocks_per_region: " << sk_blocks_per_region() <<
      ", sk_regions: " << sk_regions() <<
      ", sk_waves: " << sk_waves <<
      ", sk_iters_per_normal_block: " << sk_iters_per_normal_block() <<
      ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
      ", remap_block_indices: " << remap_block_indices <<
      ", cohort_raster: " << cohort_raster <<
      ", sm_occupancy: " << sm_occupancy <<
      ", avail_sms: " << avail_sms <<
      ", num_blocks: " << get_num_blocks() <<
      "\n\n";
#endif
  }

ThreadblockSwizzleStreamK::get_sk_blocks

savings_iters is initialized to the smallest integer value and sk_blocks to 0; if sk_tiles is 0 the function returns immediately.

  // Compute sk_blocks to dispatch for a given number of sk_tiles
  static void get_sk_blocks(
    int &sk_blocks,     /// [out]
    int &savings_iters, /// [out]
    int sk_tiles,
    int iters_per_tile,
    int avail_sms,
    int max_sk_occupancy,
    bool allow_partial_wave)
  {
    savings_iters = INT_MIN;
    sk_blocks = 0;

    if (sk_tiles == 0) {
      return;
    }

sk_iters is the total number of MAC iterations across the SK tiles.
dp_equiv_iters is the per-SM critical-path iteration count of an equivalent data-parallel dispatch; because dp_equiv_waves rounds up to whole waves, it reflects the quantization overhead Stream-K tries to recover (a numeric sketch follows the code below).

    int sk_iters = sk_tiles * iters_per_tile;

    int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms;
    int dp_equiv_iters = iters_per_tile * dp_equiv_waves;
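For instance, under assumed numbers:

    // Sketch: sk_tiles = 7, iters_per_tile = 128, avail_sms = 4:
    //   sk_iters       = 7 * 128 = 896
    //   dp_equiv_waves = ceil(7 / 4) = 2
    //   dp_equiv_iters = 128 * 2 = 256   // critical path of the data-parallel alternative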

kMinItersPerSkBlock is the minimum number of MAC iterations per SK block.
If a partial wave is allowed, the minimum SK block count is the smaller of avail_sms and sk_tiles + 1; otherwise it equals avail_sms.
The maximum SK block count max_sk_blocks is capped by both avail_sms * max_sk_occupancy and sk_iters / kMinItersPerSkBlock.

    int min_sk_blocks = (allow_partial_wave) ? fast_min(avail_sms, sk_tiles + 1) : avail_sms;
    int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock);

The paper's cost model is:

$$time_{CTA}(g) \leftarrow a + b\,(FixupPeers(g) > 1) + c \cdot ItersPerCta(g) + d\,(FixupPeers(g) - 1)$$

where:

$$ItersPerCta(g) \leftarrow \left\lceil \frac{\lceil m/\mathrm{BLK\_M} \rceil \times \lceil n/\mathrm{BLK\_N} \rceil \times \lceil k/\mathrm{BLK\_K} \rceil}{g} \right\rceil \qquad FixupPeers(g) \leftarrow \left\lceil \frac{\lceil k/\mathrm{BLK\_K} \rceil}{ItersPerCta(g)} \right\rceil$$
The loop sweeps every candidate block count in [min_sk_blocks, max_sk_blocks].

  • From trial_sk_blocks it derives the SK wave count sk_waves and the maximum SK iterations per CTA, max_sk_iters_per_block, i.e. $ItersPerCta(g)$.
  • sk_iter_equiv is the equivalent SK critical-path iteration count.
  • num_peers is the number of CTAs covering the same tile, i.e. $FixupPeers(g)$.
  • base_cost is the fixed per-CTA cost $a$.
  • iter_cost is the iteration cost $c$.
  • peer_cost covers the CTA cooperation costs $b$ and $d$.
  • If trial_savings_iters is at least the current savings_iters, update savings_iters and sk_blocks. (A worked trial follows the code below.)
    for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks)
    {
      int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms;

      int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks;
      int sk_iter_equiv = max_sk_iters_per_block * sk_waves;

      int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1;        // add one for alignment skew

      float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv);

      if (trial_sk_blocks % sk_tiles == 0)
      {
        // aligned
        num_peers = (trial_sk_blocks / sk_tiles);

        iter_cost = 0.0f;
      }

      float peer_cost = 2.0f * float(num_peers);

      float base_cost = 2.0f * float(sk_waves);

      int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost);

      int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv;

      if (trial_savings_iters >= savings_iters) {
        savings_iters = trial_savings_iters;
        sk_blocks = trial_sk_blocks;
      }
    }
  }
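One trial of this loop, worked through with assumed numbers:

    // sk_tiles = 2, iters_per_tile = 16, avail_sms = 4 (so dp_equiv_iters = 16 * 1 = 16),
    // trial_sk_blocks = 4:
    //   sk_waves               = ceil(4 / 4) = 1
    //   max_sk_iters_per_block = ceil(32 / 4) = 8
    //   sk_iter_equiv          = 8 * 1 = 8
    //   4 % 2 == 0 -> aligned path: num_peers = 4 / 2 = 2, iter_cost = 0
    //   peer_cost = 2 * 2 = 4, base_cost = 2 * 1 = 2 -> fixup_iter_equiv = 6
    //   trial_savings_iters    = 16 - 8 - 6 = 2   // profitable relative to the DP dispatch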

ThreadblockSwizzleStreamK::get_blocks


Compute the full-wave tile count full_wave_tiles and the partial-wave tile count partial_wave_tiles.
If the output tiles divide evenly across the SMs, only DP tiles are used.

  /// Determine the populations of DP and SK blocks to invoke for the given number of output tiles
  static void get_blocks(
    int &dp_tiles,      /// [out]
    int &sk_blocks,     /// [out]
    int output_tiles,
    int iters_per_tile,
    int avail_sms,
    int sm_occupancy)
  {
    int full_waves = output_tiles / avail_sms;
    int full_wave_tiles = full_waves * avail_sms;
    int partial_wave_tiles = output_tiles - full_wave_tiles;

    int score = -1;
    dp_tiles = output_tiles;
    sk_blocks = 0;

    if (partial_wave_tiles == 0)
    {
      // Perfect quantization
      return;
    }

If full_waves is below the SM's maximum active CTA count, full GPU occupancy can be reached by forming an SK wave from the partial wave.
max_sk_occupancy is the occupancy still available for SK blocks.
ThreadblockSwizzleStreamK::get_sk_blocks then computes sk_blocks and score.

    if (full_waves < sm_occupancy)
    {
      // We're less than full GPU occupancy

      // Form the SK wave from the partial wave to get us up to full GPU occupancy
      int max_sk_occupancy = sm_occupancy - full_waves;

      dp_tiles = full_wave_tiles;

      get_sk_blocks(
        sk_blocks,
        score,
        partial_wave_tiles,
        iters_per_tile,
        avail_sms,
        max_sk_occupancy,
        true);                 // we can run with less than a full wave of SK-blocks

      if (score < 0) {
        // not profitable
        sk_blocks = 0;
        dp_tiles = output_tiles;
      }

      return;
    }

This branch handles the case where the number of full waves is one short of a multiple of the SM occupancy: one more wave would round occupancy up to a full multiple, so the code tries to dispatch the remaining partial_wave_tiles as SK blocks (capped at one extra CTA per SM) to fill the GPU.

    // We're at (or greater) than GPU occupancy
    if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
    {
      // If occupancy is more than one CTA per SM, form the SK wave from the partial
      // wave to get us to full GPU occupancy
      int max_sk_occupancy = 1;

      dp_tiles = full_wave_tiles;

      get_sk_blocks(
        sk_blocks,
        score,
        partial_wave_tiles,
        iters_per_tile,
        avail_sms,
        max_sk_occupancy,
        true);                 // we can run with less than a full wave of SK-blocks

      if (score >= 0) {
        return;
      }
    }

Failing that, the SK wave is formed by combining the last full wave with the partial wave, trading DP work for SK work to balance utilization.
This reduces the number of DP tiles by one full wave.

    // Form the SK wave by combining the last full wave and the partial wave
    // We're less than full GPU occupancy
    dp_tiles = full_wave_tiles - avail_sms;

    int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy);

    get_sk_blocks(
      sk_blocks,
      score,
      partial_wave_tiles + avail_sms,
      iters_per_tile,
      avail_sms,
      max_sk_occupancy,
      false);                 // we cannot run with less than a full wave of SK-blocks

    if (score < 0) {
      // not profitable
      sk_blocks = 0;
      dp_tiles = output_tiles;
    }
  }

ThreadblockSwizzleStreamK::ThreadblockSwizzleStreamK


iters_per_tile is the number of MAC iterations per tile.

  /// Constructor: *Gemm* problem size (m, n, k)
  ThreadblockSwizzleStreamK(
    GemmUniversalMode const mode_,
    GemmCoord const problem_size_,
    GemmCoord const tile_size_,
    int const batch_split_,                        /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
    int const sm_occupancy_,
    int const device_sms_,
    int const avail_sms_,                          /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
    size_t const element_A_bytes_,
    size_t const element_B_bytes_,
    size_t const element_C_bytes_,
    int const epilogue_acc_fragments_)
  :
    problem_size(problem_size_),
    batch_count((mode_ == GemmUniversalMode::kBatched || mode_ == GemmUniversalMode::kArray) ? batch_split_ : 1),
    reduction_blocks(0),
    dp_blocks(0),
    dp_first_wave_tiles(1),     // Default: one tile per DP-block in the first wave of DP blocks
    sk_tiles(0),
    sk_big_blocks_per_region(0),
    sk_iters_per_region(0),
    sk_waves(0),
    sm_occupancy(sm_occupancy_),
    remap_block_indices(false),
    avail_sms(fast_max(1, avail_sms_)),
    cohort_raster(false)
  {
    int gpu_occupancy = device_sms_ * sm_occupancy;
    int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k();
    int sk_iters_per_normal_block = 0;

    int sk_regions = 1;              // Default: a single region of iteration space (across all SK tiles)
    int sk_blocks_per_region = 0;

tiled_shape is the tile extent in all three dimensions.
flops_per_byte and dp_efficiency are unused.

    GemmCoord tiled_shape(
      (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
      (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
      batch_count);

    size_t problem_bytes =
      (element_C_bytes_ * problem_size.m() * problem_size.n()) +
      (element_A_bytes_ * problem_size.m() * problem_size.k()) +
      (element_B_bytes_ * problem_size.k() * problem_size.n());

    size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2;

    [[maybe_unused]] float flops_per_byte = float(problem_flops) / float(problem_bytes);

    int output_tiles = tiled_shape.m() * tiled_shape.n();
    int waves = (output_tiles + avail_sms - 1) / avail_sms;
    [[maybe_unused]] float dp_efficiency = float(output_tiles) / float(waves * avail_sms);

The dispatch starts from a DP-only configuration.

    //
    // Determine dispatch composition of DP-tiles and SK-blocks
    //

    // Start with a DP-only configuration
    int dp_tiles = output_tiles;    // Number of data-parallel tiles
    int sk_blocks = 0;              // Number of thread blocks to produce the remaining SK tiles

Only kGemm mode supports SK load balancing.

  • If split_factor is greater than 1, each tile is processed by split_factor SK blocks and no DP CTA is used;
  • otherwise, if kReductionStrategy is enabled and avail_sms is greater than 1, ThreadblockSwizzleStreamK::get_blocks computes dp_tiles and sk_blocks heuristically.
    // Only kGemm mode allows for SK load balancing
    if (mode_ == GemmUniversalMode::kGemm)
    {
      int split_factor = batch_split_;

      if (split_factor > 1)
      {
        // Split-K override
        dp_tiles = 0;
        sk_blocks = output_tiles * split_factor;
      }
      else if ((kReductionStrategy != kNone) &&   // Load-balancing strategy statically enabled
               (avail_sms > 1))                   // Plurality of SMs to load balance across
      {
        // Use heuristics
        get_blocks(
          dp_tiles,      /// [out]
          sk_blocks,     /// [out]
          output_tiles,
          iters_per_tile,
          avail_sms,
          sm_occupancy);
      }
    }

Compute the SK CTA bookkeeping (a numeric sketch follows the code below).
sk_iters_per_normal_block is the iteration count of a normal SK CTA.
sk_regions is the number of sub-partitions processed by groups of SK CTAs.
When the SK blocks divide evenly among the SK tiles, sk_regions equals the number of SK tiles (the Split-K decomposition).
This yields the three quantities sk_blocks_per_region, sk_big_blocks_per_region, and sk_iters_per_region.

    sk_tiles = output_tiles - dp_tiles;

    // Compute SK block iteration details
    if (sk_blocks > 0)
    {
      sk_waves = (sk_blocks + avail_sms - 1) / avail_sms;

      int sk_iters = sk_tiles * iters_per_tile;
      sk_blocks = fast_min(sk_blocks, sk_iters);

      sk_iters_per_normal_block = sk_iters / sk_blocks;
      int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks);
      int sk_big_blocks = extra_sk_iters;

      if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0))
      {
        // Split-K decomposition
        sk_regions = sk_tiles;
      }

      sk_blocks_per_region = sk_blocks / sk_regions;
      sk_big_blocks_per_region = sk_big_blocks / sk_regions;
      sk_iters_per_region = sk_iters / sk_regions;
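Under assumed numbers, the bookkeeping works out as:

    // Sketch: sk_tiles = 3, iters_per_tile = 11, sk_blocks = 5:
    //   sk_iters                  = 33
    //   sk_iters_per_normal_block = 33 / 5 = 6
    //   extra_sk_iters            = 33 - 30 = 3   // three "big" blocks run 7 iterations each
    //   sk_blocks > sk_tiles but 5 % 3 != 0, so sk_regions stays 1:
    //   sk_blocks_per_region = 5, sk_big_blocks_per_region = 3, sk_iters_per_region = 33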

A separate reduction wave is used when all of the following hold:

  • the reduction strategy is non-atomic;
  • the SK waves will not fully occupy the GPU;
  • more than three CTAs cooperate on one SK tile.

epilogue_acc_fragments_ is Epilogue::kAccumulatorFragments, i.e. LinearCombination::kCount, i.e. AlignmentC, which equals 8 here.

      // Use a separate reduction wave when all of:
      // - Non-atomic reduction stratgy
      // - The number of SK waves won't fully occupy the GPU (Otherwise we don't have
      //   a strong-scaling case for more parallel reduction)
      // - More than three peers working on an SK tile.  (This occurs when the ratio of
      //   SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
      //   e.g.:[partial-block | block | block | partial-block] ).  With three or
      //   less peers, the two non-finishing SK-blocks are not expexted to contend.
      if ((kReductionStrategy == kMixed) &&
          (sk_waves < sm_occupancy) &&
          (sk_blocks > 2 * sk_tiles))
      {
        // Launch a reduction block for every accumulator fragment in each SK-tile
        reduction_blocks = sk_tiles * epilogue_acc_fragments_;
      }

Block indices are remapped when:

  • more than one CTA can occupy an SM;
  • all of the device's SMs are available;
  • the number of active CTAs exceeds twice the number of available SMs.
      // When we have a multi-occupancy kernel and at least two waves of active blocks (where
      // at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2)
      // remap the block indices so that we can reliably spread the SK blocks evenly across the
      // device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx().
      remap_block_indices = (
        (sm_occupancy > 1) &&
        (device_sms_ == avail_sms) &&
        (get_num_active_blocks() > avail_sms * 2));

      // Initialize fast div/mod members related to SK
      div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
      div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
      div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region);
      div_mod_sk_regions = FastDivmod(sk_regions);
      div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region);
    }

Compute the DP CTA bookkeeping.
Tiles are grouped into cohorts on the 2D tile grid to improve L2 reuse, similar in spirit to swizzling; the cohort grid has shape tiled_cohort_shape (a numeric sketch follows the code below).
cohort_blocks is the number of CTAs implied by the cohort grid.
cohort_efficiency is the fraction of those CTAs doing useful work; rounding the grid up to whole cohorts can push it below 1.

    //
    // Compute DP blocks
    //

    dp_blocks = dp_tiles;

    cutlass::gemm::GemmCoord tiled_cohort_shape(
      (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM,
      (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN,
      tiled_shape.k());

    int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort;
    float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);
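For example (assumed numbers):

    // tiled_shape = (20, 10, 1) with 8x4 cohorts:
    //   tiled_cohort_shape = (ceil(20/8), ceil(10/4), 1) = (3, 3, 1)
    //   cohort_blocks      = 3 * 3 * 32 = 288
    //   cohort_efficiency  = 200 / 288 ~= 0.69   // below 0.85, so cohort raster is rejected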

Compute the cohort-grid coordinates (cohort_grid_m, cohort_grid_n) of the cohort holding the last SK tile, and check that it stays within tiled_shape.

    // Check if the SK tiles would be in cohorts that are in-bounds
    bool sk_in_range = true;
    if (sk_tiles > 0)
    {
      int last_sk_tile = sk_tiles - 1;
      int cohort_tile_idx = last_sk_tile / kCtasPerCohort;
      int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n();
      int cohort_grid_n = (cohort_grid_m > 0) ?
        tiled_cohort_shape.n() - 1 :
        cohort_tile_idx % tiled_cohort_shape.n();

      if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) ||
          (((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n()))
      {
        sk_in_range = false;
      }
    }

If the SK tiles stay in range, the number of DP CTAs is at least twice the GPU occupancy, and cohort_efficiency exceeds 0.85, then:

  • enable cohort rasterization and update dp_blocks to cohort_blocks.

Otherwise, adjust the semi-persistence of the first DP wave to ensure full grid wavesets (this only applies when there is an SK component and blocked cohort rasterization is not in use):

  • dp_tile_waves is the number of waves needed for the DP tiles; full_dp_tile_waves is the number of full waves among them.
  • dp_first_wave_tiles is the number of output tiles each CTA in the first DP wave produces.
  • waveset_excess is the remainder of (SK waves + DP tile waves) modulo the SM occupancy.
  • If dp_first_wave_tiles + waveset_excess does not exceed full_dp_tile_waves, the first DP wave absorbs the excess tiles, removing the partial waveset; a worked example follows the code below.
    // Decide if we're going to be doing cohort raster
    if (sk_in_range &&
        (dp_blocks >= gpu_occupancy * 2) &&
        (cohort_efficiency > 0.85f))
    {
      cohort_raster = true;
      dp_blocks = cohort_blocks;
    }
    else if (sk_waves > 0)
    {
      // Update semi-persistence of first DP wave to ensure full grid wavesets
      // (Only applies when there's an SK component and we're not doing blocked cohort rasterization)
      int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms;
      int full_dp_tile_waves = dp_tiles / avail_sms;
      int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy;
      if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves)
      {
        dp_first_wave_tiles += waveset_excess;
        dp_blocks -= (waveset_excess * avail_sms);
      }
    }
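A hypothetical example of the second branch, assuming dp_first_wave_tiles starts at its default of 1: with avail_sms = 4, sm_occupancy = 3, sk_waves = 1 and dp_tiles = 10, we get dp_tile_waves = 3, full_dp_tile_waves = 2 and waveset_excess = (1 + 3) % 3 = 1. Since 1 + 1 ≤ 2, dp_first_wave_tiles becomes 2 and dp_blocks drops from 10 to 6: the 4 CTAs of the first DP wave produce two tiles each and the remaining 2 CTAs one tile each, still covering all 10 DP tiles but with a wave count that divides evenly into wavesets.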

The tile-partitioning divisors are wrapped in FastDivmod objects to speed up division and modulo on the device.

    // Setup fast-div/mod for device-side usage
    div_mod_tiled_shape_m = FastDivmod(tiled_shape.m());
    div_mod_tiled_shape_n = FastDivmod(tiled_shape.n());
    div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
    div_mod_iters_per_tile = FastDivmod(iters_per_tile);
  }
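As an aside, the following standalone sketch illustrates one way such fixed-divisor division can be made fast (this is Lemire's fastmod formulation; CUTLASS's FastDivmod uses its own multiplier-and-shift scheme, so treat this as an illustration of the idea rather than the library's implementation):

    #include <cassert>
    #include <cstdint>

    // Sketch of the fast div/mod idea: for a fixed divisor d >= 2, precompute
    // a magic multiplier once on the host, then each quotient/remainder costs
    // a couple of multiplies instead of a hardware divide. Requires a compiler
    // with unsigned __int128 support (GCC/Clang).
    struct FastDivmodSketch {
      uint32_t d;   // fixed divisor, d >= 2
      uint64_t M;   // magic multiplier = ceil(2^64 / d)

      explicit FastDivmodSketch(uint32_t divisor) : d(divisor) {
        assert(d >= 2);            // d == 1 would need special-casing
        M = UINT64_MAX / d + 1;    // host-side precomputation, done once
      }

      // Replaces n / d and n % d with multiplies and shifts.
      void divmod(uint32_t n, uint32_t &q, uint32_t &r) const {
        q = uint32_t((unsigned __int128)M * n >> 64);
        uint64_t lowbits = M * n;  // 64-bit wraparound keeps the fractional part
        r = uint32_t((unsigned __int128)lowbits * d >> 64);
      }
    };

    int main() {
      FastDivmodSketch fd(7);
      for (uint32_t n = 0; n < 100000; ++n) {
        uint32_t q, r;
        fd.divmod(n, q, r);
        assert(q == n / 7 && r == n % 7);  // matches ordinary division
      }
      return 0;
    }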

ThreadblockSwizzleStreamK::get_num_active_blocks

Computes the number of CTAs doing useful work: SK CTAs + DP CTAs + reduction CTAs.

  /// Number of blocks performing useful work
  int get_num_active_blocks() const
  {
    return (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
  }

ThreadblockSwizzleStreamK::get_num_blocks


Obtains the number of CTAs used by the current GEMM.
It first calls ThreadblockSwizzleStreamK::get_num_active_blocks to get the number of CTAs actually needed.
If block-index remapping is enabled, the grid is padded to at least four waves across the SMs.

  /// Obtains number of threadblocks per GEMM
  int get_num_blocks() const
  {
    int active_blocks = get_num_active_blocks();
    if (remap_block_indices)
    {
      // Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves
      return fast_max(active_blocks, avail_sms * 4);
    }
    return active_blocks;
  }
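For example, with sk_waves = 1, avail_sms = 108, dp_blocks = 50 and reduction_blocks = 0, get_num_active_blocks() returns 108 + 50 + 0 = 158; if remap_block_indices is set, the grid is padded to fast_max(158, 108 * 4) = 432 blocks.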

ThreadblockSwizzleStreamK::get_grid_dims


Calls ThreadblockSwizzleStreamK::get_num_blocks for the total CTA count and returns it as the grid's x extent.

  /// Obtains grid extents in CTAs
  dim3 get_grid_dims() const
  {
    return dim3(get_num_blocks(), 1, batch_count);
  }

ThreadblockSwizzleStreamK::device_num_blocks

  //
  // Device-side interface
  //

  /// Obtains number of threadblocks per GEMM
  CUTLASS_DEVICE
  int device_num_blocks() const
  {
    return gridDim.x;
  }

ThreadblockSwizzleStreamK::get_sk_tile_idx

  /// Obtains tile index for the given sk iteration
  CUTLASS_DEVICE
  int get_sk_tile_idx(int iter) const
  {
    int tile_idx = div_mod_iters_per_tile.div(iter);
    return tile_idx;
  }

ThreadblockSwizzleStreamK::get_batch_idx

  /// Obtains the batch index
  CUTLASS_DEVICE
  int get_batch_idx() const
  {
    return RematerializeBlockIdxZ();
  }

ThreadblockSwizzleStreamK::get_tile_offset

Computes the calling threadblock's 2-D tiled coordinates (m, n) in the grid from the given tile_idx and wraps them in a GemmCoord.
The coordinates are first computed with row-major rasterization.

  /// Obtains the calling threadblock's tiled coordinates for the given tile index
  CUTLASS_DEVICE
  GemmCoord get_tile_offset(int tile_idx) const
  {
    int m, n;

    // row-major raster
    div_mod_tiled_shape_n(m, n, tile_idx);

If the tiled grid has fewer rows than columns (a wide output matrix), switch to column-major rasterization; traversing columns first may improve memory-access performance and resource utilization in that case.

    if (tiled_shape().m() < tiled_shape().n())
    {
      // column-major raster
      div_mod_tiled_shape_m(n, m, tile_idx);
    }

When cohort_raster is enabled, threadblocks are rasterized in cohorts (a sketch follows the code below):

  • compute the cohort's linear index cohort_tile_idx and convert it to 2-D coordinates (cohort_grid_m, cohort_grid_n) in the cohort grid;
  • compute the CTA's linear index within its cohort, block_idx_cohort, and decompose it into the intra-cohort 2-D coordinates (block_cohort_m, block_cohort_n);
  • combine the cohort-grid coordinates with the intra-cohort coordinates to form m and n.

This rasterization distributes threadblocks more evenly across the GPU, which may help load balancing and reduce resource contention.

    if (cohort_raster)
    {
      // tiled cohort raster
      int cohort_tile_idx = tile_idx / kCtasPerCohort;
      int cohort_grid_m, cohort_grid_n;
      div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);

      int block_idx_cohort = tile_idx % kCtasPerCohort;
      int block_cohort_m = block_idx_cohort / kCohortCtasN;
      int block_cohort_n = block_idx_cohort % kCohortCtasN;

      m = (cohort_grid_m * kCohortCtasM) + block_cohort_m;
      n = (cohort_grid_n * kCohortCtasN) + block_cohort_n;
    }

    return GemmCoord(m, n, get_batch_idx());
  }
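A minimal host-side sketch of this cohort mapping (again assuming an 8 × 4 cohort shape, with plain / and % standing in for the FastDivmod members) makes the raster order visible:

    #include <cstdio>

    // Host-side sketch of the cohort raster above. The 8 x 4 cohort shape is
    // an assumption standing in for kCohortCtasM / kCohortCtasN.
    constexpr int kCohortCtasM = 8;
    constexpr int kCohortCtasN = 4;
    constexpr int kCtasPerCohort = kCohortCtasM * kCohortCtasN;

    // Maps a linear tile index to (m, n) CTA coordinates, given the cohort
    // grid width tiled_cohort_n.
    void cohort_tile_coords(int tile_idx, int tiled_cohort_n, int &m, int &n) {
      int cohort_tile_idx = tile_idx / kCtasPerCohort;        // which cohort
      int cohort_grid_m   = cohort_tile_idx / tiled_cohort_n; // cohort row
      int cohort_grid_n   = cohort_tile_idx % tiled_cohort_n; // cohort column

      int block_idx_cohort = tile_idx % kCtasPerCohort;       // CTA within cohort
      int block_cohort_m   = block_idx_cohort / kCohortCtasN; // row inside cohort
      int block_cohort_n   = block_idx_cohort % kCohortCtasN; // column inside cohort

      m = (cohort_grid_m * kCohortCtasM) + block_cohort_m;
      n = (cohort_grid_n * kCohortCtasN) + block_cohort_n;
    }

    int main() {
      // The first 32 consecutive tile indices stay inside one 8 x 4 CTA
      // rectangle before the raster moves on to the next cohort.
      for (int tile_idx = 0; tile_idx < 2 * kCtasPerCohort; ++tile_idx) {
        int m, n;
        cohort_tile_coords(tile_idx, /*tiled_cohort_n=*/2, m, n);
        std::printf("tile %2d -> (m=%d, n=%d)\n", tile_idx, m, n);
      }
      return 0;
    }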

ThreadblockSwizzleStreamK::get_tile_offset_row_major

Computes the calling threadblock's position in the grid using row-major rasterization and returns a GemmCoord.
ThreadblockSwizzleStreamK::get_batch_idx returns the grid's z-axis index.

  /// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rasterization)
  CUTLASS_DEVICE
  GemmCoord get_tile_offset_row_major(int tile_idx) const
  {
    // row-major raster
    int m, n;
    div_mod_tiled_shape_n(m, n, tile_idx);
    return GemmCoord(m, n, get_batch_idx());
  }

ThreadblockSwizzleStreamK::get_block_idx

Obtains the calling CTA's linear index.
First, read the raw block index.

  /// Obtains calling threadblock's linear threadblock index
  CUTLASS_DEVICE
  int get_block_idx() const
  {
    int block_idx = RematerializeBlockIdxX();

If remap_block_indices is enabled and the current CTA belongs to the first two waves, remap its index: remapped_block_idx assigns adjacent threadblocks to different waves to make better use of compute resources.

    // Remap the block indices for the first two waves of thread blocks if
    // we have multi-occupancy and the grid constitutes four or more waves
    if (remap_block_indices && (block_idx < avail_sms * 2))
    {
      int dest_sm = block_idx / 2;
      int dest_wave = block_idx % 2;
      int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
      block_idx = remapped_block_idx;
    }
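For example, with avail_sms = 4, physical block indices 0–7 are remapped to 0, 4, 1, 5, 2, 6, 3, 7: even physical indices fill the first logical wave and odd ones the second.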

If the current CTA falls within the SK blocks, remap further: by swapping the order in which block_in_region and region are passed to div_mod_sk_regions, the quotient and remainder roles are exchanged, so block_in_region receives block_idx / sk_regions and region receives block_idx % sk_regions. This interleaving is meant to limit waiting within a region.
Suppose there are 3 regions with 4 blocks per region. Applying the mapping above gives the following region, block_in_region, and remapped index for each block_idx:

  block_idx   region   block_in_region   remapped index
      0         0             0                 0
      1         1             0                 4
      2         2             0                 8
      3         0             1                 1
      4         1             1                 5
      5         2             1                 9
      6         0             2                 2
      7         1             2                 6
      8         2             2                10
      9         0             3                 3
     10         1             3                 7
     11         2             3                11

I am not sure how this operation affects memory-access patterns and cache reuse.

    // Remap block indices to interleave SK regions to limit intra-region waiting
    if (block_idx < sk_regions() * sk_blocks_per_region())
    {
      int block_in_region;
      int region;
      div_mod_sk_regions(block_in_region, region, block_idx);
      block_idx = (region * sk_blocks_per_region()) + block_in_region;
    }

    return block_idx;
  }
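The table above can be reproduced with a few lines of host code (a minimal sketch, with plain / and % standing in for div_mod_sk_regions):

    #include <cstdio>

    // Reproduces the remapping table above: 3 SK regions of 4 blocks each.
    int main() {
      const int sk_regions = 3;
      const int sk_blocks_per_region = 4;
      for (int block_idx = 0; block_idx < sk_regions * sk_blocks_per_region; ++block_idx) {
        int block_in_region = block_idx / sk_regions;  // quotient
        int region          = block_idx % sk_regions;  // remainder
        int remapped = (region * sk_blocks_per_region) + block_in_region;
        std::printf("block %2d -> region %d, block_in_region %d, remapped %2d\n",
                    block_idx, region, block_in_region, remapped);
      }
      return 0;
    }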

ThreadblockSwizzleStreamK::get_sk_block_idx

Computes, from a given iteration index iter, the index of the SK CTA that owns that iteration.
First determine which region iter falls in, and its offset within that region, iter_in_region.

  /// Obtains calling linear threadblock index of the first block to work on the given tile
  CUTLASS_DEVICE
  int get_sk_block_idx(int iter) const
  {
    int region_idx;
    int iter_in_region;
    div_mod_sk_iters_per_region(region_idx, iter_in_region, iter);

ThreadblockSwizzleStreamK::sk_iters_per_normal_block is the iteration count of a normal SK CTA.

big_block_iters is the total number of iterations covered by the region's big blocks; normal_block_iters is the remaining offset of iter_in_region into the region's normal blocks (it is negative while the iteration still falls inside a big block). A big block carries one more iteration than a normal block, and big blocks come before normal blocks within a region.

    int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region;   // number of iterations in the region's big blocks
    int normal_block_iters = iter_in_region - big_block_iters;                                                   // number of iterations in the region's normal blocks

Assume the owning CTA is a big block and compute its index big_block_idx_in_region; likewise assume it is a normal block and compute normal_block_idx_in_region.
The true index block_idx_in_region is selected by checking whether the big-block hypothesis lands within the region's big blocks.

    int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region);
    int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters);

    int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ?
        big_block_idx_in_region :
        normal_block_idx_in_region;

ThreadblockSwizzleStreamK::sk_blocks_per_region is the number of SK CTAs per region.
Finally, compute the global index owning_block_idx of the threadblock that processes the given iteration.

    int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region;
    return owning_block_idx;
  }
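A host-side sketch under assumed sizes walks this lookup for one region: 4 SK blocks, 2 iterations per normal block, and 2 big blocks that take one extra iteration each, i.e. the iteration layout [3 | 3 | 2 | 2]:

    #include <cstdio>

    // Walks get_sk_block_idx's arithmetic for one region under assumed sizes.
    int main() {
      const int sk_iters_per_normal_block = 2;
      const int sk_big_blocks_per_region  = 2;
      const int sk_iters_per_big_block    = sk_iters_per_normal_block + 1;
      const int sk_iters_per_region       = 10;  // 2*3 + 2*2

      for (int iter_in_region = 0; iter_in_region < sk_iters_per_region; ++iter_in_region) {
        // Iterations covered by the region's big blocks (they come first)
        int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block)
                            + sk_big_blocks_per_region;
        // Offset into the normal blocks (negative while inside a big block)
        int normal_block_iters = iter_in_region - big_block_iters;

        int big_block_idx    = iter_in_region / sk_iters_per_big_block;
        int normal_block_idx = sk_big_blocks_per_region
                             + (normal_block_iters / sk_iters_per_normal_block);

        int block_idx_in_region = (big_block_idx < sk_big_blocks_per_region)
                                ? big_block_idx
                                : normal_block_idx;
        std::printf("iter %d -> block %d\n", iter_in_region, block_idx_in_region);
      }
      // Prints the ownership sequence 0,0,0, 1,1,1, 2,2, 3,3
      return 0;
    }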

ThreadblockSwizzleStreamK::get_iter_extents

Computes the begin and end iteration indices that the given SK threadblock should process.
First determine which region sk_block_idx belongs to, and its index within that region, block_idx_in_region.
Then compute the global starting iteration block_iter_begin as if there were no big blocks.

  /// Obtains iteration extents for the given SK block index
  CUTLASS_DEVICE
  void get_iter_extents(
      int sk_block_idx,
      int &block_iter_begin,
      int &block_iter_end) const
  {
    int region_idx;
    int block_idx_in_region;
    div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);

    block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block());

Then adjust block_iter_begin:

  • if the current CTA is a big block, shift the start by one iteration per preceding big block and extend its iteration count by one;
  • if it is a normal block, skip over the extra iterations of all the big blocks in front of it.
    // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration
    int block_iters = sk_iters_per_normal_block();
    if (block_idx_in_region < sk_big_blocks_per_region) {
      // This is a +1 iteration block
      block_iter_begin += block_idx_in_region;
      block_iters++;
    } else {
      // This is a regular block
      block_iter_begin += sk_big_blocks_per_region;
    }

Finally, compute the end of the iteration range.

    block_iter_end = block_iter_begin + block_iters;
  }
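Continuing the [3 | 3 | 2 | 2] example from the sketch above: block 0 gets iterations [0, 3), block 1 gets [3, 6), block 2 gets [6, 8) and block 3 gets [8, 10), matching the ownership computed by get_sk_block_idx.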

ThreadblockSwizzleStreamK::get_first_block_idx


Obtains the index of the first CTA to work on tile tile_idx.
For a DP tile, simply return the calling CTA's own index;

  /// Obtains calling linear threadblock index of the first block to work on the given tile
  CUTLASS_DEVICE
  int get_first_block_idx(int tile_idx, int block_idx) const
  {
    if (tile_idx >= sk_tiles) {
      // DP tile
      return block_idx;
    }

otherwise, compute the tile's first global MAC iteration iter using ThreadblockSwizzleStreamK::iters_per_tile, then call ThreadblockSwizzleStreamK::get_sk_block_idx to obtain the index of the CTA that owns that iteration.

    int iter = tile_idx * iters_per_tile();
    return get_sk_block_idx(iter);
  }
};

