Taichi是一款高性能空间稀疏数据结构的计算引擎。其涉及到的计算密集型任务全部由C++写成,而前端则选择了易于上手且灵活性强的Python。乍一看重点应该是C++,然而一个好的前端设计同样很重要,因为它是用户认识Taichi的第一关。这里的前端不单指Python本身,也是Taichi在Python的基础上开发出的自己的一套使用规则。
故事的起源是从这个Issue开始的:https://github.com/taichi-dev/taichi/issues/548
Unify
ti.kernel
andti.classkernel
先来说一下这两个decorator分别在干什么。
一般来说,Taichi用户需要用@ti.kernel
来修饰一个用于计算的Python function。举个例子:
import taichi as ti
x = ti.var(ti.i32, shape=(42,))
@ti.kernel
def compute():
for i in x:
x[i] += 1
Taichi同时还支持OOP。但为此,Taichi需要两个decorator:@ti.data_oriented
和@ti.classkernel
,使用方法如下:
import taichi as ti
# 下文会单独讲解@ti.data_oriented
@ti.data_oriented
class X(object):
def __init__(self):
self.x = ti.var(ti.i32, shape=(42,))
@ti.classkernel
def compute(self):
for i in self.x:
self.x[i] += 1
可以看到,为了正确使用Taichi,用户需要记住@ti.classkernel
和@ti.kernel
各自的使用场景。这在一定程度上增加了用户的心智负担,因此这个前端的设计仍有改进的空间。
改进目标很明确:只留下@ti.kernel
就好。
如何做这个改进?思路也比较清晰,判断一下被修饰的函数是否是一个class method即可:
- 理想情况:
inspect.ismethod(func)
- 实际情况:Python class内定义的函数和普通的函数没什么不同,都只是plain function。因此在decorator执行期间,我们无法知道被修饰的函数是否属于某个class。
思路1
一般来说,一个Python decorator大概长这样:
def decorator(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
func(*args, **kwargs)
return wrapped
一种正常人的思路是,把这个决定放到wrapped
的执行期。由于在wrapped
中我们有了args
,我们可以尝试查看args[0]
的一些元数据来确定func
是否属于某个class。
先把我最开始想到的方法写出来:
# 判断func是否是个某个class的函数
def is_func_inside_class(x, func):
try:
# __wrapped__是因为func已经被@functools.wraps修饰过
return type(x).__dict__[func.__name__].__wrapped__ == func
except:
return False
def decorator(func):
@functools.wrap(func)
def wrapped(*args, **kwargs):
is_classkernel = False
try:
is_classkernel = is_func_inside_class(args[0], func)
except:
pass
# ...
return wrapped
为了理解is_func_inside_class
在干什么,我们需要理解Python class中的function究竟是如何被绑定到一个instance上的,即——self
是从哪里来的?
还是通过例子来解释,例子来自 https://stackoverflow.com/a/18342905/12003165
>>> X.compute
<function X.compute at 0x7fc2a0016d40>
>>> x = X()
>>> x.compute
<bound method X.compute of <__main__.X object at 0x7fd5f0031fd0>>
可以看到,通过class X
本身和通过instance x
来获取compute
,返回的结果是不一样的。前者仍然是一个function
,而后者变成了bound method
。这里发生的事情涉及到了Python descriptor的概念。
长话短说,在执行x.compute
时候,Python内部发生了这么一个过程:
try:
#1
return x.__dict__['compute']
except KeyError:
value = type(x).__dict__['compute']
try:
#2
return value.__get__(x)
except AttributeError:
#3
return value
-
X.compute
作为classX
中的function,并没有被存在instancex
的__dict__
中,因此#1
抛出KeyError
-
X.compute
被存在了classX
的__dict__
中,因此value
就指向X.compute
本身。同时function
定义了__get__
,因此我们从#2
返回。
那么,把X.compute
绑定到x
上就是发生在function.__get__
了。这是其概念上的实现:
class function(object):
# Built-in function class
# ...
def __get__(self, instance):
return Boundmethod(self, instance)
# ...
class BoundMethod(object):
def __init__(self, func, instance):
self.__func__ = func
self.__self__ = instance
def __call__(self, *args, **kwargs):
return self.__func__(self.__self__, *args, **kwargs)
可以看见,BoundMethod
不过是同时存了X.compute
的function pointer(无状态的plain function)以及instance x
。被调用时,它会将x
绑定到X.compute
的第一个参数上。
最后,由于执行期间任意一步都可能会挂掉(args
是空的、__dict__
中找不到等),这个判别式被放到了一个try
block中。一旦挂了立刻返回False
。
到这里,我们似乎是顺利解决了这个问题?
Autodiff
Taichi另一个特性是反向自动微分。所有被@ti.kernel
修饰过的函数都自动带有一个grad
的callable,调用它将计算这个kernel的导数。
@ti.kernel
def compute():
# ...
compute()
# 自动生成的导数kernel
compute.grad()
这个grad
同样也是在被@ti.kernel
修饰期间加上的。
但是这就造成了一个问题。由于上面这个方案需要在wrapped
执行期间才能判定function是否属于class,而grad
针对class function或plain function是由完全不同的两种方案实现的。这就导致了一个限制:想要使用compute.grad()
,用户必须至少运行一次compute
本身,使得wrapped
得以执行。
这个人为限制是由于实现方法本身并非最优导致的,有没有更给力的方法呢?
思路2
前面说完了正常人的思路。而下面这个思路,我第一次是跪着看完的。详情见 https://stackoverflow.com/questions/8793233/python-can-a-decorator-determine-if-a-function-is-being-defined-inside-a-class
PO主的问题和我们的如出一辙:python decorator可否判断所修饰的function是否在一个class中?
高赞回答的思路是...检查定义class时的stackframe!
具体来说,对于下面这个例子:
def decor(func):
import inspect
frames = inspect.stack()
# ...
class X(object):
@decor
def compute(self):
# ...
@decor
本身作用到compute
是在Python解释class X
的定义期间执行的。也就是说,当执行到frames = inspect.stack()
这一步时,我们还在定义class X
的过程中。此时的frames
大概长这样
FrameInfo(..., code_context=[' frames = inspect.stack()n'], index=0)
|- FrameInfo(..., code_context=[' frames = @decn'], index=1)
|- FrameInfo(..., code_context=[' frames = class X(object):n'], index=2)
可以看到,在这种情况下,index=2
的stackframe正是class X
本身,因此我们通过检查frames[2].code_context[0].startswith('class ')
就可以完成这个判断[注1]。我们不再需要把这个判断推迟到wrapped
执行时进行。
grad
是如何被添加的
讲到这里,我们最开始想要解决的问题已经结束了。然而Taichi本身实现的grad
也非常巧妙,值得说道一番。
对于plain function kernel,grad
并没有什么特殊的,向返回的wrapped
对象上添加grad
kernel即可。
def kernel(func):
is_classkernel = check_inside_class_by_stackframe()
primal = Kernel(func, is_grad=False, ...)
adjoint = Kernel(func, is_grad=True, ...)
if is_classkernel:
@functools.wraps(func)
def wrapped(*args, **kwargs):
# TODO: 如何实现???
else:
@functools.wraps(func)
def wrapped(*args, **kwargs):
primal(*args, **kwargs)
wrapped.grad = adjoint
# ...
return wrapped
但对于OOP kernel,这个问题变得有意思了很多。
先来设想一下我们如何调用OOP kernel grad,非常简单:
x.compute.grad()
然而这里有个问题,我们需要把grad
绑定到x
上,而x
和grad
之间隔着compute
。
如果我们把.
的作用域划分的更清楚一些,如下图所示:
(x.compute).grad()
|---------| |
|---------------|
可以看到,如果令x.compute
返回某个包含了x
的proxy object(类比前面提到的BoundMethod
),那么这个proxy在调用grad()
时候可以自动把x
作为第一个参数传给grad()
。
Taichi实现OOP grad的原理正是如此。在accessx
的某个attribute时,如果能利用某种方法截获这个attribute,并且做一些检查,判断这个attribute是不是一个kernel。如果是,我们就把它变成一个proxy。否则的话我们退化到Python本身对attribute的搜索规则。
到这里,答案已经呼之欲出了。Python的__getattribute__
恰好可以满足我们的需求。进一步的,想要实现这个方案,我们需要@ti.kernel
和@ti.data_oriented
这两个decorator配合工作。前者会在返回的object上添加几个私有的标记,而后者则override了所修饰的class本身的__getattribute__
,来读取这些标记。
def kernel(func):
is_classkernel = check_inside_class_by_stackframe()
primal = ...
adjoint = ...
# ...
wrapped._is_wrapped_kernel = True
wrapped._classkernel = is_classkernel
wrapped._primal = primal
wrapped._adjoint = adjoint
return wrapped
def data_oriented(cls):
def getattr(self, item):
x = super(cls, self).__getattribute__(item)
if hasattr(x, '_is_wrapped_kernel'):
wrapped = x.__func__
if wrapped._classkernel:
return BoundedDifferentiableMethod(self, wrapped)
return x
cls.__getattribute__ = getattr
return cls
之前提到的proxy就是这个BoundedDifferentiableMethod
。其原理也和将无状态的function
变为bound method
的方法类似,以下为其实现:
class BoundedDifferentiableMethod:
def __init__(self, kernel_owner, wrapped_kernel_func):
self._kernel_owner = kernel_owner
self._primal = wrapped_kernel_func._primal
self._adjoint = wrapped_kernel_func._adjoint
def __call__(self, *args, **kwargs):
return self._primal(self._kernel_owner, *args, **kwargs)
def grad(self, *args, **kwargs):
return self._adjoint(self._kernel_owner, *args, **kwargs)
我们还剩下最后一个小细节。之前在实现kernel
这个decorator时,我们并没有给出在is_classkernel == True
的情况下wrapped
的实现。
Take a guess of its implementation first :)
事实上,它的实现毫无影响。因为在这个情况下,使用这个kernel时会被BoundedDifferentiableMethod
接管,因此wrapped
的实现并不会调用。
为了确保这个invariant,Taichi在这里只不过是抛出异常而已:
def kernel(func):
is_classkernel = check_inside_class_by_stackframe()
# ...
if is_classkernel:
@functools.wraps(func)
def wrapped(*args, **kwargs):
raise KernelDefError(...)
# ...
return wrapped
备注
- 事实上,这个说法并不完全准确。经过测试,stackframe在Python3.8上和更低的版本上的
code_context
是有一些区别的,需要分别处理。以我个人的品味来看,这并非一个很优雅的解决方案。但是写软件本身就是妥协的过程:没有其他方案的情况下,能用的就是最好的。