Triton 简介

Triton 是一个专为深度学习优化而设计的 GPU 编程语言,能够帮助我们在 Python 中编写类似 CUDA 的高性能内核,但语法更简洁、更适合科研和工程开发。

优点包括:

  • 用 Python 写 GPU 算子,无需手写 CUDA;
  • 自动完成线程分配、寄存器映射;
  • 与 PyTorch 无缝集成,支持 torch.Tensor

环境准备

我们建议使用 Conda 环境管理:

1
2
3
conda create -n triton101 python=3.8 -y
conda activate triton101
pip install triton

确保你的设备具备 NVIDIA GPU 并安装正确的 CUDA 驱动。


第一个 Triton 内核

创建一个名为 add.py 的文件,写入如下内容:

1
2
3
4
5
6
7
8
9
10
11
12
import triton
import triton.language as tl

@triton.jit
def add_kernel(X_ptr, Y_ptr, Z_ptr, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
x = tl.load(X_ptr + offsets, mask=mask)
y = tl.load(Y_ptr + offsets, mask=mask)
z = x + y
tl.store(Z_ptr + offsets, z, mask=mask)

这个内核的功能是将两个张量逐元素相加,结果存入第三个张量中。


调用 Triton 内核

写一个主函数来调用上面定义的 Triton 内核:

1
2
3
4
5
6
7
8
9
10
import torch

N = 1024
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
z = torch.empty_like(x)

add_kernel[(N // 256,)](x, y, z, N, BLOCK_SIZE=256)

print(z[:10])

你将看到 x + y 的结果张量中前 10 项的值。


概念解释

  • @triton.jit:将 Python 函数即时编译为 GPU 内核;
  • tl.load / tl.store:读取和写入显存数据;
  • tl.arange:生成并行线程索引;
  • BLOCK_SIZE:每个 block 中的线程数,是性能调优的重要参数。

完整示例

创建一个名为 add.py 的文件,写入如下内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# add.py

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(X_ptr, Y_ptr, Z_ptr, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
x = tl.load(X_ptr + offsets, mask=mask)
y = tl.load(Y_ptr + offsets, mask=mask)
z = x + y
tl.store(Z_ptr + offsets, z, mask=mask)

if __name__ == "__main__":
N = 1024
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
z = torch.empty_like(x)
add_kernel[(N // 256,)](x, y, z, N, BLOCK_SIZE=256)
print("前10个结果:", z[:10])

保存为 add.py,运行:

1
python add.py

总结

本文介绍了 Triton 的基本安装方式,并带你构建了一个最简单的加法内核:

  • 学会使用 @triton.jit 编写 GPU 算子;
  • 学会 block 级调度的基本结构;
  • 学会使用 Triton 操作显存张量。