你好,我是朱晔。作为课程的第一讲,我今天要和你聊聊使用并发工具类库相关的话题。

在代码审核讨论的时候,我们有时会听到有关线程安全和并发工具的一些片面的观点和结论,比如“把HashMap改为ConcurrentHashMap,就可以解决并发问题了呀”“要不我们试试无锁的CopyOnWriteArrayList吧,性能更好”。事实上,这些说法都不太准确。

的确,为了方便开发者进行多线程编程,现代编程语言会提供各种并发工具类。但如果我们没有充分了解它们的使用场景、解决的问题,以及最佳实践的话,盲目使用就可能会导致一些坑,小则损失性能,大则无法确保多线程情况下业务逻辑的正确性。

我需要先说明下,这里的并发工具类是指用来解决多线程环境下并发问题的工具类库。一般而言并发工具包括同步器和容器两大类,业务代码中使用并发容器的情况会多一些,我今天分享的例子也会侧重并发容器。

接下来,我们就看看在使用并发工具时,最常遇到哪些坑,以及如何解决、避免这些坑吧。

没有意识到线程重用导致用户信息错乱的Bug

之前有业务同学和我反馈,在生产上遇到一个诡异的问题,有时获取到的用户信息是别人的。查看代码后,我发现他使用了ThreadLocal来缓存获取到的用户信息。

我们知道,ThreadLocal适用于变量在线程间隔离,而在方法或类间共享的场景。如果用户信息的获取比较昂贵(比如从数据库查询用户信息),那么在ThreadLocal中缓存数据是比较合适的做法。但,这么做为什么会出现用户信息错乱的Bug呢?

我们看一个具体的案例吧。

使用Spring Boot创建一个Web应用程序,使用ThreadLocal存放一个Integer的值,来暂且代表需要在线程中保存的用户信息,这个值初始是null。在业务逻辑中,我先从ThreadLocal获取一次值,然后把外部传入的参数设置到ThreadLocal中,来模拟从当前上下文获取到用户信息的逻辑,随后再获取一次值,最后输出两次获得的值和线程名称。

private static final ThreadLocal<Integer> currentUser = ThreadLocal.withInitial(() -> null);


@GetMapping("wrong")
public Map wrong(@RequestParam("userId") Integer userId) {
    //设置用户信息之前先查询一次ThreadLocal中的用户信息
    String before  = Thread.currentThread().getName() + ":" + currentUser.get();
    //设置用户信息到ThreadLocal
    currentUser.set(userId);
    //设置用户信息之后再查询一次ThreadLocal中的用户信息
    String after  = Thread.currentThread().getName() + ":" + currentUser.get();
    //汇总输出两次查询结果
    Map result = new HashMap();
    result.put("before", before);
    result.put("after", after);
    return result;
}

按理说,在设置用户信息之前第一次获取的值始终应该是null,但我们要意识到,程序运行在Tomcat中,执行程序的线程是Tomcat的工作线程,而Tomcat的工作线程是基于线程池的。

顾名思义,线程池会重用固定的几个线程,一旦线程重用,那么很可能首次从ThreadLocal获取的值是之前其他用户的请求遗留的值。这时,ThreadLocal中的用户信息就是其他用户的信息。

为了更快地重现这个问题,我在配置文件中设置一下Tomcat的参数,把工作线程池最大线程数设置为1,这样始终是同一个线程在处理请求:

server.tomcat.max-threads=1

运行程序后先让用户1来请求接口,可以看到第一和第二次获取到用户ID分别是null和1,符合预期:

随后用户2来请求接口,这次就出现了Bug,第一和第二次获取到用户ID分别是1和2,显然第一次获取到了用户1的信息,原因就是Tomcat的线程池重用了线程。从图中可以看到,两次请求的线程都是同一个线程:http-nio-8080-exec-1。

这个例子告诉我们,在写业务代码时,首先要理解代码会跑在什么线程上:

理解了这个知识点后,我们修正这段代码的方案是,在代码的finally代码块中,显式清除ThreadLocal中的数据。这样一来,新的请求过来即使使用了之前的线程也不会获取到错误的用户信息了。修正后的代码如下:

@GetMapping("right")
public Map right(@RequestParam("userId") Integer userId) {
    String before  = Thread.currentThread().getName() + ":" + currentUser.get();
    currentUser.set(userId);
    try {
        String after = Thread.currentThread().getName() + ":" + currentUser.get();
        Map result = new HashMap();
        result.put("before", before);
        result.put("after", after);
        return result;
    } finally {
        //在finally代码块中删除ThreadLocal中的数据,确保数据不串
        currentUser.remove();
    }
}

重新运行程序可以验证,再也不会出现第一次查询用户信息查询到之前用户请求的Bug:

ThreadLocal是利用独占资源的方式,来解决线程安全问题,那如果我们确实需要有资源在线程之间共享,应该怎么办呢?这时,我们可能就需要用到线程安全的容器了。

使用了线程安全的并发工具,并不代表解决了所有线程安全问题

JDK 1.5后推出的ConcurrentHashMap,是一个高性能的线程安全的哈希表容器。“线程安全”这四个字特别容易让人误解,因为ConcurrentHashMap只能保证提供的原子性读写操作是线程安全的。

我在相当多的业务代码中看到过这个误区,比如下面这个场景。有一个含900个元素的Map,现在再补充100个元素进去,这个补充操作由10个线程并发进行。开发人员误以为使用了ConcurrentHashMap就不会有线程安全问题,于是不加思索地写出了下面的代码:在每一个线程的代码逻辑中先通过size方法拿到当前元素数量,计算ConcurrentHashMap目前还需要补充多少元素,并在日志中输出了这个值,然后通过putAll方法把缺少的元素添加进去。

为方便观察问题,我们输出了这个Map一开始和最后的元素个数。

//线程个数
private static int THREAD_COUNT = 10;
//总元素数量
private static int ITEM_COUNT = 1000;

//帮助方法,用来获得一个指定元素数量模拟数据的ConcurrentHashMap
private ConcurrentHashMap<String, Long> getData(int count) {
    return LongStream.rangeClosed(1, count)
            .boxed()
            .collect(Collectors.toConcurrentMap(i -> UUID.randomUUID().toString(), Function.identity(),
                    (o1, o2) -> o1, ConcurrentHashMap::new));
}

@GetMapping("wrong")
public String wrong() throws InterruptedException {
    ConcurrentHashMap<String, Long> concurrentHashMap = getData(ITEM_COUNT - 100);
    //初始900个元素
    log.info("init size:{}", concurrentHashMap.size());

    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);
    //使用线程池并发处理逻辑
    forkJoinPool.execute(() -> IntStream.rangeClosed(1, 10).parallel().forEach(i -> {
        //查询还需要补充多少个元素
        int gap = ITEM_COUNT - concurrentHashMap.size();
        log.info("gap size:{}", gap);
        //补充元素
        concurrentHashMap.putAll(getData(gap));
    }));
    //等待所有任务完成
    forkJoinPool.shutdown();
    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);
    //最后元素个数会是1000吗?
    log.info("finish size:{}", concurrentHashMap.size());
    return "OK";
}

访问接口后程序输出的日志内容如下:

从日志中可以看到:

针对这个场景,我们可以举一个形象的例子。ConcurrentHashMap就像是一个大篮子,现在这个篮子里有900个桔子,我们期望把这个篮子装满1000个桔子,也就是再装100个桔子。有10个工人来干这件事儿,大家先后到岗后会计算还需要补多少个桔子进去,最后把桔子装入篮子。

ConcurrentHashMap这个篮子本身,可以确保多个工人在装东西进去时,不会相互影响干扰,但无法确保工人A看到还需要装100个桔子但是还未装的时候,工人B就看不到篮子中的桔子数量。更值得注意的是,你往这个篮子装100个桔子的操作不是原子性的,在别人看来可能会有一个瞬间篮子里有964个桔子,还需要补36个桔子。

回到ConcurrentHashMap,我们需要注意ConcurrentHashMap对外提供的方法或能力的限制

代码的修改方案很简单,整段逻辑加锁即可:

@GetMapping("right")
public String right() throws InterruptedException {
    ConcurrentHashMap<String, Long> concurrentHashMap = getData(ITEM_COUNT - 100);
    log.info("init size:{}", concurrentHashMap.size());


    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);
    forkJoinPool.execute(() -> IntStream.rangeClosed(1, 10).parallel().forEach(i -> {
        //下面的这段复合逻辑需要锁一下这个ConcurrentHashMap
        synchronized (concurrentHashMap) {
            int gap = ITEM_COUNT - concurrentHashMap.size();
            log.info("gap size:{}", gap);
            concurrentHashMap.putAll(getData(gap));
        }
    }));
    forkJoinPool.shutdown();
    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);


    log.info("finish size:{}", concurrentHashMap.size());
    return "OK";
}

重新调用接口,程序的日志输出结果符合预期:

可以看到,只有一个线程查询到了需要补100个元素,其他9个线程查询到不需要补元素,最后Map大小为1000。

到了这里,你可能又要问了,使用ConcurrentHashMap全程加锁,还不如使用普通的HashMap呢。

其实不完全是这样。

ConcurrentHashMap提供了一些原子性的简单复合逻辑方法,用好这些方法就可以发挥其威力。这就引申出代码中常见的另一个问题:在使用一些类库提供的高级工具类时,开发人员可能还是按照旧的方式去使用这些新类,因为没有使用其特性,所以无法发挥其威力。

没有充分了解并发工具的特性,从而无法发挥其威力

我们来看一个使用Map来统计Key出现次数的场景吧,这个逻辑在业务代码中非常常见。

代码如下:

//循环次数
private static int LOOP_COUNT = 10000000;
//线程数量
private static int THREAD_COUNT = 10;
//元素数量
private static int ITEM_COUNT = 10;
private Map<String, Long> normaluse() throws InterruptedException {
    ConcurrentHashMap<String, Long> freqs = new ConcurrentHashMap<>(ITEM_COUNT);
    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);
    forkJoinPool.execute(() -> IntStream.rangeClosed(1, LOOP_COUNT).parallel().forEach(i -> {
        //获得一个随机的Key
        String key = "item" + ThreadLocalRandom.current().nextInt(ITEM_COUNT);
                synchronized (freqs) {      
                    if (freqs.containsKey(key)) {
                        //Key存在则+1
                        freqs.put(key, freqs.get(key) + 1);
                    } else {
                        //Key不存在则初始化为1
                        freqs.put(key, 1L);
                    }
                }
            }
    ));
    forkJoinPool.shutdown();
    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);
    return freqs;
}

我们吸取之前的教训,直接通过锁的方式锁住Map,然后做判断、读取现在的累计值、加1、保存累加后值的逻辑。这段代码在功能上没有问题,但无法充分发挥ConcurrentHashMap的威力,改进后的代码如下:

private Map<String, Long> gooduse() throws InterruptedException {
    ConcurrentHashMap<String, LongAdder> freqs = new ConcurrentHashMap<>(ITEM_COUNT);
    ForkJoinPool forkJoinPool = new ForkJoinPool(THREAD_COUNT);
    forkJoinPool.execute(() -> IntStream.rangeClosed(1, LOOP_COUNT).parallel().forEach(i -> {
        String key = "item" + ThreadLocalRandom.current().nextInt(ITEM_COUNT);
                //利用computeIfAbsent()方法来实例化LongAdder,然后利用LongAdder来进行线程安全计数
                freqs.computeIfAbsent(key, k -> new LongAdder()).increment();
            }
    ));
    forkJoinPool.shutdown();
    forkJoinPool.awaitTermination(1, TimeUnit.HOURS);
    //因为我们的Value是LongAdder而不是Long,所以需要做一次转换才能返回
    return freqs.entrySet().stream()
            .collect(Collectors.toMap(
                    e -> e.getKey(),
                    e -> e.getValue().longValue())
            );
}

在这段改进后的代码中,我们巧妙利用了下面两点:

这样在确保线程安全的情况下达到极致性能,把之前7行代码替换为了1行。

我们通过一个简单的测试比较一下修改前后两段代码的性能:

@GetMapping("good")
public String good() throws InterruptedException {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start("normaluse");
    Map<String, Long> normaluse = normaluse();
    stopWatch.stop();
    //校验元素数量
    Assert.isTrue(normaluse.size() == ITEM_COUNT, "normaluse size error");
    //校验累计总数    
    Assert.isTrue(normaluse.entrySet().stream()
                    .mapToLong(item -> item.getValue()).reduce(0, Long::sum) == LOOP_COUNT
            , "normaluse count error");
    stopWatch.start("gooduse");
    Map<String, Long> gooduse = gooduse();
    stopWatch.stop();
    Assert.isTrue(gooduse.size() == ITEM_COUNT, "gooduse size error");
    Assert.isTrue(gooduse.entrySet().stream()
                    .mapToLong(item -> item.getValue())
                    .reduce(0, Long::sum) == LOOP_COUNT
            , "gooduse count error");
    log.info(stopWatch.prettyPrint());
    return "OK";
}

这段测试代码并无特殊之处,使用StopWatch来测试两段代码的性能,最后跟了一个断言判断Map中元素的个数以及所有Value的和,是否符合预期来校验代码的正确性。测试结果如下:

可以看到,优化后的代码,相比使用锁来操作ConcurrentHashMap的方式,性能提升了10倍

你可能会问,computeIfAbsent为什么如此高效呢?

答案就在源码最核心的部分,也就是Java自带的Unsafe实现的CAS。它在虚拟机层面确保了写入数据的原子性,比加锁的效率高得多:

    static final <K,V> boolean casTabAt(Node<K,V>[] tab, int i,
                                        Node<K,V> c, Node<K,V> v) {
        return U.compareAndSetObject(tab, ((long)i << ASHIFT) + ABASE, c, v);
    }

像ConcurrentHashMap这样的高级并发工具的确提供了一些高级API,只有充分了解其特性才能最大化其威力,而不能因为其足够高级、酷炫盲目使用。

没有认清并发工具的使用场景,因而导致性能问题

除了ConcurrentHashMap这样通用的并发工具类之外,我们的工具包中还有些针对特殊场景实现的生面孔。一般来说,针对通用场景的通用解决方案,在所有场景下性能都还可以,属于“万金油”;而针对特殊场景的特殊实现,会有比通用解决方案更高的性能,但一定要在它针对的场景下使用,否则可能会产生性能问题甚至是Bug。

之前在排查一个生产性能问题时,我们发现一段简单的非数据库操作的业务逻辑,消耗了超出预期的时间,在修改数据时操作本地缓存比回写数据库慢许多。查看代码发现,开发同学使用了CopyOnWriteArrayList来缓存大量的数据,而数据变化又比较频繁。

CopyOnWrite是一个时髦的技术,不管是Linux还是Redis都会用到。在Java中,CopyOnWriteArrayList虽然是一个线程安全的ArrayList,但因为其实现方式是,每次修改数据时都会复制一份数据出来,所以有明显的适用场景,即读多写少或者说希望无锁读的场景。

如果我们要使用CopyOnWriteArrayList,那一定是因为场景需要而不是因为足够酷炫。如果读写比例均衡或者有大量写操作的话,使用CopyOnWriteArrayList的性能会非常糟糕。

我们写一段测试代码,来比较下使用CopyOnWriteArrayList和普通加锁方式ArrayList的读写性能吧。在这段代码中我们针对并发读和并发写分别写了一个测试方法,测试两者一定次数的写或读操作的耗时。

//测试并发写的性能
@GetMapping("write")
public Map testWrite() {
    List<Integer> copyOnWriteArrayList = new CopyOnWriteArrayList<>();
    List<Integer> synchronizedList = Collections.synchronizedList(new ArrayList<>());
    StopWatch stopWatch = new StopWatch();
    int loopCount = 100000;
    stopWatch.start("Write:copyOnWriteArrayList");
    //循环100000次并发往CopyOnWriteArrayList写入随机元素
    IntStream.rangeClosed(1, loopCount).parallel().forEach(__ -> copyOnWriteArrayList.add(ThreadLocalRandom.current().nextInt(loopCount)));
    stopWatch.stop();
    stopWatch.start("Write:synchronizedList");
    //循环100000次并发往加锁的ArrayList写入随机元素
    IntStream.rangeClosed(1, loopCount).parallel().forEach(__ -> synchronizedList.add(ThreadLocalRandom.current().nextInt(loopCount)));
    stopWatch.stop();
    log.info(stopWatch.prettyPrint());
    Map result = new HashMap();
    result.put("copyOnWriteArrayList", copyOnWriteArrayList.size());
    result.put("synchronizedList", synchronizedList.size());
    return result;
}

//帮助方法用来填充List
private void addAll(List<Integer> list) {
    list.addAll(IntStream.rangeClosed(1, 1000000).boxed().collect(Collectors.toList()));
}

//测试并发读的性能
@GetMapping("read")
public Map testRead() {
    //创建两个测试对象
    List<Integer> copyOnWriteArrayList = new CopyOnWriteArrayList<>();
    List<Integer> synchronizedList = Collections.synchronizedList(new ArrayList<>());
    //填充数据   
    addAll(copyOnWriteArrayList);
    addAll(synchronizedList);
    StopWatch stopWatch = new StopWatch();
    int loopCount = 1000000;
    int count = copyOnWriteArrayList.size();
    stopWatch.start("Read:copyOnWriteArrayList");
    //循环1000000次并发从CopyOnWriteArrayList随机查询元素
    IntStream.rangeClosed(1, loopCount).parallel().forEach(__ -> copyOnWriteArrayList.get(ThreadLocalRandom.current().nextInt(count)));
    stopWatch.stop();
    stopWatch.start("Read:synchronizedList");
    //循环1000000次并发从加锁的ArrayList随机查询元素
    IntStream.range(0, loopCount).parallel().forEach(__ -> synchronizedList.get(ThreadLocalRandom.current().nextInt(count)));
    stopWatch.stop();
    log.info(stopWatch.prettyPrint());
    Map result = new HashMap();
    result.put("copyOnWriteArrayList", copyOnWriteArrayList.size());
    result.put("synchronizedList", synchronizedList.size());
    return result;
}

运行程序可以看到,大量写的场景(10万次add操作),CopyOnWriteArray几乎比同步的ArrayList慢一百倍

而在大量读的场景下(100万次get操作),CopyOnWriteArray又比同步的ArrayList快五倍以上:

你可能会问,为何在大量写的场景下,CopyOnWriteArrayList会这么慢呢?

答案就在源码中。以add方法为例,每次add时,都会用Arrays.copyOf创建一个新数组,频繁add时内存的申请释放消耗会很大:

    /**
     * Appends the specified element to the end of this list.
     *
     * @param e element to be appended to this list
     * @return {@code true} (as specified by {@link Collection#add})
     */
    public boolean add(E e) {
        synchronized (lock) {
            Object[] elements = getArray();
            int len = elements.length;
            Object[] newElements = Arrays.copyOf(elements, len + 1);
            newElements[len] = e;
            setArray(newElements);
            return true;
        }
    }

重点回顾

今天,我主要与你分享了,开发人员使用并发工具来解决线程安全问题时容易犯的四类错。

一是,只知道使用并发工具,但并不清楚当前线程的来龙去脉,解决多线程问题却不了解线程。比如,使用ThreadLocal来缓存数据,以为ThreadLocal在线程之间做了隔离不会有线程安全问题,没想到线程重用导致数据串了。请务必记得,在业务逻辑结束之前清理ThreadLocal中的数据。

二是,误以为使用了并发工具就可以解决一切线程安全问题,期望通过把线程不安全的类替换为线程安全的类来一键解决问题。比如,认为使用了ConcurrentHashMap就可以解决线程安全问题,没对复合逻辑加锁导致业务逻辑错误。如果你希望在一整段业务逻辑中,对容器的操作都保持整体一致性的话,需要加锁处理。

三是,没有充分了解并发工具的特性,还是按照老方式使用新工具导致无法发挥其性能。比如,使用了ConcurrentHashMap,但没有充分利用其提供的基于CAS安全的方法,还是使用锁的方式来实现逻辑。你可以阅读一下ConcurrentHashMap的文档,看一下相关原子性操作API是否可以满足业务需求,如果可以则优先考虑使用。

四是,没有了解清楚工具的适用场景,在不合适的场景下使用了错误的工具导致性能更差。比如,没有理解CopyOnWriteArrayList的适用场景,把它用在了读写均衡或者大量写操作的场景下,导致性能问题。对于这种场景,你可以考虑是用普通的List。

其实,这四类坑之所以容易踩到,原因可以归结为,我们在使用并发工具的时候,并没有充分理解其可能存在的问题、适用场景等。所以最后,我还要和你分享两点建议

  1. 一定要认真阅读官方文档(比如Oracle JDK文档)。充分阅读官方文档,理解工具的适用场景及其API的用法,并做一些小实验。了解之后再去使用,就可以避免大部分坑。
  2. 如果你的代码运行在多线程环境下,那么就会有并发问题,并发问题不那么容易重现,可能需要使用压力测试模拟并发场景,来发现其中的Bug或性能问题。

今天用到的代码,我都放在了GitHub上,你可以点击这个链接查看。

思考与讨论

  1. 今天我们多次用到了ThreadLocalRandom,你觉得是否可以把它的实例设置到静态变量中,在多线程情况下重用呢?
  2. ConcurrentHashMap还提供了putIfAbsent方法,你能否通过查阅JDK文档,说说computeIfAbsent和putIfAbsent方法的区别?

你在使用并发工具时,还遇到过其他坑吗?我是朱晔,欢迎在评论区与我留言分享你的想法,也欢迎你把这篇文章分享给你的朋友或同事,一起交流。

评论