A Classic CUDA Problem: Matrix Multiplication
Last updated: November 21, 2022
Overview
There are already quite a few articles on optimizing matrix multiplication; this series is recommended:
- 深入浅出GPU优化系列:GEMM优化(一):矩阵乘法优化的思路
- 深入浅出GPU优化系列:GEMM优化(二):不使用汇编的优化方法
- 深入浅出GPU优化系列:GEMM优化(三):使用汇编的优化方法
This article is based on the three articles above.
Code Implementation
Baseline
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(SIZE, SIZE);
// dim3 grid(n / SIZE, m / SIZE);
int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
int idx_y = blockIdx.y * blockDim.y + threadIdx.y;
float sum = 0.f;
if (idx_x < n && idx_y < m) {
for (int kk = 0; kk < k; ++kk) {
sum += A[idx_y * k + kk] * B[kk * n + idx_x];
}
C[idx_y * n + idx_x] = sum;
}
}
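For reference, a minimal host-side launch for this baseline could look like the sketch below. SIZE, the device pointers d_A/d_B/d_C, and the rounded-up grid are assumptions for illustration; the boundary check inside the kernel makes the round-up safe when m or n is not a multiple of SIZE.
// host-side launch sketch (assumed names: SIZE, d_A, d_B, d_C)
#define SIZE 32
void launchBaseline(float *d_A, float *d_B, float *d_C, int m, int n, int k) {
    dim3 block(SIZE, SIZE);
    // round up so the grid covers the whole output even if m, n are not multiples of SIZE
    dim3 grid((n + SIZE - 1) / SIZE, (m + SIZE - 1) / SIZE);
    matrixMultiplyKernel<<<grid, block>>>(d_A, d_B, d_C, m, n, k);
}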
Using shared memory
Stage the data that each block needs in shared memory first, reducing the number of global memory accesses.
In this code, each SIZE x SIZE block computes a SIZE x SIZE tile of C. In every iteration of the outer loop, the corresponding tiles of A and B are first copied into shared memory, the partial product is then computed from the data in shared memory, and the results of all iterations are accumulated to produce the block's tile of C. The shared memory buffers are overwritten in every iteration.
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(SIZE, SIZE);
// dim3 grid(n / SIZE, m / SIZE);
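// note: the shared-memory loads below have no boundary checks, so this kernel assumes m, n, k are all multiples of SIZE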
__shared__ float s_a[SIZE][SIZE];
__shared__ float s_b[SIZE][SIZE];
int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
int idx_y = blockIdx.y * blockDim.y + threadIdx.y;
float sum = 0.0;
for (int bk = 0; bk < k; bk += SIZE) {
s_a[threadIdx.y][threadIdx.x] = A[idx_y * k + (bk + threadIdx.x)];
s_b[threadIdx.y][threadIdx.x] = B[(bk + threadIdx.y) * n + idx_x];
__syncthreads();
for (int i = 0; i < SIZE; ++i) {
sum += s_a[threadIdx.y][i] * s_b[i][threadIdx.x];
}
__syncthreads();
}
if (idx_x < n && idx_y < m) {
C[idx_y * n + idx_x] = sum;
}
}
Computing multiple elements per thread
A new tiling scheme
In the kernels above, each thread computes a single element of the output matrix. The next implementation lets each thread compute multiple elements, which greatly increases the compute-to-memory-access ratio. A sufficiently high ratio improves the utilization of the compute units and helps hide memory latency. This section uses a new tiling scheme: the tile size is no longer tied to the block size, and shared memory is tiled as well.
For the parameters bm, bn, bk, rm and rn we take bm=128, bn=128, bk=8, rm=8, rn=8; the reasoning behind these choices is discussed in CUDA 矩阵乘法终极优化指南. To get a feel for what they mean, suppose we are given three matrices A, B, C, each of size 2048x2048, and need to compute C=AxB. We then launch (2048/128)x(2048/128)=256 blocks, each block contains (128/8)x(128/8)=256 threads, each thread computes 8x8=64 elements of C, and each block is therefore responsible for 256x64=16384 elements.
In other words, each block runs 2048/8=256 outer iterations (over the k dimension, in steps of bk), and each outer iteration contains 8 inner iterations.
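The kernels below refer to a set of macros corresponding to these parameters. One possible set of definitions, together with the matching launch configuration (a sketch that assumes m and n are multiples of the block tile sizes), is:
// bm / bn / bk / rm / rn expressed as the macros used by the kernels below
#define BLOCK_SIZE_M 128                                                // bm: rows of C computed per block
#define BLOCK_SIZE_N 128                                                // bn: columns of C computed per block
#define BLOCK_SIZE_K 8                                                  // bk: depth of one outer iteration
#define THREAD_SIZE_Y 8                                                 // rm: rows of C computed per thread
#define THREAD_SIZE_X 8                                                 // rn: columns of C computed per thread
#define THREAD_X_PER_BLOCK (BLOCK_SIZE_N / THREAD_SIZE_X)               // 16
#define THREAD_Y_PER_BLOCK (BLOCK_SIZE_M / THREAD_SIZE_Y)               // 16
#define THREAD_NUM_PER_BLOCK (THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK)  // 256 threads per block
// launch configuration
dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);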
Data prefetching (double buffer)
The idea is to prefetch the data needed by the next iteration (from global memory into registers, and from shared memory into registers) while the current iteration is still computing, so that memory latency overlaps with computation; the corresponding code is shown in the later double-buffer section.
Code example
You can use #pragma unroll to unroll loops; note that it is only effective for loops whose bounds and step are known at compile time.
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
// dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);
__shared__ float s_a[BLOCK_SIZE_M][BLOCK_SIZE_K];
__shared__ float s_b[BLOCK_SIZE_K][BLOCK_SIZE_N];
float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;
// each thread moves one element at a time
// row/column of the first element this thread loads into s_a / s_b
const int A_TILE_ROW = tid / BLOCK_SIZE_K;
const int A_TILE_COL = tid % BLOCK_SIZE_K;
const int B_TILE_ROW = tid / BLOCK_SIZE_N;
const int B_TILE_COL = tid % BLOCK_SIZE_N;
// rows skipped between two consecutive loads by the same thread
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_N;
for (int bk = 0; bk < k; bk += BLOCK_SIZE_K) {
// load A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW;
const int col = bk + A_TILE_COL;
if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
s_a[i + A_TILE_ROW][A_TILE_COL] = row < m && col < k ? A[row * k + col] : 0;
} else {
s_a[i + A_TILE_ROW][A_TILE_COL] = A[row * k + col];
}
}
// load B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = bk + i + B_TILE_ROW;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
s_b[i + B_TILE_ROW][B_TILE_COL] = row < k && col < n ? B[row * n + col] : 0;
} else {
s_b[i + B_TILE_ROW][B_TILE_COL] = B[row * n + col];
}
}
__syncthreads();
// the elements a thread loads are not necessarily the elements it will use in its own computation
// calculate C
#pragma unroll
for (int kk = 0; kk < BLOCK_SIZE_K; ++kk) {
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += s_a[THREAD_SIZE_Y * threadIdx.y + ty][kk] * s_b[kk][THREAD_SIZE_X * threadIdx.x + tx];
}
}
}
__syncthreads();
}
// store back to C
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
if (row < m && col < n) {
C[row * n + col] = r_c[ty][tx];
}
} else {
C[row * n + col] = r_c[ty][tx];
}
}
}
}
Using vectorized instructions
We can try to guide the compiler into emitting LDG.128 and STG.128 instructions to speed up data movement.
// reinterpreting the address as float4 in this way guides the compiler to emit LDG.128 instructions (the address must be 16-byte aligned)
#define FETCH_FLOAT4(p) (reinterpret_cast<float4*>(&(p))[0])
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
// dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);
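// note: boundary checks are omitted in this version so that every access can stay a 128-bit load/store;
// m, n, k are assumed to be multiples of the corresponding tile sizes (and of 4)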
__shared__ float s_a[BLOCK_SIZE_M][BLOCK_SIZE_K];
__shared__ float s_b[BLOCK_SIZE_K][BLOCK_SIZE_N];
float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
float frag_a[THREAD_SIZE_Y];
float frag_b[THREAD_SIZE_X];
const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;
// each thread moves four elements (one float4) at a time
// number of threads needed to fill one row of s_a / s_b
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;
// row/column of the first element of the first float4 this thread loads into s_a / s_b
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;
// rows skipped between two consecutive loads by the same thread
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;
for (int bk = 0; bk < k; bk += BLOCK_SIZE_K) {
// load A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
const int col = bk + A_TILE_COL;
FETCH_FLOAT4(s_a[i + A_TILE_ROW_START][A_TILE_COL]) = FETCH_FLOAT4(A[row * k + col]);
}
// load B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = bk + i + B_TILE_ROW_START;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
FETCH_FLOAT4(s_b[i + B_TILE_ROW_START][B_TILE_COL]) = FETCH_FLOAT4(B[row * n + col]);
}
__syncthreads();
// the elements a thread loads are not necessarily the elements it will use in its own computation
// calculate C
#pragma unroll
for (int kk = 0; kk < BLOCK_SIZE_K; ++kk) {
// load A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
frag_a[ty] = s_a[THREAD_SIZE_Y * threadIdx.y + ty][kk];
}
// load B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[tx]) = FETCH_FLOAT4(s_b[kk][THREAD_SIZE_X * threadIdx.x + tx]);
}
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += frag_a[ty] * frag_b[tx];
}
}
}
}
// store back to C
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
FETCH_FLOAT4(C[row * n + col]) = FETCH_FLOAT4(r_c[ty][tx]);
}
}
}
Using double buffering
Double buffering can further hide memory access latency: while the current buffer is being used for computation, the next tile is prefetched into the other buffer.
__global__ void matrixMultiplyKernel(float * A, float * B, float * C, int m, int n, int k) {
// dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
// dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);
__shared__ float s_a[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
__shared__ float s_b[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
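// note: s_a is stored transposed (K-major) so that each thread's THREAD_SIZE_Y values of A lie
// contiguously along the M dimension and can later be read into frag_a with a single FETCH_FLOAT4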
float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
float frag_a[2][THREAD_SIZE_Y];
float frag_b[2][THREAD_SIZE_X];
// staging registers: to move a BLOCK_SIZE_M x BLOCK_SIZE_K (resp. BLOCK_SIZE_K x BLOCK_SIZE_N) tile through registers, each thread needs this many extra float registers
float ldg_a_reg[BLOCK_SIZE_M * BLOCK_SIZE_K / THREAD_NUM_PER_BLOCK];
float ldg_b_reg[BLOCK_SIZE_K * BLOCK_SIZE_N / THREAD_NUM_PER_BLOCK];
const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;
// each thread moves four elements (one float4) at a time
// number of threads needed to fill one row of s_a / s_b
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;
// row/column of the first element of the first float4 this thread loads into s_a / s_b
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;
// rows skipped between two consecutive loads by the same thread
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;
// preload A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
const int col = A_TILE_COL;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[row * k + col]);
s_a[0][A_TILE_COL + 0][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 0];
s_a[0][A_TILE_COL + 1][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 1];
s_a[0][A_TILE_COL + 2][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 2];
s_a[0][A_TILE_COL + 3][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 3];
}
// preload B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = i + B_TILE_ROW_START;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
FETCH_FLOAT4(s_b[0][i + B_TILE_ROW_START][B_TILE_COL]) = FETCH_FLOAT4(B[row * n + col]);
}
__syncthreads();
// preload A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
FETCH_FLOAT4(frag_a[0][ty]) = FETCH_FLOAT4(s_a[0][0][THREAD_SIZE_Y * threadIdx.y + ty]);
}
// preload B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[0][tx]) = FETCH_FLOAT4(s_b[0][0][THREAD_SIZE_X * threadIdx.x + tx]);
}
int write_stage_idx = 1;
int bk = 0;
do {
bk += BLOCK_SIZE_K;
if (bk < k) {
// preload A from global memory to register
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
const int col = bk + A_TILE_COL;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[row * k + col]);
}
// preload B from global memory to register
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
const int row = bk + i + B_TILE_ROW_START;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[row * n + col]);
}
}
// the elements a thread loads are not necessarily the elements it will use in its own computation
int load_stage_idx = write_stage_idx ^ 1;
// calculate C
#pragma unroll
for (int kk = 0; kk < BLOCK_SIZE_K - 1; ++kk) {
// preload A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
FETCH_FLOAT4(frag_a[(kk + 1) % 2][ty]) = FETCH_FLOAT4(s_a[load_stage_idx][kk + 1][THREAD_SIZE_Y * threadIdx.y + ty]);
}
// preload B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[(kk + 1) % 2][tx]) = FETCH_FLOAT4(s_b[load_stage_idx][kk + 1][THREAD_SIZE_X * threadIdx.x + tx]);
}
// calculate C (this tile)
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += frag_a[kk % 2][ty] * frag_b[kk % 2][tx];
}
}
}
if (bk < k) {
// preload A from register to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
s_a[write_stage_idx][A_TILE_COL + 0][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 0];
s_a[write_stage_idx][A_TILE_COL + 1][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 1];
s_a[write_stage_idx][A_TILE_COL + 2][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 2];
s_a[write_stage_idx][A_TILE_COL + 3][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 3];
}
// preload B from register to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(s_b[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
}
__syncthreads();
write_stage_idx ^= 1;
}
// preload A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
FETCH_FLOAT4(frag_a[0][ty]) = FETCH_FLOAT4(s_a[load_stage_idx ^ 1][0][THREAD_SIZE_Y * threadIdx.y + ty]);
}
// preload B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[0][tx]) = FETCH_FLOAT4(s_b[load_stage_idx ^ 1][0][THREAD_SIZE_X * threadIdx.x + tx]);
}
// compute last tile matmul THREAD_SIZE_X * THREAD_SIZE_Y
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += frag_a[1][ty] * frag_b[1][tx];
}
}
} while(bk < k);
// store back to C
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
FETCH_FLOAT4(C[row * n + col]) = FETCH_FLOAT4(r_c[ty][tx]);
}
}
}
Performance Comparison
Device: NVIDIA Tesla T4, CUDA 11.1
For m = n = k = 1024, the measured times are:
- baseline: 4.05 ms
- with shared memory: 2.40 ms
- multiple elements per thread
  - without unroll: 0.82 ms
  - with unroll: 0.65 ms
- with vectorized instructions: 0.57 ms
- with double buffering: 0.54 ms
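For reference, timings like these can be collected with CUDA events. The sketch below is not the original benchmark code; launchKernel and NUM_RUNS are placeholder names for whichever kernel launch is being measured.
// timing sketch with CUDA events (placeholder names: launchKernel, NUM_RUNS)
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
launchKernel();                      // warm-up launch
cudaEventRecord(start);
for (int i = 0; i < NUM_RUNS; ++i) {
    launchKernel();                  // e.g. matrixMultiplyKernel<<<grid, block>>>(d_A, d_B, d_C, m, n, k)
}
cudaEventRecord(stop);
cudaEventSynchronize(stop);
float ms = 0.f;
cudaEventElapsedTime(&ms, start, stop);
printf("average time: %f ms\n", ms / NUM_RUNS);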
Performance comparison across more matrix sizes: (figure omitted)
The complete code is available at zh0ngtian/cuda_learning.