目录
概述
1. ThreadLocal
基本原理
使用示例
局限性
2. InheritableThreadLocal
基本原理
使用示例
局限性
3. TransmittableThreadLocal
基本原理
使用示例
核心机制
TransmittableThreadLocal的源码分析
核心代码示例
4. 使用框架提供的上下文传递功能
示例(Spring @Async)
总结
概述
在多线程编程中,我们常常需要在线程之间传递上下文信息。Java 提供了 ThreadLocal
和 InheritableThreadLocal
来帮助管理线程局部变量,但在某些场景下,如线程池和异步执行中,这些工具存在一些局限性。让我们详细探讨这些问题的发展历程,并介绍最终的解决方案。
1. ThreadLocal
基本原理
ThreadLocal
提供了一种机制,使每个线程都可以有自己独立的变量副本,从而避免线程之间的变量共享和竞争。每个线程都有自己的 ThreadLocalMap
,ThreadLocal
的变量存储在其中。
使用示例
public class ThreadLocalExample {private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();public static void main(String[] args) {ExecutorService executor = Executors.newFixedThreadPool(2);threadLocal.set("ValueA");System.out.println("主线程设置值为: ValueA");executor.submit(() -> System.out.println("任务1 ThreadLocal 值: " + threadLocal.get()));threadLocal.set("ValueB");System.out.println("主线程设置值为: ValueB");executor.submit(() -> System.out.println("任务2 ThreadLocal 值: " + threadLocal.get()));executor.shutdown();}
}
在这个示例中,主线程设置了 ThreadLocal
的值为 "ValueA",然后提交了一个任务给线程池。在任务提交后,主线程又将 ThreadLocal
的值设置为 "ValueB" 并提交了第二个任务。由于线程池中的线程会复用,两个任务可能会输出相同的值 "ValueB"。
局限性
- 线程池复用问题:在线程池中,线程会被重复使用。如果一个线程在一次任务中设置了
ThreadLocal
的值,那么该值可能会在后续任务中被误用,从而导致数据污染。 - 上下文丢失:
ThreadLocal
只在线程内有效,不能自动在父子线程之间传递数据。
2. InheritableThreadLocal
为了克服 ThreadLocal
不能在父子线程之间传递数据的问题,Java 引入了 InheritableThreadLocal
。
基本原理
InheritableThreadLocal
是 ThreadLocal
的一个子类,允许父线程的值自动传递给子线程。当创建一个新的子线程时,InheritableThreadLocal
会将父线程的值拷贝到子线程中。
使用示例
public class InheritableThreadLocalExample {private static final InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();public static void main(String[] args) {ExecutorService executor = Executors.newFixedThreadPool(2);inheritableThreadLocal.set("ValueA");System.out.println("主线程设置值为: ValueA");executor.submit(() -> System.out.println("任务1 InheritableThreadLocal 值: " + inheritableThreadLocal.get()));inheritableThreadLocal.set("ValueB");System.out.println("主线程设置值为: ValueB");executor.submit(() -> System.out.println("任务2 InheritableThreadLocal 值: " + inheritableThreadLocal.get()));executor.shutdown();}
}
在这个示例中,主线程设置 InheritableThreadLocal
的值为 "ValueA" 并提交第一个任务。然后,主线程将 InheritableThreadLocal
的值改为 "ValueB" 并提交第二个任务。然而,第二个任务可能仍会打印 "ValueA" 的值,因为线程池中的线程复用了之前的线程上下文。
局限性
- 线程池复用问题:与
ThreadLocal
相同,在线程池中,InheritableThreadLocal
也存在数据污染的问题。子线程不会继承父线程的最新值,而是第一次创建线程时的值。 - 上下文更新问题:如果父线程更新了
InheritableThreadLocal
的值,已经存在的子线程不会反映这些变化。
3. TransmittableThreadLocal
随着应用程序复杂度的增加,尤其是在使用线程池和异步编程时,简单的 ThreadLocal
和 InheritableThreadLocal
已经不能满足需求。TransmittableThreadLocal
(TTL) 由阿里巴巴开源,旨在解决这些问题。
基本原理
TransmittableThreadLocal
通过捕获和恢复上下文信息,并包装线程池和任务,确保在线程执行任务前后进行上下文传递和清理。
使用示例
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;public class TransmittableThreadLocalExample {private static final TransmittableThreadLocal<String> transmittableThreadLocal = new TransmittableThreadLocal<>();public static void main(String[] args) {ExecutorService executor = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(2));transmittableThreadLocal.set("值A");System.out.println("主线程设置值为: 值A");executeTasks(executor, "任务组1");transmittableThreadLocal.set("值B");System.out.println("主线程设置值为: 值B");executeTasks(executor, "任务组2");executor.shutdown();}private static void executeTasks(ExecutorService executor, String taskGroup) {Runnable task = () -> {String value = transmittableThreadLocal.get();System.out.println(taskGroup + " - TransmittableThreadLocal 值: " + value);if (!"值B".equals(value) && "任务组2".equals(taskGroup)) {System.out.println(taskGroup + " - 数据污染检测!预期值为: 值B,但实际值为: " + value);}};for (int i = 0; i < 5; i++) {executor.submit(task);}}
}
在这个示例中,TTL 确保了在线程池中每个任务执行时,能够正确获取到当前线程的上下文数据,而不会受到之前任务的影响。
核心机制
- 捕获上下文:在任务提交前,捕获当前线程的所有
TransmittableThreadLocal
数据。 - 恢复上下文:在任务执行时,恢复捕获的上下文数据,确保子线程能够继承父线程的上下文。
- 清理上下文:在任务执行完毕后,清理子线程的上下文数据,避免数据污染和内存泄漏。
TransmittableThreadLocal的源码分析
TTL 的核心实现主要在 TransmittableThreadLocal.Transmitter
类中进行,它负责捕获、恢复和清理上下文信息。
完整代码示例
// TransmittableThreadLocal 类继承自 InheritableThreadLocal,并实现了 TtlCopier 接口
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {private static final Logger logger = Logger.getLogger(TransmittableThreadLocal.class.getName());private final boolean disableIgnoreNullValueSemantics;// 一个 InheritableThreadLocal 变量,用于存储当前线程的所有 TransmittableThreadLocal 对象private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder = new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {// 初始化值protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {return new WeakHashMap<>();}// 复制父线程的值给子线程protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {return new WeakHashMap<>(parentValue);}};// 默认构造函数,初始化 disableIgnoreNullValueSemantics 为 falsepublic TransmittableThreadLocal() {this(false);}// 带参构造函数,可以设置 disableIgnoreNullValueSemantics 的值public TransmittableThreadLocal(boolean disableIgnoreNullValueSemantics) {this.disableIgnoreNullValueSemantics = disableIgnoreNullValueSemantics;}// 创建一个带初始值的 TransmittableThreadLocal 实例@NonNullpublic static <S> TransmittableThreadLocal<S> withInitial(@NonNull Supplier<? extends S> supplier) {if (supplier == null) {throw new NullPointerException("supplier is null");} else {return new SuppliedTransmittableThreadLocal<>(supplier, null, null);}}// 创建一个带初始值和复制器的 TransmittableThreadLocal 实例@ParametersAreNonnullByDefault@NonNullpublic static <S> TransmittableThreadLocal<S> withInitialAndCopier(Supplier<? extends S> supplier, TtlCopier<S> copierForChildValueAndCopy) {if (supplier == null) {throw new NullPointerException("supplier is null");} else if (copierForChildValueAndCopy == null) {throw new NullPointerException("ttl copier is null");} else {return new SuppliedTransmittableThreadLocal<>(supplier, copierForChildValueAndCopy, copierForChildValueAndCopy);}}// 创建一个带初始值和不同复制器的 TransmittableThreadLocal 实例@ParametersAreNonnullByDefault@NonNullpublic static <S> TransmittableThreadLocal<S> withInitialAndCopier(Supplier<? extends S> supplier, TtlCopier<S> copierForChildValue, TtlCopier<S> copierForCopy) {if (supplier == null) {throw new NullPointerException("supplier is null");} else if (copierForChildValue == null) {throw new NullPointerException("ttl copier for child value is null");} else if (copierForCopy == null) {throw new NullPointerException("ttl copier for copy value is null");} else {return new SuppliedTransmittableThreadLocal<>(supplier, copierForChildValue, copierForCopy);}}// 复制父值public T copy(T parentValue) {return parentValue;}// 任务执行前的钩子方法,子类可重写protected void beforeExecute() {}// 任务执行后的钩子方法,子类可重写protected void afterExecute() {}// 获取值,必要时添加到 holderpublic final T get() {T value = super.get();if (this.disableIgnoreNullValueSemantics || null != value) {this.addThisToHolder();}return value;}// 设置值,必要时添加到 holderpublic final void set(T value) {if (!this.disableIgnoreNullValueSemantics && null == value) {this.remove();} else {super.set(value);this.addThisToHolder();}}// 移除值,同时从 holder 中移除public final void remove() {this.removeThisFromHolder();super.remove();}private void superRemove() {super.remove();}// 复制当前值private T copyValue() {return this.copy(this.get());}// 将当前对象添加到 holderprivate void addThisToHolder() {if (!((WeakHashMap)holder.get()).containsKey(this)) {((WeakHashMap)holder.get()).put(this, null);}}// 将当前对象从 holder 中移除private void removeThisFromHolder() {((WeakHashMap)holder.get()).remove(this);}// 执行回调方法private static void doExecuteCallback(boolean isBefore) {WeakHashMap<TransmittableThreadLocal<Object>, ?> ttlInstances = new WeakHashMap<>((Map)holder.get());for (TransmittableThreadLocal<Object> threadLocal : ttlInstances.keySet()) {try {if (isBefore) {threadLocal.beforeExecute();} else {threadLocal.afterExecute();}} catch (Throwable t) {if (logger.isLoggable(Level.WARNING)) {logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t, t);}}}}// 打印调试信息static void dump(@Nullable String title) {if (title != null && title.length() > 0) {System.out.printf("Start TransmittableThreadLocal[%s] Dump...%n", title);} else {System.out.println("Start TransmittableThreadLocal Dump...");}for (TransmittableThreadLocal<Object> threadLocal : ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet()) {System.out.println(threadLocal.get());}System.out.println("TransmittableThreadLocal Dump end!");}static void dump() {dump(null);}// Transmitter 类,负责捕获和恢复上下文public static class Transmitter {private static volatile WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> threadLocalHolder = new WeakHashMap<>();private static final Object threadLocalHolderUpdateLock = new Object();private static final Object threadLocalClearMark = new Object();private static final TtlCopier<Object> shadowCopier = parentValue -> parentValue;// 捕获当前线程的上下文信息@NonNullpublic static Object capture() {return new Snapshot(captureTtlValues(), captureThreadLocalValues());}// 捕获 TTL 值private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>();for (TransmittableThreadLocal<Object> threadLocal : ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet()) {ttl2Value.put(threadLocal, threadLocal.copyValue());}return ttl2Value;}// 捕获 ThreadLocal 值private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<>();for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {ThreadLocal<Object> threadLocal = entry.getKey();TtlCopier<Object> copier = entry.getValue();threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));}return threadLocal2Value;}// 重新设置捕获的上下文信息@NonNullpublic static Object replay(@NonNull Object captured) {Snapshot capturedSnapshot = (Snapshot) captured;return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));}// 重新设置 TTL 值@NonNullprivate static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>();Iterator<TransmittableThreadLocal<Object>> iterator = ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet().iterator();while (iterator.hasNext()) {TransmittableThreadLocal<Object> threadLocal = iterator.next();backup.put(threadLocal, threadLocal.get());if (!captured.containsKey(threadLocal)) {iterator.remove();threadLocal.superRemove();}}setTtlValuesTo(captured);TransmittableThreadLocal.doExecuteCallback(true);return backup;}// 重新设置 ThreadLocal 值private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {HashMap<ThreadLocal<Object>, Object> backup = new HashMap<>();for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {ThreadLocal<Object> threadLocal = entry.getKey();backup.put(threadLocal, threadLocal.get());Object value = entry.getValue();if (value == threadLocalClearMark) {threadLocal.remove();} else {threadLocal.set(value);}}return backup;}// 清除上下文信息@NonNullpublic static Object clear() {HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>();HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<>();for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {ThreadLocal<Object> threadLocal = entry.getKey();threadLocal2Value.put(threadLocal, threadLocalClearMark);}return replay(new Snapshot(ttl2Value, threadLocal2Value));}// 恢复上下文信息public static void restore(@NonNull Object backup) {Snapshot backupSnapshot = (Snapshot) backup;restoreTtlValues(backupSnapshot.ttl2Value);restoreThreadLocalValues(backupSnapshot.threadLocal2Value);}// 恢复 TTL 值private static void restoreTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {TransmittableThreadLocal.doExecuteCallback(false);Iterator<TransmittableThreadLocal<Object>> iterator = ((WeakHashMap<TransmittableThreadLocal<Object>, ?>) holder.get()).keySet().iterator();while (iterator.hasNext()) {TransmittableThreadLocal<Object> threadLocal = iterator.next();if (!backup.containsKey(threadLocal)) {iterator.remove();threadLocal.superRemove();}}setTtlValuesTo(backup);}// 设置 TTL 值private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {TransmittableThreadLocal<Object> threadLocal = entry.getKey();threadLocal.set(entry.getValue());}}// 恢复 ThreadLocal 值private static void restoreThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {ThreadLocal<Object> threadLocal = entry.getKey();threadLocal.set(entry.getValue());}}// 使用捕获的上下文信息执行 Supplierpublic static <R> R runSupplierWithCaptured(@NonNull Object captured, @NonNull Supplier<R> bizLogic) {Object backup = replay(captured);try {return bizLogic.get();} finally {restore(backup);}}// 清除上下文信息后执行 Supplierpublic static <R> R runSupplierWithClear(@NonNull Supplier<R> bizLogic) {Object backup = clear();try {return bizLogic.get();} finally {restore(backup);}}// 使用捕获的上下文信息执行 Callablepublic static <R> R runCallableWithCaptured(@NonNull Object captured, @NonNull Callable<R> bizLogic) throws Exception {Object backup = replay(captured);try {return bizLogic.call();} finally {restore(backup);}}// 清除上下文信息后执行 Callablepublic static <R> R runCallableWithClear(@NonNull Callable<R> bizLogic) throws Exception {Object backup = clear();try {return bizLogic.call();} finally {restore(backup);}}// 注册 ThreadLocalpublic static <T> boolean registerThreadLocal(@NonNull ThreadLocal<T> threadLocal, @NonNull TtlCopier<T> copier) {return registerThreadLocal(threadLocal, copier, false);}// 注册带有 ShadowCopier 的 ThreadLocalpublic static <T> boolean registerThreadLocalWithShadowCopier(@NonNull ThreadLocal<T> threadLocal) {return registerThreadLocal(threadLocal, shadowCopier, false);}// 注册 ThreadLocal,带有复制器和是否强制注册的选项public static <T> boolean registerThreadLocal(@NonNull ThreadLocal<T> threadLocal, @NonNull TtlCopier<T> copier, boolean force) {if (threadLocal instanceof TransmittableThreadLocal) {TransmittableThreadLocal.logger.warning("register a TransmittableThreadLocal instance, this is unnecessary!");return true;} else {synchronized (threadLocalHolderUpdateLock) {if (!force && threadLocalHolder.containsKey(threadLocal)) {return false;} else {WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> newHolder = new WeakHashMap<>(threadLocalHolder);newHolder.put(threadLocal, copier);threadLocalHolder = newHolder;return true;}}}}// 注册带有 ShadowCopier 的 ThreadLocal,并带有是否强制注册的选项public static <T> boolean registerThreadLocalWithShadowCopier(@NonNull ThreadLocal<T> threadLocal, boolean force) {return registerThreadLocal(threadLocal, shadowCopier, force);}// 取消注册 ThreadLocalpublic static <T> boolean unregisterThreadLocal(@NonNull ThreadLocal<T> threadLocal) {if (threadLocal instanceof TransmittableThreadLocal) {TransmittableThreadLocal.logger.warning("unregister a TransmittableThreadLocal instance, this is unnecessary!");return true;} else {synchronized (threadLocalHolderUpdateLock) {if (!threadLocalHolder.containsKey(threadLocal)) {return false;} else {WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> newHolder = new WeakHashMap<>(threadLocalHolder);newHolder.remove(threadLocal);threadLocalHolder = newHolder;return true;}}}}private Transmitter() {throw new InstantiationError("Must not instantiate this class");}// Snapshot 类,用于存储捕获的上下文信息private static class Snapshot {final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;final HashMap<ThreadLocal<Object>, Object> threadLocal2Value;private Snapshot(HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) {this.ttl2Value = ttl2Value;this.threadLocal2Value = threadLocal2Value;}}}// SuppliedTransmittableThreadLocal 类,带有初始值和复制器的 TransmittableThreadLocal 实现private static final class SuppliedTransmittableThreadLocal<T> extends TransmittableThreadLocal<T> {private final Supplier<? extends T> supplier;private final TtlCopier<T> copierForChildValue;private final TtlCopier<T> copierForCopy;SuppliedTransmittableThreadLocal(Supplier<? extends T> supplier, TtlCopier<T> copierForChildValue, TtlCopier<T> copierForCopy) {if (supplier == null) {throw new NullPointerException("supplier is null");} else {this.supplier = supplier;this.copierForChildValue = copierForChildValue;this.copierForCopy = copierForCopy;}}protected T initialValue() {return this.supplier.get();}protected T childValue(T parentValue) {return this.copierForChildValue != null ? this.copierForChildValue.copy(parentValue) : super.childValue(parentValue);}public T copy(T parentValue) {return this.copierForCopy != null ? this.copierForCopy.copy(parentValue) : super.copy(parentValue);}}
}
4. 使用框架提供的上下文传递功能
许多现代框架提供了对线程局部变量和上下文传递的支持。例如,Spring 框架提供了 @Async
注解,可以在异步方法中自动传递上下文信息。
示例(Spring @Async
)
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;@Service
public class AsyncService {@Asyncpublic void asyncMethod(String value) {System.out.println("异步方法执行,传递的值为: " + value);}
}
配置类:
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;@Configuration
@EnableAsync
public class AsyncConfig {
}
调用方法:
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;@Component
public class AsyncCaller {@Autowiredprivate AsyncService asyncService;public void callAsyncMethod() {asyncService.asyncMethod("测试值");}
}
总结
在多线程编程中,ThreadLocal
和 `InheritableThreadLocal 解决了线程局部变量的问题,但在复杂的线程池和异步执行场景下,这些工具存在局限性。
TransmittableThreadLocal通过捕获和恢复上下文信息,并包装线程池和任务,确保上下文的正确传递和清理,是一种有效的解决方案。此外,现代框架提供的上下文传递功能(如 Spring 的
@Async`)也是解决上下文传递问题的有效方式。
选择适合的工具和方法,可以更好地管理上下文数据,确保系统的稳定性和可靠性。通过这些方式,我们可以在复杂的多线程环境中有效地传递和管理上下文信息,避免数据污染和内存泄漏问题。