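The previous post showed that InheritableThreadLocal cannot pass ThreadLocal values when threads are reused. TransmittableThreadLocal (TTL), open-sourced by Alibaba, passes ThreadLocal values to tasks run by components that pool and reuse threads, such as thread pools, which solves context propagation for asynchronous execution. TransmittableThreadLocal has to be used together with TtlExecutors, TtlRunnable, or TtlCallable from the TTL library; alternatively, the TTL Java Agent can add the same behavior to thread pools without code changes. The class itself extends InheritableThreadLocal.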
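The limitation mentioned above can be reproduced with the JDK alone. The following is a minimal sketch, not code from the previous post; the class name `InheritableReuseDemo` and the field `CONTEXT` are made up for illustration. The single worker thread copies the inheritable value once, when it is created, and never sees later updates.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo class: shows that a reused pool thread keeps the value
// it inherited at creation time.
final class InheritableReuseDemo {
    static final InheritableThreadLocal<Integer> CONTEXT = new InheritableThreadLocal<>();

    static List<Integer> valuesSeenByPool() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        List<Integer> seen = new ArrayList<>();
        Callable<Integer> read = CONTEXT::get; // runs in the worker thread
        try {
            for (int i = 0; i < 3; i++) {
                CONTEXT.set(i);                    // update the submitting thread's value
                seen.add(pool.submit(read).get()); // the worker is created on the first submit
            }
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(valuesSeenByPool()); // prints [0, 0, 0]
    }
}
```

Every task sees 0, the value inherited when the worker thread was created, whereas the TTL tests in this post print the value that was current at submission time.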
@Test
public void testTtlRunnableTransmittableThreadLocalByThreadPool() {
    TransmittableThreadLocal<Integer> threadLocal = new TransmittableThreadLocal<>();
    IntStream.range(0, 10).forEach(i -> {
        System.out.println(i);
        threadLocal.set(i);
        // wrap the task so that the value set above is captured at submission time
        service.execute(TtlRunnable.get(() -> {
            System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
        }));
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    });
}
Output:
0
pool-1-thread-1:0
1
pool-1-thread-1:1
2
pool-1-thread-1:2
3
pool-1-thread-1:3
4
pool-1-thread-1:4
5
pool-1-thread-1:5
6
pool-1-thread-1:6
7
pool-1-thread-1:7
8
pool-1-thread-1:8
9
pool-1-thread-1:9
private ExecutorService service = Executors.newFixedThreadPool(1);
@Test
public void testTransmittableThreadLocalByTtlThreadPool() {
    // wrap the pool once; tasks submitted afterwards are decorated automatically
    service = TtlExecutors.getTtlExecutorService(service);
    TransmittableThreadLocal<Integer> threadLocal = new TransmittableThreadLocal<>();
    IntStream.range(0, 10).forEach(i -> {
        System.out.println(i);
        threadLocal.set(i);
        service.execute(() -> {
            System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
        });
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    });
}
Output:
0
pool-1-thread-1:0
1
pool-1-thread-1:1
2
pool-1-thread-1:2
3
pool-1-thread-1:3
4
pool-1-thread-1:4
5
pool-1-thread-1:5
6
pool-1-thread-1:6
7
pool-1-thread-1:7
8
pool-1-thread-1:8
9
pool-1-thread-1:9
As the output shows, when combined with TtlExecutors, TtlRunnable, or TtlCallable, TransmittableThreadLocal achieves what InheritableThreadLocal cannot: passing ThreadLocal values when threads are reused.
// Note about holder:
// 1. The value of holder is type Map<TransmittableThreadLocal<?>, ?> (WeakHashMap implementation),
//    but it is used as *set*.
// 2. WeakHashMap support null value.
private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
        new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
            @Override
            protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
            }

            @Override
            protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
            }
        };
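holder is an InheritableThreadLocal whose value is a WeakHashMap. initialValue creates an empty map for a thread, and childValue gives a newly created child thread its own copy of the parent thread's map.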
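The effect of overriding childValue can be seen in isolation with a small JDK-only sketch. It is an illustration under assumptions, not TTL code: the class name `HolderCopyDemo` and the String keys are hypothetical stand-ins for the TransmittableThreadLocal keys in the real holder.

```java
import java.util.Map;
import java.util.WeakHashMap;

// Hypothetical demo: a child thread receives a copy of the parent's map,
// so later changes in the child do not leak back to the parent.
final class HolderCopyDemo {
    static final InheritableThreadLocal<Map<String, Object>> HOLDER =
            new InheritableThreadLocal<Map<String, Object>>() {
                @Override
                protected Map<String, Object> initialValue() {
                    return new WeakHashMap<String, Object>();
                }

                @Override
                protected Map<String, Object> childValue(Map<String, Object> parentValue) {
                    return new WeakHashMap<String, Object>(parentValue); // copy, not share
                }
            };

    /** Returns {child saw the parent's key, parent sees the child's key}. */
    static boolean[] run() {
        HOLDER.get().put("a", null);
        boolean[] childSawA = new boolean[1];
        Thread child = new Thread(() -> {
            childSawA[0] = HOLDER.get().containsKey("a");
            HOLDER.get().put("b", null); // only changes the child's copy
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return new boolean[] {childSawA[0], HOLDER.get().containsKey("b")};
    }
}
```

Calling `HolderCopyDemo.run()` yields `{true, false}`: the child starts with the parent's registrations, but the two maps are independent afterwards.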
@Override
public final void set(T value) {
    super.set(value);
    // may set null to remove value
    if (null == value) removeValue();
    else addValue();
}

private void removeValue() {
    holder.get().remove(this);
}

private void addValue() {
    if (!holder.get().containsKey(this)) {
        holder.get().put(this, null); // WeakHashMap supports null value.
    }
}
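Every time a value is set or removed, holder is updated as well. holder.get() returns a Map whose keys are TransmittableThreadLocal instances and whose values are of type Object. addValue stores this as the key with a null value, so the WeakHashMap is effectively used as a weak set: once no other reference points to this, the JVM can garbage-collect the entry when it needs to.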
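The "map used as a set" idea can be sketched on its own with the JDK. The class `WeakSetSketch` and its method names are hypothetical; they only mirror the shape of addValue and removeValue and are not part of TTL.

```java
import java.util.Map;
import java.util.WeakHashMap;

// Hypothetical sketch: a WeakHashMap whose values are always null acts as a
// set of weakly referenced keys.
final class WeakSetSketch {
    private final Map<Object, Object> holder = new WeakHashMap<>();

    void register(Object key) {
        if (!holder.containsKey(key)) {
            holder.put(key, null); // only the key matters; WeakHashMap accepts null values
        }
    }

    void unregister(Object key) {
        holder.remove(key);
    }

    boolean isRegistered(Object key) {
        // containsKey, not get: get would return null for registered keys too
        return holder.containsKey(key);
    }
}
```

Because the values are always null, membership has to be checked with containsKey. Entries whose keys are no longer strongly referenced may disappear after a garbage collection, at a time the JVM chooses.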
@Override
public final T get() {
    T value = super.get();
    if (null != value) addValue();
    return value;
}
get mostly delegates to the parent class; the only addition is registering this ThreadLocal in holder so that TTL can keep track of it.
private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    // take a snapshot of the current thread's TTL values and keep it in an AtomicReference
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

@Override
public void run() {
    Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    // replay the snapshot in this thread; the return value is a backup of the thread's previous TTL values
    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        // restore the backup
        restore(backup);
    }
}
Next, look at the replay and restore methods:
@Nonnull
public static Object replay(@Nonnull Object captured) {
    // the captured snapshot: TransmittableThreadLocal -> value
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> capturedMap = (Map<TransmittableThreadLocal<?>, Object>) captured;
    // the map that backs up the current thread's TransmittableThreadLocal values
    Map<TransmittableThreadLocal<?>, Object> backup = new HashMap<TransmittableThreadLocal<?>, Object>();

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!capturedMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set values to captured TTL
    setTtlValuesTo(capturedMap);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}
public static void restore(@Nonnull Object backup) {
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> backupMap = (Map<TransmittableThreadLocal<?>, Object>) backup;
    // call afterExecute callback
    doExecuteCallback(false);

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backupMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore TTL values
    setTtlValuesTo(backupMap);
}

private static void setTtlValuesTo(@Nonnull Map<TransmittableThreadLocal<?>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<?>, Object> entry : ttlValues.entrySet()) {
        @SuppressWarnings("unchecked")
        TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

private static void doExecuteCallback(boolean isBefore) {
    for (Map.Entry<TransmittableThreadLocal<?>, ?> entry : holder.get().entrySet()) {
        TransmittableThreadLocal<?> threadLocal = entry.getKey();
        try {
            if (isBefore) threadLocal.beforeExecute();
            else threadLocal.afterExecute();
        } catch (Throwable t) {
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t.toString(), t);
            }
        }
    }
}
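Before the wrapped run method actually executes, the worker thread's existing TransmittableThreadLocal values are backed up, and they are restored once it finishes. beforeExecute and afterExecute are callbacks invoked before and after execution. In summary, there are two main steps: first, a snapshot of the submitting thread's TTL values is captured when the TtlRunnable is created; second, inside run, the snapshot is replayed in the worker thread (backing up its current values), the task runs, and the backup is restored.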
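The capture, replay, and restore sequence can be sketched with a single plain ThreadLocal and only the JDK. This is a simplified illustration, not TTL's implementation: the class `SnapshotRunnable` and the field `CONTEXT` are hypothetical, and the sketch tracks one variable instead of the whole holder set.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical sketch of capture -> replay -> run -> restore for one ThreadLocal.
final class SnapshotRunnable implements Runnable {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    private final String captured; // snapshot taken in the submitting thread
    private final Runnable task;

    SnapshotRunnable(Runnable task) {
        this.captured = CONTEXT.get(); // capture
        this.task = task;
    }

    @Override
    public void run() {
        String backup = CONTEXT.get(); // the worker thread's own value
        setOrRemove(captured);         // replay
        try {
            task.run();
        } finally {
            setOrRemove(backup);       // restore
        }
    }

    private static void setOrRemove(String value) {
        if (value == null) CONTEXT.remove();
        else CONTEXT.set(value);
    }

    /** Returns {value seen by the wrapped task, value seen by the next plain task}. */
    static String[] demo() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        String[] seen = new String[2];
        try {
            pool.submit(() -> { CONTEXT.set("worker-own"); }).get();
            CONTEXT.set("request-1");
            pool.submit(new SnapshotRunnable(() -> { seen[0] = CONTEXT.get(); })).get();
            pool.submit(() -> { seen[1] = CONTEXT.get(); }).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
        return seen;
    }
}
```

`SnapshotRunnable.demo()` returns `{"request-1", "worker-own"}`: the wrapped task sees the submitting thread's value, and the worker's own value is back in place for the next task.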
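When a task goes through the wrapped thread pool, the execute method of ExecutorTtlWrapper is called. It calls TtlRunnable.get(command), which creates and returns a TtlRunnable. Interested readers can read that code themselves.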
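The wrapping described above is a decorator over the executor. A minimal JDK-only sketch of that pattern follows; `DecoratingExecutor` and its `decorator` parameter are hypothetical names, and the decorator here only logs, whereas TTL's wrapper would apply TtlRunnable.get at that point.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.UnaryOperator;

// Hypothetical sketch: an Executor that decorates every task before
// handing it to the real executor.
final class DecoratingExecutor implements Executor {
    private final Executor delegate;
    private final UnaryOperator<Runnable> decorator;

    DecoratingExecutor(Executor delegate, UnaryOperator<Runnable> decorator) {
        this.delegate = delegate;
        this.decorator = decorator;
    }

    @Override
    public void execute(Runnable command) {
        // callers submit plain tasks; the wrapping happens here
        delegate.execute(decorator.apply(command));
    }

    static List<String> demo() {
        List<String> log = new ArrayList<>();
        Executor direct = Runnable::run; // runs each task in the calling thread
        Executor wrapped = new DecoratingExecutor(direct, task -> () -> {
            log.add("before");
            try {
                task.run();
            } finally {
                log.add("after");
            }
        });
        wrapped.execute(() -> log.add("task"));
        return log;
    }
}
```

`DecoratingExecutor.demo()` returns `[before, task, after]`. Callers keep submitting ordinary Runnables, which is why the second test in this post needs no TtlRunnable.get call.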
log4j2 MDC:
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>log4j2-ttl-thread-context-map</artifactId>
    <version>1.2.0</version>
</dependency>