Taichi是一款高性能空间稀疏数据结构的计算引擎。其涉及到的计算密集型任务全部由C++写成,而前端则选择了易于上手且灵活性强的Python。乍一看重点应该是C++,然而一个好的前端设计同样很重要,因为它是用户认识Taichi的第一关。这里的前端不单指Python本身,也是Taichi在Python的基础上开发出的自己的一套使用规则。

故事的起源是从这个Issue开始的:https://github.com/taichi-dev/taichi/issues/548

Unify ti.kernel and ti.classkernel

先来说一下这两个decorator分别在干什么。

一般来说,Taichi用户需要用@ti.kernel来修饰一个用于计算的Python function。举个例子:

import taichi as ti

x = ti.var(ti.i32, shape=(42,))

@ti.kernel
def compute():
  for i in x:
    x[i] += 1

Taichi同时还支持OOP。但为此,Taichi需要两个decorator:@ti.data_oriented@ti.classkernel,使用方法如下:

import taichi as ti

# 下文会单独讲解@ti.data_oriented
@ti.data_oriented
class X(object):
  def __init__(self):
    self.x = ti.var(ti.i32, shape=(42,))
  
  @ti.classkernel
  def compute(self):
    for i in self.x:
      self.x[i] += 1

可以看到,为了正确使用Taichi,用户需要记住@ti.classkernel@ti.kernel各自的使用场景。这在一定程度上增加了用户的心智负担,因此这个前端的设计仍有改进的空间。

改进目标很明确:只留下@ti.kernel就好。

如何做这个改进?思路也比较清晰,判断一下被修饰的函数是否是一个class method即可:

  • 理想情况:inspect.ismethod(func)
  • 实际情况:Python class内定义的函数和普通的函数没什么不同,都只是plain function。因此在decorator执行期间,我们无法知道被修饰的函数是否属于某个class。

思路1

一般来说,一个Python decorator大概长这样:

def decorator(func):
  @functools.wraps(func)
  def wrapped(*args, **kwargs):
    func(*args, **kwargs)
  return wrapped

一种正常人的思路是,把这个决定放到wrapped的执行期。由于在wrapped中我们有了args,我们可以尝试查看args[0]的一些元数据来确定func是否属于某个class。

先把我最开始想到的方法写出来:

# 判断func是否是个某个class的函数
def is_func_inside_class(x, func):
  try:
    # __wrapped__是因为func已经被@functools.wraps修饰过
    return type(x).__dict__[func.__name__].__wrapped__ == func
  except:
    return False

def decorator(func):
  @functools.wrap(func)
  def wrapped(*args, **kwargs):
    is_classkernel = False
    try:
      is_classkernel = is_func_inside_class(args[0], func)
    except:
      pass
    # ...

  return wrapped

为了理解is_func_inside_class在干什么,我们需要理解Python class中的function究竟是如何被绑定到一个instance上的,即——self是从哪里来的?

还是通过例子来解释,例子来自 https://stackoverflow.com/a/18342905/12003165

>>> X.compute
<function X.compute at 0x7fc2a0016d40>

>>> x = X()
>>> x.compute
<bound method X.compute of <__main__.X object at 0x7fd5f0031fd0>>

可以看到,通过class X本身和通过instance x来获取compute,返回的结果是不一样的。前者仍然是一个function,而后者变成了bound method。这里发生的事情涉及到了Python descriptor的概念。

长话短说,在执行x.compute时候,Python内部发生了这么一个过程:

try:
  #1
  return x.__dict__['compute']
except KeyError:
  value = type(x).__dict__['compute']
  try:
    #2
    return value.__get__(x)
  except AttributeError:
    #3
    return value

  1. X.compute作为class X中的function,并没有被存在instance x__dict__中,因此#1抛出KeyError
  2. X.compute被存在了class X__dict__中,因此value就指向X.compute本身。同时function定义了__get__,因此我们从#2返回。

那么,把X.compute绑定到x上就是发生在function.__get__了。这是其概念上的实现:

class function(object):
  # Built-in function class
  # ...
  def __get__(self, instance):
    return Boundmethod(self, instance)
  # ...

class BoundMethod(object):
  def __init__(self, func, instance):
    self.__func__ = func
    self.__self__ = instance

  def __call__(self, *args, **kwargs):
    return self.__func__(self.__self__, *args, **kwargs)

可以看见,BoundMethod不过是同时存了X.compute的function pointer(无状态的plain function)以及instance x。被调用时,它会将x绑定到X.compute的第一个参数上。

最后,由于执行期间任意一步都可能会挂掉(args是空的、__dict__中找不到等),这个判别式被放到了一个try block中。一旦挂了立刻返回False

到这里,我们似乎是顺利解决了这个问题?

Autodiff

Taichi另一个特性是反向自动微分。所有被@ti.kernel修饰过的函数都自动带有一个grad 的callable,调用它将计算这个kernel的导数。

@ti.kernel
def compute():
  # ...

compute()
# 自动生成的导数kernel
compute.grad()

这个grad同样也是在被@ti.kernel修饰期间加上的。

但是这就造成了一个问题。由于上面这个方案需要在wrapped执行期间才能判定function是否属于class,而grad针对class function或plain function是由完全不同的两种方案实现的。这就导致了一个限制:想要使用compute.grad(),用户必须至少运行一次compute本身,使得wrapped得以执行。

这个人为限制是由于实现方法本身并非最优导致的,有没有更给力的方法呢?

思路2

前面说完了正常人的思路。而下面这个思路,我第一次是跪着看完的。详情见 https://stackoverflow.com/questions/8793233/python-can-a-decorator-determine-if-a-function-is-being-defined-inside-a-class

PO主的问题和我们的如出一辙:python decorator可否判断所修饰的function是否在一个class中?

高赞回答的思路是...检查定义class时的stackframe!

具体来说,对于下面这个例子:

def decor(func):
  import inspect
  frames = inspect.stack()
  # ...


class X(object):
  @decor
  def compute(self):
    # ...

@decor本身作用到compute是在Python解释class X的定义期间执行的。也就是说,当执行到frames = inspect.stack()这一步时,我们还在定义class X的过程中。此时的frames大概长这样

FrameInfo(..., code_context=['  frames = inspect.stack()n'], index=0)
|- FrameInfo(..., code_context=['  frames = @decn'], index=1)
   |- FrameInfo(..., code_context=['  frames = class X(object):n'], index=2)

可以看到,在这种情况下,index=2的stackframe正是class X本身,因此我们通过检查frames[2].code_context[0].startswith('class ')就可以完成这个判断[注1]。我们不再需要把这个判断推迟到wrapped执行时进行。

grad是如何被添加的

讲到这里,我们最开始想要解决的问题已经结束了。然而Taichi本身实现的grad也非常巧妙,值得说道一番。

对于plain function kernel,grad并没有什么特殊的,向返回的wrapped对象上添加grad kernel即可。

def kernel(func):
  is_classkernel = check_inside_class_by_stackframe()
  primal = Kernel(func, is_grad=False, ...)
  adjoint = Kernel(func, is_grad=True, ...)

  if is_classkernel:
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
      # TODO: 如何实现???
  else:
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
      primal(*args, **kwargs)
    wrapped.grad = adjoint
  # ...
  return wrapped

但对于OOP kernel,这个问题变得有意思了很多。

先来设想一下我们如何调用OOP kernel grad,非常简单:

x.compute.grad()

然而这里有个问题,我们需要把grad绑定到x上,而xgrad之间隔着compute

如果我们把.的作用域划分的更清楚一些,如下图所示:

(x.compute).grad()
 |---------|     |
 |---------------|

可以看到,如果令x.compute返回某个包含了x的proxy object(类比前面提到的BoundMethod),那么这个proxy在调用grad()时候可以自动把x作为第一个参数传给grad()

Taichi实现OOP grad的原理正是如此。在accessx的某个attribute时,如果能利用某种方法截获这个attribute,并且做一些检查,判断这个attribute是不是一个kernel。如果是,我们就把它变成一个proxy。否则的话我们退化到Python本身对attribute的搜索规则。

到这里,答案已经呼之欲出了。Python的__getattribute__恰好可以满足我们的需求。进一步的,想要实现这个方案,我们需要@ti.kernel@ti.data_oriented这两个decorator配合工作。前者会在返回的object上添加几个私有的标记,而后者则override了所修饰的class本身的__getattribute__,来读取这些标记。

def kernel(func):
  is_classkernel = check_inside_class_by_stackframe()
  primal = ...
  adjoint = ...
  # ...
  wrapped._is_wrapped_kernel = True
  wrapped._classkernel = is_classkernel
  wrapped._primal = primal
  wrapped._adjoint = adjoint
  return wrapped

def data_oriented(cls):
  def getattr(self, item):
    x = super(cls, self).__getattribute__(item)
    if hasattr(x, '_is_wrapped_kernel'):
      wrapped = x.__func__
      if wrapped._classkernel:
        return BoundedDifferentiableMethod(self, wrapped)
    return x
  
  cls.__getattribute__ = getattr
  return cls

之前提到的proxy就是这个BoundedDifferentiableMethod。其原理也和将无状态的function变为bound method的方法类似,以下为其实现:

class BoundedDifferentiableMethod:
  def __init__(self, kernel_owner, wrapped_kernel_func):
    self._kernel_owner = kernel_owner
    self._primal = wrapped_kernel_func._primal
    self._adjoint = wrapped_kernel_func._adjoint

  def __call__(self, *args, **kwargs):
    return self._primal(self._kernel_owner, *args, **kwargs)

  def grad(self, *args, **kwargs):
    return self._adjoint(self._kernel_owner, *args, **kwargs)

我们还剩下最后一个小细节。之前在实现kernel这个decorator时,我们并没有给出在is_classkernel == True的情况下wrapped的实现。

Take a guess of its implementation first :)

事实上,它的实现毫无影响。因为在这个情况下,使用这个kernel时会被BoundedDifferentiableMethod接管,因此wrapped的实现并不会调用。

为了确保这个invariant,Taichi在这里只不过是抛出异常而已:

def kernel(func):
  is_classkernel = check_inside_class_by_stackframe()
  # ...
  if is_classkernel:
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
      raise KernelDefError(...)
  # ...
  return wrapped

备注

  1. 事实上,这个说法并不完全准确。经过测试,stackframe在Python3.8上和更低的版本上的code_context是有一些区别的,需要分别处理。以我个人的品味来看,这并非一个很优雅的解决方案。但是写软件本身就是妥协的过程:没有其他方案的情况下,能用的就是最好的。