前言
ThreadLocal 是 Java 语言中的一个类,可以使用它为每个线程存储数据。这些数据只能被当前线程访问,而其他线程无法访问。这个类可以用于避免多次传递、线程间数据隔离、事务操作等场景。
本次源码分析基于 JDK 21.0.1。
ThreadLocal 使用简介
基本操作
使用 ThreadLocal 时,可以将数据存储在一个特殊的对象中,这个对象会被自动关联到当前线程。例如,可以使用以下代码创建一个 ThreadLocal 对象,其中存储了一个整数值:
ThreadLocal<Integer> threadLocalValue = new ThreadLocal<>();
threadLocalValue.set(1);
Integer result = threadLocalValue.get();
如果想要在创建 ThreadLocal 对象时就设置初始值,可以使用 withInitial()
方法,并通过 lambda 表达式传入一个 Supplier 对象,例如:
ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 1);
如果想要删除 ThreadLocal 中的值,可以调用 remove()
方法。例如:
threadLocal.remove();
多线程下 ThreadLocal 使用
以下代码演示了 ThreadLocal 的使用,代码首先创建了 NUM_THREADS
个线程,然后在每个线程内创建了 ThreadLocal。随后,每个线程分别对线程私有的 ThreadLocal 自增 NUM_THREADS
次,并对共享的 sharedValue
自增 NUM_THREADS
次。
import java.util.concurrent.atomic.AtomicInteger;public class Main {private static final int NUM_THREADS = 3;private static final int NUM_INCREMENTS = 5;public static void main(String[] args) {AtomicInteger sharedValue = new AtomicInteger(0);for (int i = 0; i < NUM_THREADS; i++) {new Thread(() -> {ThreadLocal<Integer> threadLocalValue = ThreadLocal.withInitial(() -> 0);for (int j = 0; j < NUM_INCREMENTS; j++) {int localValue = threadLocalValue.get();localValue++;threadLocalValue.set(localValue);int currentValue = sharedValue.get();currentValue++;sharedValue.set(currentValue);}System.out.println("Thread " + Thread.currentThread().getId() + ": Thread-local value = " + threadLocalValue.get() + ", Shared value = " + sharedValue.get());}).start();}}
}
ThreadLocal 源码解析
初始化
使用无参构造器时仅创建一个空的 ThreadLocal 对象:
public ThreadLocal() {}
使用 withInitial
设置 ThreadLocal 初值时,返回的是 SuppliedThreadLocal 类型:
// supplier 为传入的 lambda 表达式public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {// 创建并返回了一个 SuppliedThreadLocalreturn new SuppliedThreadLocal<>(supplier);}
传入的 Supplier 定义如下:
@FunctionalInterface
public interface Supplier<T> {T get();
}
其中 SuppliedThreadLocal 是 ThreadLocal 的静态内部类,它继承了 ThreadLocal 并重写了 initialValue()
方法:
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {private final Supplier<? extends T> supplier;// 将赋初值的 lambda 表达式设置为 supplier 成员变量SuppliedThreadLocal(Supplier<? extends T> supplier) {this.supplier = Objects.requireNonNull(supplier);}@Overrideprotected T initialValue() {return supplier.get();}}
后续第一次调用 get()
时,会调用 SuppliedThreadLocal 重写的 initialValue()
方法,该方法调用了传入的 Supplier 表达式返回 ThreadLocal 初值。
set()
set()
用于设置 ThreadLocal 的值,其实现如下。
public void set(T value) {// 为了设置 ThreadLocal 的值,传入了当前线程set(Thread.currentThread(), value);if (TRACE_VTHREAD_LOCALS) {dumpStackIfVirtualThread();}}private void set(Thread t, T value) {// 获取和当前 ThreadLocal 关联的哈希表ThreadLocalMap map = getMap(t);if (map == ThreadLocalMap.NOT_SUPPORTED) {throw new UnsupportedOperationException();}if (map != null) {// map 已经初始化,则直接设置值map.set(this, value);} else {// lazy 初始化 ThreadLocalMapcreateMap(t, value);}}
首先看 getMap(t)
,它获取了和当前 ThreadLocal 关联的哈希表:
ThreadLocalMap getMap(Thread t) {// 从线程对象获取 ThreadLocalMap,由此可以看出每个对象一个 ThreadLocalMapreturn t.threadLocals;}
t.threadLocals
是 Thread 对象的成员,其类型为 ThreadLocal.ThreadLocalMap
:
public class Thread {...ThreadLocal.ThreadLocalMap threadLocals;...
}
ThreadLocalMap 是 ThreadLocal 类的内部类,它用于存储线程本地变量。 ThreadLocalMap 是 Thread 对象的成员变量,这说明每个线程都有一个 ThreadLocalMap 对象,而 ThreadLocalMap 保存了当前线程拥有的所有 ThreadLocal 对象和对应的变量副本。
回到set()
方法,由 set()
方法可以看出 ThreadLocalMap 是延迟到第一次使用的时候创建的。创建 ThreadLocalMap 的代码如下:
void createMap(Thread t, T firstValue) {// 创建 ThreadLocal 并将关联的线程和赋予的值传入t.threadLocals = new ThreadLocalMap(this, firstValue);}
ThreadLocalMap 是一个专门保存 ThreadLocal 的哈希表,其构造器的实现如下:
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {// 创建哈希表的底层数组table = new Entry[INITIAL_CAPACITY];// 哈希值取余定位int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);// 创建一个 entry 放到槽位中table[i] = new Entry(firstKey, firstValue);size = 1;// 设置哈希表扩容的大小门槛,为总容量的 2/3setThreshold(INITIAL_CAPACITY);}private void setThreshold(int len) {threshold = len * 2 / 3;}
get()
get()
方法用于获取 ThreadLocal 的值,其实现如下。
public T get() {// 根据 Thread 获取值return get(Thread.currentThread());}// 1. 根据 Thread 获取 ThreadLocalMap// 2. 从 ThreadLocalMap 获取 entry,并将 entry 的 value 作为结果返回// 3. 如果 map 为 null,说明未初始化,调用 setInitialValue 进行初始化private T get(Thread t) {ThreadLocalMap map = getMap(t);if (map != null) {ThreadLocalMap.Entry e = map.getEntry(this);if (e != null) {@SuppressWarnings("unchecked")T result = (T) e.value;return result;}}return setInitialValue(t);}private T setInitialValue(Thread t) {// 如果使用无参构造器,返回的是 null// 如果使用了 ThreadLocal.withInitial 创建 ThreadLocal,返回的是 lambda 表达式的结果T value = initialValue();// 获取 ThreadLocalMap,如果是第一次访问则进行初始化ThreadLocalMap map = getMap(t);if (map != null) {map.set(this, value);} else {createMap(t, value);}if (this instanceof TerminatingThreadLocal<?> ttl) {TerminatingThreadLocal.register(ttl);}if (TRACE_VTHREAD_LOCALS) {dumpStackIfVirtualThread();}return value;}
总结
ThreadLocal 可以用于保存线程私有的数据,其源码具有下关键点:
- ThreadLocalMap 的创建是懒加载的;
- ThreadLocal 的实现是通过将一个 ThreadLocalMap 作为 Thread 对象的成员实现的;
- 各个线程的全部 ThreadLocal 都保存在 ThreadLocalMap 中。