我们从ThreadLocal
的英文注释看起:
This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its {@code get} or {@code set} method) has its own, independently initialized copy of the variable. {@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID).
翻译一下:ThreadLocal
提供一些线程本地(即线程自有)的变量,这些变量之间有一个不同点——在每个线程中存取一个变量(通过ThreadLocal
的get
或者set
方法)的时候,每个线程都会有自己的变量实例(独立初始化的实例)。ThreadLocal
的实例通常会设置成private static
类型,以便将一些状态和某个线程保持关联(比如用户编号或者事物编号)。
(第二句是个长难句,看了好一会才明白😅,their normal counterparts
指的也是这些变量,所以翻译成“这些变量之间有一个不同点”;own
和independently initialized copy of the variabl
是并列状语,都是形容its
,而its
指的是each thread
。)
如何实现“每个线程都有自己独立初始化的变量”呢,继续看注释给出的例子:
public class ThreadId { // Atomic integer containing the next thread ID to be assigned // private static final AtomicInteger nextId = new AtomicInteger(0); // Thread local variable containing each thread's ID private static final ThreadLocal<Integer> threadId = new ThreadLocal<>() { @Override protected Integer initialValue() { return nextId.getAndIncrement(); } }; // Returns the current thread's unique ID, assigning it if necessary public static int get() { return threadId.get(); } } 复制代码
这个例子给每个Thread
设置了独立的ID
,通过重写initialValue
方法实现变量的“独立初始化”,这个方法会在调用get
方法时触发(前提是没有预先自己调用set
方法设置变量),这个会在后面分析,先看下半部分注释:
Each thread holds an implicit reference to its copy of a thread-local variable as long as the thread is alive and the {@code ThreadLocal} instance is accessible; after a thread goes away, all of its copies of thread-local instances are subject to garbage collection (unless other references to these copies exist).
对于一个ThreadLocal
对象,每个线程在存活的时候都保存了一个该对象的隐式引用(implicit reference),并且这个ThreadLocal
对象是可以进行存取数据的。当线程死亡的时候,线程中所有对ThreadLocal
对象的引用都会被提交给垃圾回收(除非仍有其他对ThreadLocal
对象引用存在)
好了,对ThreadLocal
有了直观的认识后,我们来看看它的数据结构是怎样的。
对于上面例子的initialValue
方法,当我们在线程中第一次调用get
方法的时候会被触发:
public T get() { Thread t = Thread.currentThread(); // 从当前线程中获取 ThreadLocalMap ThreadLocalMap map = getMap(t); // 如果map不为空,则寻找对应的变量(如果已经调用过get、set,则map不为空) if (map != null) { ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } // 如果为空,则设置预定义的初始化变量(一般是当前线程首次调用该ThreadLocal对象的get方法) return setInitialValue(); } ThreadLocalMap getMap(Thread t) { // 注意了,这里的threadLocals是Thread里面的变量 return t.threadLocals; } private T setInitialValue() { // 调用initialValue,生成预定义的变量 T value = initialValue(); Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); // 将这个变量保存到map中 if (map != null) map.set(this, value); else // 这里会去创建threadLocals对象 createMap(t, value); return value; } // 如果你需要预定义变量,就继承ThreadLocal并重写这个方法 protected T initialValue() { return null; } // 创建ThreadLocalMap void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue); } 复制代码
可以看到,get
方法会使用Thread
中的threadLocals
进行保存数据,也就是说,ThreadLocal
并没有把数据存在本身,而是放在了对应的线程上。我们来看看Thread
中的这个变量:
/* ThreadLocal values pertaining to this thread. This map is maintained * by the ThreadLocal class. */ ThreadLocal.ThreadLocalMap threadLocals = null; 复制代码
ThreadLocalMap
是什么呢,它的英文介绍是这样的:
ThreadLocalMap is a customized hash map suitable only for maintaining thread local values. No operations are exported outside of the ThreadLocal class. The class is package private to allow declaration of fields in class Thread. To help deal with very large and long-lived usages, the hash table entries use WeakReferences for keys. However, since reference queues are not used, stale entries are guaranteed to be removed only when the table starts running out of space.
大概意思是:ThreadLocalMap
是一个仅用来维护线程本地数据的自定义hash map
,仅能在ThreadLocal
内进行操作(它是package private
的静态内部类),它与Thread
同包,可以在其中声明ThreadLocalMap
变量。为了方便回收内存,这里的hash map
的key
使用了WeakReferences
进行引用。但是呢,当key
不再被其他地方引用的时候,脏数据只会在ThreadLocalMap
空间即将耗尽的时候进行移除(rehash的过程中移除)
我们暂时只关注它的数据结构。
观看源码后知道,这个所谓的hash map
其实是使用一个数组来保存数据的:
static class ThreadLocalMap { static class Entry extends WeakReference<ThreadLocal<?>> { /** The value associated with this ThreadLocal. */ Object value; Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } } private Entry[] table; } 复制代码
这个数组的名字叫table
,数组元素是Entry
,它是一个键值对,key
是使用了WeakReference
引用ThreadLocal
变量,value
就是我们保存的数据。所以,ThreadLocalMap
本质上是一个数组,数组的元素是一个个的键值对。WeakReference
就是前面所说的隐式引用(implicit reference),当ThreadLocal
对象没有被其他地方引用的时候,它就可以被正常回收掉了。
这里先说结论:ThreadLocal
对象将自己作为索引(key
),绑定索要储存的变量(value
),形成一个键值对(Entry
),存入线程中(Thread
的ThreadLocalMap
中),大致结构如下:
上图有两个ThreadLocal
对象,在三个线程中各自保存了三个value
,这里注意了,一个ThreadLocal
对象在一个线程中只能保存一个value
,正如前的面介绍所说。如何做到的呢,我们来分析一下ThreadLocal
与ThreadLocalMap
的存取数据相关的方法吧。
先来看ThreadLocalMap
的set
方法:
private void set(ThreadLocal<?> key, Object value) { Entry[] tab = table; int len = tab.length; // 计算数组下标(这里使用了一个特殊的数字,暂时可以不用理解) int i = key.threadLocalHashCode & (len-1); // 遍历数组,当遍历到数组中的空元素的时候才跳出循环,执行后面的代码 for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { // 从之前的讲解知道,e是弱引用,这里获取ThreadLocal对象 ThreadLocal<?> k = e.get(); // 如果k与key相同,则替换value,结束方法 if (k == key) { e.value = value; return; } // 如果k为空(被回收了),则替换key和value,结束方法 if (k == null) { replaceStaleEntry(key, value, i); return; } } // 如果上面的循环遍历到空的元素,则直接新建一个Entry tab[i] = new Entry(key, value); // 长度+1 int sz = ++size; // 清除脏数据,并且长度大于threshold的时候进行rehash(数组长度翻倍) // 这里threshold为数组长度的2/3,所以数组永远都有空的元素,上面的循环不至于变成死循环 if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); } 复制代码
在上面代码的遍历中看到,当传入的key
(ThreadLocal
对象)已经存在于数组中的时候,会去替换value
,并且从ThreadLocal
的set
方法中可以看到,传入的key
就是当前的ThreadLocal
对象:
public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) // 传入当前对象本身 map.set(this, value); else createMap(t, value); } 复制代码
这里可以看出对于一个ThreadLocal
对象,key
是固定的,是当前ThreadLocal
对象;而获取的当前线程可能不同,也就是存入的map
是属于不同线程的,而因为key
只有一个,所以一个ThreadLocal
对象在一个线程中只能保存一个value
,而在每个线程中都可以存一个value
,这样各个value
就独立在各个线程中了。
同样先看ThreadLocalMap
的getEntry
方法:
private Entry getEntry(ThreadLocal<?> key) { // 获取下标 int i = key.threadLocalHashCode & (table.length - 1); Entry e = table[i]; if (e != null && e.get() == key) return e; else // 遍历数组寻找key,并回收发现的脏数据 return getEntryAfterMiss(key, i, e); } 复制代码
这个方法很简单,直接从table
中取出了相关数据,如果对应下标的数据不一致,就会去遍历table
寻找对应的key
。关于ThreadLocal
的get
方法在一开始已经说过,就不再赘述了。
好了,ThreadLocal
的结构大致了解了,接下来继续细看。
在set
方法中:
int i = key.threadLocalHashCode & (len-1); 复制代码
key
为当前ThreadLocal
对象,threadLocalHashCode
为ThreadLocal
中的变量,通过nextHashCode
方法获取:
/** * ThreadLocals rely on per-thread linear-probe hash maps attached * to each thread (Thread.threadLocals and * inheritableThreadLocals). The ThreadLocal objects act as keys, * searched via threadLocalHashCode. This is a custom hash code * (useful only within ThreadLocalMaps) that eliminates collisions * in the common case where consecutively constructed ThreadLocals * are used by the same threads, while remaining well-behaved in * less common cases. */ private final int threadLocalHashCode = nextHashCode(); private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647; private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); } 复制代码
我们看到nextHashCode
是个静态的AtomicInteger
,每创建一个ThreadLocal
对象nextHashCode
都会增加固定的值,每个ThreadLocal
对象都能获得一个不一样的threadLocalHashCode
。
threadLocalHashCode
的注释说到了本质:ThreadLocal
依赖于每个线程中的线性hash map
,ThreadLocal
对象作为key
,通过threadLocalHashCode
得到自己的下标。当多个连续创建的ThreadLocal
对象都在同一个线程中保存数据的时候,这个自定义的hash code
能最大限度地排除哈希碰撞(碰撞就是产生了相同的下标)。
接下来我想试验一下,我们参考ThreadLocalMap
的构造方法设置初始数据:
/** * The initial capacity -- MUST be a power of two. */ private static final int INITIAL_CAPACITY = 16; ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { table = new Entry[INITIAL_CAPACITY]; int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); table[i] = new Entry(firstKey, firstValue); size = 1; setThreshold(INITIAL_CAPACITY); } 复制代码
table
的初始长度为16(数组长度要求为2的倍数,这样会方便计算,比如16-1的二进制为1111,方便&运算时取数字的低位,比如 1111010 & 1111 = 001010,取了低4位),如果创建了与table
长度相同数量的ThreadLocal
对象,并在同一个线程中都保存数据,它们的下标是怎样的?
private static final int INITIAL_CAPACITY = 16; private static final int HASH_INCREMENT = 0x61c88647; public static void main(String[] args) throws Exception { // 新建一个AtomicInteger,此时nextHashCode.get()为0 AtomicInteger nextHashCode = new AtomicInteger(); // 测试哈希值 int threadLocalHashCode; // 假设创建了16个ThreadLocal对象 for (int j = 0; j < INITIAL_CAPACITY; j++) { // 连续生成threadLocalHashCode,一个code代表一个ThreadLocal threadLocalHashCode = nextHashCode.getAndAdd(HASH_INCREMENT); System.out.println("threadLocalHashCode: " + threadLocalHashCode); // 计算每个ThreadLocal保存的index int index = threadLocalHashCode & (INITIAL_CAPACITY - 1); System.out.println("index: " + index); } // j 在 INITIAL_CAPACITY 之内,产生的 index 没有重复 } 复制代码
我们看看输出:
threadLocalHashCode: 0 index: 0 threadLocalHashCode: 1640531527 index: 7 threadLocalHashCode: -1013904242 index: 14 threadLocalHashCode: 626627285 index: 5 threadLocalHashCode: -2027808484 index: 12 threadLocalHashCode: -387276957 index: 3 threadLocalHashCode: 1253254570 index: 10 threadLocalHashCode: -1401181199 index: 1 threadLocalHashCode: 239350328 index: 8 threadLocalHashCode: 1879881855 index: 15 threadLocalHashCode: -774553914 index: 6 threadLocalHashCode: 865977613 index: 13 threadLocalHashCode: -1788458156 index: 4 threadLocalHashCode: -147926629 index: 11 threadLocalHashCode: 1492604898 index: 2 threadLocalHashCode: -1161830871 index: 9 复制代码
很神奇,在table
数组长度之内,产生的 index 没有重复,这就是没有哈希碰撞的效果。
这个很简单,主要为了方便ThreadLocal
对象回收。
但是,但是,但是要注意一种内存泄漏的情况:
ThreadLocal的原理是操作Thread内部的一个ThreadLocalMap,这个Map的Entry继承了WeakReference,设值完成后map中是(WeakReference,value)这样的数据结构。Java中的弱引用在内存不足的时候会被回收掉,回收之后变成(null,value)的形式,key被收回掉了。 如果这个线程执行完之后销毁,value也会被回收,这样也不会出现内存泄露。但如果是在线程池中,线程 执行完后不被回收,而是返回线程池中。此时Thread有个强引用 指向 ThreadLocalMap,ThreadLocalMap有强引用 指向 Entry,导致Entry中key为null的value无法被回收,一直存在内存中。在执行了ThreadLocal.set()方法之后一定要记得使用ThreadLocal.remove(),将不要的数据移除掉,避免内存泄漏。
相信大家看源码的时候也看到了:
// sThreadLocal.get() will return null unless you've called prepare(). static final ThreadLocal<Looper> sThreadLocal = new ThreadLocal<Looper>(); private static void prepare(boolean quitAllowed) { if (sThreadLocal.get() != null) { throw new RuntimeException("Only one Looper may be created per thread"); } sThreadLocal.set(new Looper(quitAllowed)); } /** * Return the Looper object associated with the current thread. Returns * null if the calling thread is not associated with a Looper. */ public static @Nullable Looper myLooper() { return sThreadLocal.get(); } 复制代码
如果明白了ThreadLocal
的原理,上面这段代码应该秒懂了吧~~😉
最后,来吃一波水果~~ 🍇🍈🍉🍊🍋🍌🍍🍎🍏🍐🍑🍒🍓