线程池工作窃取实例

本文主要来展示一下简版的work stealing线程池的实现。

Executors

Executors默认提供了几个工厂方法

/**
     * Creates a thread pool that maintains enough threads to support
     * the given parallelism level, and may use multiple queues to
     * reduce contention. The parallelism level corresponds to the
     * maximum number of threads actively engaged in, or available to
     * engage in, task processing. The actual number of threads may
     * grow and shrink dynamically. A work-stealing pool makes no
     * guarantees about the order in which submitted tasks are
     * executed.
     *
     * @param parallelism the targeted parallelism level
     * @return the newly created thread pool
     * @throws IllegalArgumentException if {@code parallelism <= 0}
     * @since 1.8
     */
    public static ExecutorService newWorkStealingPool(int parallelism) {
        return new ForkJoinPool
            (parallelism,
             ForkJoinPool.defaultForkJoinWorkerThreadFactory,
             null, true);
    }

    /**
     * Creates a work-stealing thread pool using all
     * {@link Runtime#availableProcessors available processors}
     * as its target parallelism level.
     * @return the newly created thread pool
     * @see #newWorkStealingPool(int)
     * @since 1.8
     */
    public static ExecutorService newWorkStealingPool() {
        return new ForkJoinPool
            (Runtime.getRuntime().availableProcessors(),
             ForkJoinPool.defaultForkJoinWorkerThreadFactory,
             null, true);
    }

思路

ForkJoinPool主要用到的是双端队列,不过这里我们粗糙的实现的话,也可以不用到deque。

public class WorkStealingChannel<T> {

    private static final Logger LOGGER = LoggerFactory.getLogger(WorkStealingChannel.class);

    BlockingDeque<T>[] managedQueues;

    AtomicLongMap<Integer> stat = AtomicLongMap.create();

    public WorkStealingChannel() {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int queueCount = nCPU / 2 + 1;
        managedQueues = new LinkedBlockingDeque[queueCount];
        for(int i=0;i<queueCount;i++){
            managedQueues[i] = new LinkedBlockingDeque<T>();
        }
    }

    public void put(T item) throws InterruptedException {
        int targetIndex = Math.abs(item.hashCode() % managedQueues.length);
        BlockingQueue<T> targetQueue = managedQueues[targetIndex];
        targetQueue.put(item);
    }

    public T take() throws InterruptedException {
        int rdnIdx = ThreadLocalRandom.current().nextInt(managedQueues.length);
        int idx = rdnIdx;
        while (true){
            idx = idx % managedQueues.length;
            T item = null;
            if(idx == rdnIdx){
                item = managedQueues[idx].poll();
            }else{
                item = managedQueues[idx].pollLast();
            }
            if(item != null){
                LOGGER.info("take ele from queue {}",idx);
                stat.addAndGet(idx,1);
                return item;
            }
            idx++;
            if(idx == rdnIdx){
                break;
            }
        }

        //走完一轮没有,则随机取一个等待
        LOGGER.info("wait for queue:{}",rdnIdx);
        stat.addAndGet(rdnIdx,1);
        return managedQueues[rdnIdx].take();
    }

    public AtomicLongMap<Integer> getStat() {
        return stat;
    }
}

测试实例

public class WorkStealingDemo {

    static final WorkStealingChannel<String> channel = new WorkStealingChannel<>();

    static volatile boolean running = true;

    static class Producer extends Thread{
        @Override
        public void run() {
            while(running){
                try {
                    channel.put(UUID.randomUUID().toString());
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    static class Consumer extends Thread{
        @Override
        public void run() {
            while(running){
                try {
                    String value = channel.take();
                    System.out.println(value);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    public static void stop(){
        running = false;
        System.out.println(channel.getStat());
    }

    public static void main(String[] args) throws InterruptedException {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int consumerCount = nCPU / 2 + 1;
        for (int i = 0; i < nCPU; i++) {
            new Producer().start();
        }

        for (int i = 0; i < consumerCount; i++) {
            new Consumer().start();
        }

        Thread.sleep(30*1000);
        stop();
    }
}

输出

{0=660972, 1=660613, 2=661537, 3=659846, 4=659918}

原文发布于微信公众号 - 码匠的流水账(geek_luandun)

原文发表时间:2017-09-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏码匠的流水账

聊聊eureka client的fetch-remote-regions-registry属性

本文主要研究一下eureka client的fetch-remote-regions-registry属性

1421
来自专栏编舟记

命令式到函数式编程

应用场景:当我们用到 if-elseif-else 的时候,可以考虑使用 Optional 语义。 举例说明:

782
来自专栏Hongten

python开发_calendar

如果你用过linux,你可能知道在linux下面的有一个强大的calendar功能,即日历

1272
来自专栏ml

HDUOJ----(1030)Delta-wave

Delta-wave Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K...

3447
来自专栏cmazxiaoma的架构师之路

你真的了解Spring MVC处理请求流程吗?

3174
来自专栏函数式编程语言及工具

Scalaz(44)- concurrency :scalaz Future,尚不完整的多线程类型

scala已经配备了自身的Future类。我们先举个例子来了解scala Future的具体操作: 1 import scala.concurrent._ ...

2039
来自专栏小樱的经验随笔

Code forces 719A Vitya in the Countryside

A. Vitya in the Countryside time limit per test:1 second memory limit per test:2...

3496
来自专栏码匠的流水账

聊聊sentinel的ModifyRulesCommandHandler

本文主要研究一下sentinel的ModifyRulesCommandHandler

1171
来自专栏desperate633

Java并发之ScheduledThreadPoolExecutor在Executor中延时执行任务在Executor中周期的执行任务

ScheduledExecutorService类顾名思义,就是可以延迟执行的Executor。如果,对于某些任务,我们并不想马上执行,而是想让任务过一段时间后...

951
来自专栏开发与安全

90% of python in 90 minutes

注:本文整理自 http://www.slideshare.net/MattHarrison4/learn-90 -----------------------...

2180

扫码关注云+社区

领取腾讯云代金券