前言

本节不具体讲解cython的原理和细节,提供一个最简单的例子,将一个python代码转化为一个cython代码,同时由于本人对cython刚入门,只会一个简单的操作,即在cython中声明变量的类型。实验证明,就这样简单添加变量类型,代码运行速度提升了将近4倍
cython对于代码中许多循环的情况很有帮助!

python代码

这里给的是CVPPP官方提供的evaluate代码(evaluate.py)
为了节省空间,这里删除了注释和一些无关紧要的判断语句

import numpy as np
def DiffFGLabels(inLabel,gtLabel):
    """Return the signed difference between the label ranges of two label images.

    Computes ``(max(inLabel) - min(inLabel)) - (max(gtLabel) - min(gtLabel))``.
    Presumably this equals the difference in the number of foreground objects
    when labels are consecutive integers starting at the background value.

    Args:
        inLabel: array of predicted labels.
        gtLabel: array of ground-truth labels.

    Returns:
        The difference as a builtin ``int`` (negative when ``gtLabel`` spans
        a wider label range).
    """
    # The builtin int() is used because the np.int alias no longer exists in
    # NumPy >= 1.24. Converting before subtracting also keeps unsigned dtypes
    # (e.g. uint16) from wrapping around on a negative result.
    in_span = int(np.max(inLabel)) - int(np.min(inLabel))
    gt_span = int(np.max(gtLabel)) - int(np.min(gtLabel))
    return in_span - gt_span

def BestDice(inLabel,gtLabel):
    """Average, over the labels of ``inLabel``, of the best Dice match in ``gtLabel``.

    Every integer label strictly above ``min(inLabel)`` up to ``max(inLabel)``
    is compared against every integer label strictly above ``min(gtLabel)`` up
    to ``max(gtLabel)``; the highest Dice score per input label is summed and
    the sum is divided by ``max(inLabel) - min(inLabel)``.

    Args:
        inLabel: integer label array whose labels are being scored.
        gtLabel: integer label array searched for the best match.

    Returns:
        0 when ``inLabel`` contains a single value, otherwise the mean best
        Dice score.
    """
    in_lo, in_hi = np.min(inLabel), np.max(inLabel)
    gt_lo, gt_hi = np.min(gtLabel), np.max(gtLabel)

    total = 0
    if in_hi == in_lo:
        # Only one label value present: nothing to match.
        return total

    for in_id in range(in_lo + 1, in_hi + 1):
        # The leading 0 is the floor value, so an empty candidate list yields 0.
        candidates = [0] + [
            Dice(inLabel, gtLabel, in_id, gt_id)
            for gt_id in range(gt_lo + 1, gt_hi + 1)
        ]
        total = total + max(candidates)

    return total / (in_hi - in_lo)

def FGBGDice(inLabel,gtLabel):
    """Dice score between the foreground masks of two label images.

    In each image, every pixel that differs from that image's minimum label
    is marked 1.0 and all others 0.0; the two resulting masks are then
    compared with ``Dice`` for label 1 against label 1.

    Args:
        inLabel: predicted label array.
        gtLabel: ground-truth label array.

    Returns:
        Whatever ``Dice`` returns for the two binary masks.
    """
    ones = np.ones(inLabel.shape)
    in_background = np.min(inLabel)
    gt_background = np.min(gtLabel)

    # Boolean comparison times the ones array gives a 0.0/1.0 float mask.
    in_foreground = (inLabel != in_background * ones) * ones
    gt_foreground = (gtLabel != gt_background * ones) * ones

    return Dice(in_foreground, gt_foreground, 1, 1)

def Dice(inLabel, gtLabel, i, j):
    """Dice coefficient between label ``i`` of ``inLabel`` and label ``j`` of ``gtLabel``.

    Computes ``2 * |A & B| / (|A| + |B|)`` where ``A`` is the set of pixels
    equal to ``i`` in ``inLabel`` and ``B`` the set of pixels equal to ``j``
    in ``gtLabel``.

    Args:
        inLabel: first label array.
        gtLabel: second label array.
        i: label value selected in ``inLabel``.
        j: label value selected in ``gtLabel``.

    Returns:
        The Dice coefficient, or 0 when neither label occurs.
    """
    ones = np.ones(inLabel.shape)
    in_mask = inLabel == i * ones
    gt_mask = gtLabel == j * ones

    # Multiplying by the float ones array turns the boolean masks into 0.0/1.0.
    in_size = np.sum(in_mask * ones)
    gt_size = np.sum(gt_mask * ones)
    overlap = np.sum(in_mask * gt_mask * ones)

    denominator = in_size + gt_size
    if denominator > 1e-8:
        return 2 * overlap / denominator
    # Neither label is present, so the score is defined as 0.
    return 0

def AbsDiffFGLabels(inLabel,gtLabel):
    """Return the absolute value of ``DiffFGLabels(inLabel, gtLabel)``."""
    signed_diff = DiffFGLabels(inLabel, gtLabel)
    return np.abs(signed_diff)

def SymmetricBestDice(inLabel,gtLabel):
    """Return the smaller of the two directional ``BestDice`` scores.

    ``BestDice`` is evaluated once as (inLabel, gtLabel) and once with the
    arguments swapped; the lower of the two results is returned.
    """
    forward = BestDice(inLabel, gtLabel)
    backward = BestDice(gtLabel, inLabel)
    return forward if forward < backward else backward

Cython代码

创建一个evaluate.pyx文件(注意:后缀得是pyx!!!)

from __future__ import division
from libcpp cimport bool as bool_t
import numpy as np
cimport numpy as np
cimport cython

# Readable aliases for C scalar types, for use as NumPy buffer element types.
# Of these, the functions visible below only use TYPE_U_INT16.
# NOTE(review): the NumPy dtype pairings assume the usual C type widths on
# common platforms - confirm for the target compiler.
ctypedef bint TYPE_BOOL                     # Cython's C boolean integer type
ctypedef unsigned long long TYPE_U_INT64    # pairs with np.uint64
ctypedef unsigned int TYPE_U_INT32          # pairs with np.uint32
ctypedef unsigned short TYPE_U_INT16        # pairs with np.uint16
ctypedef unsigned char TYPE_U_INT8          # pairs with np.uint8
ctypedef long long TYPE_INT64               # pairs with np.int64
ctypedef int TYPE_INT32                     # pairs with np.int32
ctypedef short TYPE_INT16                   # pairs with np.int16
ctypedef signed char TYPE_INT8              # pairs with np.int8
ctypedef float TYPE_FLOAT                   # single precision (np.float32)
ctypedef double TYPE_DOUBLE                 # double precision (np.float64)

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
def DiffFGLabels(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    """Return the signed difference between the label ranges of two label images.

    Computes (max(inLabel) - min(inLabel)) - (max(gtLabel) - min(gtLabel)).

    Args:
        inLabel: 2-D uint16 array of predicted labels.
        gtLabel: 2-D uint16 array of ground-truth labels.

    Returns:
        The difference as a float (the result is stored in a C double).
    """
    # The builtin int() is used because the np.int alias no longer exists in
    # NumPy >= 1.24.
    cdef int maxInLabel = int(np.max(inLabel))
    cdef int minInLabel = int(np.min(inLabel))
    cdef int maxGtLabel = int(np.max(gtLabel))
    cdef int minGtLabel = int(np.min(gtLabel))
    # The arithmetic runs on C ints, so a negative result cannot wrap the
    # way uint16 values would.
    cdef double out = (maxInLabel-minInLabel) - (maxGtLabel-minGtLabel)
    return out

@cython.boundscheck(False)
@cython.wraparound(False)
def BestDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    """Average, over the labels of inLabel, of the best Dice match in gtLabel.

    Every label in (min(inLabel), max(inLabel)] is compared against every
    label in (min(gtLabel), max(gtLabel)]; the best score per input label is
    summed and divided by max(inLabel) - min(inLabel).

    Args:
        inLabel: 2-D uint16 label array whose labels are being scored.
        gtLabel: 2-D uint16 label array searched for the best match.

    Returns:
        0.0 when inLabel holds a single value, otherwise the mean best score.
    """
    cdef int inId, gtId
    cdef int inHigh = np.max(inLabel)
    cdef int inLow = np.min(inLabel)
    cdef int gtHigh = np.max(gtLabel)
    cdef int gtLow = np.min(gtLabel)
    cdef double total = 0.0
    cdef double best = 0.0
    cdef double current = 0.0

    if inHigh == inLow:
        # Only one label value present: nothing to match.
        return total

    for inId in range(inLow + 1, inHigh + 1):
        best = 0.0
        for gtId in range(gtLow + 1, gtHigh + 1):
            current = Dice(inLabel, gtLabel, inId, gtId)
            if current > best:
                best = current
        total += best

    total = total / (inHigh - inLow)
    return total

@cython.boundscheck(False)
@cython.wraparound(False)
def FGBGDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    """Dice score between the foreground masks of two label images.

    In each image, pixels that differ from that image's minimum label become
    1 and all others 0; the two masks are then compared with Dice for label 1
    against label 1.

    Args:
        inLabel: 2-D uint16 predicted label array.
        gtLabel: 2-D uint16 ground-truth label array.

    Returns:
        The Dice score of the two binary masks as a float.
    """
    cdef np.ndarray[TYPE_U_INT16, ndim=2] ones = np.ones_like(inLabel)
    cdef int inBackground = np.min(inLabel)
    cdef int gtBackground = np.min(gtLabel)
    # Boolean comparison times the uint16 ones array gives a 0/1 uint16 mask.
    cdef np.ndarray[TYPE_U_INT16, ndim=2] inForeground = (inLabel != inBackground*ones)*ones
    cdef np.ndarray[TYPE_U_INT16, ndim=2] gtForeground = (gtLabel != gtBackground*ones)*ones
    cdef double result = Dice(inForeground, gtForeground, 1, 1)
    return result

@cython.boundscheck(False)
@cython.wraparound(False)
def Dice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel, int i, int j):
    """Dice coefficient between label i of inLabel and label j of gtLabel.

    Computes 2 * |A & B| / (|A| + |B|) where A is the set of pixels equal to
    i in inLabel and B the set of pixels equal to j in gtLabel.

    Args:
        inLabel: 2-D uint16 label array.
        gtLabel: 2-D uint16 label array.
        i: label value selected in inLabel.
        j: label value selected in gtLabel.

    Returns:
        The Dice coefficient as a float, or 0.0 when neither label occurs.
    """
    cdef double out = 0.0
    cdef int inSize, gtSize, overlap
    cdef np.ndarray[TYPE_U_INT16, ndim=2] one = np.ones_like(inLabel)

    # Each boolean mask is built once and reused for both the size and the
    # overlap, instead of repeating the full-array comparison.
    inMask = (inLabel == i*one)
    gtMask = (gtLabel == j*one)

    inSize = np.sum(inMask*one)
    gtSize = np.sum(gtMask*one)
    overlap = np.sum(inMask*gtMask*one)

    if ((inSize + gtSize) > 1e-8):
        # True division, because the module enables `from __future__ import division`.
        out = 2*overlap/(inSize + gtSize)
    else:
        out = 0
    return out

@cython.boundscheck(False)
@cython.wraparound(False)
def AbsDiffFGLabels(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    """Return the absolute value of DiffFGLabels(inLabel, gtLabel) as a float."""
    signedDiff = DiffFGLabels(inLabel, gtLabel)
    cdef double magnitude = np.abs(signedDiff)
    return magnitude

@cython.boundscheck(False)
@cython.wraparound(False)
def SymmetricBestDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    """Return the smaller of the two directional BestDice scores.

    BestDice is evaluated once as (inLabel, gtLabel) and once with the
    arguments swapped; the lower of the two results is returned.
    """
    cdef double forward = BestDice(inLabel, gtLabel)
    cdef double backward = BestDice(gtLabel, inLabel)
    return forward if forward < backward else backward

编译

写好pyx文件后,需要再写一个setup.py文件,里面的内容也很简单(注意修改相应的pyx文件名!!!):

# Build script for the Cython extension module compiled from evaluate.pyx.
# NOTE(review): distutils is deprecated by PEP 632 and absent from Python
# 3.12+; confirm the target interpreter or migrate this script to setuptools.
from distutils.core import setup

from Cython.Build import cythonize
import numpy as np

# Translate the .pyx source into extension module definitions.
extensions = cythonize("evaluate.pyx")
# The NumPy C headers are needed because the .pyx file uses `cimport numpy`.
numpy_headers = [np.get_include()]

setup(ext_modules=extensions, include_dirs=numpy_headers)

编译:

python setup.py build_ext --inplace

编译成功后,就可以正常的 import 里面的函数了

解释

通过对比两个代码,我们可以总结出一些规律:

  1. 在导入包的时候,最重要的一句是:cimport numpy as np,表明使用的是numpy的Cython(C级)接口。(同时也保留了import numpy as np,编译器会根据上下文决定使用Python层的numpy还是C层的numpy接口);还有一句是:from libcpp cimport bool as bool_t,这是为了使用C++中的bool类型(这个例子里面没有用到bool类型,可以不用管)
  2. 为数据类型起一个新名字:ctypedef。这个不是必须,但这里为了可读性,我列举了一些numpy中常用的数据类型对应的C语言中的数据类型

numpy

C

np.uint8

unsigned char

np.uint16

unsigned short

np.uint32

unsigned int

np.uint64

unsigned long long

np.int8

signed char

np.int16

short

np.int32

int

np.int64

long long

np.float32

float

np.float64

double

  3. 每个函数前面都有:@cython.boundscheck(False) 和 @cython.wraparound(False),这是为了加速而关闭数组的边界检查和负数索引支持,这样做就需要提前保证代码的准确性,建议先在python下验证代码的准确性
  4. 每个函数的输入变量都定义了数据类型,比如这里全是:np.ndarray[TYPE_U_INT16, ndim=2],这表明输入的是一个二维的uint16的numpy数组,如果输入类型不是这样,那就会报错
  5. cdef int/double:定义整型/双精度浮点型变量。使用每个变量之前须先声明它的类型;如果没有声明,变量会被当作普通的Python对象处理,运行时需要动态判断类型,因此会更耗时