随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。
今天,我们就来看一道面试题:
如何充分利用多核CPU,计算很大数组中所有整数的和?
OK,剖析完了,我们直接来看三种实现,不墨迹,直接上菜。
/**
* 计算1亿个整数的和
*/
publicclassForkJoinPoolTest01{
publicstaticvoid main(String[] args) throwsExecutionException, InterruptedException{
// 构造数据
int length = 100000000;
long[] arr = newlong[length];
for(int i = 0; i < length; i++) {
arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
}
// 单线程
singleThreadSum(arr);
// ThreadPoolExecutor线程池
multiThreadSum(arr);
// ForkJoinPool线程池
forkJoinSum(arr);
}
privatestaticvoid singleThreadSum(long[] arr) {
long start = System.currentTimeMillis();
long sum = 0;
for(int i = 0; i < arr.length; i++) {
// 模拟耗时
sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
}
System.out.println("sum: "+ sum);
System.out.println("single thread elapse: "+ (System.currentTimeMillis() - start));
}
privatestaticvoid multiThreadSum(long[] arr) throwsExecutionException, InterruptedException{
long start = System.currentTimeMillis();
int count = 8;
ExecutorService threadPool = Executors.newFixedThreadPool(count);
List<Future<Long>> list = newArrayList<>();
for(int i = 0; i < count; i++) {
int num = i;
// 分段提交任务
Future<Long> future = threadPool.submit(() -> {
long sum = 0;
for(int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
try{
// 模拟耗时
sum += (arr[j]/3*3/3*3/3*3/3*3/3*3);
} catch(Exception e) {
e.printStackTrace();
}
}
return sum;
});
list.add(future);
}
// 每个段结果相加
long sum = 0;
for(Future<Long> future : list) {
sum += future.get();
}
System.out.println("sum: "+ sum);
System.out.println("multi thread elapse: "+ (System.currentTimeMillis() - start));
}
privatestaticvoid forkJoinSum(long[] arr) throwsExecutionException, InterruptedException{
long start = System.currentTimeMillis();
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
// 提交任务
ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(newSumTask(arr, 0, arr.length));
// 获取结果
Long sum = forkJoinTask.get();
forkJoinPool.shutdown();
System.out.println("sum: "+ sum);
System.out.println("fork join elapse: "+ (System.currentTimeMillis() - start));
}
privatestaticclassSumTaskextendsRecursiveTask<Long> {
privatelong[] arr;
privateint from;
privateint to;
publicSumTask(long[] arr, int from, int to) {
this.arr = arr;
this.from = from;
this.to = to;
}
@Override
protectedLong compute() {
// 小于1000的时候直接相加,可灵活调整
if(to - from <= 1000) {
long sum = 0;
for(int i = from; i < to; i++) {
// 模拟耗时
sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
}
return sum;
}
// 分成两段任务
int middle = (from + to) / 2;
SumTask left = newSumTask(arr, from, middle);
SumTask right = newSumTask(arr, middle, to);
// 提交左边的任务
left.fork();
// 右边的任务直接利用当前线程计算,节约开销
Long rightResult = right.compute();
// 等待左边计算完毕
Long leftResult = left.join();
// 返回结果
return leftResult + rightResult;
}
}
}
彤哥偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。
所以,为了演示ForkJoinPool的牛逼之处,我把每个数都 /3*3/3*3/3*3/3*3/3*3
了一顿操作,用来模拟计算耗时。
来看结果:
sum: 107352457433800662
single thread elapse: 789
sum: 107352457433800662
multi thread elapse: 228
sum: 107352457433800662
fork join elapse: 189
可以看到,ForkJoinPool相对普通线程池还是有很大提升的。
问题:普通线程池能否实现ForkJoinPool这种计算方式呢,即大任务拆中任务,中任务拆小任务,最后再汇总?
你可以试试看(-᷅_-᷄)
OK,下面我们正式进入ForkJoinPool的解析。
在分治法中,子问题一般是相互独立的,因此,经常通过递归调用算法来求解子问题。
ForkJoinPool是 java 7 中新增的线程池类,它的继承体系如下:
ForkJoinPool和ThreadPoolExecutor都是继承自AbstractExecutorService抽象类,所以它和ThreadPoolExecutor的使用几乎没有多少区别,除了任务变成了ForkJoinTask以外。
这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。
可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码造成干扰,还能利用原有的特性。
ForkJoinPool内部使用的是“工作窃取”算法实现的。
(1)每个工作线程都有自己的工作队列WorkQueue;
(2)这是一个双端队列,它是线程私有的;
(3)ForkJoinTask中fork的子任务,将放入运行该任务的工作线程的队头,工作线程将以LIFO的顺序来处理工作队列中的任务;
(4)为了最大化地利用CPU,空闲的线程将从其它线程的队列中“窃取”任务来执行;
(5)从工作队列的尾部窃取任务,以减少竞争;
(6)双端队列的操作:push()/pop()仅在其所有者工作线程中调用,poll()是由其它线程窃取任务时调用的;
(7)当只剩下最后一个任务时,还是会存在竞争,是通过CAS来实现的;
(1)最适合的是计算密集型任务;
(2)在需要阻塞工作线程时,可以使用ManagedBlocker;
(3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();
(1)ForkJoinPool特别适合于“分而治之”算法的实现;
(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;
(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;
(4)ForkjoinPool内部基于“工作窃取”算法实现;
(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;
(6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;
(7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;
ManagedBlocker怎么使用?
答:ManagedBlocker相当于明确告诉ForkJoinPool框架要阻塞了,ForkJoinPool就会启另一个线程来运行任务,以最大化地利用CPU。
请看下面的例子,自己琢磨哈^^。
/**
* 斐波那契数列
* 一个数是它前面两个数之和
* 1,1,2,3,5,8,13,21
*/
publicclassFibonacci{
publicstaticvoid main(String[] args) {
long time = System.currentTimeMillis();
Fibonacci fib = newFibonacci();
int result = fib.f(1_000).bitCount();
time = System.currentTimeMillis() - time;
System.out.println("result = "+ result);
System.out.println("test1_000() time = "+ time);
}
publicBigInteger f(int n) {
Map<Integer, BigInteger> cache = newConcurrentHashMap<>();
cache.put(0, BigInteger.ZERO);
cache.put(1, BigInteger.ONE);
return f(n, cache);
}
privatefinalBigInteger RESERVED = BigInteger.valueOf(-1000);
publicBigInteger f(int n, Map<Integer, BigInteger> cache) {
BigInteger result = cache.putIfAbsent(n, RESERVED);
if(result == null) {
int half = (n + 1) / 2;
RecursiveTask<BigInteger> f0_task = newRecursiveTask<BigInteger>() {
@Override
protectedBigInteger compute() {
return f(half - 1, cache);
}
};
f0_task.fork();
BigInteger f1 = f(half, cache);
BigInteger f0 = f0_task.join();
long time = n > 10_000 ? System.currentTimeMillis() : 0;
try{
if(n % 2== 1) {
result = f0.multiply(f0).add(f1.multiply(f1));
} else{
result = f0.shiftLeft(1).add(f1).multiply(f1);
}
synchronized(RESERVED) {
cache.put(n, result);
RESERVED.notifyAll();
}
} finally{
time = n > 10_000 ? System.currentTimeMillis() - time : 0;
if(time > 50)
System.out.printf("f(%d) took %d%n", n, time);
}
} elseif(result == RESERVED) {
try{
ReservedFibonacciBlocker blocker = newReservedFibonacciBlocker(n, cache);
ForkJoinPool.managedBlock(blocker);
result = blocker.result;
} catch(InterruptedException e) {
thrownewCancellationException("interrupted");
}
}
return result;
// return f(n - 1).add(f(n - 2));
}
privateclassReservedFibonacciBlockerimplementsForkJoinPool.ManagedBlocker{
privateBigInteger result;
privatefinalint n;
privatefinalMap<Integer, BigInteger> cache;
publicReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) {
this.n = n;
this.cache = cache;
}
@Override
publicboolean block() throwsInterruptedException{
synchronized(RESERVED) {
while(!isReleasable()) {
RESERVED.wait();
}
}
returntrue;
}
@Override
publicboolean isReleasable() {
return(result = cache.get(n)) != RESERVED;
}
}
}