浅析JDK1.7 ConcurrentHashMap源码。

写在开篇

先贴张图看下ConcurrentHashMap JDK 1.7的结构:

CHM1.7.01
CHM1.7.01

先大体介绍一下:ConcurrentHashMap 是由Segment数组结构和HashEntry数组结构组成。

  • Segment是一种可重入锁ReentrantLock,在ConcurrentHashMap里扮演锁的角色,HashEntry则用于存储键值对数据
  • 一个Segment里包含一个HashEntry数组,每个HashEntry是一个链表结构的元素,这与HashMap结构相似

简单来讲,就是ConcurrentHashMap比HashMap多了一次hash过程,第1次hash定位到Segment,第2次hash定位到HashEntry,然后链表搜索找到指定节点。

该种实现方式的缺点是hash过程比普通的HashMap要长,但是优点也很明显,在进行写操作时,只需锁住写元素所在的Segment即可,其他Segment无需加锁,提高了并发读写的效率。

Segment

先看下Segment的定义,是ConcurrentHashMap的一个静态内部类,继承了ReentrantLock。

1
static final class Segment<K,V> extends ReentrantLock implements Serializable {

接着看下其中几个重要的属性:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
* 一个HashEntry数组table,HashEntry是链表的节点定义
*/
transient volatile HashEntry<K,V>[] table;

/**
* Segment中Entry的数量
*/
transient int count;

/**
* 阈值
*/
transient int threshold;

/**
* 负载因子
*/
final float loadFactor;

HashEntry

1
2
3
4
5
6
7
8
static final class HashEntry<K,V> {
final int hash;
final K key;
volatile V value;
volatile HashEntry<K,V> next;

// ...
}

由定义可知,value和next均为使用volatile修饰,一个线程对该Segment内部的某个链表节点HashEntry的value或下一个节点next修改能够对其他线程可见。

ConcurrentHashMap

类的继承关系

1
2
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V>
implements ConcurrentMap<K, V>, Serializable {

继承于抽象的AbstractMap,实现了ConcurrentMap,Serializable这两个接口。

几个重要的属性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/**
* 段掩码
*/
final int segmentMask;

/**
* 段偏移量
*/
final int segmentShift;

/**
* ConcurrentHashMap中的桶
*/
final Segment<K,V>[] segments;

segmentShift和segmentMask,这两个全局变量在定位segment时的哈希算法里需要使用。

几个重要的默认常量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
* 指的是整个ConcurrentHashMap默认的初始容量
*/
static final int DEFAULT_INITIAL_CAPACITY = 16;

/**
* ConcurrentHashMap容量的最大值
*/
static final int MAXIMUM_CAPACITY = 1 << 30;

/**
* Segment 数组不可以扩容,所以这个负载因子是给每个 Segment 内部使用的。
*/
static final float DEFAULT_LOAD_FACTOR = 0.75f;

/**
* 默认的并发数,即segments数组的大小,ConcurrentHashMap会使用大于等于该值的最小2幂指数作为实际并发度
*/
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

/**
* segment中table的最小容量
*/
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

/**
* segment数组最大的大小
*/
static final int MAX_SEGMENTS = 1 << 16;

ConcurrentHashMap的构造函数

  • Segment数组初始化后,不能再扩容
  • 只初始化了segment[0],其他位置仍然是 null
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
int sshift = 0;
int ssize = 1;
/**
* ssize:segments数组的大小
* 不能小于concurrencyLevel,默认为16
*/
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
// segmentShift和segmentMask将用于取segment的下标
/**
* 比如concurrencyLevel为17,那么ssize为32,即2^5;sshift为5
* segmentShift即27了,后面在取segment下标的时候,会无符号左移27位,也就是取高5位的时候,就是0-31,此时segment下标也是0-31,取模后对应着每个segment
*
* segmentMask就是2的n次方-1,这里是5,用于取模
*/
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
/**
* cap:Segment内部HashEntry数组的大小
* 最小为MIN_SEGMENT_TABLE_CAPACITY,默认为2
* 实际大小根据initialCapacity/ssize得到
* 即整体容量大小除以Segment数组的数量
* 得到每个Segment内部的table的大小
*/
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
// 创建segments和segments[0]
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
// 往数组写入 segment[0]
UNSAFE.putOrderedObject(ss, SBASE, s0);
this.segments = ss;
}

put()

在ConcurrentHashMap中的put操作是没有加锁的,而在Segment中的put操作,通过ReentrantLock加锁。

首先通过key的hash确定segments数组的下标,即需要往哪个segment存放数据。确定好segment之后,则调用该segment的put方法,写到该segment内部的table数组的某个链表中。

先看下ConcurrentHashMap中的put():

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public V put(K key, V value) {
Segment<K,V> s;
if (value == null)
throw new NullPointerException();
// 根据key的hash,确定具体的Segment
int hash = hash(key);
int j = (hash >>> segmentShift) & segmentMask;
// 如果segments数组的该位置还没segment,初始化
if ((s = (Segment<K,V>)UNSAFE.getObject
(segments, (j << SSHIFT) + SBASE)) == null)
s = ensureSegment(j);
// 插入新值至槽s中
return s.put(key, hash, value, false);
}

看下ensureSegment()如何初始化槽:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
// 使用当前 segment[0] 处的数组长度和负载因子来初始化 segment[k]
Segment<K,V> proto = ss[0];
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
// 初始化 segment[k] 内部的数组
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
// 循环CAS赋值给Segment[]后退出
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}

接下来是Segment中的put(),首先获取lock锁,然后根据key的hash值,获取在segment内部的HashEntry数组table的下标,从而获取对应的链表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
/**
* tryLock:非阻塞获取lock
* scanAndLockForPut:该segment锁被其他线程持有了,则非阻塞重试3次,超过3次则阻塞等待锁。
*/
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
// 链表头结点
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
// table中已存在结点
if (e != null) {
K k;
// 已经存在,则更新value值
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
// 更新value时,也递增modCount,而在HashMap中是结构性修改才递增。
++modCount;
}
break;
}
e = e.next;
}
// 头插法新增结点
else {
if (node != null)
node.setNext(first);
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
// 扩容
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}

我们详细看一下scanAndLockForPut()是怎么实现的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1;
// 非阻塞自旋获取lock锁
while (!tryLock()) {
HashEntry<K,V> f;
if (retries < 0) {
if (e == null) {
if (node == null)
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key))
retries = 0;
else
e = e.next;
}
// MAX_SCAN_RETRIES为2,尝试3次后,则当前线程阻塞等待lock锁
else if (++retries > MAX_SCAN_RETRIES) {
lock();
break;
}
// 如果链表被修改过,重置retries
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
e = first = f;
retries = -1;
}
}
return node;
}

扩容

接下来看下Segment的rehash():

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
// 新容量为旧容量的2倍
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
// 创建新数组
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
// 新掩码
int sizeMask = newCapacity - 1;
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
// 元素在新数组中的位置
int idx = e.hash & sizeMask;
if (next == null) // 该位置只有一个元素
newTable[idx] = e;
else {
// e是链表表头
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
// for 循环找到一个 lastRun 结点,这个结点之后的所有元素是将要放到一起的
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
// 将lastRun及其之后的所有结点组成的这个链表放到 lastIdx这个位置
newTable[lastIdx] = lastRun;
// 下面的操作是处理lastRun之前的结点,看看分配在哪个链表中
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
// 头插法插入元素
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
// 头插法插入新结点
int nodeIndex = node.hash & sizeMask;
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
// 赋值新数组给table
table = newTable;
}

get()

get()是不用加锁的,通过使用UNSAFE的volatile版本的方法保证线程可见性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public V get(Object key) {
Segment<K,V> s;
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
// 获取segment
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
// 遍历table,返回该key对应的value
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}

remove()

remove()和put()实现逻辑还是挺相似的。

先看下ConcurrentHashMap中的remove():

1
2
3
4
5
6
public V remove(Object key) {
int hash = hash(key);
// 找到对应的segment
Segment<K,V> s = segmentForHash(hash);
return s == null ? null : s.remove(key, hash, null);
}

接着继续看segment中的remove():

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
final V remove(Object key, int hash, Object value) {
// 获取锁
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
// 获取table对应下标的列表的第一个元素
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
V v = e.value;
if (value == null || value == v || value.equals(v)) {
// 若将删除的是首结点,则将下一个Entry设置为首结点
if (pred == null)
setEntryAt(tab, index, next);
else
pred.setNext(next);
++modCount;
--count;
oldValue = v;
}
break;
}
pred = e;
e = next;
}
} finally {
unlock();
}
return oldValue;
}

size()

size方法主要是计算当前hashmap中存放的元素的总个数,即累加各个segments的内部的哈希表table数组内的所有链表的所有链表节点的个数。

实现逻辑为:整个计算过程刚开始是不对segments加锁的,重复计算两次,如果前后两次hashmap都没有修改过,则直接返回计算结果,如果修改过了,则再加锁计算一次。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
public int size() {
final Segment<K,V>[] segments = this.segments;
int size;
boolean overflow; // true if size overflows 32 bits
// 累加modCounts
long sum;
// 记录前一次累加的modCounts
long last = 0L;
// 尝试的次数
int retries = -1;
try {
for (;;) {
/**
* RETRIES_BEFORE_LOCK值为2
* retries++ == RETRIES_BEFORE_LOCK,表示已经是第三次了,故需要加锁
*/
if (retries++ == RETRIES_BEFORE_LOCK) {
// 每个segment都加锁,此时不能执行写操作了
for (int j = 0; j < segments.length; ++j)
ensureSegment(j).lock(); // force creation
}
// sum重置为0
sum = 0L;
size = 0;
overflow = false;
// 遍历每个segment
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
// 累加各个segment的modCount,以便与上一次的modCount进行比较
sum += seg.modCount;
int c = seg.count;
// size+=c 计算ConcurrentHashMap中size的数量
if (c < 0 || (size += c) < 0)
overflow = true;
}
}
// 如果前后两次都相等,说明在这期间没有写的操作,可以直接返回
if (sum == last)
break;
last = sum;
}
} finally {
if (retries > RETRIES_BEFORE_LOCK) {
// 释放锁
for (int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock();
}
}
return overflow ? Integer.MAX_VALUE : size;
}