CountDownLatch解析

本文代码基于Java8

前言

CountDownLatch ,英文翻译为倒计时锁存器,是一个同步辅助类,在完成一组正在其他线程中执行的操作之前,它允许一个或多个线程一直等待。也是基于 AQS,它是 AQS 的共享功能的一个实现。

它主要用来保证完成某个任务的先决条件满足,是一个同步工具类,用来协调多个线程之间的同步。这个工具通常用来控制线程等待,它可以让某一个线程等待直到倒计时结束,再开始执行。

  • 确保某个计算在其需要的所有资源都被初始化之后才继续执行;
  • 确保某个服务在其依赖的所有其他服务都已经启动之后才启动;
  • 等待直到某个操作所有参与者都准备就绪再继续执行。

CountDownLatch 有一个正数计数器,countDown() 方法对计数器做减操作,await() 方法等待计数器达到0。所有 await 的线程都会阻塞直到计数器为0或者等待线程中断或者超时。

CountDownLatch 类结构

其中SyncCountDownLatch 的内部类,Sync 继承自 AbstractQueuedSynchronizer 。使用 AQS state 表示 count 计数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 设置初始计数
Sync(int count) {
setState(count);
}
// 获取当前计数
int getCount() {
return getState();
}
// 重写 AQS 的 tryAcquireShared 方法,-1 表示 count>0,可以获取。
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// // count-1,如果 count 变为0,则唤醒所有。
for (;;) {
// 获取当前状态,为0表示未锁,不用释放。
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
// 利用 CAS 来更新 state 的状态,这里可能有并发,所以这也是用死循环更新的原因
// c为期望值,nextc为更新值。
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}

CountDownLatch 构造方法

使用给定的 count 构造 CountDownLatch,count 表示线程通过 await 前必须要执行的次数,count 不能小于0。

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

CountDownLatch 是一次性的,计数器的值只能在构造方法中初始化一次,之后没有任何机制再次对其设置值,当CountDownLatch 使用完毕后,它不能再次被使用。

CountDownLatch 线程等待方法

await() 是通过轮询 state 的状态来判断所有的任务是否都完成。

无限等待

让当前线程等待直到 count 减数为0,除非线程被中断。如果 count 为0,线程将立即返回,不再阻塞等待。
如果当前计数大于零,则出于线程调度目的,当前线程将禁用,并处于休眠状态,直到发生以下两种情况之一:

  1. countDown() 方法调用使得 count 减数为0;
  2. 当前线程被中断 (如果被中断将会抛出 InterruptedException 异常)。
1
2
3
4
public void await() throws InterruptedException {
// 参考 AQS 的 acquireSharedInterruptibly() 方法
sync.acquireSharedInterruptibly(1);
}

超时等待

使当前线程处理等待状态直到 count 减为0或者等待超时。如果当前count是0,则线程立即返回true。
如果当前计数大于零,则出于线程调度目的,当前线程将禁用,并处于休眠状态,直到发生以下三种情况之一:

  1. countDown() 方法调用使得 count 减数为0;
  2. 当前线程被中断 (如果被中断将会抛出 InterruptedException 异常);
  3. 等待超时。

如果等待超时但是 count>0,则返回 false。如果超时时间小于或等于零,方法将不会等待。

1
2
3
4
5
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
// 参考 AQS 的 tryAcquireSharedNanos() 方法
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

CountDownLatch 其他方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// 倒计时,将会将count-1
public void countDown() {
// 参考 AQS 的 releaseShared() 方法
sync.releaseShared(1);
}
// AQS 的 releaseShared() 方法
public final boolean releaseShared(int arg) {
// 参考 Sync 重写的 tryReleaseShared(int releases) 方法
if (tryReleaseShared(arg)) {
// 唤醒主线程,因为如果 state 不等于0的话,主线程一直是阻塞的。
doReleaseShared();
return true;
}
return false;
}

// 获取当前计数
public long getCount() {
return sync.getCount();
}

// 返回标识锁及其状态的字符串
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}

AQS 的 doReleaseShared() 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) { // 至少有两个节点
int ws = h.waitStatus;
if (ws == Node.SIGNAL) { // 后继节点需要唤醒
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h); // 唤醒后继节点
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

使用示例

官方示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Driver { // ...
void main() throws InterruptedException {
CountDownLatch startSignal = new CountDownLatch(1);
CountDownLatch doneSignal = new CountDownLatch(N);
for (int i = 0; i < N; ++i) // create and start threads
new Thread(new Worker(startSignal, doneSignal)).start();
doSomethingElse(); // don't let run yet
// 所有worker线程继续执行
startSignal.countDown(); // let all threads proceed
doSomethingElse();
// 允许driver等待直到所有的worker都完成
doneSignal.await(); // wait for all to finish
}
}
class Worker implements Runnable {
private final CountDownLatch startSignal;
private final CountDownLatch doneSignal;
Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
this.startSignal = startSignal;
this.doneSignal = doneSignal;
}
public void run() {
try {
startSignal.await(); // 开始信号阻止任何worker直到driver准备好
doWork();
doneSignal.countDown(); // 完成信号,计数减一
} catch (InterruptedException ex) {} // return;
}
void doWork() { ... }
}

火箭发射示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public class CountDownLatchDemo implements Runnable{
static final CountDownLatch latch = new CountDownLatch(10);
static final CountDownLatchDemo demo = new CountDownLatchDemo();

@Override
public void run() {
// 模拟检查任务
try {
Thread.sleep(new Random().nextInt(10) * 1000);
System.out.println("check complete");
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
//计数减一
//放在finally避免任务执行过程出现异常,导致countDown()不能被执行
latch.countDown();
}
}

public static void main(String[] args) throws InterruptedException {
ExecutorService exec = Executors.newFixedThreadPool(10);
for (int i=0; i<10; i++){
exec.submit(demo);
}

// 等待检查
latch.await();

// 发射火箭
System.out.println("Fire!");
// 关闭线程池
exec.shutdown();
}
}

总结

CountDownLatch 主要用来保证完成某个任务的先决条件满足,是一个同步工具类,用来协调多个线程之间的同步。这个工具通常用来控制线程等待,它可以让某一个线程等待直到倒计时结束,再开始执行。