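Passing values from a parent thread to its child threads comes up a lot in real projects, and APM systems all have some version of this need. In distributed tracing, for example, when thread A creates thread B, the trace cannot follow what thread B executes unless the context travels with it.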

Here are the common approaches:

1. InheritableThreadLocal

InheritableThreadLocal passes a value from the parent thread to a child thread at the moment the parent creates the child. A simple example:

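@Test
public void testInheritableThreadLocal() throws InterruptedException {
    final InheritableThreadLocal<Object> tl = new InheritableThreadLocal<>();
    tl.set("test demo");
    Thread child = new Thread(() -> {
        // The child reads the value its parent set before creating it: prints "test demo"
        System.out.println(tl.get());
    });
    child.start();
    // Wait for the child, otherwise the test may end before anything is printed
    child.join();
}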

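1.1 How InheritableThreadLocal works
Before digging into InheritableThreadLocal, it helps to understand how ThreadLocal works. If you are not familiar with it, see the article on ThreadLocal's internals (ThreadLocal深入原理分析). First, the source of InheritableThreadLocal: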

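// JDK source of java.lang.InheritableThreadLocal (Javadoc removed)
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    // Called when a child thread is created; by default the child gets the parent's reference unchanged
    protected T childValue(T parentValue) {
        return parentValue;
    }
    // Read from Thread.inheritableThreadLocals instead of Thread.threadLocals
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }
    // Create the map in Thread.inheritableThreadLocals on first set()/get()
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}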

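Yes, that's the whole class. It extends ThreadLocal and overrides getMap and createMap, so values live in t.inheritableThreadLocals instead of t.threadLocals. It also overrides childValue, which by default hands the child the very same object the parent holds. The real work happens in the Thread class.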
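Because the default childValue returns parentValue itself, a mutable value such as a HashMap ends up shared by parent and child. childValue is a protected method meant to be overridden, so one option is to return a copy. A minimal sketch; ChildValueDemo, CTX and parentTraceIdAfterChildWrite are illustrative names, not from any library, and the copy is shallow, so mutable objects stored inside the map would still be shared:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo class, not part of the original post.
class ChildValueDemo {

    // Context holder whose childValue gives each child thread its own copy
    // of the map instead of the parent's HashMap instance.
    static final InheritableThreadLocal<Map<String, String>> CTX =
            new InheritableThreadLocal<Map<String, String>>() {
                @Override
                protected Map<String, String> initialValue() {
                    return new HashMap<>();
                }

                @Override
                protected Map<String, String> childValue(Map<String, String> parentValue) {
                    // Shallow copy: entries are copied, the values themselves are not
                    return new HashMap<>(parentValue);
                }
            };

    // Sets a value in the parent, lets a child overwrite it, returns the parent's value.
    static String parentTraceIdAfterChildWrite() {
        CTX.get().put("traceId", "t-1");
        Thread child = new Thread(() -> CTX.get().put("traceId", "changed-by-child"));
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return CTX.get().get("traceId");
    }

    public static void main(String[] args) {
        System.out.println(parentTraceIdAfterChildWrite()); // prints t-1
    }
}
```

Without the childValue override the same code would print changed-by-child, since parent and child would be writing to one shared HashMap.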

1.2 What happens with t.inheritableThreadLocals?
The Thread class source contains this snippet:

private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
    // ... other code omitted for now ...
    // inheritThreadLocals == true (passed in by the calling constructor)
    // Thread parent = currentThread();
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        // Copy the parent thread's inheritableThreadLocals into this (child) thread's inheritableThreadLocals
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    // ... other code omitted for now ...
}

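init is called from the Thread constructor, so the parent's inheritableThreadLocals are copied into the child's inheritableThreadLocals at the moment new Thread(...) runs, with each value passed through childValue.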
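Since the copy is made once, in the constructor, the child holds a snapshot: anything the parent changes after new Thread(...) is not visible to the child, even if the child has not started yet. A small sketch to check this; InheritSnapshotDemo and valueSeenByChild are illustrative names:

```java
// Hypothetical demo class, not part of the original post.
class InheritSnapshotDemo {

    static final InheritableThreadLocal<String> TL = new InheritableThreadLocal<>();

    // Returns what the child sees when the parent changes the value after creating it.
    static String valueSeenByChild() {
        TL.set("before");
        String[] seen = new String[1];
        // The parent's inheritableThreadLocals are copied here, inside the Thread constructor
        Thread child = new Thread(() -> seen[0] = TL.get());
        // Only the parent's own map changes; the child already holds its copy
        TL.set("after");
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen[0];
    }

    public static void main(String[] args) {
        System.out.println(valueSeenByChild()); // prints before
    }
}
```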

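The limitation of InheritableThreadLocal: it does not work with thread pools. Pool threads are created once and then reused, so the copy happens when a worker thread is created, not each time a task is submitted.
Let's verify this: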

@Test
public void testExecutorServiceInheritableThreadLocal() throws InterruptedException {
    final InheritableThreadLocal<Object> tl = new InheritableThreadLocal<>();
    tl.set("test demo");

    ExecutorService executorService = Executors.newFixedThreadPool(1);
    // Call this task thread1 (both tasks actually run on the pool's single worker)
    executorService.execute(new Runnable() {
        @Override
        public void run() {
            System.out.println("thread1 : " + tl.get());
            tl.set("test demo 2");
        }
    });

    Thread.sleep(10);
    // Call this task thread2
    executorService.execute(new Runnable() {
        @Override
        public void run() {
            System.out.println("thread2 : " + tl.get());
        }
    });

    Thread.sleep(10);

    System.out.println("main:" + tl.get());
    executorService.shutdown();
}
Output:
thread1 : test demo
thread2 : test demo 2
main:test demo

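The test shows that thread1 changed tl's value, and thread2 then printed "test demo 2" instead of main's "test demo". The reason: tl.get() in thread2 reads from the worker thread's inheritableThreadLocals, and because the pool size is 1, thread1 and thread2 ran on the same worker. That worker copied main's value only once, when it was created, so thread2 sees whatever thread1 left behind rather than main's value.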
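The same staleness shows up even when no task modifies the value: a pooled worker inherits the submitting thread's value once, when the pool creates it, and keeps that copy for every later task. A sketch under that assumption (PoolReuseDemo is an illustrative name); it relies on ThreadPoolExecutor, which backs Executors.newFixedThreadPool, creating its core thread on demand during the first submit, unless prestartCoreThread or prestartAllCoreThreads is called:

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo class, not part of the original post.
class PoolReuseDemo {

    // Submits two tasks to a one-thread pool, changing the parent's value in between.
    static String[] valuesSeenByTasks() {
        InheritableThreadLocal<String> tl = new InheritableThreadLocal<>();
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<String> read = tl::get;
        try {
            tl.set("request-1");
            // The worker is created during this first submit and inherits "request-1"
            String first = pool.submit(read).get();
            tl.set("request-2");
            // Same worker reused: it still holds the copy taken when it was created
            String second = pool.submit(read).get();
            return new String[] {first, second};
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        String[] r = valuesSeenByTasks();
        System.out.println(r[0] + ", " + r[1]); // prints request-1, request-1
    }
}
```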

2. A hand-written utility for passing values between threads

This is the approach I use in my own projects; the implementation logic is explained alongside the code.
2.1 Wrapping Runnable

For example, say I have a context class that holds some basic information:

import java.util.HashMap;
import java.util.Map;

public class ContextUtil {

    // Each thread gets its own context map, created empty on first access
    public static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    public static Object get(String key) {
        return CONTEXT.get().get(key);
    }

    // Returns the current thread's own map (the live object, not a copy)
    public static Map<String, Object> getAll() {
        return CONTEXT.get();
    }

    public static void put(String key, Object value) {
        if (value == null) {
            return;
        }
        CONTEXT.get().put(key, value);
    }

    public static void putAll(Map<String, Object> map) {
        if (map == null) {
            return;
        }
        map.forEach(ContextUtil::put);
    }
}

The Runnable wrapper:

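import java.util.HashMap;
import java.util.Map;

public class RunnableWrapper {

    // Called in the parent thread: take a copy of its context at wrap time
    public static Runnable wrap(Runnable runnable) {
        Map<String, Object> parentContextMap = new HashMap<>(ContextUtil.getAll());
        return get(runnable, parentContextMap);
    }

    private static Runnable get(Runnable runnable, Map<String, Object> parentContextMap) {
        return () -> {
            // Runs in the child thread: save its own context, then replay the parent's
            Map<String, Object> context = ContextUtil.getAll();
            Map<String, Object> backup = new HashMap<>(context);
            context.clear();
            ContextUtil.putAll(parentContextMap);
            try {
                runnable.run();
            } finally {
                // Restore, so a reused (pooled) thread does not keep this task's values
                context.clear();
                context.putAll(backup);
            }
        };
    }
}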

The test and its output (traceId is a random UUID, so the value differs on every run):

@Test
public void testThreadWrap() throws InterruptedException {
    ContextUtil.put("traceId", UUID.randomUUID().toString());
    Runnable runnable = new Runnable(){
        @Override
        public void run() {
            Map<String, Object> all = ContextUtil.getAll();
            System.out.println("runnable:"+all);
        }
    };
    Runnable wrap = RunnableWrapper.wrap(runnable);
    Thread child = new Thread(wrap);
    child.start();
    // Wait for the child instead of sleeping a fixed 10 ms
    child.join();
    System.out.println("main:"+ContextUtil.getAll());
}

runnable:{traceId=e6355bb5-6ad0-423d-bf43-06bca5c569f1}
main:{traceId=e6355bb5-6ad0-423d-bf43-06bca5c569f1}
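One detail worth checking in such a wrapper is whether it captures the parent's live map (ContextUtil.getAll() returns the thread's own HashMap, not a copy) or a copy taken at wrap time. With the live map, the child sees whatever the parent changed between wrapping and running, and the two threads touch the same HashMap. The sketch below uses a stand-alone CONTEXT modeled on ContextUtil.CONTEXT and two hypothetical wrappers, wrapLive and wrapSnapshot, to show the difference:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo class; CONTEXT plays the role of ContextUtil.CONTEXT.
class SnapshotVsLiveDemo {

    static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    // Keeps a reference to the caller's live map and reads it only when the task runs.
    static Runnable wrapLive(Runnable task) {
        Map<String, Object> parent = CONTEXT.get();
        return () -> {
            CONTEXT.get().putAll(parent);
            task.run();
        };
    }

    // Copies the caller's map at wrap time.
    static Runnable wrapSnapshot(Runnable task) {
        Map<String, Object> parent = new HashMap<>(CONTEXT.get());
        return () -> {
            CONTEXT.get().putAll(parent);
            task.run();
        };
    }

    // Wraps a task, changes the parent context, then runs the task on a new thread.
    static Object seenByChild(boolean snapshot) {
        CONTEXT.get().put("traceId", "original");
        Object[] seen = new Object[1];
        Runnable task = () -> seen[0] = CONTEXT.get().get("traceId");
        Runnable wrapped = snapshot ? wrapSnapshot(task) : wrapLive(task);
        CONTEXT.get().put("traceId", "changed-after-wrap");
        Thread child = new Thread(wrapped);
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen[0];
    }

    public static void main(String[] args) {
        System.out.println(seenByChild(false)); // prints changed-after-wrap
        System.out.println(seenByChild(true));  // prints original
    }
}
```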

2.2 ExecutorServiceWrapper

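import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

public class ExecutorServiceUtil {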

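    /**
     * Wraps a thread pool so that every task submitted to it receives the submitting (parent) thread's context.
     *
     * @param executorService the pool to wrap
     * @return a wrapper that passes ContextUtil values on to each task
     */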
    public static ExecutorService executorServiceWrap(ExecutorService executorService) {
        return new ExecutorServiceWrapper(executorService);
    }

    public static class ExecutorServiceWrapper implements ExecutorService {

        private final ExecutorService executorService;

        public ExecutorServiceWrapper(ExecutorService executorService) {
            this.executorService = executorService;
        }

        @Override
        public void shutdown() {
            executorService.shutdown();
        }

        @Override
        public List<Runnable> shutdownNow() {
            return executorService.shutdownNow();
        }

        @Override
        public boolean isShutdown() {
            return executorService.isShutdown();
        }

        @Override
        public boolean isTerminated() {
            return executorService.isTerminated();
        }

        @Override
        public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
            return executorService.awaitTermination(timeout, unit);
        }

        @Override
        public <T> Future<T> submit(Callable<T> task) {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.submit(get(task, currentContextMap));
        }

        @Override
        public <T> Future<T> submit(Runnable task, T result) {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.submit(get(task, currentContextMap), result);
        }

        @Override
        public Future<?> submit(Runnable task) {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.submit(get(task, currentContextMap));
        }

        @Override
        public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.
                    invokeAll(tasks.stream()
                            .map(callable -> get(callable, currentContextMap))
                            .collect(Collectors.toList()));
        }

        @Override
        public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
                throws InterruptedException {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.
                    invokeAll(tasks.stream()
                            .map(callable -> get(callable, currentContextMap))
                            .collect(Collectors.toList()), timeout, unit);
        }

        @Override
        public <T> T invokeAny(Collection<? extends Callable<T>> tasks)
                throws InterruptedException, ExecutionException {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.
                    invokeAny(tasks.stream()
                            .map(callable -> get(callable, currentContextMap))
                            .collect(Collectors.toList()));
        }

        @Override
        public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
                throws InterruptedException, ExecutionException, TimeoutException {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            return executorService.
                    invokeAny(tasks.stream()
                            .map(callable -> get(callable, currentContextMap))
                            .collect(Collectors.toList()), timeout, unit);
        }

        @Override
        public void execute(Runnable command) {
            Map<String, Object> currentContextMap = ContextUtil.getAll();
            executorService.execute(get(command, currentContextMap));
        }

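        /**
         * Returns the wrapped Runnable.
         *
         * @param runnable         the original Runnable
         * @param parentContextMap the parent thread's context
         * @return the wrapped Runnable
         */
        private Runnable get(Runnable runnable, Map<String, Object> parentContextMap) {
            // Called in the submitting thread: copy the context now, not when the task runs
            Map<String, Object> snapshot = new HashMap<>(parentContextMap);
            return () -> {
                Map<String, Object> backup = replay(snapshot);
                try {
                    runnable.run();
                } finally {
                    restore(backup);
                }
            };
        }

        /**
         * Returns the wrapped Callable.
         *
         * @param callable         the original Callable
         * @param parentContextMap the parent thread's context
         * @return the wrapped Callable
         */
        private <T> Callable<T> get(Callable<T> callable, Map<String, Object> parentContextMap) {
            Map<String, Object> snapshot = new HashMap<>(parentContextMap);
            return () -> {
                Map<String, Object> backup = replay(snapshot);
                try {
                    return callable.call();
                } finally {
                    restore(backup);
                }
            };
        }

        // Runs in the worker: saves the worker's own context and installs the parent's
        private static Map<String, Object> replay(Map<String, Object> snapshot) {
            Map<String, Object> context = ContextUtil.getAll();
            Map<String, Object> backup = new HashMap<>(context);
            context.clear();
            context.putAll(snapshot);
            return backup;
        }

        // Puts the worker's own context back so nothing leaks into the next task
        private static void restore(Map<String, Object> backup) {
            Map<String, Object> context = ContextUtil.getAll();
            context.clear();
            context.putAll(backup);
        }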
    }
}

Testing the code above:

@Test
public void testExecutorServiceWrap() throws InterruptedException {
    BlockingQueue<Runnable> blockingQueue = new ArrayBlockingQueue<>(100);

    ThreadPoolExecutor executor = new ThreadPoolExecutor(2,
            8, 3000,
            TimeUnit.MILLISECONDS,blockingQueue , new ThreadPoolExecutor.DiscardPolicy());

    ContextUtil.put("traceId",UUID.randomUUID().toString());

    ExecutorService executorService = ExecutorServiceUtil.executorServiceWrap(executor);
    for(int i=0;i<200;i++){
        executorService.execute(new Runnable() {
            @Override
            public void run() {
                System.out.println(Thread.currentThread().getName()+"  :   "+ ContextUtil.getAll());
                // Try to interfere with later tasks that run on the same worker thread
                ContextUtil.put("traceId",UUID.randomUUID().toString());
            }
        });
    }
    Thread.sleep(2000);
    executor.shutdown();
}

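Running this code shows that every task that runs prints the parent thread's traceId. (With a queue of 100 and DiscardPolicy, some of the 200 tasks may be silently dropped.)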
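That result holds partly because every task writes the same key, traceId, which the next task's replay overwrites. If a wrapper only copies the parent's entries in and never cleans up, a key the parent no longer has survives on the reused worker thread. The sketch below assumes a single-thread pool so both requests land on the same worker; CONTEXT, wrapNoRestore and wrapWithRestore are hypothetical names. The restoring variant follows a capture, replay, restore pattern: copy the parent's context at submit time, install it before the task, and put the worker's own context back in a finally block.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo class; CONTEXT plays the role of ContextUtil.CONTEXT.
class PooledLeakDemo {

    static final ThreadLocal<Map<String, Object>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    // Copies the parent's entries into the worker but never cleans up.
    static <T> Callable<T> wrapNoRestore(Callable<T> task) {
        Map<String, Object> parent = new HashMap<>(CONTEXT.get());
        return () -> {
            CONTEXT.get().putAll(parent);
            return task.call();
        };
    }

    // Capture at wrap time, replay before the task, restore after it.
    static <T> Callable<T> wrapWithRestore(Callable<T> task) {
        Map<String, Object> parent = new HashMap<>(CONTEXT.get());
        return () -> {
            Map<String, Object> ctx = CONTEXT.get();
            Map<String, Object> backup = new HashMap<>(ctx);
            ctx.clear();
            ctx.putAll(parent);
            try {
                return task.call();
            } finally {
                ctx.clear();
                ctx.putAll(backup);
            }
        };
    }

    // Two "requests" on one worker thread; only the first carries a userId.
    static Object userIdSeenBySecondRequest(boolean restore) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Callable<Object> readUser = () -> CONTEXT.get().get("userId");
        try {
            CONTEXT.get().put("userId", "alice");
            Callable<Object> first = restore ? wrapWithRestore(readUser) : wrapNoRestore(readUser);
            pool.submit(first).get();
            CONTEXT.get().remove("userId");
            Callable<Object> second = restore ? wrapWithRestore(readUser) : wrapNoRestore(readUser);
            return pool.submit(second).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(userIdSeenBySecondRequest(false)); // prints alice (leaked)
        System.out.println(userIdSeenBySecondRequest(true));  // prints null
    }
}
```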

3. The open-source transmittable-thread-local solution

TTL on GitHub: https://github.com/alibaba/transmittable-thread-local. The project documents its implementation and usage in detail, so I won't repeat it here.