前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >手把手教你写一个延时队列

手把手教你写一个延时队列

原创
作者头像
派大星在吗
发布2021-12-18 11:10:53
3860
发布2021-12-18 11:10:53
举报
文章被收录于专栏:我的技术专刊我的技术专刊

一、单机

1. while+sleep组合

定义一个线程,然后 while 循环

代码语言:txt
复制
public static void main(String[] args) {
代码语言:txt
复制
    final long timeInterval = 5000;
代码语言:txt
复制
    new Thread(new Runnable() {
代码语言:txt
复制
        @Override
代码语言:txt
复制
        public void run() {
代码语言:txt
复制
            while (true) {
代码语言:txt
复制
                System.out.println(Thread.currentThread().getName() + "每隔5秒执行一次");
代码语言:txt
复制
                try {
代码语言:txt
复制
                    Thread.sleep(timeInterval);
代码语言:txt
复制
                } catch (InterruptedException e) {
代码语言:txt
复制
                    e.printStackTrace();
代码语言:txt
复制
                }
代码语言:txt
复制
            }
代码语言:txt
复制
        }
代码语言:txt
复制
    }).start();
代码语言:txt
复制
}

这种实现方式下多个定时任务需要开启多个线程,而且线程在做无意义sleep,消耗资源,性能低下。

2. 最小堆实现

2.1 Timer

实现代码,调度两个任务

代码语言:txt
复制
public static void main(String[] args) {
代码语言:txt
复制
    Timer timer = new Timer();
代码语言:txt
复制
    //每隔1秒调用一次
代码语言:txt
复制
    timer.schedule(new TimerTask() {
代码语言:txt
复制
        @Override
代码语言:txt
复制
        public void run() {
代码语言:txt
复制
            System.out.println("test1");
代码语言:txt
复制
        }
代码语言:txt
复制
    }, 1000, 1000);
代码语言:txt
复制
    //每隔3秒调用一次
代码语言:txt
复制
    timer.schedule(new TimerTask() {
代码语言:txt
复制
        @Override
代码语言:txt
复制
        public void run() {
代码语言:txt
复制
            System.out.println("test2");
代码语言:txt
复制
        }
代码语言:txt
复制
    }, 3000, 3000);
代码语言:txt
复制
}

schedule实现源码

代码语言:txt
复制
    public void schedule(TimerTask task, long delay, long period) {
代码语言:txt
复制
        if (delay < 0)
代码语言:txt
复制
            throw new IllegalArgumentException("Negative delay.");
代码语言:txt
复制
        if (period <= 0)
代码语言:txt
复制
            throw new IllegalArgumentException("Non-positive period.");
代码语言:txt
复制
        sched(task, System.currentTimeMillis()+delay, -period);
代码语言:txt
复制
    }

shed里面将任务add到最小堆,然后fixUp进行调整

TimerThread其实就是一个任务调度线程,首先从TaskQueue里面获取排在最前面的任务,然后判断它是否到达任务执行时间点,如果已到达,就会立刻执行任务

代码语言:txt
复制
class TimerThread extends Thread {
代码语言:txt
复制
    boolean newTasksMayBeScheduled = true;
代码语言:txt
复制
    private TaskQueue queue;
代码语言:txt
复制
    TimerThread(TaskQueue queue) {
代码语言:txt
复制
        this.queue = queue;
代码语言:txt
复制
    }
代码语言:txt
复制
    public void run() {
代码语言:txt
复制
        try {
代码语言:txt
复制
            mainLoop();
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            // Someone killed this Thread, behave as if Timer cancelled
代码语言:txt
复制
            synchronized(queue) {
代码语言:txt
复制
                newTasksMayBeScheduled = false;
代码语言:txt
复制
                queue.clear();  // Eliminate obsolete references
代码语言:txt
复制
            }
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * The main timer loop.  (See class comment.)
     */
    private void mainLoop() {
        while (true) {
            try {
                TimerTask task;
                boolean taskFired;
                synchronized(queue) {
                    // Wait for queue to become non-empty
                    while (queue.isEmpty() && newTasksMayBeScheduled)
                        queue.wait();
                    if (queue.isEmpty())
                        break; // Queue is empty and will forever remain; die
代码语言:txt
复制
                    // Queue nonempty; look at first evt and do the right thing
代码语言:txt
复制
                    long currentTime, executionTime;
代码语言:txt
复制
                    task = queue.getMin();
代码语言:txt
复制
                    synchronized(task.lock) {
代码语言:txt
复制
                        if (task.state == TimerTask.CANCELLED) {
代码语言:txt
复制
                            queue.removeMin();
代码语言:txt
复制
                            continue;  // No action required, poll queue again
代码语言:txt
复制
                        }
代码语言:txt
复制
                        currentTime = System.currentTimeMillis();
代码语言:txt
复制
                        executionTime = task.nextExecutionTime;
代码语言:txt
复制
                        if (taskFired = (executionTime<=currentTime)) {
代码语言:txt
复制
                            if (task.period == 0) { // Non-repeating, remove
代码语言:txt
复制
                                queue.removeMin();
代码语言:txt
复制
                                task.state = TimerTask.EXECUTED;
代码语言:txt
复制
                            } else { // Repeating task, reschedule
代码语言:txt
复制
                                queue.rescheduleMin(
代码语言:txt
复制
                                  task.period<0 ? currentTime   - task.period
代码语言:txt
复制
                                                : executionTime + task.period);
代码语言:txt
复制
                            }
代码语言:txt
复制
                        }
代码语言:txt
复制
                    }
代码语言:txt
复制
                    if (!taskFired) // Task hasn't yet fired; wait
代码语言:txt
复制
                        queue.wait(executionTime - currentTime);
代码语言:txt
复制
                }
代码语言:txt
复制
                if (taskFired)  // Task fired; run it, holding no locks
代码语言:txt
复制
                    task.run();
代码语言:txt
复制
            } catch(InterruptedException e) {
代码语言:txt
复制
            }
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
}

总结这个利用最小堆实现的方案,相比 while + sleep

方案,多了一个线程来管理所有的任务,优点就是减少了线程之间的性能开销,提升了执行效率;但是同样也带来的了一些缺点,整体的新加任务写入效率变成了

O(log(n))。

同时,细心的发现,这个方案还有以下几个缺点:

串行阻塞:调度线程只有一个,长任务会阻塞短任务的执行,例如,A任务跑了一分钟,B任务至少需要等1分钟才能跑

容错能力差:没有异常处理能力,一旦一个任务执行故障,后续任务都无法执行

2.2 ScheduledThreadPoolExecutor

鉴于 Timer 的上述缺陷,从 Java 5 开始,推出了基于线程池设计的 ScheduledThreadPoolExecutor 。

image

其设计思想是,每一个被调度的任务都会由线程池来管理执行,因此任务是并发执行的,相互之间不会受到干扰。需要注意的是,只有当任务的执行时间到来时,ScheduledThreadPoolExecutor

才会真正启动一个线程,其余时间 ScheduledThreadPoolExecutor 都是在轮询任务的状态。

简单的使用示例:

代码语言:txt
复制
        ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(3);
代码语言:txt
复制
        //启动1秒之后,每隔1秒执行一次
代码语言:txt
复制
        executor.scheduleAtFixedRate(()-> System.out.println("test3"),1,1, TimeUnit.SECONDS);
代码语言:txt
复制
        //启动1秒之后,每隔3秒执行一次
代码语言:txt
复制
        executor.scheduleAtFixedRate((() -> System.out.println("test4")),1,3, TimeUnit.SECONDS);

同样的,我们首先打开源码,看看里面到底做了啥

  • 进入scheduleAtFixedRate()方法

首先是校验基本参数,然后将任务作为封装到ScheduledFutureTask线程中,ScheduledFutureTask继承自RunnableScheduledFuture,并作为参数调用delayedExecute()方法进行预处理

代码语言:txt
复制
public ScheduledFuture<?> scheduleAtFixedRate(Runnable command,
代码语言:txt
复制
                                              long initialDelay,
代码语言:txt
复制
                                              long period,
代码语言:txt
复制
                                              TimeUnit unit) {
代码语言:txt
复制
    if (command == null || unit == null)
代码语言:txt
复制
        throw new NullPointerException();
代码语言:txt
复制
    if (period <= 0)
代码语言:txt
复制
        throw new IllegalArgumentException();
代码语言:txt
复制
    ScheduledFutureTask<Void> sft =
代码语言:txt
复制
        new ScheduledFutureTask<Void>(command,
代码语言:txt
复制
                                      null,
代码语言:txt
复制
                                      triggerTime(initialDelay, unit),
代码语言:txt
复制
                                      unit.toNanos(period));
代码语言:txt
复制
    RunnableScheduledFuture<Void> t = decorateTask(command, sft);
代码语言:txt
复制
    sft.outerTask = t;
代码语言:txt
复制
    delayedExecute(t);
代码语言:txt
复制
    return t;
代码语言:txt
复制
}
  • 继续看delayedExecute()方法

可以很清晰的看到,当线程池没有关闭的时候,会通过super.getQueue().add(task)操作,将任务加入到队列,同时调用ensurePrestart()方法做预处理

代码语言:txt
复制
private void delayedExecute(RunnableScheduledFuture<?> task) {
代码语言:txt
复制
    if (isShutdown())
代码语言:txt
复制
        reject(task);
代码语言:txt
复制
    else {
代码语言:txt
复制
        super.getQueue().add(task);
代码语言:txt
复制
        if (isShutdown() &&
代码语言:txt
复制
            !canRunInCurrentRunState(task.isPeriodic()) &&
代码语言:txt
复制
            remove(task))
代码语言:txt
复制
            task.cancel(false);
代码语言:txt
复制
        else
代码语言:txt
复制
   //预处理
代码语言:txt
复制
            ensurePrestart();
代码语言:txt
复制
    }
代码语言:txt
复制
}

其中super.getQueue()得到的是一个自定义的new DelayedWorkQueue()阻塞队列,数据存储方面也是一个最小堆结构的队列,这一点在初始化new ScheduledThreadPoolExecutor()的时候,可以看出!

代码语言:txt
复制
public ScheduledThreadPoolExecutor(int corePoolSize) {
代码语言:txt
复制
    super(corePoolSize, Integer.MAX_VALUE, 0, NANOSECONDS,
代码语言:txt
复制
          new DelayedWorkQueue());
代码语言:txt
复制
}

打开源码可以看到,DelayedWorkQueue其实是ScheduledThreadPoolExecutor中的一个静态内部类,在添加的时候,会将任务加入到RunnableScheduledFuture数组中。然后调用线程池的ensurePrestart方法将任务添加到线程池。调用链:addWorker->t.run->new Worker.run-> runWorker->Runnable r = timed ?undefined workQueue.poll(keepAliveTime, TimeUnit.NANOSECONDS) :undefined workQueue.take();->task.run->RunnableScheduledFuture.run

代码语言:txt
复制
static class DelayedWorkQueue extends AbstractQueue<Runnable>
代码语言:txt
复制
        implements BlockingQueue<Runnable> {
代码语言:txt
复制
    private static final int INITIAL_CAPACITY = 16;
代码语言:txt
复制
    private RunnableScheduledFuture<?>[] queue =
代码语言:txt
复制
        new RunnableScheduledFuture<?>[INITIAL_CAPACITY];
代码语言:txt
复制
    private final ReentrantLock lock = new ReentrantLock();
代码语言:txt
复制
    private int size = 0;   
代码语言:txt
复制
    //....
代码语言:txt
复制
    public boolean add(Runnable e) {
代码语言:txt
复制
        return offer(e);
代码语言:txt
复制
    }
代码语言:txt
复制
    public boolean offer(Runnable x) {
代码语言:txt
复制
        if (x == null)
代码语言:txt
复制
            throw new NullPointerException();
代码语言:txt
复制
        RunnableScheduledFuture<?> e = (RunnableScheduledFuture<?>)x;
代码语言:txt
复制
        final ReentrantLock lock = this.lock;
代码语言:txt
复制
        lock.lock();
代码语言:txt
复制
        try {
代码语言:txt
复制
            int i = size;
代码语言:txt
复制
            if (i >= queue.length)
代码语言:txt
复制
                grow();
代码语言:txt
复制
            size = i + 1;
代码语言:txt
复制
            if (i == 0) {
代码语言:txt
复制
                queue[0] = e;
代码语言:txt
复制
                setIndex(e, 0);
代码语言:txt
复制
            } else {
代码语言:txt
复制
                siftUp(i, e);
代码语言:txt
复制
            }
代码语言:txt
复制
            if (queue[0] == e) {
代码语言:txt
复制
                leader = null;
代码语言:txt
复制
                available.signal();
代码语言:txt
复制
            }
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
        return true;
代码语言:txt
复制
    }
代码语言:txt
复制
    public RunnableScheduledFuture<?> take() throws InterruptedException {
代码语言:txt
复制
        final ReentrantLock lock = this.lock;
代码语言:txt
复制
        lock.lockInterruptibly();
代码语言:txt
复制
        try {
代码语言:txt
复制
            for (;;) {
代码语言:txt
复制
                RunnableScheduledFuture<?> first = queue[0];
代码语言:txt
复制
                if (first == null)
代码语言:txt
复制
                    available.await();
代码语言:txt
复制
                else {
代码语言:txt
复制
                    long delay = first.getDelay(NANOSECONDS);
代码语言:txt
复制
                    if (delay <= 0)
代码语言:txt
复制
                        return finishPoll(first);
代码语言:txt
复制
                    first = null; // don't retain ref while waiting
代码语言:txt
复制
                    if (leader != null)
代码语言:txt
复制
                        available.await();
代码语言:txt
复制
                    else {
代码语言:txt
复制
                        Thread thisThread = Thread.currentThread();
代码语言:txt
复制
                        leader = thisThread;
代码语言:txt
复制
                        try {
代码语言:txt
复制
                            available.awaitNanos(delay);
代码语言:txt
复制
                        } finally {
代码语言:txt
复制
                            if (leader == thisThread)
代码语言:txt
复制
                                leader = null;
代码语言:txt
复制
                        }
代码语言:txt
复制
                    }
代码语言:txt
复制
                }
代码语言:txt
复制
            }
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            if (leader == null && queue[0] != null)
代码语言:txt
复制
                available.signal();
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
}
  • 回到我们最开始说到的ScheduledFutureTask任务线程类,最终执行任务的其实就是它

ScheduledFutureTask任务线程,才是真正执行任务的线程类,只是绕了一圈,做了很多包装,run()方法就是真正执行定时任务的方法。

代码语言:txt
复制
private class ScheduledFutureTask<V>
代码语言:txt
复制
            extends FutureTask<V> implements RunnableScheduledFuture<V> {
代码语言:txt
复制
    /** Sequence number to break ties FIFO */
代码语言:txt
复制
    private final long sequenceNumber;
代码语言:txt
复制
    /** The time the task is enabled to execute in nanoTime units */
代码语言:txt
复制
    private long time;
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Period in nanoseconds for repeating tasks.  A positive
     * value indicates fixed-rate execution.  A negative value
     * indicates fixed-delay execution.  A value of 0 indicates a
     * non-repeating task.
     */
    private final long period;
代码语言:txt
复制
    /** The actual task to be re-enqueued by reExecutePeriodic */
代码语言:txt
复制
    RunnableScheduledFuture<V> outerTask = this;
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Overrides FutureTask version so as to reset/requeue if periodic.
     */
    public void run() {
        boolean periodic = isPeriodic();
        if (!canRunInCurrentRunState(periodic))
            cancel(false);
        else if (!periodic)//非周期性定时任务
            ScheduledFutureTask.super.run();
        else if (ScheduledFutureTask.super.runAndReset()) {//周期性定时任务,需要重置
            setNextRunTime();
            reExecutePeriodic(outerTask);
        }
    }
代码语言:txt
复制
 //...
代码语言:txt
复制
}

3.3、小结

ScheduledExecutorService 相比 Timer 定时器,完美的解决上面说到的 Timer 存在的两个缺点!

在单体应用里面,使用 ScheduledExecutorService 可以解决大部分需要使用定时任务的业务需求!

但是这是否意味着它是最佳的解决方案呢?

我们发现线程池中 ScheduledExecutorService 的排序容器跟 Timer

一样,都是采用最小堆的存储结构,新任务加入排序效率是O(log(n)),执行取任务是O(1)。

这里的写入排序效率其实是有空间可提升的,有可能优化到O(1)的时间复杂度,也就是我们下面要介绍的 时间轮实现

2.3 DelayQueue

DelayQueue是一个无界延时队列,内部有一个优先队列,可以重写compare接口,按照我们想要的方式进行排序。

实现Demo

代码语言:txt
复制
    public static void main(String[] args) throws Exception {
代码语言:txt
复制
        DelayQueue<Order> orders = new DelayQueue<>();
代码语言:txt
复制
        Order order1 = new Order(1000, "1x");
代码语言:txt
复制
        Order order2 = new Order(2000, "2x");
代码语言:txt
复制
        Order order3 = new Order(3000, "3x");
代码语言:txt
复制
        Order order4 = new Order(4000, "4x");
代码语言:txt
复制
        orders.add(order1);
代码语言:txt
复制
        orders.add(order2);
代码语言:txt
复制
        orders.add(order3);
代码语言:txt
复制
        orders.add(order4);
代码语言:txt
复制
        for (; ; ) {
代码语言:txt
复制
            //没有到期会阻塞
代码语言:txt
复制
            Order take = orders.take();
代码语言:txt
复制
            System.out.println(take);
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
}
代码语言:txt
复制
class Order implements Delayed {
代码语言:txt
复制
    @Override
代码语言:txt
复制
    public String toString() {
代码语言:txt
复制
        return "DelayedElement{" + "delay=" + delayTime +
代码语言:txt
复制
                ", expire=" + expire +
代码语言:txt
复制
                ", data='" + data + '\'' +
代码语言:txt
复制
                '}';
代码语言:txt
复制
    }
代码语言:txt
复制
    Order(long delay, String data) {
代码语言:txt
复制
        delayTime = delay;
代码语言:txt
复制
        this.data = data;
代码语言:txt
复制
        expire = System.currentTimeMillis() + delay;
代码语言:txt
复制
    }
代码语言:txt
复制
    private final long delayTime; //延迟时间
代码语言:txt
复制
    private final long expire;  //到期时间
代码语言:txt
复制
    private String data;   //数据
代码语言:txt
复制
    /**
代码语言:txt
复制
     * 剩余时间=到期时间-当前时间
     */
    @Override
    public long getDelay(TimeUnit unit) {
        return unit.convert(this.expire - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * 优先队列里面优先级规则
     */
    @Override
    public int compareTo(Delayed o) {
        return (int) (this.getDelay(TimeUnit.MILLISECONDS) - o.getDelay(TimeUnit.MILLISECONDS));
    }

从源码可以看出,DelayQueue的offer和take方法调用的是优先队列的offer和take。并且使用了ReetrtantLock保证线程安全

代码语言:txt
复制
    public boolean offer(E e) {
代码语言:txt
复制
        final ReentrantLock lock = this.lock;
代码语言:txt
复制
        lock.lock();
代码语言:txt
复制
        try {
代码语言:txt
复制
            q.offer(e);
代码语言:txt
复制
            if (q.peek() == e) {
代码语言:txt
复制
                leader = null;
代码语言:txt
复制
                available.signal();
代码语言:txt
复制
            }
代码语言:txt
复制
            return true;
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
public E take() throws InterruptedException {
代码语言:txt
复制
        final ReentrantLock lock = this.lock;
代码语言:txt
复制
        lock.lockInterruptibly();
代码语言:txt
复制
        try {
代码语言:txt
复制
            for (;;) {
代码语言:txt
复制
                E first = q.peek();
代码语言:txt
复制
                if (first == null)
代码语言:txt
复制
                    available.await();
代码语言:txt
复制
                else {
代码语言:txt
复制
                    long delay = first.getDelay(NANOSECONDS);
代码语言:txt
复制
                    if (delay <= 0)
代码语言:txt
复制
                        return q.poll();
代码语言:txt
复制
                    first = null; // don't retain ref while waiting
代码语言:txt
复制
                    if (leader != null)
代码语言:txt
复制
                        available.await();
代码语言:txt
复制
                    else {
代码语言:txt
复制
                        Thread thisThread = Thread.currentThread();
代码语言:txt
复制
                        leader = thisThread;
代码语言:txt
复制
                        try {
代码语言:txt
复制
                            available.awaitNanos(delay);
代码语言:txt
复制
                        } finally {
代码语言:txt
复制
                            if (leader == thisThread)
代码语言:txt
复制
                                leader = null;
代码语言:txt
复制
                        }
代码语言:txt
复制
                    }
代码语言:txt
复制
                }
代码语言:txt
复制
            }
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            if (leader == null && q.peek() != null)
代码语言:txt
复制
                available.signal();
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
    }

https://my.oschina.net/u/2474629/blog/1919127

3. 时间轮实现

代码实现:支持秒级别的循环队列,从下标最小的任务集合开始,提交到线程池执行。然后休眠1s,指针移动到下一个下标处。

所谓时间轮(RingBuffer)实现,从数据结构上看,简单的说就是循环队列,从名称上看可能感觉很抽象。

它其实就是一个环形的数组,如图所示,假设我们创建了一个长度为 8 的时间轮。

image

插入、取值流程:

  • 1.当我们需要新建一个 1s 延时任务的时候,则只需要将它放到下标为 1 的那个槽中,2、3、...、7也同样如此。
  • 2.而如果是新建一个 10s 的延时任务,则需要将它放到下标为 2 的槽中,但同时需要记录它所对应的圈数,也就是 1 圈,不然就和 2 秒的延时消息重复了
  • 3.当创建一个 21s 的延时任务时,它所在的位置就在下标为 5 的槽中,同样的需要为他加上圈数为 2,依次类推...

因此,总结起来有两个核心的变量:

  • 数组下标:表示某个任务延迟时间,从数据操作上对执行时间点进行取余
  • 圈数:表示需要循环圈数

通过这张图可以更直观的理解!

image

当我们需要取出延时任务时,只需要每秒往下移动这个指针,然后取出该位置的所有任务即可,取任务的时间消耗为O(1)。

当我们需要插入任务,也只需要计算出对应的下表和圈数,即可将任务插入到对应的数组位置中,插入任务的时间消耗为O(1)。

如果时间轮的槽比较少,会导致某一个槽上的任务非常多,那么效率也比较低,这就和 HashMap 的 hash

冲突是一样的,因此在设计槽的时候不能太大也不能太小。

代码语言:txt
复制
package com.hui.hui;
代码语言:txt
复制
import java.util.Collection;
代码语言:txt
复制
import java.util.HashSet;
代码语言:txt
复制
import java.util.Map;
代码语言:txt
复制
import java.util.Set;
代码语言:txt
复制
import java.util.concurrent.ConcurrentHashMap;
代码语言:txt
复制
import java.util.concurrent.ExecutorService;
代码语言:txt
复制
import java.util.concurrent.Executors;
代码语言:txt
复制
import java.util.concurrent.TimeUnit;
代码语言:txt
复制
import java.util.concurrent.atomic.AtomicBoolean;
代码语言:txt
复制
import java.util.concurrent.atomic.AtomicInteger;
代码语言:txt
复制
import java.util.concurrent.locks.Condition;
代码语言:txt
复制
import java.util.concurrent.locks.Lock;
代码语言:txt
复制
import java.util.concurrent.locks.ReentrantLock;
代码语言:txt
复制
public class RingBuffer {
代码语言:txt
复制
    private static final int STATIC_RING_SIZE = 64;
代码语言:txt
复制
    private Object[] ringBuffer;
代码语言:txt
复制
    private int bufferSize;
代码语言:txt
复制
    /**
代码语言:txt
复制
     * business thread pool
     */
    private ExecutorService executorService;
代码语言:txt
复制
    private volatile int size = 0;
代码语言:txt
复制
    /***
代码语言:txt
复制
     * task stop sign
     */
    private volatile boolean stop = false;
代码语言:txt
复制
    /**
代码语言:txt
复制
     * task start sign
     */
    private volatile AtomicBoolean start = new AtomicBoolean(false);
代码语言:txt
复制
    /**
代码语言:txt
复制
     * total tick times
     */
    private AtomicInteger tick = new AtomicInteger();
代码语言:txt
复制
    private Lock lock = new ReentrantLock();
代码语言:txt
复制
    private Condition condition = lock.newCondition();
代码语言:txt
复制
    private AtomicInteger taskId = new AtomicInteger();
代码语言:txt
复制
    private Map<Integer, Task> taskMap = new ConcurrentHashMap<>(16);
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Create a new delay task ring buffer by default size
     *
     * @param executorService the business thread pool
     */
    public RingBuffer(ExecutorService executorService) {
        this.executorService = executorService;
        this.bufferSize = STATIC_RING_SIZE;
        this.ringBuffer = new Object[bufferSize];
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Create a new delay task ring buffer by custom buffer size
     *
     * @param executorService the business thread pool
     * @param bufferSize      custom buffer size
     */
    public RingBuffer(ExecutorService executorService, int bufferSize) {
        this(executorService);
代码语言:txt
复制
        if (!powerOf2(bufferSize)) {
代码语言:txt
复制
            throw new RuntimeException("bufferSize=[" + bufferSize + "] must be a power of 2");
代码语言:txt
复制
        }
代码语言:txt
复制
        this.bufferSize = bufferSize;
代码语言:txt
复制
        this.ringBuffer = new Object[bufferSize];
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Add a task into the ring buffer(thread safe)
     *
     * @param task business task extends {@link Task}
     */
    public int addTask(Task task) {
        int key = task.getKey();
        int id;
代码语言:txt
复制
        try {
代码语言:txt
复制
            lock.lock();
代码语言:txt
复制
            int index = mod(key, bufferSize);
代码语言:txt
复制
            task.setIndex(index);
代码语言:txt
复制
            Set<Task> tasks = get(index);
代码语言:txt
复制
            int cycleNum = cycleNum(key, bufferSize);
代码语言:txt
复制
            if (tasks != null) {
代码语言:txt
复制
                task.setCycleNum(cycleNum);
代码语言:txt
复制
                tasks.add(task);
代码语言:txt
复制
            } else {
代码语言:txt
复制
                task.setIndex(index);
代码语言:txt
复制
                task.setCycleNum(cycleNum);
代码语言:txt
复制
                Set<Task> sets = new HashSet<>();
代码语言:txt
复制
                sets.add(task);
代码语言:txt
复制
                put(key, sets);
代码语言:txt
复制
            }
代码语言:txt
复制
            id = taskId.incrementAndGet();
代码语言:txt
复制
            task.setTaskId(id);
代码语言:txt
复制
            taskMap.put(id, task);
代码语言:txt
复制
            size++;
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
        start();
代码语言:txt
复制
        return id;
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Cancel task by taskId
     *
     * @param id unique id through {@link #addTask(Task)}
     * @return
     */
    public boolean cancel(int id) {
代码语言:txt
复制
        boolean flag = false;
代码语言:txt
复制
        Set<Task> tempTask = new HashSet<>();
代码语言:txt
复制
        try {
代码语言:txt
复制
            lock.lock();
代码语言:txt
复制
            Task task = taskMap.get(id);
代码语言:txt
复制
            if (task == null) {
代码语言:txt
复制
                return false;
代码语言:txt
复制
            }
代码语言:txt
复制
            Set<Task> tasks = get(task.getIndex());
代码语言:txt
复制
            for (Task tk : tasks) {
代码语言:txt
复制
                if (tk.getKey() == task.getKey() && tk.getCycleNum() == task.getCycleNum()) {
代码语言:txt
复制
                    size--;
代码语言:txt
复制
                    flag = true;
代码语言:txt
复制
                    taskMap.remove(id);
代码语言:txt
复制
                } else {
代码语言:txt
复制
                    tempTask.add(tk);
代码语言:txt
复制
                }
代码语言:txt
复制
            }
代码语言:txt
复制
            //update origin data
代码语言:txt
复制
            ringBuffer[task.getIndex()] = tempTask;
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
        return flag;
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Thread safe
     *
     * @return the size of ring buffer
     */
    public int taskSize() {
        return size;
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Same with method {@link #taskSize}
     *
     * @return
     */
    public int taskMapSize() {
        return taskMap.size();
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Start background thread to consumer wheel timer, it will always run until you call method {@link #stop}
     */
    public void start() {
        if (!start.get()) {
            System.out.println("Delay task is starting");
            if (start.compareAndSet(start.get(), true)) {
                Thread job = new Thread(new TriggerJob());
                job.setName("consumer RingBuffer thread");
                job.start();
                start.set(true);
            }
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Stop consumer ring buffer thread
     *
     * @param force True will force close consumer thread and discard all pending tasks
     *              otherwise the consumer thread waits for all tasks to completes before closing.
     */
    public void stop(boolean force) {
        if (force) {
            stop = true;
            executorService.shutdownNow();
        } else {
            System.out.println("Delay task is stopping");
            if (taskSize() > 0) {
                try {
                    lock.lock();
                    condition.await();
                    stop = true;
                } catch (InterruptedException e) {
                    System.out.println("InterruptedException" + e);
                } finally {
                    lock.unlock();
                }
            }
            executorService.shutdown();
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    private Set<Task> get(int index) {
代码语言:txt
复制
        return (Set<Task>) ringBuffer[index];
代码语言:txt
复制
    }
代码语言:txt
复制
    private void put(int key, Set<Task> tasks) {
代码语言:txt
复制
        int index = mod(key, bufferSize);
代码语言:txt
复制
        ringBuffer[index] = tasks;
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * Remove and get task list.
     *
     * @param key
     * @return task list
     */
    private Set<Task> remove(int key) {
        Set<Task> tempTask = new HashSet<>();
        Set<Task> result = new HashSet<>();
代码语言:txt
复制
        Set<Task> tasks = (Set<Task>) ringBuffer[key];
代码语言:txt
复制
        if (tasks == null) {
代码语言:txt
复制
            return result;
代码语言:txt
复制
        }
代码语言:txt
复制
        for (Task task : tasks) {
代码语言:txt
复制
            if (task.getCycleNum() == 0) {
代码语言:txt
复制
                result.add(task);
代码语言:txt
复制
                size2Notify();
代码语言:txt
复制
            } else {
代码语言:txt
复制
                // decrement 1 cycle number and update origin data
代码语言:txt
复制
                task.setCycleNum(task.getCycleNum() - 1);
代码语言:txt
复制
                tempTask.add(task);
代码语言:txt
复制
            }
代码语言:txt
复制
            // remove task, and free the memory.
代码语言:txt
复制
            taskMap.remove(task.getTaskId());
代码语言:txt
复制
        }
代码语言:txt
复制
        //update origin data
代码语言:txt
复制
        ringBuffer[key] = tempTask;
代码语言:txt
复制
        return result;
代码语言:txt
复制
    }
代码语言:txt
复制
    private void size2Notify() {
代码语言:txt
复制
        try {
代码语言:txt
复制
            lock.lock();
代码语言:txt
复制
            size--;
代码语言:txt
复制
            if (size == 0) {
代码语言:txt
复制
                condition.signal();
代码语言:txt
复制
            }
代码语言:txt
复制
        } finally {
代码语言:txt
复制
            lock.unlock();
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    private boolean powerOf2(int target) {
代码语言:txt
复制
        if (target < 0) {
代码语言:txt
复制
            return false;
代码语言:txt
复制
        }
代码语言:txt
复制
        int value = target & (target - 1);
代码语言:txt
复制
        if (value != 0) {
代码语言:txt
复制
            return false;
代码语言:txt
复制
        }
代码语言:txt
复制
        return true;
代码语言:txt
复制
    }
代码语言:txt
复制
    private int mod(int target, int mod) {
代码语言:txt
复制
        // equals target % mod
代码语言:txt
复制
        target = target + tick.get();
代码语言:txt
复制
        return target & (mod - 1);
代码语言:txt
复制
    }
代码语言:txt
复制
    private int cycleNum(int target, int mod) {
代码语言:txt
复制
        //equals target/mod
代码语言:txt
复制
        return target >> Integer.bitCount(mod - 1);
代码语言:txt
复制
    }
代码语言:txt
复制
    /**
代码语言:txt
复制
     * An abstract class used to implement business.
     */
    public abstract static class Task extends Thread {
代码语言:txt
复制
        private int index;
代码语言:txt
复制
        private int cycleNum;
代码语言:txt
复制
        private int key;
代码语言:txt
复制
        /**
代码语言:txt
复制
         * The unique ID of the task
         */
        private int taskId;
代码语言:txt
复制
        @Override
代码语言:txt
复制
        public void run() {
代码语言:txt
复制
        }
代码语言:txt
复制
        public int getKey() {
代码语言:txt
复制
            return key;
代码语言:txt
复制
        }
代码语言:txt
复制
        /**
代码语言:txt
复制
         * @param key Delay time(seconds)
         */
        public void setKey(int key) {
            this.key = key;
        }
代码语言:txt
复制
        public int getCycleNum() {
代码语言:txt
复制
            return cycleNum;
代码语言:txt
复制
        }
代码语言:txt
复制
        private void setCycleNum(int cycleNum) {
代码语言:txt
复制
            this.cycleNum = cycleNum;
代码语言:txt
复制
        }
代码语言:txt
复制
        public int getIndex() {
代码语言:txt
复制
            return index;
代码语言:txt
复制
        }
代码语言:txt
复制
        private void setIndex(int index) {
代码语言:txt
复制
            this.index = index;
代码语言:txt
复制
        }
代码语言:txt
复制
        public int getTaskId() {
代码语言:txt
复制
            return taskId;
代码语言:txt
复制
        }
代码语言:txt
复制
        public void setTaskId(int taskId) {
代码语言:txt
复制
            this.taskId = taskId;
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    private class TriggerJob implements Runnable {
代码语言:txt
复制
        @Override
代码语言:txt
复制
        public void run() {
代码语言:txt
复制
            int index = 0;
代码语言:txt
复制
            while (!stop) {
代码语言:txt
复制
                try {
代码语言:txt
复制
                    Set<Task> tasks = remove(index);
代码语言:txt
复制
                    for (Task task : tasks) {
代码语言:txt
复制
                        executorService.submit(task);
代码语言:txt
复制
                    }
代码语言:txt
复制
                    if (++index > bufferSize - 1) {
代码语言:txt
复制
                        index = 0;
代码语言:txt
复制
                    }
代码语言:txt
复制
                    //Total tick number of records
代码语言:txt
复制
                    tick.incrementAndGet();
代码语言:txt
复制
                    TimeUnit.SECONDS.sleep(1);
代码语言:txt
复制
                } catch (Exception e) {
代码语言:txt
复制
                    System.out.println("Exception" + e);
代码语言:txt
复制
                }
代码语言:txt
复制
            }
代码语言:txt
复制
            System.out.println("Delay task has stopped");
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    public static void main(String[] args) {
代码语言:txt
复制
        RingBuffer ringBufferWheel = new RingBuffer(Executors.newFixedThreadPool(2));
代码语言:txt
复制
        for (int i = 0; i < 3; i++) {
代码语言:txt
复制
            RingBuffer.Task job = new Job();
代码语言:txt
复制
            job.setKey(i);
代码语言:txt
复制
            ringBufferWheel.addTask(job);
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
    public static class Job extends RingBuffer.Task {
代码语言:txt
复制
        @Override
代码语言:txt
复制
        public void run() {
代码语言:txt
复制
            System.out.println("test5"+getIndex());
代码语言:txt
复制
        }
代码语言:txt
复制
    }
代码语言:txt
复制
}

二、分布式

之前说的单机实现,一旦服务器重启,那么延时任务会丢失,而分布式的方案则不会丢失任务。

Redis ZSet实现

  1. 底层实现:Redis的底层实现是当key大小小于某个阈值,并且键值对个数小于某个阈值(都可配置),使用ZipList实现,否则使用SkipList和Hash实现,SkipList中按照score排序,hash存储成员到分数的映射。
  2. ZSet API
  • 添加,如果值存在添加,将会重新排序。zaddundefined127.0.0.1:6379>zadd myZSet 1 zlh ---添加分数为1,值为zlh的zset集合
  • 查看zset集合的成员个数。zcardundefined127.0.0.1:6379>zcard myZSet
  • 查看Zset指定范围的成员,withscores为输出结果带分数。zrangeundefined127.0.0.1:6379>zrange mZySet 0 -1 ----0为开始,-1为结束,输出顺序结果为: zlh tom jim
  • 获取zset成员的下标位置,如果值不存在返回null。zrankundefined127.0.0.1:6379>zrank mZySet Jim ---Jim的在zset集合中的下标为2
  • 获取zset集合指定分数之间存在的成员个数。zcountundefined127.0.0.1:6379>zcount mySet 1 3 ---输出分数>=1 and 分数 <=3的成员个数为3
  1. 实现思路:
  • 添加任务时,将当前时间+延时时间作为SkipList的分词,job的key作为成员标识加入ZSet
  • 搬运线程开启定时任务,将在当前时间戳之前的任务添加到队列中
  • 开启消费线程,无限循环,超时从队列获取Job,将任务放到线程池中消费
  • 添加任务,消费线程,搬运线程,都需要获取Redis分布式锁

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
作者已关闭评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、单机
    • 1. while+sleep组合
      • 2. 最小堆实现
      • 3.3、小结
        • 3. 时间轮实现
          • 二、分布式
            • Redis ZSet实现
        相关产品与服务
        云数据库 Redis
        腾讯云数据库 Redis(TencentDB for Redis)是腾讯云打造的兼容 Redis 协议的缓存和存储服务。丰富的数据结构能帮助您完成不同类型的业务场景开发。支持主从热备,提供自动容灾切换、数据备份、故障迁移、实例监控、在线扩容、数据回档等全套的数据库服务。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档