前几天学习了AQS源码为了加深印象今天来基于AQS自己实现一个锁
之前我们学习了AQS的源码,了解到了自定义AQS需要实现重写一系列函数,还需要定义原子变量state的含义。
下文我们自己实现一个锁,定义state为0表示锁没有被线程持有,state为1表示锁已经被某一个线程持有,由于是不可重入锁,所以不需要记录持有锁的线程获取锁的次数,另外,我们自定义的锁支持条件变量,因为我们要实现生产者——消费者模型
class NonReentrantLock implements Lock, Serializable { //实现AQS private static class Sync extends AbstractQueuedSynchronizer { //是否锁被占有 @Override protected boolean isHeldExclusively() { return getState() == 1; } //如果state为0,则尝试获取锁 @Override protected boolean tryAcquire(int acquires) { assert acquires == 1; if (compareAndSetState(0, 1)) { //CAS成功则将当前线程设置获取到锁 setExclusiveOwnerThread(Thread.currentThread()); return true; } return false; } //尝试释放锁,将state改为1 @Override protected boolean tryRelease(int releases) { assert releases == 1; if (getState() == 0) throw new IllegalMonitorStateException(); setExclusiveOwnerThread(null); setState(0); return true; } //提供条件变量接口 Condition newCondition() { return new ConditionObject(); } } //创建一个Sync来做具体的工作 private final Sync sync = new Sync(); @Override public void lock() { sync.acquire(1); } @Override public boolean tryLock() { return sync.tryAcquire(1); } @Override public void unlock() { sync.release(1); } @Override public Condition newCondition() { return sync.newCondition(); } public boolean isLocked() { return sync.isHeldExclusively(); } @Override public void lockInterruptibly() throws InterruptedException { sync.acquireInterruptibly(1); } @Override public boolean tryLock(long time, TimeUnit unit) throws InterruptedException { return sync.tryAcquireNanos(1, unit.toNanos(time)); } }
在如上代码中,NonReentrantLock定义了一个内部类Sync用来实现具体的锁的操作,Sync则继承了AQS,由于我们实现的是独占模式的锁,所以Sync重写了tryAcquire、tryRelease和isHeldExclusively三个方法,另外Sync提供了newCondition这个方法用来支持条件变量
public class AQSDemo { final static NonReentrantLock lock = new NonReentrantLock(); final static Condition notFull = lock.newCondition(); final static Condition notEmpty = lock.newCondition(); final static Queue<String> queue = new LinkedBlockingQueue<>(); final static int queueSize = 10; public static void main(String[] args) { Thread producer = new Thread(() -> { lock.lock(); try { //如果队列满了则等待 while (queue.size() == queueSize) { notEmpty.await(); } //添加元素到队列 queue.add("ele"); //唤醒消费者线程 notFull.signalAll(); } catch (Exception e) { e.printStackTrace(); } finally { lock.unlock(); } }); Thread consumer = new Thread(() -> { lock.lock(); try { //如果队列满了则等待 while (queue.size() == 0) { notFull.await(); } //消费队列 queue.poll(); //唤醒生产线程 notEmpty.signalAll(); } catch (Exception e) { e.printStackTrace(); } finally { lock.unlock(); } }); producer.start(); consumer.start(); } }
如上代码首先创建了一个NonReentrantLock的一个对象Lock,然后调用lock.newCondition创建了两个条件变量,用来进行生产者和消费者线程之间的同步。
在main函数里面,首先创建了生产者线程,在线程内部先调用lock.lock()获取独占锁,然后判断当前队列是否已经满了,如果满了掉用notEmpty.await()阻塞挂起当前线程。需要注意的是,这里使用while而不是if是为了避免虚假唤醒,如果队列不满则直接向队列里面添加元素,然后调用notFull.signalAll()唤醒所有因为消费元素而被阻塞的消费线程,最后释放获取的锁。