引言
在深度学习和高性能计算领域,NVIDIA GPU 凭借其强大的并行计算能力占据着重要地位。而 FP8(8 位浮点格式)作为一种新兴的数据格式,在提供较低精度的同时,能显著减少内存占用和提高计算速度。本文将详细解读一个基于 NVIDIA GPU(假设为 NVIDIA 5090)的 FP8 CUDA 程序,该程序主要实现了矩阵乘法(GEMM)并进行性能测试。
代码整体概述
此 CUDA 程序借助 FP8 数据格式开展矩阵乘法运算,并且对其性能进行测试。代码主要包含以下几个关键部分:
- 核函数:实现矩阵乘法。
- 矩阵初始化函数:对矩阵进行随机初始化。
- 性能测试函数:测量矩阵乘法的执行时间。
- 命令行参数解析函数:解析命令行输入的矩阵维度和测试轮数。
- 主函数:调用上述函数,输出性能测试结果。
代码详细解析
1. 头文件包含
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <string.h> // 用于参数解析
这里引入了 CUDA 运行时库、FP8 相关库以及标准 C 库,为后续代码的运行提供必要的支持。
2. 数据类型定义
typedef __nv_fp8_e4m3 fp8_t;
将 __nv_fp8_e4m3 定义为 fp8_t,方便后续代码使用。__nv_fp8_e4m3 是 NVIDIA 定义的一种 FP8 数据格式,采用 4 位指数和 3 位尾数。具体数据格式可以参考如下文档
15.18. __nv_fp8_e4m3 — CUDA Math API Reference Manual 12.8 documentation
3. 核函数 fp8_gemm
__global__ void fp8_gemm(
const fp8_t* A,
const fp8_t* B,
fp8_t* C,
int M, // 矩阵A的行数 (M x K)
int N, // 矩阵B的列数 (K x N)
int K // 公共维度
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int i = 0; i < K; i++) {
float a = (float)A[row * K + i];
float b = (float)B[i * N + col];
sum += a * b;
}
C[row * N + col] = fp8_t(sum);
}
}
- 功能:实现矩阵乘法 C=A*B,其中 A 是 M *K的矩阵,B是 K *N 的矩阵,C 是 M*N 的矩阵。
- 线程索引计算:借助线程块索引(blockIdx)和线程索引(threadIdx)算出当前线程要处理的矩阵元素的行和列。
- 矩阵乘法计算:对于每个线程,计算 C 矩阵中对应元素的值,通过将A矩阵的对应行与 \(B\) 矩阵的对应列元素相乘并累加得到。
- 数据类型转换:在计算过程中,把 FP8 类型的数据转换为 float 类型进行计算,最后再将结果转换回 FP8 类型。
4. 矩阵初始化函数 init_matrix
void init_matrix(fp8_t* mat, int rows, int cols) {
srand(time(NULL));
for (int i = 0; i < rows * cols; i++) {
float val = ((float)rand() / RAND_MAX) * 10.0f; // 生成0-10的随机数(符合FP8范围)
mat[i] = fp8_t(val);
}
}
- 功能:对矩阵进行随机初始化,生成 0 到 10 之间的随机数,并将其转换为 FP8 类型。
- 随机数生成:使用 rand() 函数生成随机数,再通过 srand(time(NULL)) 保证每次运行程序时生成的随机数不同。
5. 性能测试函数 test_performance
float test_performance(int M, int N, int K, int num_runs) {
// 主机内存分配
fp8_t* h_A = (fp8_t*)malloc(M * K * sizeof(fp8_t));
fp8_t* h_B = (fp8_t*)malloc(K * N * sizeof(fp8_t));
fp8_t* h_C = (fp8_t*)malloc(M * N * sizeof(fp8_t));
init_matrix(h_A, M, K);
init_matrix(h_B, K, N);
// 设备内存分配
fp8_t* d_A, *d_B, *d_C;
cudaMalloc((void**)&d_A, M * K * sizeof(fp8_t));
cudaMalloc((void**)&d_B, K * N * sizeof(fp8_t));
cudaMalloc((void**)&d_C, M * N * sizeof(fp8_t));
// 数据传输到设备(仅第一次传输,后续复用数据)
cudaMemcpy(d_A, h_A, M * K * sizeof(fp8_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, h_B, K * N * sizeof(fp8_t), cudaMemcpyHostToDevice);
// 配置核函数参数
dim3 block_size(16, 16); // 固定线程块大小(可优化为动态)
dim3 grid_size(
(N + block_size.x - 1) / block_size.x,
(M + block_size.y - 1) / block_size.y
);
// 性能测量事件
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
float total_time = 0.0f;
for (int run = 0; run < num_runs; run++) {
cudaEventRecord(start, 0);
fp8_gemm<<<grid_size, block_size>>>(d_A, d_B, d_C, M, N, K);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
float ms;
cudaEventElapsedTime(&ms, start, stop);
total_time += ms;
}
// 释放资源
cudaFree(d_A);
cudaFree(d_B);
cudaFree(d_C);
free(h_A);
free(h_B);
free(h_C);
return total_time / num_runs; // 返回平均时间
}
- 功能:对矩阵乘法的性能进行测试,返回多次运行的平均执行时间。
- 内存分配:在主机和设备上分别分配矩阵所需的内存。
- 数据传输:把主机上的矩阵数据传输到设备上。
- 核函数配置:设定线程块大小和网格大小,调用核函数进行矩阵乘法运算。
- 性能测量:利用 CUDA 事件记录核函数的执行时间,多次运行取平均值。
- 资源释放:释放主机和设备上分配的内存。
6. 命令行参数解析函数 parse_arguments
void parse_arguments(int argc, char** argv, int* M, int* N, int* K, int* num_runs) {
if (argc != 5) {
fprintf(stderr, "Usage: %s <M> <N> <K> <num_runs>\n", argv[0]);
fprintf(stderr, "Example: %s 1024 1024 1024 10\n", argv[0]);
exit(EXIT_FAILURE);
}
// 解析参数并检查有效性
*M = atoi(argv[1]);
*N = atoi(argv[2]);
*K = atoi(argv[3]);
*num_runs = atoi(argv[4]);
if (*M <= 0 || *N <= 0 || *K <= 0 || *num_runs <= 0) {
fprintf(stderr, "Error: All parameters must be positive integers\n");
exit(EXIT_FAILURE);
}
}
- 功能:解析命令行输入的矩阵维度(\(M\)、\(N\)、\(K\))和测试轮数(num_runs),并检查参数的有效性。
- 参数检查:若参数数量不对或者参数为非正整数,程序会输出错误信息并退出。
7. 主函数 main
int main(int argc, char** argv) {
int M, N, K, num_runs;
parse_arguments(argc, argv, &M, &N, &K, &num_runs); // 解析命令行参数
// 执行性能测试
float avg_time_ms = test_performance(M, N, K, num_runs);
// 计算理论峰值(GEMM运算量:2*M*N*K FLOPs)
double flops = 2.0 * M * N * K;
double tflops = flops / (avg_time_ms * 1e9); // 转换为GFLOPS
// 输出结果
printf("FP8 GEMM Performance Test\n");
printf("------------------------\n");
printf("Parameters:\n");
printf(" M (rows of A): %d\n", M);
printf(" N (cols of B): %d\n", N);
printf(" K (common dimension): %d\n", K);
printf(" Num Runs: %d\n", num_runs);
printf("\nResults:\n");
printf(" Average Execution Time: %.2f ms\n", avg_time_ms);
printf(" Performance: %.2f TFLOPS\n", tflops);
return 0;
}
- 功能:调用参数解析函数和性能测试函数,计算并输出矩阵乘法的平均执行时间和性能(以 TFLOPS 为单位)。
- 性能计算:矩阵乘法的理论运算量为 2*M*N*K FLOPs,通过平均执行时间计算出性能(TFLOPS)。
代码使用方法
编译代码:
nvcc -o fp8_gemm fp8_gemm.cu
运行代码:
./fp8_gemm 1024 1024 1024 10
其中,1024 1024 1024 分别为矩阵的维度M、N、K,10 为测试轮数。你可以根据需要调整这些参数。
测试样例
5090在64x7168