import torch import triton import triton.language as tl
@triton.jit
def add_kernel(X_ptr, Y_ptr, Z_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Element-wise addition: write X + Y into Z for the first N elements.

    Each program instance along grid axis 0 handles one run of BLOCK_SIZE
    consecutive element indices.

    Args:
        X_ptr: Base pointer of the first input.
        Y_ptr: Base pointer of the second input.
        Z_ptr: Base pointer of the output.
        N: Number of elements to process; indices >= N are masked off.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # First element index owned by this program instance.
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)

    # Lanes at or past N must neither read nor write memory.
    in_bounds = idx < N

    lhs = tl.load(X_ptr + idx, mask=in_bounds)
    rhs = tl.load(Y_ptr + idx, mask=in_bounds)
    tl.store(Z_ptr + idx, lhs + rhs, mask=in_bounds)
if __name__ == "__main__":
    # Demo driver: add two random CUDA vectors with add_kernel and print the
    # first ten results.
    N = 1024
    # Single source of truth for the block size, used both to size the grid
    # and as the kernel's compile-time BLOCK_SIZE argument.
    BLOCK_SIZE = 256

    x = torch.randn(N, device='cuda')
    y = torch.randn(N, device='cuda')
    z = torch.empty_like(x)

    # Ceiling division: a floor-divided grid would skip the trailing partial
    # block whenever N is not a multiple of BLOCK_SIZE, leaving the tail of z
    # uninitialised. The kernel's mask handles the out-of-range lanes of the
    # last block.
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    add_kernel[grid](x, y, z, N, BLOCK_SIZE=BLOCK_SIZE)

    # The label below reads "first 10 results:".
    print("前10个结果:", z[:10])