专栏首页Java技术大杂烩ThreadLocal 源码解析

ThreadLocal 源码解析

本文将从以下几个方面介绍 前言 栗子 类图 ThreadLocal源码分析 ThreadLocalMap 源码分析 ThreadLocal 可能会导致内存泄漏

前言

ThreadLocal 顾名思义就是在每个线程内部都会存储只有当前线程才能访问的变量的一个副本,然后当前线程修改了该副本的值后而不会影响其他线程的值,各个变量之间相互不影响。

当我们需要共享一个变量,而该变量又不是线程安全的时候,可以使用 ThreadLocal 来复制该变量的一个副本; 再比如使用 SimpleDateFormat 的时候,由于 SimpleDateFormat 不是线程安全的,所以当把它定义为属性的时候,有可能会出现问题;此时可以使用 ThreadLocal 来进行包装 SimpleDateFormat 等等。

栗子

首先看一个不使用 ThreadLocal 的简单不成熟栗子,每个线程都要修改共享变量 i的值

    private int i = 0;

    private void createThread() throws InterruptedException {
        Thread thread = new Thread(() -> {
            i = 0;
            System.out.println(Thread.currentThread().getName() + " : " +  i);
            i+=10;
            System.out.println(Thread.currentThread().getName() + " : " +  i);
        });
        thread.start();
        thread.join();
    }

    public static void main(String[] args) throws InterruptedException {
        Main m = new Main();
        for (int j = 0; j < 5; j++) {
            m.createThread();
        }
    }

输出:
Thread-0 : 0
Thread-0 : 10
Thread-1 : 0
Thread-1 : 10
Thread-2 : 0
Thread-2 : 10
Thread-3 : 0
Thread-3 : 10
Thread-4 : 0
Thread-4 : 10

在每个线程修改该共享变量的值之前,都需要重置该变量的值,之后才会进行修改,这样结果才会符合我们的预期。

接下来看下使用 ThreadLocal 是来实现的:

    private int i = 0;
    // 为每个线程创建变量 i 的副本
    private ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> i);

    private void createThread2() throws InterruptedException {
        Thread thread = new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + " : " + threadLocal.get());
            threadLocal.set(threadLocal.get() + 10);
            System.out.println(Thread.currentThread().getName() + " : " + threadLocal.get());
        });
        thread.start();
        thread.join();
    }

    public static void main(String[] args) throws InterruptedException {
        Main m = new Main();
        for (int j = 0; j < 5; j++) {
            m.createThread2();
        }
    }
输出:
Thread-0 : 0
Thread-0 : 10
Thread-1 : 0
Thread-1 : 10
Thread-2 : 0
Thread-2 : 10
Thread-3 : 0
Thread-3 : 10
Thread-4 : 0
Thread-4 : 10

可以看到,使用 ThreadLocal 同样实现上述的效果,但是不需要再每个线程执行之前重置该共享变量了。

注:使用 join() 方法为了让线程顺序执行,线程1执行完了线程2再执行

源码分析

接下来看下 ThreadLocal 的一个实现

类图

先来看下它的一个类图

从该类图中,可以看到,ThreadLocal 并没有实现任何的类,也没有实现任何的接口,它只有两个内部类,ThreadLocalMapSuppliedThreadLocalThreadLocalMap类中还有一个 Entry 内部类,可以看到,类结构是很简单的。SuppliedThreadLocal 只是为了实现 Java 8 的函数式编程(Lambda表达式),可以忽略。

关于 Java 8 的 Lambda 可以参考 Lambda表达式 : https://my.oschina.net/mengyuankan/blog/1575424 Java 8 中的流--Stream: https://my.oschina.net/mengyuankan/blog/1575565

ThreadLoal 方法

返回值

方法名

描述

T

get

返回当前线程本地变量的值

protected T

initialValue()

初始化当前线程本地变量的值,默认为null,一般需要重写该方法

void

remove()

删除不再使用的 ThreadLocal

void

set(T value)

设置当前线程本地变量的值

ThreadLocal

withInitial(Supplier supplier)

使用Lambda表达式设置初始值,和 initialValue() 作用是一样的

ThreadLocal 的方法使用都比较简单,接下来就看看它们是怎么实现的,

ThreadLocal

public class ThreadLocal<T> {
    // 哈希值
    private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode = new AtomicInteger();

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    //防止哈希冲突
    private static final int HASH_INCREMENT = 0x61c88647;

    // 当前线程的本地变量的初始值,默认为null,一般需要重写该方法
    protected T initialValue() {
        return null;
    }

    // Lambda 方式设置初始值
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        return new SuppliedThreadLocal<>(supplier);
    }

    // 构造方法
    public ThreadLocal() {
    }

    // 获取 ThreadLocalMap 
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    // 根据线程,和变量值创建 ThreadLocalMap
    // 每个线程都在自己的 ThreadLocalMap
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

上述是 ThreadLocal 的一些辅助的方法,主要方法 set , get 方法主要是在 ThreadLocalMap 中实现,所以需要在下面结合 ThreadLocalMap 中来说:

ThreadLocalMap

首先,ThreadLocalMap 是一个自定义的哈希映射,仅仅是用来维护线程本地变量的值,ThreadLocalMap 使用 WeakReferences 作为键,为了能够及时的GC.

关于 WeakReferences ,可以参考 java虚拟机之初探:https://my.oschina.net/mengyuankan/blog/1825562

ThreadLocalMap 的定义

    static class ThreadLocalMap {

        // 内部类,有两个属性:ThreadLocal 和 Object
        // ThreadLocal:作为key,当key==ull(即entry.get()== null)表示不再引用该键,因此可以从表中删除
        // Object:本地变量的值
        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        // Entry数组的初始容量,为16,必须为2的幂
        private static final int INITIAL_CAPACITY = 16;

        // Entry数组,可重置大小,数组的长度必须为2的幂
        private Entry[] table;

        // Entry数组中元素的个数
        private int size = 0;

        //Entry扩容的阈值,默认为0
        private int threshold; 

        //设置阈值,为 len 的三分之二
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        // Entry数组的下一个索引
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        // Entry数组的上一个索引
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

     ....方法.........
    }

从上述定义的属性和类可以看到,ThreadLocalMap 主要使用数组来实现的,数组的每一项是一个 Entry 对象,Entry 对象中会持有当前线程的引用和当前线程所绑定的变量值。结构如下所示:

接下来看下 ThreadLocalMap 方法的实现,在该部分中,需要结合 ThreadLocal 的方法一起来看,

get() 方法

// 返回当前线程所绑定的本地变量值,如果当前线程为null,则返回setInitialValue()方法中的值
public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取ThreadLocalMap 
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 在 ThreadLocalMap 中获取当前线程对应的Entry,Entry 中存储了当前线程所绑定的本地变量的值
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            // 获取当前线程所绑定的本地变量的值,并返回
            T result = (T)e.value;
            return result;
        }
    }
    // 如果当前线程没有的 ThreadLocalMap 中,则返回 setInitialValue 中的值
    return setInitialValue();
}

get() 方法不要有以下几步:

  1. 获取当前线程
  2. 获取线程内的 ThreadLocalMap,如果map已经存在,则以当前的ThreadLocal为键,获取Entry对象,并从从Entry中取出值
  3. 如果 map 不存在,则调用setInitialValue方法执行初始化

现在,来看下如何从 ThreadLocalMap中获取当前线程所对应的 Entry 对象:

getEntry() 方法如下

private Entry getEntry(ThreadLocal<?> key) {
    // 获取对应线程的hashcode
    // 计算 Entry数组的索引
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 如果该索引处的Entry对象刚好等于key,则直接返回
    if (e != null && e.get() == key)
        return e;
    else
    // 如果上述条件不满足,则进入 getEntryAfterMiss 方法
        return getEntryAfterMiss(key, i, e);
}

getEntryAfterMiss()方法

该方法主要作用是,当在当前的索引中找不到对应的 Entry 对象时执行,在该方法内部,主要是在 Entry 数组中循环查找对应key,如果key为空,则进行清理操作

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    // 当前 Entry数组
    Entry[] tab = table;
    int len = tab.length;
    // 如果当前Entry数组 i 对应的位置的Entry对象不为空 
    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 如果key等于Entry数组 i 对应的位置的Entry对象,则直接返回
        if (k == key)
            return e;
        if (k == null)
            // 如果 Entry 数组 i 对应的位置的 Entry 对象为空,则删除该 Entry 对象,resize Entry数组
            expungeStaleEntry(i);
        else
           // 否则,获取 Entry 数组的下一个索引位置,继续查找
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

expungeStaleEntry()方法

当在 Entry 数组中对应的位置不存在任何引用的时候,进行 Entry 数组的清理操作,resize Entry 数组:

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 把当前索引对应位置的对象设置为null
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--; // Entry数组大小减1

    // Rehash
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

在执行完上述方法后,get() 方法就会得到一个 Entry 对象,之后返回该对象的value,该value就是当前线程所绑定的本地变量的值。

在上面所说的 get() 方法中,如果 ThreadLocalMap 不存在,则执行 setInitialValue 进行初始化,下面看下setInitialValue:

setInitialValue()方法

private T setInitialValue() {
    // 调用 initialValue 方法,该方法默认返回null,一般需要重写该方法
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    // 如果 ThreadLocalMap 已存在,则设置初值为 initialValue 方法的返回值
    if (map != null)
        map.set(this, value);
    else
        // 如果 ThreadLocalMap 不存在,则创建
        createMap(t, value);
    return value;
}

ThreadLocalMap.set()方法

ThreadLocalMap 的 set 方法,主要用来设置其对应的值:

private void set(ThreadLocal<?> key, Object value) {
    // 当前的Entry数组
    Entry[] tab = table;
    int len = tab.length;
    // 数组索引
    int i = key.threadLocalHashCode & (len-1);
    // 遍历 Entry 数组
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 如果在 Entry 找到,则设置value
        if (k == key) {
            e.value = value;
            return;
        }
        // 如果当前的ThreadLocal为空,则调用replaceStaleEntry来更换这个key为空的Entry
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 如果在 Entry 数组中没有找到对应的key ,则创建,插入到数组中
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 清理 Entry 数组中为null的项,且如果数组大小大于等于我们设置的阈值,则rehash数组
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

cleanSomeSlots 方法里面还是会调用上面所说的 expungeStaleEntry 方法进行清理 Entry数组为null的项。

rehash()方法

如果当Entry数组的大小大于等于设置的阈值的话,Entry数组就需要进行扩容操作:

private void rehash() {
    // 清空Entry数组
    expungeStaleEntries();
    // 如果 数组大小大于等于 阈值的 3/4,则扩容
    if (size >= threshold - threshold / 4)
        // 扩容
        resize();
}

resize()方法

把 Entry数组的容量扩大为原来的 2 倍:

private void resize() {
    // 旧的数组
    Entry[] oldTab = table;
    // 旧数组的长度
    int oldLen = oldTab.length;
    // 新的数组的长度为旧的的2倍
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    // 复制数据
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // 重新计算数组的索引值
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

ThreadLocal的 set() 方法

ThreadLocal 的 set 方法用来设置当前线程所绑定的变量的值,它的实现和setInitialValue差不多:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    // 如果 ThreadLocalMap  存在,则设置值
    if (map != null)
        map.set(this, value);
    else
        // 如果 ThreadLocalMap  不存在则创建
        createMap(t, value);
}

ThreadLocal 的remove() 方法

 public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread());
     if (m != null)
         // 调用 ThreadLocalMap 的 remove 方法
         m.remove(this);
 }

 private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算索引
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

以上就是 ThreadLocal 的一个实现过程。

ThreadLocal 可能会导致内存泄漏

从上面的代码中可以看到,ThreadLocalMap 使用 ThreadLocal 的弱引用作为key,如果一个 ThreadLocal 没有外部强引用来引用它,那么系统 GC 的时候,这个ThreadLocal 就会被回收,这样一来,ThreadLocalMap 中就会出现 key 为 null 的 Entry,就没有办法访问这些key为null的Entry的value,如果当前线程再迟迟不结束的话,这些key为null的Entry的value永远无法回收,造成内存泄漏。在 ThreadLocal 中 的 get, set 和remove 方法中,都对 Entry 的key进行的null的判断,如果为null,则会 expungeStaleEntry 进行清理操作;

所以,在线程中使用完 ThreadLocal 变量后,要记得及时remove掉。

本文分享自微信公众号 - Java技术大杂烩(tsmyk0715),作者:TSMYK

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-05-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Redis 数据结构-字典源码分析

    字典这种数据结构并不是 Redis 那几种基本数据结构,但是 hash , sets 和 sorted sets 这几种数据结构在底层都是使用字典来实现的(并不...

    Java技术大杂烩
  • Mybatis 解析配置文件的源码解析

    使用过Mybatis 的都知道,Mybatis 有个配置文件,用来配置数据源,别名,一些全局的设置如开启缓存之类的, 在Mybatis 在初始化的时候,会加载该...

    Java技术大杂烩
  • Spring bean 创建过程源码解析

    在上一篇文章 Spring 中 bean 注册的源码解析 中分析了 Spring 中 bean 的注册过程,就是把配置文件中配置的 bean 的信息加载到内...

    Java技术大杂烩
  • ThreadLocal和InheritableThreadLocal深入分析

      通过ThreadLocal和InheritableThreadLocal,我们能够很方便的设计出线程安全的类。JDK底层是如何做到的呢?ThreadLoca...

    良辰美景TT
  • HashMap实现原理分析

    HashMap实现原理分析 HashMap主要是用数组来存储数据的,我们都知道它会对key进行哈希运算,哈系运算会有重复的哈希值,对于哈希值的冲突,HashMa...

    xiangzhihong
  • java进阶|HashTable源码分析和理解

    键值对集合Map,HashTable都是我们常用的,但是随着多线程环境以及代码的普及,ConcurrentHashMap这样的并发集合也常用了起来,今天...

    后端Coder
  • 学习笔记:Hashtable和HashMap

    学了这么些天的基础知识发现自己还是个门外汗,难怪自己一直混的不怎么样。但这样的恶补不知道有没有用,是不是过段时间这些知识又忘了呢?这些知识平时的工作好像都是随拿...

    用户1105954
  • Salesforce市值破万亿!爱因斯坦AI平台将大显身手,首席科学家却离职

    想当年,Salesforce首次亮相的时候,公司市值才刚过10亿美金,一眨眼现在公司的市值已经到1694亿美元,折合人民币超万亿元。

    新智元
  • 利用数据可视化和相关历史背景分析在COVID-19影响下美国股市暴跌

    如果你不知道过去两个月美国股市的下跌,那么你要么是个象牙塔里的大学生,要么是个既没有任何投资又没有要还一辈子的助学贷款的低级工人。不管怎样,不管你是否有没有在股...

    deephub
  • ThreadLocal企业中真实应用

    SimpleDateFormat(下面简称sdf)类内部有一个Calendar对象引用,它用来储存和这个sdf相关的日期信息,例如sdf.parse(dateS...

    公众号 IT老哥

扫码关注云+社区

领取腾讯云代金券