Python数值计算:使用插值函数提高特殊函数的计算速度

使用插值函数提高特殊函数的计算速度

在最近的数值模拟中,有一类函数被上万次地调用,而库函数中的计算速率很慢。所以尝试做了优化,最终将此热点函数提升了大概11倍的运算速度、并保持了float64的数值精度,在此做个记录。

源起

涉及到的函数叫第一类贝塞尔函数, python插值 python插值求速度_python,python的第三方库scipy中有这个函数可以调用,叫做scipy.special.jv(n, x),这玩意大概长这样:

python插值 python插值求速度_python_02


这种特殊函数的求值一般是需要利用其递推性质、用一系列的递归、迭代才能得到。相比于python插值 python插值求速度_numpy_03这种简单函数,特殊函数求值往往需要更多的CPU计算操作. 我们可以做一个简单的测试。先创建两个函数,一个是我们感兴趣的python插值 python插值求速度_python_04JV, 另一个则是最简单的python插值 python插值求速度_python插值_05函数simple_func

import numpy as np
import time

from scipy.special import jv as JV # the target function. 

simple_func= lambda x: 3.2*x**2 + 0.6

我们创建一个维的数组x

SZ = 2048 
REPEAT = 10000 
n = 5 # order of Bessel function. 
x = np.linspace(-30., 30., SZ) 
y = np.zeros_like(x)

然后重复10000次来计算函数值,并比较耗费的时间:

# y = x^2: 
start_time = time.time() 
for _ in range(REPEAT): 
    y = simple_func(x) 
end_time = time.time() 
print("simple TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
    n, SZ, REPEAT, end_time-start_time))

# y = J_n(x), n = 5:
start_time = time.time() 
for _ in range(REPEAT): 
    y = JV(n, x)
end_time = time.time() 
print("Jn(x) TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
    n, SZ, REPEAT, end_time-start_time))

输出是这样的:

simple TimeCost[sec]_n5_SZ2048_REP10000: 0.033106
 Jn(x) TimeCost[sec]_n5_SZ2048_REP10000: 17.599081

为了排除python可能的优化, 我们将REPEAT加倍后得到大概两倍的运算时间,没问题:

simple TimeCost[sec]_n5_SZ2048_REP20000: 0.064189
 Jn(x) TimeCost[sec]_n5_SZ2048_REP20000: 34.858215

WoW! 可以看到,计算贝塞尔函数的用时远远高于python插值 python插值求速度_数据分析_06,大概是530倍!可以想象,计算一个贝塞尔函数值时背后的循环、递归等等操作的消耗量有多大。所以我们需要珍惜每一次调用JV得到的数据,而不是大量地重新计算!

优化

为了充分利用JV函数的结果,我们选择在程序一开始,先计算大量的数据,并存储起来。当遇到求取某个x处的函数值时,用已经计算好的、临近x的函数值表示,而不是再去计算一遍。 也就是使用插值函数取代原来的复杂函数。

Scipy的interpolate库提供了一维的插值函数interp1d,我们使用它来对上述两个函数制作插值版本:

from scipy.interpolate import interp1d 
xmax = 50. 
x_interp = np.linspace(-xmax, xmax, 4096*2) 

Simple_interp = interp1d(x_interp, simple_func(x_interp), 
                kind="linear", 
                bounds_error=False, 
                fill_value=np.nan) 

JV_interp = interp1d(x_interp, JV(n, x_interp), 
                kind="linear", # linear interpolation for now. 
                bounds_error=False, 
                fill_value=np.nan)

现在,原来的simple_funcJV(n, x)分别有了插值版本
Simple_interpJV_interp。这里插值的区间只在[-xmax, xmax],超出的话函数值就是np.nan;插值类型暂时使用了linear。那我们看看这两个函数的效率如何呢?

start_time = time.time() 
for _ in range(REPEAT): 
    y = Simple_interp(x) 
end_time = time.time() 
print("simple_interp TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
    n, SZ, REPEAT, end_time-start_time))

start_time = time.time() 
for _ in range(REPEAT): 
    y = JV_interp(x)
end_time = time.time() 
print("Jn(x)_interp TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
    n, SZ, REPEAT, end_time-start_time))

REPEAT为10000时输出是:

simple_interp TimeCost[sec]_n5_SZ2048_REP10000: 0.386011
 Jn(x)_interp TimeCost[sec]_n5_SZ2048_REP10000: 0.384005

贝塞尔函数耗时从python插值 python插值求速度_python插值_07sec减少到了python插值 python插值求速度_numpy_08sec, 速度提升了45倍!这是非常可观的优化。

而简单函数的耗时从python插值 python插值求速度_numpy_09sec增加到了相同的量级python插值 python插值求速度_算法_10sec。我们得出的结论是:

  1. 对于复杂函数,如果在程序中被多次使用、且每一次的调用都是计算密集(Dense)的, 那么可以通过预计算,使用插值函数替代旧的函数,提高效率;
  2. 对于简单函数,如果上述操作后,插值函数调用成本高于原来函数直接使用的成本,那么不必这么操作。

数值精度

插值函数的精度可以保证吗?这是我们数值计算必须关系的问题。在合理地取样本点和插值阶次kind后,答案是肯定的:float64的精度是可以达到的(python插值 python插值求速度_数据分析_11)。
我们定义一个绝对误差量: python插值 python插值求速度_python_12,这里python插值 python插值求速度_算法_13 python插值 python插值求速度_算法_14分别是两种方法得到的结果。尝试不同阶次的插值,并考虑相对误差的大小,interp1d支持的插值阶次是kind=1, 2, 3, 5, 7, 9, ... (这里我们使用了插值和非插值版本的函数作差来表示相对误差,实际也可以使用另外的第三方库函数、比如mpmath来相互比较,后者相当于引入裁判、会更合理,不过最终结论相同):

x_interp = np.linspace(-xmax, xmax, 4096*4) 
import matplotlib.pyplot as plt 

kinds = [1, 2, 3, 5, 7, 9, 11] 
err = lambda f1, f2: 2.*np.abs(f1 - f2) 
log10_err = lambda f1, f2: np.log10(err(f1, f2)) 

interp_funcs = []
for kindi in kinds:
    interp_funcs.append(
        interp1d(x_interp, JV(n, x_interp), 
                    kind=kindi, 
                    bounds_error=False, 
                    fill_value=np.nan) 
    )
f1 = JV(n, x) 

# compare:
fig = plt.figure() 
ax = fig.add_subplot(121) 
ax.set_xlabel("x") 
ax.set_ylabel("$J_{n}(x)$")
ax.plot(x, f1) 

ax2 = fig.add_subplot(122) 
ax2.set_xlabel("x")
ax2.set_ylabel("log10(err)")
for idx, func in enumerate(interp_funcs): 
    f2 = func(x) 
    logerr = log10_err(f1, f2) 
    ax2.plot(x, logerr, label="interp kind=" + str(kinds[idx]))  
ax2.legend()
plt.show()

得到的图像是:

python插值 python插值求速度_python_15


左图是函数图像, 右边则是绝对误差的大小。可以看到,当kind>=5的时候,绝对误差已经是python插值 python插值求速度_数据分析_11以下了,完全满足实际的需求。这时的相对误差实际上也在python插值 python插值求速度_算法_17以下。

那么这种数值误差控制完好的情况下效率提升多少呢?

simple TimeCost[sec]_n5_SZ2048_REP10000: 0.034089
 Jn(x) TimeCost[sec]_n5_SZ2048_REP10000: 17.582057
 simple_interp TimeCost[sec]_n5_SZ2048_REP10000: 1.467844
 Jn(x)_interp TimeCost[sec]_n5_SZ2048_REP10000: 1.454807

python插值 python插值求速度_python_18sec到python插值 python插值求速度_算法_19sec, 大概python插值 python插值求速度_numpy_20倍。虽然由于插值密度和阶次的提高,这个增强不如之前的45倍那么夸张,但对于密集计算程序来说,也是非常可观的收获。

总结

针对大量使用的热点函数,如果函数本身的求值是计算密集的,那么可以使用插值函数取代,并注意数值精度。

欢迎提出新的想法~
原创,如果分享或转载,必须附本文链接

测试环境:python3.8.10-64bit, numpy=1.24.1, scipy=1.10.0, matplotlib=3.6.3, AMD R7 68** 64bit, single thread.

参考

  1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.jv.html
  2. https://netlib.org/amos/zbesj.f