每个线程 Thread 内部有 ThreadLocalMap,Map 的键是 ThreadLocal<T>,值是泛型 T。
那么,根据线程 Thread 和 ThreadLocal<T> 才能唯一确定一个实际存储的值 T。
ThreadLocalMap 是 ThreadLocal 的静态内部类,是 ThreadLocal 的关键所在。ThreadLocalMap 是一种 HashMap,通过开放寻址法(每次向后移动一位)和环形数组解决哈希冲突,该类没有显式的继承 Map。
ThreadLocalMap 内部存储是一个 Entry[] table,这里的 Entry 是该类的静态内部类,没有向后的指针,也就是每个位置只能放置一个 Entry,而不像 HashMap 桶数组的每个位置可以放一个链表。
static class ThreadLocalMap { /** * The entries in this hash map extend WeakReference, using * its main ref field as the key (which is always a * ThreadLocal object). Note that null keys (i.e. entry.get() * == null) mean that the key is no longer referenced, so the * entry can be expunged from table. Such entries are referred to * as "stale entries" in the code that follows. */ // 注意到ThreadLocal是弱引用的,而value仍然是强引用 static class Entry extends WeakReference<ThreadLocal<?>> { /** The value associated with this ThreadLocal. */ Object value; Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } } /** * The initial capacity -- MUST be a power of two. */ private static final int INITIAL_CAPACITY = 16; /** * The table, resized as necessary. * table.length MUST always be a power of two. */ private Entry[] table; /** * table中Entry非空的个数 */ private int size = 0; /** * The next size value at which to resize. */ private int threshold; // Default to 0 /** *阈值为长度的 2/3 */ private void setThreshold(int len) { threshold = len * 2 / 3; } /** * 增加后取模 */ private static int nextIndex(int i, int len) { return ((i + 1 < len) ? i + 1 : 0); } /** * 减少后取模 */ private static int prevIndex(int i, int len) { return ((i - 1 >= 0) ? i - 1 : len - 1); } ... }
由于 Entry extends WeakReference<ThreadLocal<?>> ,所以 Entry get() 继承自 Reference,返回相应的泛型引用。在 Entry 内具体指的是返回对应的键 ThreadLocal k。
在后面清理时,称满足e != null && e.get() == null
的 Entry e 为 StaleEntry。
// WeakReference public class WeakReference<T> extends Reference<T> { /** * Creates a new weak reference that refers to the given object. The new * reference is not registered with any queue. * * @param referent object the new weak reference will refer to */ public WeakReference(T referent) { super(referent); } ... } // Reference public abstract class Reference<T> { public T get() { return this.referent; } ... }
以下摘自 https://zhuanlan.zhihu.com/p/158033837
看看官话,为什么要用弱引用。
To help deal with very large and long-lived usages, the hash table entries use WeakReferences for keys.
为了处理非常大和生命周期非常长的线程,哈希表使用弱引用作为 key。
ThreadLocal在没有外部对象强引用时如Thread,发生GC时弱引用Key会被回收,而Value是强引用不会回收,如果创建ThreadLocal的线程一直持续运行如线程池中的线程,那么这个Entry对象中的value就有可能一直得不到回收,发生内存泄露。
由于上述原因,下面不断使用 expungeStaleEntry 完成 value/entry 的清除。
有两种构造器,第一种是当前线程的 ThreadLocalMap threadLocals 为空,向里面放第一个 ThreadLocal 和值的情况;第二种是使用父线程的 ThreadLocalMap inheritableThreadLocals 来构造当前子线程的 inheritableThreadLocals。具体解释可以见 2.1 和 2.2。
/** * Construct a new map initially containing (firstKey, firstValue). * ThreadLocalMaps are constructed lazily, so we only create * one when we have at least one entry to put in it. */ ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { table = new Entry[INITIAL_CAPACITY]; int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); table[i] = new Entry(firstKey, firstValue); size = 1; setThreshold(INITIAL_CAPACITY); } /** * Construct a new map including all Inheritable ThreadLocals * from given parent map. Called only by createInheritedMap. * * @param parentMap the map associated with parent thread. */ private ThreadLocalMap(ThreadLocalMap parentMap) { Entry[] parentTable = parentMap.table; int len = parentTable.length; setThreshold(len); table = new Entry[len]; for (int j = 0; j < len; j++) { Entry e = parentTable[j]; if (e != null) { @SuppressWarnings("unchecked") ThreadLocal<Object> key = (ThreadLocal<Object>) e.get(); if (key != null) { // 放入的值是key.childValue(e.value) Object value = key.childValue(e.value); Entry c = new Entry(key, value); int h = key.threadLocalHashCode & (len - 1); // 出现冲突向后找,直到找到一个可以插入的位置 while (table[h] != null) h = nextIndex(h, len); table[h] = c; size++; } } } }
该方法相当于 HashMap 中 set 和 put,如果 map 中有对应的 key,替换 value;没有则插入。
在 for 循环中,从开始计算出的 i 向后查找直到 tab 在当前位置的 Entry 为 null 才退出。后续方法中的 for 退出逻辑均一致。
结果有三种情况
① 在 for 中,当前位置 key 已存在,替换相应的值即可。
② 在 for 中,当前位置 key 为 null,使用 replaceStaleEntry 实现删除 i 处的 Entry 和插入(key,value)的操作。
③ 退出 for,从 i 处检查 l o g 2 s z log_{2}sz log2sz 个元素,尝试清除其中的 StaleEntry,若未实现清除则检查 sz 是否超过阈值,是则执行 rehash。
private void set(ThreadLocal<?> key, Object value) { // We don't use a fast path as with get() because it is at // least as common to use set() to create new entries as // it is to replace existing ones, in which case, a fast // path would fail more often than not. Entry[] tab = table; int len = tab.length; // 找到key初始位置 int i = key.threadLocalHashCode & (len-1); // 执行循环向后查找,直到当前检查的e为null,退出 for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); // 找到对应的k,替换value即可 if (k == key) { e.value = value; return; } // 如果对应位置k为null,通过replaceStaleEntry 实现清除和插入的操作 // 注意,在该位置之后,可能存在和当前 key相同的 k,这个也要考虑到。 // 插入位置为i if (k == null) { replaceStaleEntry(key, value, i); return; } } tab[i] = new Entry(key, value); int sz = ++size; //由于set只能放入一个,如果前一个条件为true,则至少清除了一个, //加一减一相当于没变,不需要检查是否超过阈值以及 rehash if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); }
对于 staleSlot
处的 staleEntry,执行的是放入 (key,value);对于其他的位置(如果存在)即 slotToExpunge
处的 staleEntry,执行清除。
该方法首先要做的就是将 (key,value) 正确放入 staleSlot,如果后续没有出现 key,直接在 staleSlot 处新建一个 Entry即可;如果出现了,进行相应的 Entry 交换和修改。另外, 作者还希望顺带检查一下除了 staleSlot,是否还有其他位置的 staleEntry,有的话执行清除操作。
slotToExpunge
表示找到的 staleEntry 的位置,初始为 staleSlot
,如果找到了其他位置也是 staleEntry,就修改slotToExpunge
。最后会判断 slotToExpunge
是否发生修改,如果修改,说明存在其他的 staleEntry,执行 expungeStaleEntry
和 cleanSomeSlots
完成相应的清除操作。可以用三元表达式来描述: slotToExpunge = 是否存在其他的staleEntry ? 最靠前的位置 : staleSlot
。
具体来说:
查找是否其他的存在 staleEntry,分为两步,第一步是向前查找,找到则修改;第二步是向后查找, 如果找到了新的 staleEntry,不要急着修改,会判断前面是否修改过了,修改过了则不修改,保证存在其他的 staleEntry 时一定存储最靠前的位置,对于后续清除更有帮助。
在第二步向后查找的时候,如果找到了相同的 key,则需要进行交换,保证 tab[staleSlot] 修改为 (key,value) 的 entry,且 tab[i] 变成了 staleEntry。对于不存在相同键的情况,放在退出循环后处理。
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) { Entry[] tab = table; int len = tab.length; Entry e; // Back up to check for prior stale entry in current run. // We clean out whole runs at a time to avoid continual // incremental rehashing due to garbage collector freeing // up refs in bunches (i.e., whenever the collector runs). int slotToExpunge = staleSlot; for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) if (e.get() == null) slotToExpunge = i; // Find either the key or trailing null slot of run, whichever // occurs first for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); // If we find key, then we need to swap it // with the stale entry to maintain hash table order. // The newly stale slot, or any other stale slot // encountered above it, can then be sent to expungeStaleEntry // to remove or rehash all of the other entries in run. // 这个if和后一个if可以交换顺序,交换后和我的描述一致,便于理解 if (k == key) { // 以下几步实现了staleSlot和i处Entry的互换, // 最后结果是staleSlot存储的e的键和值是传入的参数, // i处的Entry键为null,变成了 staleEntry。 e.value = value; tab[i] = tab[staleSlot]; tab[staleSlot] = e; // Start expunge at preceding stale entry if it exists if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return; } // If we didn't find stale entry on backward scan, the // first stale entry seen while scanning for key is the // first still present in the run. if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } // If key not found, put new entry in stale slot tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value); // If there are any other stale entries in run, expunge them if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); }
该方法将 staleSlot 处清空,然后顺带着向后检查,如果后面是 staleEntry,清空;如果后面是非 staleEntry 的,重新计算位置,进行相应的移动。返回满足 tab[i] 为 null 的 i。
在重新计算的代码中,旧位置为 i,新位置为 h。① len 改变,则 h 会和 i 不同; ② len 不变,但是 i 并不是上一次计算出的位置,而是计算出的位置又向后移动了一些,这样 h 和 i 也可以不一样。
private int expungeStaleEntry(int staleSlot) { Entry[] tab = table; int len = tab.length; // expunge entry at staleSlot // 清空 staleSlot tab[staleSlot].value = null; tab[staleSlot] = null; size--; // Rehash until we encounter null Entry e; int i; for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); // 清除其他的 staleEntry if (k == null) { e.value = null; tab[i] = null; size--; } else { // 重新计算,如果不同,则进行相应的移动。 int h = k.threadLocalHashCode & (len - 1); if (h != i) { tab[i] = null; // Unlike Knuth 6.4 Algorithm R, we must scan until // null because multiple entries could have been stale. while (tab[h] != null) h = nextIndex(h, len); tab[h] = e; } } } return i; }
从 i 向后依次检查 l o g 2 n log_{2}n log2n个元素,如果是 staleEntry 则清除。清除至少一个返回 true,否则是 false。
private boolean cleanSomeSlots(int i, int n) { boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; }
执行 expungeStaleEntries 来检查并清除全部 staleEntry,而后判断 size 是否达到数组一半长度,是执行扩容 resize。
计算可得 3 / 4 ∗ t h r e s h o l d = 3 / 4 ∗ 2 / 3 ∗ l e n = 1 / 2 ∗ l e n 3/4 * threshold = 3/4*2/3*len=1/2*len 3/4∗threshold=3/4∗2/3∗len=1/2∗len
private void rehash() { expungeStaleEntries(); // Use lower threshold for doubling to avoid hysteresis // 达到一半长度执行扩容 if (size >= threshold - threshold / 4) resize(); }
检查并清除全部的 StaleEntry。
private void expungeStaleEntries() { Entry[] tab = table; int len = tab.length; for (int j = 0; j < len; j++) { Entry e = tab[j]; if (e != null && e.get() == null) expungeStaleEntry(j); } }
Map 中的 table 扩容为原来的两倍。
private void resize() { Entry[] oldTab = table; int oldLen = oldTab.length; int newLen = oldLen * 2; Entry[] newTab = new Entry[newLen]; int count = 0; for (int j = 0; j < oldLen; ++j) { Entry e = oldTab[j]; if (e != null) { ThreadLocal<?> k = e.get(); if (k == null) { e.value = null; // Help the GC } else { int h = k.threadLocalHashCode & (newLen - 1); while (newTab[h] != null) h = nextIndex(h, newLen); newTab[h] = e; count++; } } } setThreshold(newLen); size = count; table = newTab; }
getEntry 只比较计算出的位置 i 是否满足条件,其他情况交给 getEntryAfterMiss
getEntryAfterMiss 在循环中处理三种情况,一个是 key 相等,直接返回;第二是 k 为 null,说明是 staleEntry,执行删除;第三种则移动位置。 如果最后没有找到,则返回 null。
private Entry getEntry(ThreadLocal<?> key) { int i = key.threadLocalHashCode & (table.length - 1); Entry e = table[i]; if (e != null && e.get() == key) return e; else return getEntryAfterMiss(key, i, e); } private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null) expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } return null; }
在 remove 中,不断查找,找到同样的 key 后,删除,并执行 expungeStaleEntry 来清除该位置。
clear 是 Reference 内部的方法,将 referent 置为 null,在这里指的是 e 的键 ThreadLocal。
private void remove(ThreadLocal<?> key) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { if (e.get() == key) { e.clear(); expungeStaleEntry(i); return; } } } // Reference public void clear() { this.referent = null; }
其中 childValue 方法在 ThreadLocal 内抛出异常,在其子类 InheritableThreadLocal 直接返回输入。
该方法只用于线程的 inheritableThreadLocals,用于构建子线程的 inheritableThreadLocals
父线程的 ThreadLocalMap 来构造当前子线程的 inheritableThreadLocals。
本文不介绍 InheritableThreadLocal,需要的可自行查看 https://www.cnblogs.com/hama1993/p/10400265.html
// ThreadLocal T childValue(T parentValue) { throw new UnsupportedOperationException(); } // InheritableThreadLocal public class InheritableThreadLocal<T> extends ThreadLocal<T> { /** * Computes the child's initial value for this inheritable thread-local * variable as a function of the parent's value at the time the child * thread is created. This method is called from within the parent * thread before the child is started. * <p> * This method merely returns its input argument, and should be overridden * if a different behavior is desired. * * @param parentValue the parent thread's value * @return the child thread's initial value */ // 返回输入 protected T childValue(T parentValue) { return parentValue; } }
每次一个新的 ThreadLocal 产生,threadLocalHashCode = 上一个 ThreadLocal 的 threadLocalHashCode + HASH_INCREMENT。这样做可以在 len 为2的幂时,key.threadLocalHashCode & (len - 1)
取得尽可能均匀。
也就是说,这里的 hashCode 与 key 的值无关,与 key 是第几个 ThreadLocal 有关。
// ThreadLocal 类 /** * ThreadLocals rely on per-thread linear-probe hash maps attached * to each thread (Thread.threadLocals and * inheritableThreadLocals). The ThreadLocal objects act as keys, * searched via threadLocalHashCode. This is a custom hash code * (useful only within ThreadLocalMaps) that eliminates collisions * in the common case where consecutively constructed ThreadLocals * are used by the same threads, while remaining well-behaved in * less common cases. */ private final int threadLocalHashCode = nextHashCode(); /** * The next hash code to be given out. Updated atomically. Starts at * zero. */ // 以下三个均为static private static AtomicInteger nextHashCode = new AtomicInteger(); /** * The difference between successively generated hash codes - turns * implicit sequential thread-local IDs into near-optimally spread * multiplicative hash values for power-of-two-sized tables. */ private static final int HASH_INCREMENT = 0x61c88647; /** * Returns the next hash code. */ private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); }
由于其他的 ThreadLocal 的方法通常是当前线程 Thread.currentThread()
内部的 ThreadLocal.ThreadLocalMap threadLocals 执行相应方法所做的操作,这里不再细述,仅举几例供参考。
//初始值为null,通常使用时需要重写 protected T initialValue() { return null; } ThreadLocalMap getMap(Thread t) { return t.threadLocals; } private T setInitialValue() { T value = initialValue(); Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value); return value; } // 调用ThreadLocalMap的getEntry // 如果不存在则设置初值 public T get() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) { ThreadLocalMap.Entry e = map.getEntry(this); if (e != null) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } return setInitialValue(); }
void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue); } // 调用ThreadLocalMap的set public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); else createMap(t, value); } // 调用ThreadLocalMap的remove public void remove() { ThreadLocalMap m = getMap(Thread.currentThread()); if (m != null) m.remove(this); }
ThreadLocal 在使用的时候,要求设置为 static 的,防止在多个线程内初始化导致算出的 hashCode 不同。
在线程内可以显式的调用 set 将对应的 ThreadLocal 插入;如果没有显式调用 set ,等到调用 get 时将 initialValue 给插入,该种情况下需要重写 initialValue,否则插入的值为 null。
public class ThreadLocalTest { // 设置为 static private static ThreadLocal<String> sThreadLocal = new ThreadLocal<String>(){ @Override protected String initialValue() { return "该线程使用 initialValue"; } }; public static void main(String args[]) { // 调用set,以set为准 sThreadLocal.set("这是在主线程中"); System.out.println("线程名字:" + Thread.currentThread().getName() + "---" + sThreadLocal.get()); //线程a new Thread(()->{ // 调用set,以set为准 sThreadLocal.set("这是在线程a中"); System.out.println("线程名字:" + Thread.currentThread().getName() + "---" + sThreadLocal.get()); }).start(); //线程b new Thread(()->{ // 直接get,则调用 initialValue。 System.out.println("线程名字:" + Thread.currentThread().getName() + "---" + sThreadLocal.get()) ;},"线程b").start(); } }
结果
线程名字:main---这是在主线程中 线程名字:Thread-0---这是在线程a中 线程名字:线程b---该线程使用 initialValue