Classic CUDA Problem: Matrix Multiplication

Last updated: November 21, 2022

Overview

There are many articles on CUDA matrix multiplication; the following series is recommended:

This article is largely based on the three articles above.

Code Implementation

Baseline

__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
  // dim3 block(SIZE, SIZE);
  // dim3 grid(n / SIZE, m / SIZE);

  int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
  int idx_y = blockIdx.y * blockDim.y + threadIdx.y;

  float sum = 0.f;
  if (idx_x < n && idx_y < m) {
    for (int kk = 0; kk < k; ++kk) {
      sum += A[idx_y * k + kk] * B[kk * n + idx_x];
    }
    C[idx_y * n + idx_x] = sum;
  }
}
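
For reference, here is a minimal host-side launch sketch for this kernel. The value of SIZE, the problem sizes, and the omitted error checking are assumptions made for illustration; the actual host code in the repository may differ.

#include <cuda_runtime.h>
#include <vector>

#define SIZE 16  // block tile width; assumed value

int main() {
  const int m = 1024, n = 1024, k = 1024;
  std::vector<float> h_A(m * k, 1.0f), h_B(k * n, 1.0f), h_C(m * n);

  float *d_A, *d_B, *d_C;
  cudaMalloc(&d_A, m * k * sizeof(float));
  cudaMalloc(&d_B, k * n * sizeof(float));
  cudaMalloc(&d_C, m * n * sizeof(float));
  cudaMemcpy(d_A, h_A.data(), m * k * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, h_B.data(), k * n * sizeof(float), cudaMemcpyHostToDevice);

  // launch configuration from the comment at the top of the kernel
  // (assumes m and n are multiples of SIZE; round up otherwise)
  dim3 block(SIZE, SIZE);
  dim3 grid(n / SIZE, m / SIZE);
  matrixMultiplyKernel<<<grid, block>>>(d_A, d_B, d_C, m, n, k);
  cudaDeviceSynchronize();

  cudaMemcpy(h_C.data(), d_C, m * n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
  return 0;
}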

Using Shared Memory

The data each block needs is first staged in shared memory, which reduces the number of global memory accesses.
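
To quantify the saving: without tiling, every thread reads 2k elements from global memory, for 2mnk reads in total; with SIZE x SIZE tiling, each block reads k x SIZE elements of A and k x SIZE elements of B, so the total drops to 2mnk / SIZE. In other words, global memory traffic shrinks by a factor of SIZE.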

In this code, each SIZE x SIZE block computes one SIZE x SIZE tile of C. In every iteration of the outer loop, the corresponding tiles of A and B are first copied into shared memory, the partial products are computed from the shared-memory data and accumulated into the running sum, and then the next iteration begins; the shared-memory tiles are overwritten on every iteration. After the last iteration, the accumulated sums form the block's tile of C.

__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
  // dim3 block(SIZE, SIZE);
  // dim3 grid(n / SIZE, m / SIZE);

  __shared__ float s_a[SIZE][SIZE];
  __shared__ float s_b[SIZE][SIZE];

  int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
  int idx_y = blockIdx.y * blockDim.y + threadIdx.y;

  float sum = 0.f;
  for (int bk = 0; bk < k; bk += SIZE) {
    s_a[threadIdx.y][threadIdx.x] = A[idx_y * k + (bk + threadIdx.x)];
    s_b[threadIdx.y][threadIdx.x] = B[(bk + threadIdx.y) * n + idx_x];
    __syncthreads();

    for (int i = 0; i < SIZE; ++i) {
      sum += s_a[threadIdx.y][i] * s_b[i][threadIdx.x];
    }
    __syncthreads();
  }

  if (idx_x < n && idx_y < m) {
    C[idx_y * n + idx_x] = sum;
  }
}

Processing Multiple Elements per Thread

A New Tiling Scheme

In the kernels above, each thread computes a single element of the output matrix. The next implementation has each thread compute multiple elements, which greatly increases the compute-to-memory-access ratio; a sufficiently high ratio keeps the compute units busy and helps hide memory latency. This section also uses a new tiling scheme: the tile size is no longer tied to the block size, and the shared-memory tiles are themselves further partitioned per thread.
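
To see why the ratio improves: in the shared-memory kernel above, every inner iteration performs 1 multiply-add for 2 shared-memory loads. When each thread instead computes an rm x rn sub-tile, one inner iteration loads rm + rn values and performs rm x rn multiply-adds, so the ratio rises from 1/2 to rm*rn / (rm + rn), which is 4 for rm = rn = 8.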

For the parameters bm, bn, bk, rm, and rn we take bm = 128, bn = 128, bk = 8, rm = 8, rn = 8; the reasoning behind these choices is discussed in 《CUDA 矩阵乘法终极优化指南》. To get an intuitive feel for what they mean, suppose A, B, and C are all 2048 x 2048 and we want C = A x B. We then launch (2048/128) x (2048/128) = 256 blocks, each block contains (128/8) x (128/8) = 256 threads, each thread computes 8 x 8 = 64 elements of C, and each block is therefore responsible for 256 x 64 = 16384 elements.

In other words, each block runs 256 big iterations (one per bk-wide slice of the K dimension, 2048 / 8 = 256), and each big iteration contains 8 small iterations.
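
The code below refers to these parameters by the names BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, THREAD_SIZE_Y, THREAD_SIZE_X, and so on. One possible set of definitions consistent with the values above is sketched here; the repository may define them differently (e.g. as template parameters).

#define BLOCK_SIZE_M 128  // bm: rows of C computed by one block
#define BLOCK_SIZE_N 128  // bn: columns of C computed by one block
#define BLOCK_SIZE_K 8    // bk: width of the K slice processed per big iteration
#define THREAD_SIZE_Y 8   // rm: rows of C computed by one thread
#define THREAD_SIZE_X 8   // rn: columns of C computed by one thread
#define THREAD_X_PER_BLOCK (BLOCK_SIZE_N / THREAD_SIZE_X)  // 16
#define THREAD_Y_PER_BLOCK (BLOCK_SIZE_M / THREAD_SIZE_Y)  // 16
#define THREAD_NUM_PER_BLOCK (THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK)  // 256 threads per block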

Data Prefetching (Double Buffering)

The idea of data prefetching (double buffering) is to start loading the next tile from global memory while the current tile is being used for computation, so that memory latency overlaps with useful work; a complete implementation is given in the "Using Double Buffering" section below.

Code Example

Loops can be unrolled with the #pragma unroll directive; note that it only fully unrolls loops whose bounds and step are known at compile time.
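
A minimal illustration with two hypothetical device functions (sum_fixed and sum_tail are made up for this example):

__device__ float sum_fixed(const float *v) {
  float s = 0.f;
  #pragma unroll  // fully unrolled: the trip count BLOCK_SIZE_K is a compile-time constant
  for (int i = 0; i < BLOCK_SIZE_K; ++i) s += v[i];
  return s;
}

__device__ float sum_tail(const float *v, int len) {
  float s = 0.f;
  #pragma unroll  // cannot be fully unrolled: len is only known at runtime
  for (int i = 0; i < len; ++i) s += v[i];
  return s;
}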

__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
  // dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
  // dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);

  __shared__ float s_a[BLOCK_SIZE_M][BLOCK_SIZE_K];
  __shared__ float s_b[BLOCK_SIZE_K][BLOCK_SIZE_N];
  float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};

  const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;

  // each thread moves one element per load

  // row and column, within s_a / s_b, of the first element this thread loads
  const int A_TILE_ROW = tid / BLOCK_SIZE_K;
  const int A_TILE_COL = tid % BLOCK_SIZE_K;
  const int B_TILE_ROW = tid / BLOCK_SIZE_N;
  const int B_TILE_COL = tid % BLOCK_SIZE_N;

  // number of rows to stride over between consecutive loads by the same thread
  const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
  const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_N;

  for (int bk = 0; bk < k; bk += BLOCK_SIZE_K) {
    // load A from global memory to shared memory
    #pragma unroll
    for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
      const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW;
      const int col = bk + A_TILE_COL;
      if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
        s_a[i + A_TILE_ROW][A_TILE_COL] = row < m && col < k ? A[row * k + col] : 0;
      } else {
        s_a[i + A_TILE_ROW][A_TILE_COL] = A[row * k + col];
      }
    }

    // load B from global memory to shared memory
    #pragma unroll
    for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
      const int row = bk + i + B_TILE_ROW;
      const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
      if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
        s_b[i + B_TILE_ROW][B_TILE_COL] = row < k && col < n ? B[row * n + col] : 0;
      } else {
        s_b[i + B_TILE_ROW][B_TILE_COL] = B[row * n + col];
      }
    }

    __syncthreads();

    // note: the data each thread loads has no necessary relation to the data it computes with below

    // calculate C
    #pragma unroll
    for (int kk = 0; kk < BLOCK_SIZE_K; ++kk) {
      #pragma unroll
      for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
        #pragma unroll
        for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
            r_c[ty][tx] += s_a[THREAD_SIZE_Y * threadIdx.y + ty][kk] * s_b[kk][THREAD_SIZE_X * threadIdx.x + tx];
        }
      }
    }

    __syncthreads();
  }

  // store back to C
  #pragma unroll
  for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
    #pragma unroll
    for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
      const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
      const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
      if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
        if (row < m && col < n) {
          C[row * n + col] += r_c[ty][tx];
        }
      } else {
        C[row * n + col] += r_c[ty][tx];
      }
    }
  }
}

Using Vectorized Memory Instructions

We can try to guide the compiler to emit LDG.128 and STG.128 instructions (128-bit vectorized loads and stores) to speed up data movement. Note that float4 accesses require 16-byte-aligned addresses, and that the kernels below drop the boundary checks, so they assume m, n, and k are multiples of the corresponding tile sizes.

// this cast pattern encourages the compiler to emit LDG.128 / STG.128 instructions
#define FETCH_FLOAT4(p) (reinterpret_cast<float4*>(&(p))[0])

__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
  // dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
  // dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);

  __shared__ float s_a[BLOCK_SIZE_M][BLOCK_SIZE_K];
  __shared__ float s_b[BLOCK_SIZE_K][BLOCK_SIZE_N];
  float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
  float frag_a[THREAD_SIZE_Y];
  float frag_b[THREAD_SIZE_X];

  const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;

  // each thread moves four elements per load

  // number of threads needed to load one row of s_a / s_b
  const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
  const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

  // row and column, within s_a / s_b, of the first element of the first float4 group this thread loads
  const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
  const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
  const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
  const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

  // number of rows to stride over between consecutive loads by the same thread
  const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
  const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;

  for (int bk = 0; bk < k; bk += BLOCK_SIZE_K) {
    // load A from global memory to shared memory
    #pragma unroll
    for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
      const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
      const int col = bk + A_TILE_COL;
      FETCH_FLOAT4(s_a[i + A_TILE_ROW_START][A_TILE_COL]) = FETCH_FLOAT4(A[row * k + col]);
    }

    // load B from global memory to shared memory
    #pragma unroll
    for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
      const int row = bk + i + B_TILE_ROW_START;
      const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
      FETCH_FLOAT4(s_b[i + B_TILE_ROW_START][B_TILE_COL]) = FETCH_FLOAT4(B[row * n + col]);
    }

    __syncthreads();

    // note: the data each thread loads has no necessary relation to the data it computes with below

    // calculate C
    #pragma unroll
    for (int kk = 0; kk < BLOCK_SIZE_K; ++kk) {
      // load A from shared memory to register
      #pragma unroll
      for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
        frag_a[ty] = s_a[THREAD_SIZE_Y * threadIdx.y + ty][kk];
      }

      // load B from shared memory to register
      #pragma unroll
      for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
        FETCH_FLOAT4(frag_b[tx]) = FETCH_FLOAT4(s_b[kk][THREAD_SIZE_X * threadIdx.x + tx]);
      }

      #pragma unroll
      for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
        #pragma unroll
        for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
            r_c[ty][tx] += frag_a[ty] * frag_b[tx];
        }
      }
    }
  }

  // store back to C
  #pragma unroll
  for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
    #pragma unroll
    for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
      const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
      const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
      FETCH_FLOAT4(C[row * n + col]) = FETCH_FLOAT4(r_c[ty][tx]);
    }
  }
}

Using Double Buffering

Double buffering can further hide memory latency by overlapping the loads for the next tile with computation on the current one.
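
The structure of the kernel below can be summarized as the following pipeline (a schematic, in comments, of the code that follows; not separate code):

// double-buffered pipeline (schematic of the kernel below)
// preload K-tile 0 from global memory into shared buffer 0,
// and its first k-slice into register fragment 0
// do {
//   issue the global loads of the NEXT K-tile into registers (ldg_a_reg / ldg_b_reg)
//   for kk = 0 .. BLOCK_SIZE_K - 2:
//     prefetch k-slice kk + 1 from the current shared buffer into the other fragment
//     accumulate r_c using fragment kk
//   write the next K-tile from registers into the other shared buffer, then __syncthreads()
//   prefetch k-slice 0 of the next tile into fragment 0;
//   accumulate r_c using the last fragment of the current tile
// } while (there are more K-tiles)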

__global__ void matrixMultiplyKernel(float * A, float * B, float * C, int m, int n, int k) {
  // dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
  // dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);

  __shared__ float s_a[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
  __shared__ float s_b[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
  float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
  float frag_a[2][THREAD_SIZE_Y];
  float frag_b[2][THREAD_SIZE_X];

  // to stage the BLOCK_SIZE_M x BLOCK_SIZE_K (and BLOCK_SIZE_K x BLOCK_SIZE_N) tiles,
  // each thread needs these extra registers as a staging area
  float ldg_a_reg[BLOCK_SIZE_M * BLOCK_SIZE_K / THREAD_NUM_PER_BLOCK];
  float ldg_b_reg[BLOCK_SIZE_K * BLOCK_SIZE_N / THREAD_NUM_PER_BLOCK];

  const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;

  // each thread moves four elements per load

  // number of threads needed to load one row of s_a / s_b
  const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
  const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

  // row and column, within s_a / s_b, of the first element of the first float4 group this thread loads
  const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
  const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
  const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
  const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

  // number of rows to stride over between consecutive loads by the same thread
  const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
  const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;

  // preload A from global memory to shared memory
  #pragma unroll
  for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
    int ldg_index = i / A_TILE_ROW_STRIDE * 4;
    const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
    const int col = A_TILE_COL;
    FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[row * k + col]);
    s_a[0][A_TILE_COL + 0][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 0];
    s_a[0][A_TILE_COL + 1][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 1];
    s_a[0][A_TILE_COL + 2][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 2];
    s_a[0][A_TILE_COL + 3][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 3];
  }

  // preload B from global memory to shared memory
  #pragma unroll
  for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
    const int row = i + B_TILE_ROW_START;
    const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
    FETCH_FLOAT4(s_b[0][i + B_TILE_ROW_START][B_TILE_COL]) = FETCH_FLOAT4(B[row * n + col]);
  }

  __syncthreads();

  // preload A from shared memory to register
  #pragma unroll
  for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
    FETCH_FLOAT4(frag_a[0][ty]) = FETCH_FLOAT4(s_a[0][0][THREAD_SIZE_Y * threadIdx.y + ty]);
  }

  // preload B from shared memory to register
  #pragma unroll
  for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
    FETCH_FLOAT4(frag_b[0][tx]) = FETCH_FLOAT4(s_b[0][0][THREAD_SIZE_X * threadIdx.x + tx]);
  }

  int write_stage_idx = 1;
  int bk = 0;
  do {
    bk += BLOCK_SIZE_K;

    if (bk < k) {
      // preload A from global memory to register
      #pragma unroll
      for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
        int ldg_index = i / A_TILE_ROW_STRIDE * 4;
        const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
        const int col = bk + A_TILE_COL;
        FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[row * k + col]);
      }

      // preload B from global memory to register
      #pragma unroll
      for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
        int ldg_index = i / B_TILE_ROW_STRIDE * 4;
        const int row = bk + i + B_TILE_ROW_START;
        const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
        FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[row * n + col]);
      }
    }

    // note: the data each thread loads has no necessary relation to the data it computes with below

    int load_stage_idx = write_stage_idx ^ 1;

    // calculate C
    #pragma unroll
    for (int kk = 0; kk < BLOCK_SIZE_K - 1; ++kk) {
      // preload A from shared memory to register
      #pragma unroll
      for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
        FETCH_FLOAT4(frag_a[(kk + 1) % 2][ty]) = FETCH_FLOAT4(s_a[load_stage_idx][kk + 1][THREAD_SIZE_Y * threadIdx.y + ty]);
      }

      // preload B from shared memory to register
      #pragma unroll
      for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
        FETCH_FLOAT4(frag_b[(kk + 1) % 2][tx]) = FETCH_FLOAT4(s_b[load_stage_idx][kk + 1][THREAD_SIZE_X * threadIdx.x + tx]);
      }

      // calculate C (this tile)
      #pragma unroll
      for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
        #pragma unroll
        for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
            r_c[ty][tx] += frag_a[kk % 2][ty] * frag_b[kk % 2][tx];
        }
      }
    }

    if (bk < k) {
      // preload A from register to shared memory
      #pragma unroll
      for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
        int ldg_index = i / A_TILE_ROW_STRIDE * 4;
        s_a[write_stage_idx][A_TILE_COL + 0][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 0];
        s_a[write_stage_idx][A_TILE_COL + 1][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 1];
        s_a[write_stage_idx][A_TILE_COL + 2][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 2];
        s_a[write_stage_idx][A_TILE_COL + 3][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 3];
      }

      // preload B from register to shared memory
      #pragma unroll
      for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
        int ldg_index = i / B_TILE_ROW_STRIDE * 4;
        FETCH_FLOAT4(s_b[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
      }

      __syncthreads();
      write_stage_idx ^= 1;
    }

    // preload A from shared memory to register
    #pragma unroll
    for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
      FETCH_FLOAT4(frag_a[0][ty]) = FETCH_FLOAT4(s_a[load_stage_idx ^ 1][0][THREAD_SIZE_Y * threadIdx.y + ty]);
    }

    // preload B from shared memory to register
    #pragma unroll
    for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
      FETCH_FLOAT4(frag_b[0][tx]) = FETCH_FLOAT4(s_b[load_stage_idx ^ 1][0][THREAD_SIZE_X * threadIdx.x + tx]);
    }

    // compute last tile matmul THREAD_SIZE_X * THREAD_SIZE_Y
    #pragma unroll
    for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
      #pragma unroll
      for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
          r_c[ty][tx] += frag_a[1][ty] * frag_b[1][tx];
      }
    }

  } while(bk < k);

  // store back to C
  #pragma unroll
  for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
    #pragma unroll
    for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
      const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
      const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
      FETCH_FLOAT4(C[row * n + col]) = FETCH_FLOAT4(r_c[ty][tx]);
    }
  }
}

Performance Comparison

Device: NVIDIA Tesla T4, CUDA 11.1

For m = n = k = 1024, the measured kernel times are:

  • baseline: 4.05 ms
  • with shared memory: 2.40 ms
  • multiple elements per thread
    • without unroll: 0.82 ms
    • with unroll: 0.65 ms
  • with vectorized instructions: 0.57 ms
  • with double buffering: 0.54 ms

Performance comparison across more matrix sizes:

The complete code is available at zh0ngtian/cuda_learning.