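1. Background
In an earlier post, "Python RPC Remote-Call Scripts: RPyC in Practice", I built a small demo showing how RPyC can be used to write a simple distributed program. Anyone with some development experience will spot its fatal flaw right away: if a user runs a command that is very slow or resource-hungry, the client may never get a result, and the server may even go down. So we need to cap how long a command may run, adding a timeout mechanism that makes the program more robust and improves the user experience.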
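2. So easy: a decorator!
If you happen to have read my earlier post "Python Decorators Explained: Master Decorators in 16 Easy Steps", the natural idea is that decorators fit this scenario perfectly: they wrap a function with extra functionality without intruding on its core business logic.
The timeout decorator below arms signal.alarm before calling the function, so SIGALRM is delivered after the given number of seconds; the signal handler raises TimeoutError, which interrupts the call: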
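```python
# coding=utf-8
import functools
import signal
import time


class TimeoutError(Exception):
    pass


def timeout(seconds, error_message=None):
    """Give up on the decorated function if it has not returned within `seconds`.

    Relies on SIGALRM, so it only works on *nix and only in the main thread.
    """
    if error_message is None:
        error_message = "Timeout Error: the cmd has not finished within %ds." % seconds

    def decorated(func):
        def _handle_timeout(signum, frame):
            # Runs when the alarm fires and interrupts the running func
            raise TimeoutError(error_message)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)  # deliver SIGALRM after `seconds` seconds
            try:
                return func(*args, **kwargs)
            except TimeoutError:
                # Hand the message back instead of propagating the exception,
                # so an RPC client gets a readable result
                return error_message
            finally:
                signal.alarm(0)  # cancel any pending alarm
        return wrapper
    return decorated


@timeout(5)  # if slowfunc has not returned within 5s, it is interrupted via TimeoutError
def slowfunc(sleep_time):
    a = 1
    time.sleep(sleep_time)
    return a

# print(slowfunc(3))  # sleeps 3s and returns normally, no exception
print(slowfunc(11))  # interrupted after 5s; prints the timeout message
```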
The test case works as expected, but when the decorator was applied to the RPC code mentioned at the beginning, it raised this exception:
```
Traceback (most recent call last):
  File "exec_cmd.py", line 79, in <module>
    exec_cmd(cmd_str)
  File "exec_cmd.py", line 53, in exec_cmd
    results = pool.map(rpc_client, host_port_list)
  File "/opt/soft/python-2.7.10/lib/python2.7/multiprocessing/pool.py", line 251, in map
    return self.map_async(func, iterable, chunksize).get()
  File "/opt/soft/python-2.7.10/lib/python2.7/multiprocessing/pool.py", line 567, in get
    raise self._value
ValueError: signal only works in main thread
========= Remote Traceback (1) =========
Traceback (most recent call last):
  File "/opt/soft/python-2.7.10/lib/python2.7/site-packages/rpyc/core/protocol.py", line 305, in _dispatch_request
    res = self._HANDLERS[handler](self, *args)
  File "/opt/soft/python-2.7.10/lib/python2.7/site-packages/rpyc/core/protocol.py", line 535, in _handle_call
    return self._local_objects[oid](*args, **dict(kwargs))
  File "flumeFileMonitor_RPC_Server.py", line 39, in wrapper
    signal.signal(signal.SIGALRM, _handle_timeout)
ValueError: signal only works in main thread
```
To show the problem more simply, here is a stripped-down test:
```python
# coding=utf-8
import threading
from time import sleep, time
from multiprocessing.dummy import Pool as ThreadPool  # thread pool with the multiprocessing.Pool API


# Uses the signal-based timeout decorator defined above
@timeout(1)
def processNum(num):
    num_add = num + 1
    sleep(2)
    return str(threading.current_thread()) + ": " + str(num) + " → " + str(num_add)


def main():
    ts = time()
    pool = ThreadPool(4)
    # processNum runs in worker threads, where signal.signal() raises ValueError
    results = pool.map(processNum, range(4))
    pool.close()
    pool.join()
    for _ in results:
        print(_)
    print("cost time is: {:.2f}s".format(time() - ts))


if __name__ == "__main__":
    main()
```
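The error says it plainly: signal handlers can only be installed from the main thread, not from the worker threads of a multithreaded program, which is exactly the situation in both the thread-pool test and the RPyC server. On top of that, signal.alarm and SIGALRM only exist on *nix, so this approach is not cross-platform either. The problem is not as easy as it first looked; we need a different approach.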
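The main-thread restriction is easy to reproduce on its own. Below is a minimal sketch, assuming Python 3; try_install_handler is an illustrative name, and SIGINT is used only because it exists on every platform, so the check isolates the thread restriction from SIGALRM's *nix-only availability.

```python
import signal
import threading


def try_install_handler(results):
    """Try to install a signal handler from the current thread and record the outcome."""
    try:
        signal.signal(signal.SIGINT, lambda signum, frame: None)
        results.append("installed")
    except ValueError as exc:
        # CPython only allows signal.signal() in the main thread
        results.append(str(exc))


results = []
worker = threading.Thread(target=try_install_handler, args=(results,))
worker.start()
worker.join()
print(results[0])
```

The worker records a ValueError mentioning the main thread, the same error as in the traceback above.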
3. Another approach: enforcing the timeout with a thread
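The overall logic: start a new child thread to run the target function while the main thread waits for its result. If the child thread has not finished within the given time, treat it as a timeout, raise a timeout exception, and kill the child thread; otherwise return the value the function returned. But the Python standard library provides no way to kill a thread. What now? Someone has already implemented a KThread class that subclasses threading.Thread and adds a kill() method, which lets us kill the child thread.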
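Before adding thread killing, the wait-and-check half of this logic can be sketched on its own. This is a minimal Python 3 sketch, not the article's code; thread_timeout and FunctionTimeout are hypothetical names. It runs the function in a daemon thread, waits with join(seconds), and treats a thread that is still alive as a timeout. Because it never touches signal, it also works when called from a worker thread.

```python
import functools
import threading
import time


class FunctionTimeout(Exception):
    """Raised when the decorated function does not finish in time."""


def thread_timeout(seconds):
    """Run the decorated function in a helper thread and wait at most `seconds`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except Exception as exc:  # hand the error back to the caller
                    outcome["error"] = exc

            worker = threading.Thread(target=target, daemon=True)
            worker.start()
            worker.join(seconds)   # block the caller for at most `seconds`
            if worker.is_alive():  # still running: treat it as a timeout
                # The helper thread is NOT stopped; it keeps running in the background
                raise FunctionTimeout("%s did not finish within %ss" % (func.__name__, seconds))
            if "error" in outcome:
                raise outcome["error"]
            return outcome["value"]
        return wrapper
    return decorator


@thread_timeout(0.5)
def slow(delay):
    time.sleep(delay)
    return delay
```

Here slow(0.1) returns 0.1, while slow(2) raises FunctionTimeout after about 0.5s. The obvious gap is that a timed-out helper thread keeps running; marking it daemon only keeps it from blocking interpreter exit. That is the gap KThread's kill() is meant to close.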
Here is the code first; afterwards I'll briefly explain how the KThread class is designed:
```python
import sys
import threading
from time import sleep, time


class KThread(threading.Thread):
    """A subclass of threading.Thread, with a kill()
    method.

    Come from:
    Kill a thread in Python:
    http://mail.python.org/pipermail/python-list/2004-May/260937.html
    """

    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)
        self.killed = False

    def start(self):
        """Start the thread."""
        self.__run_backup = self.run
        self.run = self.__run  # Force the Thread to install our trace.
        threading.Thread.start(self)

    def __run(self):
        """Hacked run function, which installs the
        trace."""
        sys.settrace(self.globaltrace)  # trace only this thread
        try:
            self.__run_backup()
        finally:
            # Stop tracing so threading's own cleanup after run() is not hit by SystemExit
            sys.settrace(None)
            self.run = self.__run_backup

    def globaltrace(self, frame, why, arg):
        # Called when a new function is entered; returns the per-frame tracer
        if why == 'call':
            return self.localtrace
        else:
            return None

    def localtrace(self, frame, why, arg):
        # Called before each line; ends the thread once kill() has been requested
        if self.killed:
            if why == 'line':
                raise SystemExit()
        return self.localtrace

    def kill(self):
        self.killed = True


class Timeout(Exception):
    """function run timeout"""


def timeout(seconds):
    """Timeout decorator with a limit in seconds.

    If the decorated function has not returned within `seconds`, the worker
    thread is killed and a timeout message is returned.
    """
    def timeout_decorator(func):
        """The actual decorator."""
        def _new_func(oldfunc, result, oldfunc_args, oldfunc_kwargs):
            result.append(oldfunc(*oldfunc_args, **oldfunc_kwargs))

        def _(*args, **kwargs):
            result = []
            new_kwargs = {  # args for _new_func; func's return value is collected in `result`
                'oldfunc': func,
                'result': result,
                'oldfunc_args': args,
                'oldfunc_kwargs': kwargs
            }
            thd = KThread(target=_new_func, args=(), kwargs=new_kwargs)
            thd.start()
            thd.join(seconds)
            alive = thd.is_alive()
            thd.kill()  # kill the child thread (harmless if it already finished)
            if alive:
                # Return a readable message instead of raising Timeout,
                # so RPC callers get a result
                return u'function run too long, timeout %d seconds.' % seconds
            else:
                return result[0]
        _.__name__ = func.__name__
        _.__doc__ = func.__doc__
        return _
    return timeout_decorator
```
Applying this new timeout decorator to processNum in the earlier thread-pool test gives the following results:
```python
@timeout(1)
def processNum(num):
    num_add = num + 1
    sleep(2)
    return str(threading.current_thread()) + ": " + str(num) + " → " + str(num_add)
```

Output:

```
function run too long, timeout 1 seconds.
function run too long, timeout 1 seconds.
function run too long, timeout 1 seconds.
function run too long, timeout 1 seconds.
cost time is: 1.17s
```
Now that we've seen the code, let's look at how KThread is designed:
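The key is the trace function installed with sys.settrace(self.globaltrace). In KThread, start() swaps in the hacked __run, so sys.settrace is called inside the new thread just before the original run() executes, and only that thread is traced.
A related API is threading.settrace. It must be set before the thread calls run(), and it only acts as a relay: before each thread's run() is called, it passes the function on to sys.settrace. Its documentation: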
threading.settrace(func)
Set a trace function for all threads started from the threading module. The func will be passed to sys.settrace() for each thread, before its run() method is called.
New in version 2.3.
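Next comes the sys.settrace documentation. It is long, but the part that matters here is short: the global trace function (the one passed to sys.settrace) is called with the 'call' event whenever a new local scope is entered, and whatever it returns becomes the local trace function for that frame, which is then called with 'line' (about to execute a new line), 'return' and 'exception' events. Read alongside the code above, that should be clear.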
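To see the call/line/return sequence concretely, here is a small Python 3 sketch (record_events and demo are illustrative names) that traces one function and records each event the trace functions receive. It restores whatever tracer was active before, so it coexists with tools such as coverage.

```python
import sys


def demo(n):
    total = 0
    for i in range(n):
        total += i
    return total


def record_events(func, *args):
    """Run func under sys.settrace and record (event, function name) pairs."""
    events = []

    def local_trace(frame, event, arg):
        # Receives 'line', 'return' and 'exception' events for the traced frame
        events.append((event, frame.f_code.co_name))
        return local_trace

    def global_trace(frame, event, arg):
        # Receives 'call' events; its return value becomes the frame's local tracer
        if frame.f_code is func.__code__:
            events.append((event, frame.f_code.co_name))
            return local_trace
        return None  # do not trace any other function

    previous = sys.gettrace()
    sys.settrace(global_trace)
    try:
        func(*args)
    finally:
        sys.settrace(previous)
    return events


for event in record_events(demo, 2):
    print(event)
```

The first event is 'call' (seen by the global tracer), followed by 'line' events and a final 'return', all delivered to the local tracer it returned. KThread's localtrace hooks into exactly those 'line' events.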
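Let's walk through the mechanism. The snippet below is a simplified variant that calls threading.settrace in start() and uses a _willKill flag in place of KThread's killed. Note that threading.settrace is process-wide: it applies to every thread started afterwards, not only this one, which is one reason to prefer calling sys.settrace inside the thread, as KThread does.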
```python
# Excerpt: methods of a threading.Thread subclass
def start(self):
    threading.settrace(self.globaltrace)  # register the trace function before the thread runs
    threading.Thread.start(self)          # start the thread

def globaltrace(self, frame, why, arg):
    if why == 'call':             # a new function (local scope) is being entered
        return self.localtrace    # trace that call with self.localtrace
    else:
        return None

def localtrace(self, frame, why, arg):
    # self._willKill is our own kill flag; `why` is the trace event,
    # and 'line' means a new line of Python code is about to run
    if self._willKill and why == 'line':
        raise SystemExit()        # flag set and a new line is about to run: end the thread
    return self.localtrace
```
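That is the whole interruption mechanism: before the thread executes each new line of code, it checks the kill flag. If the flag is set, it raises SystemExit, which ends the thread (threading silently ignores SystemExit raised in a thread); otherwise execution continues.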
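The same mechanism can be written a little more compactly by overriding run() instead of swapping it out in start(). The sketch below assumes Python 3; KillableThread and busy_loop are hypothetical names, not part of the original recipe. It also turns tracing off when run() finishes, so threading's own cleanup code (which runs after run() in the same thread) is not hit by SystemExit once the kill flag is set.

```python
import sys
import threading
import time


class KillableThread(threading.Thread):
    """KThread-style killable thread that installs its tracer by overriding run()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.killed = False

    def run(self):
        sys.settrace(self._global_trace)  # trace calls made in this thread only
        try:
            super().run()
        finally:
            sys.settrace(None)  # stop tracing before threading's own cleanup runs

    def _global_trace(self, frame, event, arg):
        # 'call' event: hand back the per-frame tracer
        return self._local_trace if event == "call" else None

    def _local_trace(self, frame, event, arg):
        if self.killed and event == "line":
            raise SystemExit()  # ends the thread; threading ignores SystemExit here
        return self._local_trace

    def kill(self):
        self.killed = True


def busy_loop(counter):
    # Pure-Python loop: every iteration produces a 'line' event
    while True:
        counter[0] += 1


counter = [0]
worker = KillableThread(target=busy_loop, args=(counter,), daemon=True)
worker.start()
time.sleep(0.2)
worker.kill()
worker.join(2)
print("alive after kill:", worker.is_alive())
```

Since the check only happens on 'line' events, a thread blocked inside a single long call such as time.sleep() or a blocking socket read is not interrupted until that call returns; in the processNum test above, the killed threads still finish their sleep(2) first.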
4. Drawbacks
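- Overall execution is somewhat slower, because the trace function is called, and the flag checked, for every Python line executed.
- The timeout works by running the function in a thread whose run() has been overridden to install the tracer. If the decorated function itself uses threads, subprocesses or similar structures, those get no timeout control, so the outer timeout may fail to stop them.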
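The second drawback follows from sys.settrace being per-thread. In this Python 3 sketch (all names are illustrative), a thread installs a tracer and then starts a helper thread that runs inner_work. The helper's calls never reach the tracer, so a kill flag checked there could never stop it, while the same function called directly in the traced thread is seen.

```python
import sys
import threading

calls = []


def tracer(frame, event, arg):
    # Record which thread each traced call happened in
    if event == "call":
        calls.append((threading.current_thread().name, frame.f_code.co_name))
    return None  # no per-line tracing needed for this demo


def inner_work():
    return 42


def outer_work():
    # Stand-in for a decorated function that spawns its own thread
    sys.settrace(tracer)  # like KThread: trace only the current thread
    try:
        helper = threading.Thread(target=inner_work, name="inner-thread")
        helper.start()
        helper.join()
        inner_work()  # the same function called directly IS traced
    finally:
        sys.settrace(None)


outer = threading.Thread(target=outer_work, name="outer-thread")
outer.start()
outer.join()
traced_threads = {name for name, _ in calls}
print(traced_threads)
print(("outer-thread", "inner_work") in calls)
```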
5. Two common misconceptions about function timeouts in multithreaded code
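sleep, wait and join cannot, by themselves, implement or replace a timeout.
The timeout parameter of join(timeout) in particular misleads many beginners into thinking that calling join(n) ends the thread after n seconds.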
Let's look at the documentation first:
When the timeout argument is present and not None, it should be a floating point number specifying a timeout for the operation in seconds (or fractions thereof). As join() always returns None, you must call isAlive() after join() to decide whether a timeout happened – if the thread is still alive, the join() call timed out.
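As you can see, the timeout only blocks the main thread: it tells join how long to wait for the child thread. When it expires, the main thread and the child thread both simply keep running, so you must call isAlive() (is_alive() in Python 3) afterwards to decide whether a timeout happened: if the child thread is still alive, that join() call timed out.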
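A minimal sketch of the pattern the documentation describes, assuming Python 3 (where isAlive() is spelled is_alive()); finished_within is a hypothetical helper name:

```python
import threading
import time


def finished_within(thread, seconds):
    """Wait up to `seconds` for `thread`; return True if it finished, False on timeout."""
    thread.join(seconds)          # always returns None, timeout or not ...
    return not thread.is_alive()  # ... so is_alive() is the only way to tell


worker = threading.Thread(target=time.sleep, args=(0.5,), daemon=True)
worker.start()
print(finished_within(worker, 0.1))  # False: the join timed out
print(worker.is_alive())             # True: the thread was not stopped
print(finished_within(worker, 2))    # True: it finished during the second wait
```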
An example:
Suppose there are 10 threads whose business logic is simply to sleep 3s, and the whole batch must finish within 2s. Many beginners might write code like this:
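```python
# ThreadTest: a threading.Thread subclass whose run() sleeps 3s
thread_arr = []
for i in range(10):
    t = ThreadTest(i)
    thread_arr.append(t)
for i in range(10):
    thread_arr[i].start()
for i in range(10):
    thread_arr[i].join(2)  # waits up to 2s for this thread; does not stop it
```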
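You'll find that this does not enforce a 2s limit at all. The join(2) calls run one after another and each can block for up to 2s, so the main thread may wait as long as 10 × 2 = 20s, and none of the joins stops a thread. (With threads that simply sleep 3s, the first join times out at 2s and the rest return once the threads finish at about 3s.) In the complete program below the worker threads loop forever, so every join(2) waits its full 2s and the run takes about 20s.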
Here is a complete program to test and experiment with:
```python
# coding=utf-8
import threading
from time import sleep, time
from Queue import Queue  # Python 2; in Python 3: from queue import Queue
from threading import Thread


def processNum(num):
    num_add = num + 1
    sleep(3)
    print(str(threading.current_thread()) + ": " + str(num) + " → " + str(num_add))


class ProcessWorker(Thread):
    def __init__(self, queue):
        Thread.__init__(self)
        self.queue = queue

    def run(self):
        # Loops forever, so a worker never finishes on its own
        while True:
            num = self.queue.get()
            processNum(num)
            self.queue.task_done()


thread_arr = []


def main():
    ts = time()
    queue = Queue()
    for x in range(10):
        worker = ProcessWorker(queue)
        worker.daemon = True  # lets the process exit while workers are still blocked
        worker.start()
        thread_arr.append(worker)
    for num in range(10):
        queue.put(num)
    # queue.join()
    for _ in thread_arr:
        _.join(2)  # no worker ever exits, so each join waits its full 2s: ~20s in total
    print("cost time is: {:.2f}s".format(time() - ts))


if __name__ == "__main__":
    main()
```
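To actually cap the total wait, each join() can be given only what is left of one shared deadline, followed by an is_alive() check. A minimal sketch, assuming Python 3 (time.monotonic); join_all is a hypothetical helper name, and the numbers mirror the article's scenario of 10 threads sleeping 3s with a 2s budget:

```python
import threading
import time


def join_all(threads, total_seconds):
    """Wait for all threads, but for at most `total_seconds` in total.

    Returns the threads that were still running when the budget ran out.
    Like join() itself, this does not stop them.
    """
    deadline = time.monotonic() + total_seconds
    for t in threads:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        t.join(remaining)  # each join gets only what is left of the shared budget
    return [t for t in threads if t.is_alive()]


# 10 threads that each sleep 3s, with a 2s overall budget
threads = [threading.Thread(target=time.sleep, args=(3,), daemon=True) for _ in range(10)]
start = time.monotonic()
for t in threads:
    t.start()
still_running = join_all(threads, 2)
print("waited %.2fs, %d threads still running" % (time.monotonic() - start, len(still_running)))
```

The wait now ends after about 2s instead of 20s, but the threads are still not stopped; marking them daemon only lets the process exit without waiting for them. Actually stopping them still needs something like KThread, or a flag that the worker itself checks.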
That's all for today. Multithreading is a never-ending topic, and the road ahead is long and winding~