浅析JDK1.7 ConcurrentHashMap源码。
写在开篇
先贴张图看下ConcurrentHashMap JDK 1.7的结构: 
                
先大体介绍一下:ConcurrentHashMap 是由Segment数组结构和HashEntry数组结构组成。
- Segment是一种可重入锁ReentrantLock,在ConcurrentHashMap里扮演锁的角色,HashEntry则用于存储键值对数据
- 一个Segment里包含一个HashEntry数组,每个HashEntry是一个链表结构的元素,这与HashMap结构相似
简单来讲,就是ConcurrentHashMap比HashMap多了一次hash过程,第1次hash定位到Segment,第2次hash定位到HashEntry,然后链表搜索找到指定节点。
该种实现方式的缺点是hash过程比普通的HashMap要长,但是优点也很明显,在进行写操作时,只需锁住写元素所在的Segment即可,其他Segment无需加锁,提高了并发读写的效率。
Segment
先看下Segment的定义,是ConcurrentHashMap的一个静态内部类,继承了ReentrantLock。1
static final class Segment<K,V> extends ReentrantLock implements Serializable {
接着看下其中几个重要的属性:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19/**
 * 一个HashEntry数组table,HashEntry是链表的节点定义
 */
transient volatile HashEntry<K,V>[] table;
/**
 * Segment中Entry的数量
 */
transient int count;
/**
 * 阈值
 */
transient int threshold;
/**
 * 负载因子
 */
final float loadFactor;
HashEntry
| 1 | static final class HashEntry<K,V> { | 
由定义可知,value和next均为使用volatile修饰,一个线程对该Segment内部的某个链表节点HashEntry的value或下一个节点next修改能够对其他线程可见。
ConcurrentHashMap
类的继承关系
| 1 | public class ConcurrentHashMap<K, V> extends AbstractMap<K, V> | 
继承于抽象的AbstractMap,实现了ConcurrentMap,Serializable这两个接口。
几个重要的属性
| 1 | /** | 
segmentShift和segmentMask,这两个全局变量在定位segment时的哈希算法里需要使用。
几个重要的默认常量
| 1 | /** | 
ConcurrentHashMap的构造函数
- Segment数组初始化后,不能再扩容
- 只初始化了segment[0],其他位置仍然是 null
| 1 | public ConcurrentHashMap(int initialCapacity, | 
put()
在ConcurrentHashMap中的put操作是没有加锁的,而在Segment中的put操作,通过ReentrantLock加锁。
首先通过key的hash确定segments数组的下标,即需要往哪个segment存放数据。确定好segment之后,则调用该segment的put方法,写到该segment内部的table数组的某个链表中。
先看下ConcurrentHashMap中的put():1
2
3
4
5
6
7
8
9
10
11
12
13
14public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
    // 根据key的hash,确定具体的Segment
    int hash = hash(key);
    int j = (hash >>> segmentShift) & segmentMask;
    // 如果segments数组的该位置还没segment,初始化
    if ((s = (Segment<K,V>)UNSAFE.getObject          
         (segments, (j << SSHIFT) + SBASE)) == null) 
        s = ensureSegment(j);
    // 插入新值至槽s中
    return s.put(key, hash, value, false);
}
看下ensureSegment()如何初始化槽:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
    	// 使用当前 segment[0] 处的数组长度和负载因子来初始化 segment[k]
        Segment<K,V> proto = ss[0]; 
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        // 初始化 segment[k] 内部的数组
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                // 循环CAS赋值给Segment[]后退出
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}
接下来是Segment中的put(),首先获取lock锁,然后根据key的hash值,获取在segment内部的HashEntry数组table的下标,从而获取对应的链表:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    /**
	 * tryLock:非阻塞获取lock
	 * scanAndLockForPut:该segment锁被其他线程持有了,则非阻塞重试3次,超过3次则阻塞等待锁。
	 */
    HashEntry<K,V> node = tryLock() ? null :
    scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        int index = (tab.length - 1) & hash;
        // 链表头结点
        HashEntry<K,V> first = entryAt(tab, index);
        for (HashEntry<K,V> e = first;;) {
            // table中已存在结点
            if (e != null) {
                K k;
                // 已经存在,则更新value值
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) {
                        e.value = value;
                        // 更新value时,也递增modCount,而在HashMap中是结构性修改才递增。
                        ++modCount;
                    }
                    break;
                }
                e = e.next;
            }
            // 头插法新增结点
            else {
                if (node != null)
                    node.setNext(first);
                else
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                // 扩容
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    rehash(node);
                else
                    setEntryAt(tab, index, node);
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock();
    }
    return oldValue;
}
我们详细看一下scanAndLockForPut()是怎么实现的:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; 
    // 非阻塞自旋获取lock锁
    while (!tryLock()) {
        HashEntry<K,V> f; 
        if (retries < 0) {
            if (e == null) {
                if (node == null) 
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            else if (key.equals(e.key))
                retries = 0;
            else
                e = e.next;
        }
        // MAX_SCAN_RETRIES为2,尝试3次后,则当前线程阻塞等待lock锁
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();
            break;
        }
        // 如果链表被修改过,重置retries
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) {
            e = first = f; 
            retries = -1;
        }
    }
    return node;
}
扩容
接下来看下Segment的rehash():1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    // 新容量为旧容量的2倍
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    // 创建新数组
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    // 新掩码
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 元素在新数组中的位置
            int idx = e.hash & sizeMask;
            if (next == null) // 该位置只有一个元素
                newTable[idx] = e;
            else {
            	// e是链表表头
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                // for 循环找到一个 lastRun 结点,这个结点之后的所有元素是将要放到一起的
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                // 将lastRun及其之后的所有结点组成的这个链表放到 lastIdx这个位置
                newTable[lastIdx] = lastRun;
                // 下面的操作是处理lastRun之前的结点,看看分配在哪个链表中
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    // 头插法插入元素
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    // 头插法插入新结点
    int nodeIndex = node.hash & sizeMask; 
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    // 赋值新数组给table
    table = newTable;
}
get()
get()是不用加锁的,通过使用UNSAFE的volatile版本的方法保证线程可见性。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19public V get(Object key) {
    Segment<K,V> s; 
    HashEntry<K,V>[] tab;
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    // 获取segment
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        // 遍历table,返回该key对应的value
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
             (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}
remove()
remove()和put()实现逻辑还是挺相似的。
先看下ConcurrentHashMap中的remove():1
2
3
4
5
6public V remove(Object key) {
    int hash = hash(key);
    // 找到对应的segment
    Segment<K,V> s = segmentForHash(hash);
    return s == null ? null : s.remove(key, hash, null);
}
接着继续看segment中的remove():1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37final V remove(Object key, int hash, Object value) {
    // 获取锁
    if (!tryLock())
        scanAndLock(key, hash);
    V oldValue = null;
    try {
        HashEntry<K,V>[] tab = table;
        int index = (tab.length - 1) & hash;
        // 获取table对应下标的列表的第一个元素
        HashEntry<K,V> e = entryAt(tab, index);
        HashEntry<K,V> pred = null;
        while (e != null) {
            K k;
            HashEntry<K,V> next = e.next;
            if ((k = e.key) == key ||
                (e.hash == hash && key.equals(k))) {
                V v = e.value;
                if (value == null || value == v || value.equals(v)) {
                    // 若将删除的是首结点,则将下一个Entry设置为首结点
                    if (pred == null)
                        setEntryAt(tab, index, next);
                    else
                        pred.setNext(next);
                    ++modCount;
                    --count;
                    oldValue = v;
                }
                break;
            }
            pred = e;
            e = next;
        }
    } finally {
        unlock();
    }
    return oldValue;
}
size()
size方法主要是计算当前hashmap中存放的元素的总个数,即累加各个segments的内部的哈希表table数组内的所有链表的所有链表节点的个数。
实现逻辑为:整个计算过程刚开始是不对segments加锁的,重复计算两次,如果前后两次hashmap都没有修改过,则直接返回计算结果,如果修改过了,则再加锁计算一次。
| 1 | public int size() { | 
