温馨提示:文章篇幅过长,请耐心阅读
生产者-消费者模式是一个经典的多线程设计模式,它为多线程间的协作提供了良好的解决方案。也经常有面试官会让手写一个生产者消费者,从代码细节可以看出你对多线程编程的熟练程度,今天我们来详细看一下如何写出一个生产者消费者模式,并且逐步对其优化争取做到高性能。
结构剖析
在生产者-消费者模式中,通常有两类线程,一类是生产者线程一类是消费者线程。生产者线程负责提交用户请求,消费者线程则负责处理生产者提交的任务。
最简单粗暴的做法就是生产者每提交一个任务,消费者就立即处理,直到所有任务处理完。但是这样直接通信很容易出现性能上的问题,消费者必须等待它的生产者提交到任务才能执行,就不能达到真正的并行。同时生产者和消费者之间存在依赖关系,在设计上耦合度非常高,这是不可取的。那么最好的做法就是加一个中间层作为通信桥梁。
生产者和消费者之间通过共享内存缓存区进行通信。多个生产者线程将任务提交给共享内存缓存区,消费者线程并不直接与生产者线程通信,而在共享内存缓冲区获取任务,并行地处理。其中内存缓冲区的主要功能是数据再多线程间的共享,同时还可以通过该缓存区,缓解生产者和消费者间的性能差。它是生产者消费者模式的核心组件,既能作为通信的桥梁,又能避免两者直接通信,从而将生产者和消费者进行解耦。生产者不需要消费者的存在,消费者也不需要知道生产者的存在。
由于内存缓冲区的存在,允许生产者和消费者在执行速度上有差异,无论哪一方速度超过另一方,缓冲区都可以得到缓解,确保系统正常运行。
生产者-消费者模式UML图如下:
实测
下面是我用IDEA的工具生成的UML图
其中Queue的选择最为关键,一般情况下应该会选择阻塞式队列,也就是BlockingQueue,但是它的并发效果并不好,所以我选择了ConcurrentLinkedQueue,笔者经过实测,生产和消费100万个数据,使用LinkedBlockingQueue耗时是38秒,并且在处理了防超量生产的情况下仍然产生了超量生产,其原因就是其并发效果不如后者,后者使用的是CAS无锁实现,而使用ConcurrentLinkedQueue耗时是27秒,无超量生产情况产生。其中ConcurrentLinkedQueue运行时占用内存最大时为546MB(jconsole测得)。
代码实现
Data类
package thread.producerConsumer;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.javase_learning</p> * Created in 19:11 2019/5/19 */public class Data { private final int data;
public Data(int data) { this.data = data; }
public int getData() { return data; }
@Override public String toString() { return "Data{" + "data=" + data + '}'; }}
生产者Producer
package thread.producerConsumer;
import java.util.concurrent.ConcurrentLinkedQueue;import java.util.concurrent.atomic.AtomicInteger;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.javase_learning</p> * Created in 19:11 2019/5/19 */public class Producer implements Runnable { // 缓冲队列 private ConcurrentLinkedQueue<Data> queue; // 每个生产者生产数据的数量 private int produceNum;
// 任务计数,同时模拟数据生产 private static AtomicInteger count = new AtomicInteger(0); // 需要生产的总任务数量,出处有耦合性,实际开发时应当处理。 private static AtomicInteger total = new AtomicInteger(Main.NUMS);
public Producer(ConcurrentLinkedQueue<Data> queue, int produceNum) { this.queue = queue; this.produceNum = produceNum; }
@Override public void run() { Data data = null; // 限定生产,防止超量生产 if (total.get() <= 0) { Thread.currentThread().interrupt(); } for (int i = 0; i < produceNum; i++) { // 再次检测,防止超量生产 if (total.get() > 0) { // 构造数据 data = new Data(count.incrementAndGet()); // 向缓冲队列中提交 if (!queue.offer(data)) { System.out.println("Filed to put data:" + data); } else { System.out.println("Successfully produced data:" + data); } // 需要生产的总量-1 total.decrementAndGet(); } else { Thread.currentThread().interrupt(); break; } } }}
消费者Consumer
package thread.producerConsumer;
import java.text.MessageFormat;import java.util.concurrent.ConcurrentLinkedQueue;import java.util.concurrent.CountDownLatch;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.javase_learning</p> * Created in 19:11 2019/5/19 */public class Consumer implements Runnable { private ConcurrentLinkedQueue<Data> queue; private CountDownLatch countDown;
public Consumer(ConcurrentLinkedQueue<Data> queue, CountDownLatch countDown) { this.queue = queue; this.countDown = countDown; }
@Override public void run() { Data data = queue.poll(); if (data != null) { // 计算+1 int res = data.getData() + 1; System.out.println("Data processing: " + MessageFormat.format("{0}+1={1}", data.getData(), res)); // 记录该线程任务已结束 countDown.countDown(); } }}
Main运行主类
package thread.producerConsumer;
import java.util.concurrent.*;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.javase_learning</p> * Created in 19:10 2019/5/19 */public class Main { public static final int NUMS = 1000000;
public static void main(String[] args) {
// 定义缓冲队列 ConcurrentLinkedQueue<Data> queue = new ConcurrentLinkedQueue<>(); // 创建生产者和消费者线程池,为防止冲爆内存,这里限定线程数 ExecutorService producerEs = Executors.newFixedThreadPool(100); ExecutorService consumerEs = Executors.newFixedThreadPool(100); // 定义计时器,等待消费者全部消费完 CountDownLatch countDown = new CountDownLatch(NUMS);
long start = System.currentTimeMillis(); long end = System.currentTimeMillis();
// 向线程池中提交 NUMS 个生产任务 for (int i = 0; i < NUMS; i++) { // 每个生产者以3倍的量生产,模拟生产者和消费者的速度差 producerEs.execute(new Producer(queue,3)); consumerEs.execute(new Consumer(queue, countDown)); } try { countDown.await(); end = System.currentTimeMillis(); producerEs.shutdown(); consumerEs.shutdown(); } catch (InterruptedException e) { e.printStackTrace(); }
System.out.println("Total time:" + (end - start)); }}
Disruptor
Disruptor是一个高效的无锁内存队列,它使用无锁的方式实现了一个有界环形队列,非常适合生产者-消费者模式。它相比于BlockingQueue具有更多的特点:
1. 同一个“事件”可以有多个消费者,消费者之间既可以并行处理,也可以相互依赖形成处理的先后次序(形成一个依赖图);
2. 预分配用于存储事件内容的内存空间;
3. 针对极高的性能目标而实现的极度优化和无锁的设计。
这里简单引用另一篇文章(https://www.jianshu.com/p/8473bbb556af)中的一部分Disruptor介绍
如其名,环形的缓冲区。曾经 RingBuffer 是 Disruptor 中的最主要的对象,但从3.0版本开始,其职责被简化为仅仅负责对通过 Disruptor 进行交换的数据(事件)进行存储和更新。在一些更高级的应用场景中,Ring Buffer 可以由用户的自定义实现来完全替代。
通过顺序递增的序号来编号管理通过其进行交换的数据(事件),对数据(事件)的处理过程总是沿着序号逐个递增处理。一个 Sequence 用于跟踪标识某个特定的事件处理者( RingBuffer/Consumer )的处理进度。虽然一个 AtomicLong 也可以用于标识进度,但定义 Sequence 来负责该问题还有另一个目的,那就是防止不同的 Sequence 之间的CPU缓存伪共享(Flase Sharing)问题。 (注:这是 Disruptor 实现高性能的关键点之一,网上关于伪共享问题的介绍已经汗牛充栋,在此不再赘述)。
Sequencer 是 Disruptor 的真正核心。此接口有两个实现类 SingleProducerSequencer、MultiProducerSequencer ,它们定义在生产者和消费者之间快速、正确地传递数据的并发算法。
用于保持对RingBuffer的 main published Sequence 和Consumer依赖的其它Consumer的 Sequence 的引用。 Sequence Barrier 还定义了决定 Consumer 是否还有可处理的事件的逻辑。
定义 Consumer 如何进行等待下一个事件的策略。 (注:Disruptor 定义了多种不同的策略,针对不同的场景,提供了不一样的性能表现)
在 Disruptor 的语义中,生产者和消费者之间进行交换的数据被称为事件(Event)。它不是一个被 Disruptor 定义的特定类型,而是由 Disruptor 的使用者定义并指定。
EventProcessor 持有特定消费者(Consumer)的 Sequence,并提供用于调用事件处理实现的事件循环(Event Loop)。
Disruptor 定义的事件处理接口,由用户实现,用于处理事件,是 Consumer 的真正实现。
即生产者,只是泛指调用 Disruptor 发布事件的用户代码,Disruptor 没有定义特定接口或类型。
Disruptor代码实现
笔者简单地写了一个用Disruptor实现的程序,结果惊人!同样是100万个数据,时间消耗只用了22秒,时间上和前面ConcurrentLinkedQueue差别不大,但有所提升,但是它的内存消耗居然最高时仅使用了60MB!!!在内存上节省了9倍的空间,可以看到这种内存复用的策略很有效果,在实际开发中很有帮助。
提示:Disruptor是第三方包,需要下载jar包导入项目或者使用Maven导入
数据类Data
package thread.producerConsumer.disruptor;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.disruptor.javase_learning</p> * Created in 21:57 2019/5/19 */public class Data { private long value;
public long getValue() { return value; }
public void setValue(long value) { this.value = value; }}
数据生产工厂
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.EventFactory;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.disruptor.javase_learning</p> * Created in 21:56 2019/5/19 */public class DataFactory implements EventFactory<Data> { @Override public Data newInstance() { return new Data(); }}
生产者Prodocer
消费者Consumer
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.WorkHandler;
import java.text.MessageFormat;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.disruptor.javase_learning</p> * Created in 21:51 2019/5/19 */public class Consumer implements WorkHandler<Data> { @Override public void onEvent(Data data) throws Exception { long res = data.getValue() + 1; System.out.println(MessageFormat.format("Data processing: {0} + 1 = {1}",data.getValue(),res)); }}
运行主类Main
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.RingBuffer;import com.lmax.disruptor.dsl.Disruptor;
import java.nio.ByteBuffer;import java.util.concurrent.ThreadFactory;
/** * @author beifengtz * <a href='http://www.beifengtz.com'>www.beifengtz.com</a> * <p>location: thread.producerConsumer.disruptor.javase_learning</p> * Created in 21:50 2019/5/19 */public class Main { private static final int NUMS = 10; private static final int SUM = 1000000;
public static void main(String[] args) { try { Thread.sleep(10000); } catch (InterruptedException e) { e.printStackTrace(); }
long start = System.currentTimeMillis(); DataFactory factory = new DataFactory(); int bufferSize = 1024; Disruptor<Data> disruptor = new Disruptor<>(factory, bufferSize, new ThreadFactory() { @Override public Thread newThread(Runnable r) { return new Thread(r); } });
// 构造100个消费者 Consumer[] consumers = new Consumer[NUMS]; for (int i = 0; i < NUMS; i++) { consumers[i] = new Consumer(); } disruptor.handleEventsWithWorkerPool(consumers); disruptor.start();
RingBuffer<Data> ringBuffer = disruptor.getRingBuffer(); Producer producer = new Producer(ringBuffer); ByteBuffer bb = ByteBuffer.allocate(8); for (long l = 0; l < SUM; l++) { bb.putLong(0, l); producer.pushData(bb); System.out.println("Successfully produced data: " + l); } long end = System.currentTimeMillis(); disruptor.shutdown(); System.out.println("Total time: " + (end - start)); }}