ThreadLocal分析

本文代码基于Java8

前言

ThreadLocal 的官方API解释为:

该类提供了线程局部 (thread-local) 变量。这些变量不同于它们的普通对应物,因为访问某个变量(通过其 get 或 set 方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。

  1. ThreadLocal 提供了一种访问某个变量的特殊方式:访问到的变量属于当前线程,即保证每个线程的变量不一样,而同一个线程在任何地方拿到的变量都是一致的,这就是所谓的线程隔离。

  2. 如果要使用 ThreadLocal ,通常定义为 private static 类型,最好是定义为 private static final 类型。

ThreadLocal的作用是提供线程内的局部变量,这种变量在线程的生命周期内起作用,减少同一个线程内多个函数或者组件之间一些公共变量的传递的复杂度。

ThreadLocal 类结构

ThreadLocal 类结构如下:

ThreadLocalMapSuppliedThreadLocalThreadLocal 内部类,且 SuppliedThreadLocal 继承自 ThreadLocalEntryThreadLocalMap 内部类。

1
2
3
4
5
6
7
8
9
10
11
public class ThreadLocal<T> {
// 下一个哈希值
private final int threadLocalHashCode = nextHashCode();
// 下一个要给出的哈希值,原子更新,从 0 开始。
private static AtomicInteger nextHashCode = new AtomicInteger();
// 连续生成的散列码之间的差异:将隐式顺序线程局部id转换为近似最优分布的乘法散列值,以获得两个大小表的幂。
private static final int HASH_INCREMENT = 0x61c88647;
// 构造函数
public ThreadLocal() {
}
}

ThreadLocalMap 实现

ThreadLocalMap 是一个定制的散列映射,仅适用于维护线程本地值。该类是包私有的,允许在 Thread 类中声明字段。在 ThreadLocal 类之外不做任何操作。为了帮助处理占用内存大和存活时间长的用法,哈希表 Entry 使用弱引用作为键。但是,由于不使用引用队列,因此只有当表开始耗尽空间时,才保证删除过时的条目。

每个线程可能有多个 ThreadLocal,同一线程的各个ThreadLocal 存放于同一个 ThreadLocalMap

1
2
3
4
5
6
7
8
static class Entry extends WeakReference<ThreadLocal<?>> {
/** 和当前 ThreadLocal 有关 */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

ThreadLocalMap 具体实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
static class ThreadLocalMap {
// 哈希表 Entry
static class Entry extends WeakReference<ThreadLocal<?>> {
/** 该值和当前 ThreadLocal 对象有关 */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
// 哈希表初始容量,必须是 2 的指数
private static final int INITIAL_CAPACITY = 16;
// Entry表,需要的时候会调整表大小,表大小一定是 2 的指数。
private Entry[] table;
// 表中 Entry 数量
private int size = 0;
// Entry表大小调整的阈值,默认是 0
private int threshold; // Default to 0
// 设置阈值为表长度的 2/3
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
// 获取下一个 index,i+1<len 则返回 i + 1,否则返回 0
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
// 获取上一个 index,i - 1 >= 0 则返回 i-1,否则返回 len-1
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

/** ThreadLocalMap 构造函数 */
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 创建大小为 INITIAL_CAPACITY 的 Entry表
table = new Entry[INITIAL_CAPACITY];
// 根据 firstKey 的哈希值和初始容量获取元素 index
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 设置表的元素
table[i] = new Entry(firstKey, firstValue);
size = 1;
// 设置表大小调整阈值
setThreshold(INITIAL_CAPACITY);
}

/** 根据 parentMap 构造 ThreadLocalMap */
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

/** 根据 key 获取 Entry */
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
// 在其直接哈希槽中找不到密钥时使用的GetEntry方法。
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

/** ThreadLocalMap set 方法 */
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

/** 从 ThreadLocalMap 中移除 key */
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
// 用指定 key 的项替换 set 操作期间遇到的过时项
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// Start expunge at preceding stale entry if it exists
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

// 通过重新清除位于过时槽和下一个空槽之间的任何可能发生冲突的条目来删除过时的条目
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// expunge entry at staleSlot
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
// 试探性地扫描一些单元格以查找过时的条目。
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
/** 重新哈希 */
private void rehash() {
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}

/** 调整 Entry 表大小,大小变为原来的 2 倍 */
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}

/** 删除表中所有过时的项 */
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
}

SuppliedThreadLocal 实现

SuppliedThreadLocal 是 JDK8 新增的内部类,只是扩展了 ThreadLocal 的初始化值的方法而已,允许使用 JDK8 新增的 Lambda 表达式赋值。需要注意的是,函数式接口 Supplier 不允许为 null。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
// Supplier容器
private final Supplier<? extends T> supplier;
// SuppliedThreadLocal 构造方法
SuppliedThreadLocal(Supplier<? extends T> supplier) {
// supplier == null 会抛出空指针异常
this.supplier = Objects.requireNonNull(supplier);
}

@Override
protected T initialValue() {
// 调用get()方法,此时会调用对象的构造方法,即获得到真正对象
// 每次get都会调用构造方法,即获取的对象不同
return supplier.get();
}
}

ThreadLocal 的基本方法

ThreadLocal 的基本方法包括取值、设置初始值、赋值、移除等

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// 取值
public T get() {
Thread t = Thread.currentThread();
// 根据当前线程获取本地线程 Map
ThreadLocalMap map = getMap(t);
if (map != null) { // 本地线程Map不为空
// 获取当前对象对应的本地线程 Map 的 Entry 对象
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) { // Entry 对象不为空
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 设置初始值后返回初始值
return setInitialValue();
}
// 设置初始值
private T setInitialValue() {
// 继承 ThreadLocal 后重写,参考 SuppliedThreadLocal
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
// 本地线程 Map 不存在则创建
createMap(t, value);
return value;
}
// 创建本地线程 Map
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
// 赋值
public void set(T value) {
Thread t = Thread.currentThread();
// 获取当前对象对应的本地线程 Map
ThreadLocalMap map = getMap(t);
if (map != null)
// 覆盖
map.set(this, value);
else
// 本地线程 Map 不存在则创建
createMap(t, value);
}
// 移除
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
// 创建线程局部变量。变量的初始值是通过调用{@code supplier}上的{@code get}方法来确定的。
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}

使用场景

变量有局部的还有全局的,局部变量没什么好说的,一涉及到全局,那自然就会出现多线程的安全问题,要保证多线程安全访问,不出现脏读脏写,那就要涉及到线程同步了。而 ThreadLocal 相当于提供了介于局部变量与全局变量中间的这样一种线程内部的全局变量。

当我们只想在本身的线程内使用的变量,可以用 ThreadLocal 来实现,并且这些变量是和线程的生命周期密切相关的,线程结束,变量也就销毁了。 ThreadLocal 不是为了解决线程间的共享变量问题的,如果是多线程都需要访问的数据,那需要用全局变量加同步机制。

  1. 线程中处理一个非常复杂的业务,可能方法有很多,那么,使用 ThreadLocal 可以代替一些参数的显式传递;
  2. 在一些多线程的情况下,如果用线程同步的方式,当并发比较高的时候会影响性能,可以改为 ThreadLocal 的方式,例如高性能序列化框架 Kyro 就要用 ThreadLocal 来保证高性能和线程安全;
  3. 线程内上下文管理器、数据库连接等可以用到 ThreadLocal ;
  4. 用来存储用户 Session。Session 的特性很适合 ThreadLocal ,因为 Session 之前当前会话周期内有效,会话结束便销毁。

内存泄漏问题

ThreadLocal 的不正确使用会导致内存泄漏。实际上 ThreadLocalMap 中使用的 key 为 ThreadLocal 的弱引用,弱引用的特点是,如果这个对象只存在弱引用,那么在下一次垃圾回收的时候必然会被清理掉。JVM 利用调用 remove、get、set 方法的时候,会清除线程 ThreadLocalMap 里所有 key 为 null 的 value,回收弱引用。

所以如果 ThreadLocal 没有被外部强引用的情况下,在垃圾回收的时候会被清理掉的,这样一来 ThreadLocalMap 中使用这个 ThreadLocal 的 key 也会被清理掉。但是,value 是强引用,不会被清理,这样一来就会出现 key 为 null 的 value。

当使用静态 ThreadLocal 的时候,延长 ThreadLocal 的生命周期,那也可能导致内存泄漏。因为,静态变量在类未加载的时候,它就已经加载,当线程结束的时候,静态变量不一定会回收。

ThreadLocal 出现内存泄漏条件:

  1. ThreadLocal 引用被设置为 null,且后面没有 set、get、remove 操作。
  2. 线程一直运行,不停止。(线程池)
  3. 触发了垃圾回收。(Minor GC或Full GC)

如何避免内存泄漏:

  1. ThreadLocal 声明为 private finalprivatefinal 尽可能不让他人修改变更引用,最好不要声明为静态的。
  2. ThreadLocal 使用后务必调用 remove 方法。最简单有效的方法是使用后将其移除。

总结

ThreadLocalMap 并不是为了解决线程安全问题,而是提供了一种将实例绑定到当前线程的机制,类似于隔离的效果。每个线程维护一个 ThreadLocalMap 的映射表,映射表的 key 是 ThreadLocal 实例本身,value 是要存储的副本变量。ThreadLocal 实例本身并不存储值,它只是提供一个在当前线程中找到副本值的 key。

ThreadLocal 设计的初衷是为了解决多线程编程中的资源共享问题。对比 synchronizedsynchronized 采取的是“以时间换空间”的策略,本质上是对关键资源上锁,让大家排队操作。而 ThreadLocal 采取的是“以空间换时间”的思路,为每个使用该变量的线程提供独立的变量副本,在本线程内部,它相当于一个“全局变量”,可以保证本线程任何时间操纵的都是同一个对象。

ThreadLocal 类最重要的一个概念是,其原理是通过一个 ThreadLocal 的静态内部类 ThreadLocalMap 实现,但是实际中,ThreadLocal 不保存 ThreadLocalMap,而是有每个 Thread 内部维护 ThreadLocal.ThreadLocalMap threadLocals 一份数据结构。