有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

java如何使用groupBy创建一个值为BigDecimal字段平均值的映射?

我有以下课程:

final public class Person {

 private final String name;
 private final String state;
 private final BigDecimal salary;

 public Person(String name, String state, BigDecimal salary) {
    this.name = name;
    this.state = state;
    this.salary = salary;
 }

 //getters omitted for brevity...
}

我想创建一个地图,列出各州的平均工资。如何使用Java8流来实现这一点?我试图在groupBy上使用下游收集器,但未能以优雅的方式做到这一点

我做了以下工作,但看起来很可怕:

Stream.of(p1,p2,p3,p4).collect(groupingBy(Person::getState, mapping(d -> d.getSalary(), toList())))
.forEach((state,wageList) -> {
        System.out.print(state+"-> ");
        final BigDecimal[] wagesArray = wageList.stream()
                .map(bd -> new BigDecimal[]{bd, BigDecimal.ONE})
                .reduce((a, b) -> new BigDecimal[]{a[0].add(b[0]), a[1].add(BigDecimal.ONE)})
                .get();
        System.out.println(wagesArray[0].divide(wagesArray[1])
                                        .setScale(2, RoundingMode.CEILING));
    });

有更好的办法吗


共 (4) 个答案

  1. # 1 楼答案

    下面是一个仅使用BigDecimal算法的完整示例,展示了如何实现自定义收集器

    import java.math.BigDecimal;
    import java.util.Collections;
    import java.util.Map;
    import java.util.Set;
    import java.util.function.BiConsumer;
    import java.util.function.BinaryOperator;
    import java.util.function.Function;
    import java.util.function.Supplier;
    import java.util.stream.Collector;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;
    
    public final class Person {
    
        private final String name;
        private final String state;
        private final BigDecimal salary;
    
        public Person(String name, String state, BigDecimal salary) {
            this.name = name;
            this.state = state;
            this.salary = salary;
        }
    
        public String getName() {
            return name;
        }
    
        public String getState() {
            return state;
        }
    
        public BigDecimal getSalary() {
            return salary;
        }
    
        public static void main(String[] args) {
            Person p1 = new Person("John", "NY", new BigDecimal("2000"));
            Person p2 = new Person("Jack", "NY", new BigDecimal("3000"));
            Person p3 = new Person("Jane", "GA", new BigDecimal("1500"));
            Person p4 = new Person("Jackie", "GA", new BigDecimal("2500"));
    
            Map<String, BigDecimal> result =
                Stream.of(p1, p2, p3, p4).collect(
                    Collectors.groupingBy(Person::getState,
                                          Collectors.mapping(Person::getSalary,
                                                             new AveragingCollector())));
            System.out.println("result = " + result);
    
        }
    
        private static class AveragingCollector implements Collector<BigDecimal, IntermediateResult, BigDecimal> {
            @Override
            public Supplier<IntermediateResult> supplier() {
                return IntermediateResult::new;
            }
    
            @Override
            public BiConsumer<IntermediateResult, BigDecimal> accumulator() {
                return IntermediateResult::add;
            }
    
            @Override
            public BinaryOperator<IntermediateResult> combiner() {
                return IntermediateResult::combine;
            }
    
            @Override
            public Function<IntermediateResult, BigDecimal> finisher() {
                return IntermediateResult::finish
            }
    
            @Override
            public Set<Characteristics> characteristics() {
                return Collections.emptySet();
            }
        }
    
        private static class IntermediateResult {
            private int count = 0;
            private BigDecimal sum = BigDecimal.ZERO;
    
            IntermediateResult() {
            }
    
            void add(BigDecimal value) {
                this.sum = this.sum.add(value);
                this.count++;
            }
    
            IntermediateResult combine(IntermediateResult r) {
                this.sum = this.sum.add(r.sum);
                this.count += r.count;
                return this;
            }
    
            BigDecimal finish() {
                return sum.divide(BigDecimal.valueOf(count), 2, BigDecimal.ROUND_HALF_UP);
            }
        }
    }
    

    如果您接受将BigDecimal值转换为double(这对于平均工资来说是完全可以接受的,IMHO),那么您可以使用

    Map<String, Double> result2 =
                Stream.of(p1, p2, p3, p4).collect(
                    Collectors.groupingBy(Person::getState,
                                          Collectors.mapping(Person::getSalary,
                                                             Collectors.averagingDouble(BigDecimal::doubleValue))));
    
  2. # 2 楼答案

    如果您需要BigDecimal精度,并且不介意额外的迭代,您可以这样做:

    static Map<String, BigDecimal> averageByState(List<Person> persons) {
        // collect sums
        Map<String, BigDecimal> sumByState = persons.stream()
                .collect(groupingBy(
                        Person::getState,
                        HashMap::new,
                        mapping(Person::getSalary, reducing(BigDecimal.ZERO, BigDecimal::add))));
    
        // collect counts
        Map<String, Long> countByState = persons.stream()
                .collect(groupingBy(Person::getState, counting()));
    
        // merge
        sumByState.replaceAll((state, sum) -> sum.divide(BigDecimal.valueOf(countByState.get(state))));
        return sumByState;
    }
    
  3. # 3 楼答案

    这里有一种方法。尽管我仍然相信还有其他更好的方法

    Map<String, Double> collect = Arrays.asList(p, p1, p2, p3, p4)
                .stream()
                .collect(groupingBy(k -> k.getState(), mapping(v -> v.salary, toList())))
                .entrySet()
                .stream()
                .collect(toMap(k -> k.getKey(), v -> v.getValue().stream().mapToDouble(BigDecimal::doubleValue).average().getAsDouble()));
    
  4. # 4 楼答案

    注意:这不是最好的解决方案。请阅读下面的(编辑部分)以了解更好的方法

    您可以使用^{}将每个州的工资累积到ArrayList中,然后使用finisher函数计算每个州的平均值:

    Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
        Collectors.groupingBy(Person::getState,
            Collectors.mapping(Person::getSalary,
                Collector.<BigDecimal, List<BigDecimal>, BigDecimal>of(
                    ArrayList::new, // create accumulator
                    List::add,      // add to accumulator
                    (l1, l2) -> {   // combine two partial accumulators
                        l1.addAll(l2);
                        return l1;
                    },
                    l -> l.stream() // finish with a reduction that returns average
                        .reduce(BigDecimal.ZERO, BigDecimal::add)
                        .divide(BigDecimal.valueOf(l.size()))))));
    

    这种方法不会失去精度,因为计算每个状态的平均值的每个操作都是在BigDecimal实例上执行的

    编辑:正如@Holger在他的评论中指出的,这不是解决问题的最佳方法,因为根本不需要使用ArrayList存储所有BigDecimal实例来计算平均值。相反,使用BigDecimal来累加部分和,使用long来累加计数就足够了(这是@JB Nizet在他的回答中遵循的方法,这里我修改了一些小细节)

    以下是一个考虑到这些因素的修改版本:

    private static class Acc {
        BigDecimal sum = BigDecimal.ZERO;
        long count = 0;
    
        void add(BigDecimal v) {
            sum = sum.add(v);
            count++;
        }
    
        Acc merge(Acc acc) {
            sum = sum.add(acc.sum);
            count += acc.count;
            return this;
        }
    
        BigDecimal avg() {
            return sum.divide(BigDecimal.valueOf(count));
        }
    }
    

    Acc是用于累积部分结果(部分和和和计数)的类

    现在,我们可以在这个类中使用Collector.of

    Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
        Collectors.groupingBy(Person::getState,
            Collectors.mapping(Person::getSalary,
                Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg))));
    

    或者更好的是,我们可以声明一个助手方法,其中Acc类是本地类:

    public static Collector<BigDecimal, ?, BigDecimal> averagingBigDecimal() {
        class Acc { // local class, lives only inside this method :P
            BigDecimal sum = BigDecimal.ZERO;
            long count = 0;
    
            void add(BigDecimal value) {
                sum = sum.add(value);
                count++;
            }
    
            Acc merge(Acc acc) {
                sum = sum.add(acc.sum);
                count += acc.count;
                return this;
            }
    
            BigDecimal avg() {
                return sum.divide(BigDecimal.valueOf(count));
            }
        }
        return Collector.of(Acc::new, Acc::add, Acc::merge, Acc::avg);
    }
    

    现在可以按如下方式使用此方法:

    Map<String, BigDecimal> salariesByState = Stream.of(p1, p2, p3, p4).collect(
        Collectors.groupingBy(Person::getState,
            Collectors.mapping(Person::getSalary, averagingBigDecimal())));