1、概述

Fork/Join Pool采用优良的设计、代码实现和硬件原子操作机制等多种思路保证其执行性能。其中包括(但不限于):计算资源共享、高性能队列、避免伪共享、工作窃取机制等。本文(以及后续文章)试图和读者一起分析JDK1.8中Fork/Join Pool的源代码实现,去理解Fork/Join Pool是怎样工作的。当然这里要说明一下,起初本人在决定阅读Fork/Join归并计算相关类的源代码时(ForkJoinPool、WorkQueue、ForkJoinTask、RecursiveTask、ForkJoinWorkerThread等),并不觉得这部分代码比起LinkedList这样的类来说有多少难度, 但其中大量使用位运算和位运算技巧,有大量Unsafe原子操作。博主能力有限,确实不能在短时间内将所有代码一一详细解读,所以也希望各位读者能帮助笔者一同完善。

Set java 多线程循环 java多线程fork_Set java 多线程循环

2.  原理

基本思想

Set java 多线程循环 java多线程fork_多线程_02

  • ForkJoinPool 的每个工作线程都维护着一个工作队列WorkQueue),这是一个双端队列(Deque),里面存放的对象是任务ForkJoinTask)。
  • 每个工作线程在运行中产生新的任务(通常是因为调用了 fork())时,会放入工作队列的队尾,并且工作线程在处理自己的工作队列时,使用的是 LIFO 方式,也就是说每次从队尾取出任务来执行。
  • 每个工作线程在处理自己的工作队列同时,会尝试窃取一个任务(或是来自于刚刚提交到 pool 的任务,或是来自于其他工作线程的工作队列),窃取的任务位于其他线程的工作队列的队首,也就是说工作线程在窃取其他工作线程的任务时,使用的是 FIFO 方式。
  • 在遇到 join() 时,如果需要 join 的任务尚未完成,则会先处理其他任务,并等待其完成。
  • 在既没有自己的任务,也没有可以窃取的任务时,进入休眠。

fork

fork() 做的工作只有一件事,既是把任务推入当前工作线程的工作队列里。可以参看以下的源代码:

public final ForkJoinTask<V> fork() {
    Thread t;
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
        ((ForkJoinWorkerThread)t).workQueue.push(this);
    else
        ForkJoinPool.common.externalPush(this);
    return this;
}

join

join() 的工作则复杂得多,也是 join() 可以使得线程免于被阻塞的原因——不像同名的 Thread.join()

  1. 检查调用 join() 的线程是否是 ForkJoinThread 线程。如果不是(例如 main 线程),则阻塞当前线程,等待任务完成。如果是,则不阻塞。
  2. 查看任务的完成状态,如果已经完成,直接返回结果。
  3. 如果任务尚未完成,但处于自己的工作队列内,则完成它。
  4. 如果任务已经被其他的工作线程偷走,则窃取这个小偷的工作队列内的任务(以 FIFO 方式),执行,以期帮助它早日完成欲 join 的任务。
  5. 如果偷走任务的小偷也已经把自己的任务全部做完,正在等待需要 join 的任务时,则找到小偷的小偷,帮助它完成它的任务。
  6. 递归地执行第5步。

将上述流程画成序列图的话就是这个样子:

Set java 多线程循环 java多线程fork_工作线程_03

以上就是 fork() 和 join() 的原理,这可以解释 ForkJoinPool 在递归过程中的执行逻辑,但还有一个问题

最初的任务是 push 到哪个线程的工作队列里的?

这就涉及到 submit() 函数的实现方法了

submit

其实除了前面介绍过的每个工作线程自己拥有的工作队列以外,ForkJoinPool 自身也拥有工作队列,这些工作队列的作用是用来接收由外部线程(非 ForkJoinThread 线程)提交过来的任务,而这些工作队列被称为 submitting queue 。

submit() 和 fork() 其实没有本质区别,只是提交对象变成了 submitting queue 而已(还有一些同步,初始化的操作)。submitting queue 和其他 work queue 一样,是工作线程”窃取“的对象,因此当其中的任务被一个工作线程成功窃取时,就意味着提交的任务真正开始进入执行阶段。

3. 示例

package ThreadTest.demo;

import lombok.Data;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.*;

/**
 * Created by lizq on 2020/3/1.
 */
public class ForkJoinTest {

    public static void main(String[] args) throws ExecutionException, InterruptedException {

        long[] arrs = RandomArr.createLongArr(30, 0, 100);
        Arrays.stream(arrs).forEach(i -> {
            System.out.print(i + " ");
        });
        long rlt = Arrays.stream(arrs).sum();
        System.out.println("sum " + rlt);

        // ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        ForkJoinPool forkJoinPool = new ForkJoinPool(5);
        // ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arrs, 0, arrs.length - 1));
        //  rlt = forkJoinTask.get();
        // System.out.println("sum " + rlt);
        ForkJoinTask forkJoinTask = forkJoinPool.submit(new orderTask(arrs, 0, arrs.length - 1));
        forkJoinTask.get();
        Arrays.stream(arrs).forEach(i -> {
            System.out.print(i + " ");
        });

    }
}

@Data
class SumTask extends RecursiveTask<Long> {

    private long[] arr;
    private int from;
    private int to;

    public SumTask(long[] arr, int from, int to) {
        this.arr = arr;
        this.from = from;
        this.to = to;

    }

    @Override
    protected Long compute() {

        System.out.println("thread begin : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
        Long l = 0l;

        if (to - from < 3) {
            for (int i = from; i <= to; i++) {
                l += this.arr[i];
            }
        } else {
            int m = (from + to) / 2;
            SumTask leftSum = new SumTask(this.arr, from, m);
            SumTask rightSum = new SumTask(this.arr, m + 1, to);
            leftSum.fork();
            rightSum.fork();
            l = leftSum.join() + rightSum.join();
        }
        System.out.println("thread end : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
        return l;
    }


}

enum Type {
    DESC, ASC;
}


class orderTask extends RecursiveAction {

    private long[] arr;
    private int from;
    private int to;
    private Type type = Type.ASC;


    public orderTask(long[] arr, int from, int to) {
        this.arr = arr;
        this.from = from;
        this.to = to;

    }

    @Override
    protected void compute() {

        //   System.out.print("thread begin : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
        if (to - from < 2) {
            if (arr[from] > arr[to]) {
                long tmp = this.arr[from];
                this.arr[from] = this.arr[to];
                this.arr[to] = tmp;
            }
        } else {
            int m = (from + to) / 2;
            orderTask leftSum = new orderTask(this.arr, from, m);
            orderTask rightSum = new orderTask(this.arr, m + 1, to);
            leftSum.fork();
            rightSum.fork();
            leftSum.join();
            rightSum.join();
            // 组合排序
            for (int l = from, r = m + 1; l < to && r <= to; l++) {
                if (arr[l] > arr[r]) {
                    long tmp = arr[r];
                    for (int i = r; i > l; ) {
                        arr[i] = arr[--i];
                    }
                    arr[l] = tmp;
                    r++;
                }
            }
        }
        //  System.out.print("thread end : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);

    }
}

class RandomArr {
    public static long[] createLongArr(int length, int low, int high) {
        if (length < 1) {
            throw new RuntimeException("length < 1");
        }
        return new Random().longs(low, high).limit(length).toArray();
    }
}