浅析JDK1.7 ConcurrentHashMap源码。
写在开篇
先贴张图看下ConcurrentHashMap JDK 1.7的结构:
先大体介绍一下:ConcurrentHashMap 是由Segment数组结构和HashEntry数组结构组成。
- Segment是一种可重入锁ReentrantLock,在ConcurrentHashMap里扮演锁的角色,HashEntry则用于存储键值对数据
- 一个Segment里包含一个HashEntry数组,每个HashEntry是一个链表结构的元素,这与HashMap结构相似
简单来讲,就是ConcurrentHashMap比HashMap多了一次hash过程,第1次hash定位到Segment,第2次hash定位到HashEntry,然后链表搜索找到指定节点。
该种实现方式的缺点是hash过程比普通的HashMap要长,但是优点也很明显,在进行写操作时,只需锁住写元素所在的Segment即可,其他Segment无需加锁,提高了并发读写的效率。
Segment
先看下Segment的定义,是ConcurrentHashMap的一个静态内部类,继承了ReentrantLock。1
static final class Segment<K,V> extends ReentrantLock implements Serializable {
接着看下其中几个重要的属性:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19/**
* 一个HashEntry数组table,HashEntry是链表的节点定义
*/
transient volatile HashEntry<K,V>[] table;
/**
* Segment中Entry的数量
*/
transient int count;
/**
* 阈值
*/
transient int threshold;
/**
* 负载因子
*/
final float loadFactor;
HashEntry
1 | static final class HashEntry<K,V> { |
由定义可知,value和next均为使用volatile修饰,一个线程对该Segment内部的某个链表节点HashEntry的value或下一个节点next修改能够对其他线程可见。
ConcurrentHashMap
类的继承关系
1 | public class ConcurrentHashMap<K, V> extends AbstractMap<K, V> |
继承于抽象的AbstractMap,实现了ConcurrentMap,Serializable这两个接口。
几个重要的属性
1 | /** |
segmentShift和segmentMask,这两个全局变量在定位segment时的哈希算法里需要使用。
几个重要的默认常量
1 | /** |
ConcurrentHashMap的构造函数
- Segment数组初始化后,不能再扩容
- 只初始化了segment[0],其他位置仍然是 null
1 | public ConcurrentHashMap(int initialCapacity, |
put()
在ConcurrentHashMap中的put操作是没有加锁的,而在Segment中的put操作,通过ReentrantLock加锁。
首先通过key的hash确定segments数组的下标,即需要往哪个segment存放数据。确定好segment之后,则调用该segment的put方法,写到该segment内部的table数组的某个链表中。
先看下ConcurrentHashMap中的put():1
2
3
4
5
6
7
8
9
10
11
12
13
14public V put(K key, V value) {
Segment<K,V> s;
if (value == null)
throw new NullPointerException();
// 根据key的hash,确定具体的Segment
int hash = hash(key);
int j = (hash >>> segmentShift) & segmentMask;
// 如果segments数组的该位置还没segment,初始化
if ((s = (Segment<K,V>)UNSAFE.getObject
(segments, (j << SSHIFT) + SBASE)) == null)
s = ensureSegment(j);
// 插入新值至槽s中
return s.put(key, hash, value, false);
}
看下ensureSegment()如何初始化槽:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
// 使用当前 segment[0] 处的数组长度和负载因子来初始化 segment[k]
Segment<K,V> proto = ss[0];
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
// 初始化 segment[k] 内部的数组
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
// 循环CAS赋值给Segment[]后退出
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
接下来是Segment中的put(),首先获取lock锁,然后根据key的hash值,获取在segment内部的HashEntry数组table的下标,从而获取对应的链表:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53final V put(K key, int hash, V value, boolean onlyIfAbsent) {
/**
* tryLock:非阻塞获取lock
* scanAndLockForPut:该segment锁被其他线程持有了,则非阻塞重试3次,超过3次则阻塞等待锁。
*/
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
// 链表头结点
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
// table中已存在结点
if (e != null) {
K k;
// 已经存在,则更新value值
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
// 更新value时,也递增modCount,而在HashMap中是结构性修改才递增。
++modCount;
}
break;
}
e = e.next;
}
// 头插法新增结点
else {
if (node != null)
node.setNext(first);
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
// 扩容
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
我们详细看一下scanAndLockForPut()是怎么实现的:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1;
// 非阻塞自旋获取lock锁
while (!tryLock()) {
HashEntry<K,V> f;
if (retries < 0) {
if (e == null) {
if (node == null)
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key))
retries = 0;
else
e = e.next;
}
// MAX_SCAN_RETRIES为2,尝试3次后,则当前线程阻塞等待lock锁
else if (++retries > MAX_SCAN_RETRIES) {
lock();
break;
}
// 如果链表被修改过,重置retries
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
e = first = f;
retries = -1;
}
}
return node;
}
扩容
接下来看下Segment的rehash():1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
// 新容量为旧容量的2倍
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
// 创建新数组
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
// 新掩码
int sizeMask = newCapacity - 1;
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
// 元素在新数组中的位置
int idx = e.hash & sizeMask;
if (next == null) // 该位置只有一个元素
newTable[idx] = e;
else {
// e是链表表头
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
// for 循环找到一个 lastRun 结点,这个结点之后的所有元素是将要放到一起的
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
// 将lastRun及其之后的所有结点组成的这个链表放到 lastIdx这个位置
newTable[lastIdx] = lastRun;
// 下面的操作是处理lastRun之前的结点,看看分配在哪个链表中
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
// 头插法插入元素
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
// 头插法插入新结点
int nodeIndex = node.hash & sizeMask;
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
// 赋值新数组给table
table = newTable;
}
get()
get()是不用加锁的,通过使用UNSAFE的volatile版本的方法保证线程可见性。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19public V get(Object key) {
Segment<K,V> s;
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
// 获取segment
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
// 遍历table,返回该key对应的value
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
remove()
remove()和put()实现逻辑还是挺相似的。
先看下ConcurrentHashMap中的remove():1
2
3
4
5
6public V remove(Object key) {
int hash = hash(key);
// 找到对应的segment
Segment<K,V> s = segmentForHash(hash);
return s == null ? null : s.remove(key, hash, null);
}
接着继续看segment中的remove():1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37final V remove(Object key, int hash, Object value) {
// 获取锁
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
// 获取table对应下标的列表的第一个元素
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
V v = e.value;
if (value == null || value == v || value.equals(v)) {
// 若将删除的是首结点,则将下一个Entry设置为首结点
if (pred == null)
setEntryAt(tab, index, next);
else
pred.setNext(next);
++modCount;
--count;
oldValue = v;
}
break;
}
pred = e;
e = next;
}
} finally {
unlock();
}
return oldValue;
}
size()
size方法主要是计算当前hashmap中存放的元素的总个数,即累加各个segments的内部的哈希表table数组内的所有链表的所有链表节点的个数。
实现逻辑为:整个计算过程刚开始是不对segments加锁的,重复计算两次,如果前后两次hashmap都没有修改过,则直接返回计算结果,如果修改过了,则再加锁计算一次。
1 | public int size() { |