Python数值计算:使用插值函数提高特殊函数的计算速度
使用插值函数提高特殊函数的计算速度
在最近的数值模拟中,有一类函数被上万次地调用,而库函数中的计算速率很慢。所以尝试做了优化,最终将此热点函数提升了大概11倍的运算速度、并保持了float64的数值精度,在此做个记录。
源起
涉及到的函数叫第一类贝塞尔函数, ,python的第三方库scipy中有这个函数可以调用,叫做scipy.special.jv(n, x),这玩意大概长这样:
这种特殊函数的求值一般是需要利用其递推性质、用一系列的递归、迭代才能得到。相比于这种简单函数,特殊函数求值往往需要更多的CPU计算操作. 我们可以做一个简单的测试。先创建两个函数,一个是我们感兴趣的
即
JV
, 另一个则是最简单的函数
simple_func
:
import numpy as np
import time
from scipy.special import jv as JV # the target function.
simple_func= lambda x: 3.2*x**2 + 0.6
我们创建一个维的数组x
,
SZ = 2048
REPEAT = 10000
n = 5 # order of Bessel function.
x = np.linspace(-30., 30., SZ)
y = np.zeros_like(x)
然后重复10000
次来计算函数值,并比较耗费的时间:
# y = x^2:
start_time = time.time()
for _ in range(REPEAT):
y = simple_func(x)
end_time = time.time()
print("simple TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
n, SZ, REPEAT, end_time-start_time))
# y = J_n(x), n = 5:
start_time = time.time()
for _ in range(REPEAT):
y = JV(n, x)
end_time = time.time()
print("Jn(x) TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
n, SZ, REPEAT, end_time-start_time))
输出是这样的:
simple TimeCost[sec]_n5_SZ2048_REP10000: 0.033106
Jn(x) TimeCost[sec]_n5_SZ2048_REP10000: 17.599081
为了排除python可能的优化, 我们将REPEAT
加倍后得到大概两倍的运算时间,没问题:
simple TimeCost[sec]_n5_SZ2048_REP20000: 0.064189
Jn(x) TimeCost[sec]_n5_SZ2048_REP20000: 34.858215
WoW! 可以看到,计算贝塞尔函数的用时远远高于,大概是530倍!可以想象,计算一个贝塞尔函数值时背后的循环、递归等等操作的消耗量有多大。所以我们需要珍惜每一次调用
JV
得到的数据,而不是大量地重新计算!
优化
为了充分利用JV
函数的结果,我们选择在程序一开始,先计算大量的数据,并存储起来。当遇到求取某个x
处的函数值时,用已经计算好的、临近x
的函数值表示,而不是再去计算一遍。 也就是使用插值函数取代原来的复杂函数。
Scipy的interpolate库提供了一维的插值函数interp1d
,我们使用它来对上述两个函数制作插值版本:
from scipy.interpolate import interp1d
xmax = 50.
x_interp = np.linspace(-xmax, xmax, 4096*2)
Simple_interp = interp1d(x_interp, simple_func(x_interp),
kind="linear",
bounds_error=False,
fill_value=np.nan)
JV_interp = interp1d(x_interp, JV(n, x_interp),
kind="linear", # linear interpolation for now.
bounds_error=False,
fill_value=np.nan)
现在,原来的simple_func
和JV(n, x)
分别有了插值版本Simple_interp
和JV_interp
。这里插值的区间只在[-xmax, xmax]
,超出的话函数值就是np.nan
;插值类型暂时使用了linear
。那我们看看这两个函数的效率如何呢?
start_time = time.time()
for _ in range(REPEAT):
y = Simple_interp(x)
end_time = time.time()
print("simple_interp TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
n, SZ, REPEAT, end_time-start_time))
start_time = time.time()
for _ in range(REPEAT):
y = JV_interp(x)
end_time = time.time()
print("Jn(x)_interp TimeCost[sec]_n%d_SZ%d_REP%d: %.6f"%(
n, SZ, REPEAT, end_time-start_time))
在REPEAT
为10000时输出是:
simple_interp TimeCost[sec]_n5_SZ2048_REP10000: 0.386011
Jn(x)_interp TimeCost[sec]_n5_SZ2048_REP10000: 0.384005
贝塞尔函数耗时从sec减少到了
sec, 速度提升了45倍!这是非常可观的优化。
而简单函数的耗时从sec增加到了相同的量级
sec。我们得出的结论是:
- 对于复杂函数,如果在程序中被多次使用、且每一次的调用都是计算密集(Dense)的, 那么可以通过预计算,使用插值函数替代旧的函数,提高效率;
- 对于简单函数,如果上述操作后,插值函数调用成本高于原来函数直接使用的成本,那么不必这么操作。
数值精度
插值函数的精度可以保证吗?这是我们数值计算必须关系的问题。在合理地取样本点和插值阶次kind
后,答案是肯定的:float64
的精度是可以达到的()。
我们定义一个绝对误差量: ,这里
分别是两种方法得到的结果。尝试不同阶次的插值,并考虑相对误差的大小,
interp1d
支持的插值阶次是kind=1, 2, 3, 5, 7, 9, ...
(这里我们使用了插值和非插值版本的函数作差来表示相对误差,实际也可以使用另外的第三方库函数、比如mpmath来相互比较,后者相当于引入裁判、会更合理,不过最终结论相同):
x_interp = np.linspace(-xmax, xmax, 4096*4)
import matplotlib.pyplot as plt
kinds = [1, 2, 3, 5, 7, 9, 11]
err = lambda f1, f2: 2.*np.abs(f1 - f2)
log10_err = lambda f1, f2: np.log10(err(f1, f2))
interp_funcs = []
for kindi in kinds:
interp_funcs.append(
interp1d(x_interp, JV(n, x_interp),
kind=kindi,
bounds_error=False,
fill_value=np.nan)
)
f1 = JV(n, x)
# compare:
fig = plt.figure()
ax = fig.add_subplot(121)
ax.set_xlabel("x")
ax.set_ylabel("$J_{n}(x)$")
ax.plot(x, f1)
ax2 = fig.add_subplot(122)
ax2.set_xlabel("x")
ax2.set_ylabel("log10(err)")
for idx, func in enumerate(interp_funcs):
f2 = func(x)
logerr = log10_err(f1, f2)
ax2.plot(x, logerr, label="interp kind=" + str(kinds[idx]))
ax2.legend()
plt.show()
得到的图像是:
左图是函数图像, 右边则是绝对误差的大小。可以看到,当kind>=5
的时候,绝对误差已经是以下了,完全满足实际的需求。这时的相对误差实际上也在
以下。
那么这种数值误差控制完好的情况下效率提升多少呢?
simple TimeCost[sec]_n5_SZ2048_REP10000: 0.034089
Jn(x) TimeCost[sec]_n5_SZ2048_REP10000: 17.582057
simple_interp TimeCost[sec]_n5_SZ2048_REP10000: 1.467844
Jn(x)_interp TimeCost[sec]_n5_SZ2048_REP10000: 1.454807
由sec到
sec, 大概
倍。虽然由于插值密度和阶次的提高,这个增强不如之前的45倍那么夸张,但对于密集计算程序来说,也是非常可观的收获。
总结
针对大量使用的热点函数,如果函数本身的求值是计算密集的,那么可以使用插值函数取代,并注意数值精度。
欢迎提出新的想法~
原创,如果分享或转载,必须附本文链接
测试环境:python3.8.10-64bit, numpy=1.24.1, scipy=1.10.0, matplotlib=3.6.3, AMD R7 68** 64bit, single thread.
参考
- https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.jv.html
- https://netlib.org/amos/zbesj.f