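Passing values from a parent thread to a child thread comes up often in real projects, and APM systems, for example, all have a similar requirement. In distributed tracing, a typical problem is that when thread A creates thread B, the execution of thread B cannot be traced.
Here are the common approaches: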
1. InheritableThreadLocal
InheritableThreadLocal passes values from the parent thread to the child thread when the parent creates the child. A simple example:
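@Test
public void testInheritableThreadLocal() throws InterruptedException {
final InheritableThreadLocal<Object> tl = new InheritableThreadLocal<>();
tl.set("test demo");
Thread child = new Thread(() -> {
// The child thread reads the value set by its parent and prints "test demo"
System.out.println(tl.get());
});
child.start();
// Wait for the child; otherwise the test may finish before anything is printed
child.join();
}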
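The following small sketch makes the difference visible. It sets the same value in a plain ThreadLocal and in an InheritableThreadLocal, then reads both from a newly created child thread. Only the inheritable one should carry over. The names `InheritDemo` and `childView` are made up for illustration.

```java
import java.util.concurrent.atomic.AtomicReference;

class InheritDemo {
    static final ThreadLocal<String> PLAIN = new ThreadLocal<>();
    static final InheritableThreadLocal<String> INHERITABLE = new InheritableThreadLocal<>();

    /** Sets both values in the current thread, then returns "plain|inheritable" as seen by a new child. */
    static String childView() {
        PLAIN.set("test demo");
        INHERITABLE.set("test demo");
        AtomicReference<String> seen = new AtomicReference<>();
        // INHERITABLE's value is copied into the child while this Thread object is constructed
        Thread child = new Thread(() -> seen.set(PLAIN.get() + "|" + INHERITABLE.get()));
        child.start();
        try {
            child.join(); // wait so the child's result is visible here
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen.get();
    }

    public static void main(String[] args) {
        System.out.println(childView()); // prints null|test demo
    }
}
```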
1.1 How InheritableThreadLocal works
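Before looking at how InheritableThreadLocal works, it helps to understand how ThreadLocal itself works. Readers unfamiliar with it can refer to the article "ThreadLocal深入原理分析" (an in-depth analysis of ThreadLocal's internals). First, here is the source of InheritableThreadLocal: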
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
protected T childValue(T parentValue) {
return parentValue;
}
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}
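That is the whole source. It extends ThreadLocal and overrides getMap and createMap, so values are stored in t.inheritableThreadLocals instead of t.threadLocals. It also overrides childValue, which by default hands the parent's value to the child unchanged (the same object reference). The real work happens in the Thread class.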
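Because the default childValue returns the parent's object itself, the parent and child share any mutable value such as a Map. The sketch below assumes you want the child to get its own copy instead. It overrides childValue in an anonymous subclass to return a shallow copy. The class `ChildValueDemo` is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

class ChildValueDemo {
    // childValue is invoked in the parent thread while the child Thread is being constructed;
    // its return value becomes the child's initial value. Here the child gets a shallow copy.
    static final InheritableThreadLocal<Map<String, String>> CTX = new InheritableThreadLocal<Map<String, String>>() {
        @Override
        protected Map<String, String> childValue(Map<String, String> parentValue) {
            return parentValue == null ? null : new HashMap<>(parentValue);
        }
    };

    /** Lets a child thread overwrite traceId, then returns the parent's own traceId. */
    static String parentTraceIdAfterChildWrite() {
        Map<String, String> parentMap = new HashMap<>();
        parentMap.put("traceId", "parent");
        CTX.set(parentMap);
        Thread child = new Thread(() -> CTX.get().put("traceId", "child")); // writes to the child's copy only
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return CTX.get().get("traceId");
    }

    public static void main(String[] args) {
        System.out.println(parentTraceIdAfterChildWrite()); // prints parent
    }
}
```

Without the override, the child's put would write into the shared map, and the method would return "child".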
1.2 What does t.inheritableThreadLocals do?
The Thread class source contains this code:
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) {
// ... other code not relevant here is omitted
// inheritThreadLocals == true (the calling constructor passes true)
// Thread parent = currentThread();
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
// Copy the parent thread's inheritableThreadLocals into this (child) thread's inheritableThreadLocals
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
// ... other code not relevant here is omitted
}
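This init method is called from the Thread constructor. So by the time new Thread(...) returns, the parent's inheritableThreadLocals have already been copied into the child's inheritableThreadLocals. The "parent" is whichever thread constructs the Thread object (currentThread() at that moment), and the copy is made only once.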
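One consequence is that the values are snapshotted when the Thread object is constructed, not when start() is called. This sketch changes the value between new Thread(...) and start(), and the child should still see the older value. `SnapshotDemo` is a hypothetical name.

```java
import java.util.concurrent.atomic.AtomicReference;

class SnapshotDemo {
    static final InheritableThreadLocal<String> TL = new InheritableThreadLocal<>();

    /** Changes TL between constructing and starting the child; returns what the child saw. */
    static String childSees() {
        TL.set("before");
        AtomicReference<String> seen = new AtomicReference<>();
        Thread child = new Thread(() -> seen.set(TL.get())); // inheritable values are copied here
        TL.set("after");                                     // too late: the child already has its own copy
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen.get();
    }

    public static void main(String[] args) {
        System.out.println(childSees()); // prints before
    }
}
```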
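The limitation of InheritableThreadLocal is that it does not work with thread pools. Pooled threads are created once and then reused, so the copy made at creation time is never refreshed.
Let's verify this: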
@Test
public void testExecutorServiceInheritableThreadLocal() throws InterruptedException {
final InheritableThreadLocal<Object> tl = new InheritableThreadLocal<>();
tl.set("test demo");
ExecutorService executorService = Executors.newFixedThreadPool(1);
// Call this task thread1
executorService.execute(new Runnable() {
@Override
public void run() {
System.out.println("thread1 : "+tl.get());
tl.set("test demo 2");
}
});
Thread.sleep(10);
// Call this task thread2
executorService.execute(new Runnable() {
@Override
public void run() {
System.out.println("thread2 : "+tl.get());
}
});
Thread.sleep(10);
System.out.println("main:"+tl.get());
}
Output:
thread1 : test demo
thread2 : test demo 2
main:test demo
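The test shows that after thread1 changed tl, thread2 did not print main's "test demo" but thread1's "test demo 2". thread2 reads tl from the inheritableThreadLocals of the thread it runs on. Because the pool size is 1, thread1 and thread2 actually ran on the same worker thread. That worker copied main's values only once, when it was created, so thread2 sees whatever thread1 left behind rather than the main thread's value.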
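The same effect shows up even when no task touches the value: the single pooled worker copies main's value once, when the pool creates it, and later changes in main never reach it. The sketch below uses a fixed pool of one thread. It relies on ThreadPoolExecutor creating its worker lazily during the first submission from main. `PoolStaleDemo` is a hypothetical name.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class PoolStaleDemo {
    static final InheritableThreadLocal<String> TL = new InheritableThreadLocal<>();

    /** Returns what the single pooled worker sees after the submitting thread has changed TL. */
    static String workerSees() {
        TL.set("v1");
        ExecutorService pool = Executors.newFixedThreadPool(1);
        try {
            // The first submission makes the pool construct its worker thread from this thread,
            // so the worker inherits "v1" at that moment
            pool.submit(() -> { }).get();
            TL.set("v2"); // the worker already exists and is never re-copied
            Callable<String> read = TL::get;
            return pool.submit(read).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(workerSees()); // prints v1
    }
}
```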
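2. A hand-written utility for passing values between threads
This is the approach I use in my own project; the implementation logic is explained in the code comments.
2.1 Wrapping Runnable
For example, suppose there is a context class that holds some basic information: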
public class ContextUtil {
public static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(() -> {
Map<String, Object> map = new HashMap<>();
return map;
});
public static Object get(String key) {
return CONTEXT.get().get(key);
}
public static Map<String, Object> getAll() {
return CONTEXT.get();
}
public static void put(String key, Object value) {
if (value == null) {
return;
}
CONTEXT.get().put(key, value);
}
public static void putAll(Map<String, Object> map) {
if (map == null) {
return;
}
map.forEach(ContextUtil::put);
}
}
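CONTEXT is an ordinary ThreadLocal created with ThreadLocal.withInitial, so a new thread does not see the parent's map at all. It gets a fresh, empty map on first access, and this is what the Runnable wrapper has to work around. Here is a minimal sketch with a hypothetical stand-in for ContextUtil (`PlainContextDemo` is not part of the original code).

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

class PlainContextDemo {
    // Hypothetical stand-in for ContextUtil.CONTEXT: an ordinary (non-inheritable) ThreadLocal
    static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    /** Puts a traceId in the current thread's context and returns what a new child thread sees. */
    static Map<String, Object> childView() {
        CONTEXT.get().put("traceId", "abc");
        AtomicReference<Map<String, Object>> seen = new AtomicReference<>();
        Thread child = new Thread(() -> seen.set(CONTEXT.get())); // the child gets its own, empty map
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen.get();
    }

    public static void main(String[] args) {
        System.out.println(childView()); // prints {}
    }
}
```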
The Runnable wrapper class:
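public class RunnableWrapper {
/**
* Captures the calling (parent) thread's context and returns a Runnable that installs it before running.
*/
public static Runnable wrap(Runnable runnable) {
// Take a copy: the live map belongs to the parent thread and may change after wrap() returns
Map<String, Object> parentContextMap = new HashMap<>(ContextUtil.getAll());
return get(runnable, parentContextMap);
}
private static Runnable get(Runnable runnable, Map<String, Object> parentContextMap) {
return () -> {
// Runs in the child thread: copy the parent's entries into the child's own context
ContextUtil.putAll(parentContextMap);
runnable.run();
};
}
}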
Test code and output:
@Test
public void testThreadWrap() throws InterruptedException {
ContextUtil.put("traceId", UUID.randomUUID().toString());
Runnable runnable = new Runnable(){
@Override
public void run() {
Map<String, Object> all = ContextUtil.getAll();
System.out.println("runnable:"+all);
}
};
Runnable wrap = RunnableWrapper.wrap(runnable);
Thread child = new Thread(wrap);
child.start();
// Wait for the child instead of sleeping a fixed 10 ms
child.join();
System.out.println("main:"+ContextUtil.getAll());
}
runnable:{traceId=e6355bb5-6ad0-423d-bf43-06bca5c569f1}
main:{traceId=e6355bb5-6ad0-423d-bf43-06bca5c569f1}
2.2 ExecutorServiceWrapper
public class ExecutorServiceUtil {
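/**
* Wraps a thread pool so that its tasks receive the submitting (parent) thread's context.
*
* @param executorService the executor to wrap
* @return an ExecutorService that propagates the context to its tasks
*/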
public static ExecutorService executorServiceWrap(ExecutorService executorService) {
return new ExecutorServiceWrapper(executorService);
}
public static class ExecutorServiceWrapper implements ExecutorService {
private final ExecutorService executorService;
public ExecutorServiceWrapper(ExecutorService executorService) {
this.executorService = executorService;
}
@Override
public void shutdown() {
executorService.shutdown();
}
@Override
public List<Runnable> shutdownNow() {
return executorService.shutdownNow();
}
@Override
public boolean isShutdown() {
return executorService.isShutdown();
}
@Override
public boolean isTerminated() {
return executorService.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return executorService.awaitTermination(timeout, unit);
}
@Override
public <T> Future<T> submit(Callable<T> task) {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.submit(get(task, currentContextMap));
}
@Override
public <T> Future<T> submit(Runnable task, T result) {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.submit(get(task, currentContextMap), result);
}
@Override
public Future<?> submit(Runnable task) {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.submit(get(task, currentContextMap));
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.
invokeAll(tasks.stream()
.map(callable -> get(callable, currentContextMap))
.collect(Collectors.toList()));
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.
invokeAll(tasks.stream()
.map(callable -> get(callable, currentContextMap))
.collect(Collectors.toList()), timeout, unit);
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
throws InterruptedException, ExecutionException {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.
invokeAny(tasks.stream()
.map(callable -> get(callable, currentContextMap))
.collect(Collectors.toList()));
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
Map<String, Object> currentContextMap = ContextUtil.getAll();
return executorService.
invokeAny(tasks.stream()
.map(callable -> get(callable, currentContextMap))
.collect(Collectors.toList()), timeout, unit);
}
@Override
public void execute(Runnable command) {
Map<String, Object> currentContextMap = ContextUtil.getAll();
executorService.execute(get(command, currentContextMap));
}
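/**
* Returns a wrapped Runnable.
* This method runs in the submitting thread, so the parent's context is copied here, not inside the lambda.
*
* @param runnable the original Runnable
* @param parentContextMap the parent thread's context
* @return the wrapped Runnable
*/
private Runnable get(Runnable runnable, Map<String, Object> parentContextMap) {
Map<String, Object> snapshot = new HashMap<>(parentContextMap);
return () -> {
ContextUtil.putAll(snapshot);
runnable.run();
};
}
/**
* Returns a wrapped Callable; like the Runnable version, it copies the parent's context in the submitting thread.
*
* @param callable the original Callable
* @param parentContextMap the parent thread's context
* @return the wrapped Callable
*/
private <T> Callable<T> get(Callable<T> callable, Map<String, Object> parentContextMap) {
Map<String, Object> snapshot = new HashMap<>(parentContextMap);
return () -> {
ContextUtil.putAll(snapshot);
return callable.call();
};
}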
}
}
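ExecutorServiceWrapper has to override every submission method. A different technique is to subclass ThreadPoolExecutor and override only execute(). This is sketched here as an alternative, not the author's approach. In the JDK, AbstractExecutorService implements submit, invokeAll and invokeAny by creating a FutureTask and passing it to execute(), so wrapping in execute() covers all of them. The class name, its CONTEXT field and the restore-on-exit behaviour are assumptions of this sketch.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical alternative to ExecutorServiceWrapper: propagate the context from execute() only
class ContextPropagatingExecutor extends ThreadPoolExecutor {
    // Hypothetical stand-in for ContextUtil.CONTEXT
    static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    ContextPropagatingExecutor(int threads) {
        super(threads, threads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    @Override
    public void execute(Runnable command) {
        Objects.requireNonNull(command);
        // Runs in the submitting thread: snapshot its context now
        Map<String, Object> snapshot = new HashMap<>(CONTEXT.get());
        super.execute(() -> {
            Map<String, Object> previous = CONTEXT.get();
            CONTEXT.set(new HashMap<>(snapshot)); // the task works on its own copy
            try {
                command.run();
            } finally {
                CONTEXT.set(previous); // put the worker's own context back
            }
        });
    }

    /** Submits a Callable (which goes through execute()) and returns the traceId it observed. */
    static Object demo() {
        ContextPropagatingExecutor pool = new ContextPropagatingExecutor(1);
        try {
            CONTEXT.get().put("traceId", "abc");
            Callable<Object> read = () -> CONTEXT.get().get("traceId");
            return pool.submit(read).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints abc
    }
}
```

The trade-off is that the queue holds the wrapper lambdas. As a result, shutdownNow() and getQueue() return the wrapped Runnables rather than the original tasks, and remove(task) will not find them.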
Testing ExecutorServiceWrapper with a ThreadPoolExecutor:
@Test
public void testExecutorServiceWrap() throws InterruptedException {
BlockingQueue<Runnable> blockingQueue = new ArrayBlockingQueue<>(100);
ThreadPoolExecutor executor = new ThreadPoolExecutor(2,
8, 3000,
TimeUnit.MILLISECONDS,blockingQueue , new ThreadPoolExecutor.DiscardPolicy());
ContextUtil.put("traceId",UUID.randomUUID().toString());
ExecutorService executorService = ExecutorServiceUtil.executorServiceWrap(executor);
for(int i=0;i<200;i++){
executorService.execute(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName()+" : "+ ContextUtil.getAll());
// Try to interfere with other tasks
ContextUtil.put("traceId",UUID.randomUUID().toString());
}
});
}
Thread.sleep(2000);
}
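Running this test shows that every task that actually runs prints the parent thread's traceId, even though each task overwrites traceId after printing. This works because the wrapper writes the parent's entries into the worker's context before each task, which overwrites the value left by the previous task. Note that with 2 core threads, at most 8 threads and a queue of 100, DiscardPolicy may silently drop some of the 200 tasks, so fewer than 200 lines may be printed.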
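One caveat of the approach in this section is that ContextUtil.putAll only adds entries, and nothing is cleared after a task finishes. On a reused pool thread, keys written by an earlier task, or copied from an earlier parent, therefore stay visible to later tasks whose parent never set them. The sketch below uses a hypothetical stand-in context and two hypothetical wrapper variants. It runs two tasks on a single worker and checks what the second task sees for userId. The add-only variant leaks the earlier value, while the variant that swaps in a fresh map and restores the previous one afterwards does not.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;

class LeakDemo {
    // Hypothetical stand-in for ContextUtil.CONTEXT
    static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    // Mirrors the add-only approach: copy the parent's entries in, never remove anything
    static Callable<Object> addOnly(Callable<Object> task, Map<String, Object> parent) {
        Map<String, Object> snapshot = new HashMap<>(parent);
        return () -> {
            CONTEXT.get().putAll(snapshot);
            return task.call();
        };
    }

    // Gives the task a fresh copy and restores the worker's previous map afterwards
    static Callable<Object> swapAndRestore(Callable<Object> task, Map<String, Object> parent) {
        Map<String, Object> snapshot = new HashMap<>(parent);
        return () -> {
            Map<String, Object> previous = CONTEXT.get();
            CONTEXT.set(new HashMap<>(snapshot));
            try {
                return task.call();
            } finally {
                CONTEXT.set(previous);
            }
        };
    }

    /** Runs two wrapped tasks on one reused worker; returns what the second task sees for "userId". */
    static Object secondTaskSees(BiFunction<Callable<Object>, Map<String, Object>, Callable<Object>> wrapper) {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        try {
            Map<String, Object> first = new HashMap<>();
            first.put("userId", "u-1");
            pool.submit(wrapper.apply(() -> null, first)).get(); // the first parent carries a userId
            Map<String, Object> second = new HashMap<>();       // the second parent has none
            return pool.submit(wrapper.apply(() -> CONTEXT.get().get("userId"), second)).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(secondTaskSees(LeakDemo::addOnly));        // u-1 (left over from the first task)
        System.out.println(secondTaskSees(LeakDemo::swapAndRestore)); // null
    }
}
```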
3. The open-source transmittable-thread-local solution
ttl on GitHub: https://github.com/alibaba/transmittable-thread-local. The repository explains its implementation and usage in detail, so I won't repeat that here.