Fork/Join 以递归方式将可以并行的任务拆分成更小的任务,然后将每个子任务的结果合并起来生成整体结果。
这个过程其实就是分治算法的并行版本,图解如下:
我们要使用 ForkJoin 框架,必须先创建一个 ForkJoinTask。它提供在任务中执行 fork() 和 join() 操作的机制,通常情况下我们不需要直接继承 ForkJoinTask 类,而只需要继承它的子类,Fork/Join 框架提供了以下两个子类:
而ForkJoinTask 需要通过 ForkJoinPool 来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。
下面使用forkJoin来计算一个Integer List之和:
import com.google.common.collect.Lists; import com.sun.istack.internal.NotNull; import java.util.List; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.RecursiveTask; import java.util.stream.Collectors; import java.util.stream.IntStream; public class SumTask extends RecursiveTask<Long> { private static final int SPLIT_NUM = 10000; @NotNull private List<Integer> numberList; public SumTask(List<Integer> numberList) { this.numberList = numberList; } @Override protected Long compute() { // 不需要进行任务拆分 if (numberList.size() <= SPLIT_NUM) { return numberList.stream().mapToLong(Integer::intValue).sum(); } // 进行任务拆分 List<List<Integer>> splitNumberList = Lists.partition(numberList, numberList.size() / 2); List<SumTask> sumTasks = splitNumberList.stream().map(SumTask::new).collect(Collectors.toList()); // 执行子任务,继续拆分 invokeAll(sumTasks); // 合并结果 return sumTasks.stream().mapToLong(SumTask::join).sum(); } public static void main(String[] args) { ForkJoinPool forkJoinPool = ForkJoinPool.commonPool(); List<Integer> numberList = IntStream.rangeClosed(1, 100000).mapToObj(Integer::new).collect(Collectors.toList()); SumTask sumTask = new SumTask(numberList); Long sum = forkJoinPool.invoke(sumTask); System.out.println("1 ~ 100000 sum result is: " + sum); } }
上述代码的执行过程为,先将List 分为两个子List, 并发执行两个子List 的计算。然后再将子List 拆分为更小的List,依此往复,直至List无法再拆分时,计算其Sum,最后合并结果。
ForkJoinTask 与一般的任务的主要区别在于它需要实现 compute 方法,在这个方法里,首先需要判断任务是否足够小,如果足够小就直接执行任务。如果不足够小,就必须分割成两个子任务,每个子任务在调用 fork 方法时,又会进入 compute 方法,形成递归调用,直到任务子任务不可再分。使用 join 方法会等待子任务执行完并得到其结果。
我们没办法在主线程中捕获ForkJoinTask 执行过程中抛出的异常。所以ForkJoinTask 提供了方法来检测Task 执行情况, 并提供了获取异常的方法。
// 检查Task 执行情况 sumTask.isCancelled(); sumTask.isCompletedNormally(); sumTask.isCompletedAbnormally(); // 获取异常信息 sumTask.getException();
参考 https://segmentfault.com/a/1190000016781127