Triton分析之一:compiler
JIT实现 链接到标题
在使用Triton编写kernel的时候,首先需要@triton.jit修饰我们的kernel,这个修饰符本质上是把函数变成了一个JITFunction。
@overload
def jit(fn: T) -> JITFunction[T]:
...
class JITFunction(KernelInterface[T]):
...
class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid) -> T:
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
Python中用obj[key]索引的时候,其实是调用class的__getitem__方法。启动kernel的时候用kernel[grid](...)语法传入grid,通过调用__getitem__方法,返回一个lambda函数,在lambda函数调用run方法来启动。
JITFunction类几个主要的方法如下:
run,执行kernel时候调用的方法;warmup,调用run方法,用于提前编译kernel;cache_key,这是一个只读方法,用于获取缓存kernel的key;add_pre_run_hook&_call_hook,自定义一些钩子;
主要的方法是run,这里大概做了几件事情:
def run(...):
# 获取kernel执行的device和stream
# 执行一些hook函数
# 计算cache key
# 如果kernel没有被编译缓存,编译kernel,这里会比较耗时;
if kernel is None:
src = self.ASTSource(self, signature, constexprs, attrs)
kernel = self.compile(src, target=target, options=options.__dict__)
kernel_cache[key] = kernel
self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)
# 如果run不是被warmup调用的,计算grid并启动kernel计算,否则把之前编译好的kernel返回给warmup;
可以看到在compile的时候,Triton源码先是被解析成AST,然后调用compile方法去编译。