Stream 的 reduce 與 collect


從一組數據依條件求得一個數,或將一組數據依條件收集至另一個容器,程式設計中不少地方都存在這類需求,使用迴圈解決這類需求,也是許多開發者最常採用的動作。舉例來說,求得一組人的男性平均年齡:

List<Person> persons = ...;
int sum = 0;
for(Person person : persons) {
    if(person.getGender() == Person.Gender.MALE) {
        sum += person.getAge();
    }
}
int average = sum / persons.size();

實際上,迴圈中進行的也是求得年齡加總,而若要求得一組人的男性最大年齡:

int max = 0;
for(Person person : persons) {
    if(person.getGender() == Person.Gender.MALE) {
        if(person.getAge() > max) {
            max = person.getAge();
        }
    }
}

實際上,你的程式中這類需求都存在著類似地流程結構,而你也不斷重複撰寫著類似結構,而且從閱讀程式碼角度來看,無法一眼察覺程式意圖,在JDK8中,可以改寫為:

int sum = persons.stream()
                 .filter(person -> person.getGender() == Person.Gender.MALE)
                 .mapToInt(Person::getAge)
                 .sum();
            
int average = (int) persons.stream()
                           .filter(person -> person.getGender() == Person.Gender.MALE)
                           .mapToInt(Person::getAge)
                           .average()
                           .getAsDouble();
            
int max = persons.stream()
                 .filter(person -> person.getGender() == Person.Gender.MALE)
                 .mapToInt(Person::getAge)
                 .max()
                 .getAsInt();

JDK8的IntStream提供了sum()average()max()min()等方法,那麼如果有其它的計算需求呢?觀察先前的迴圈結構,實際上都是將一組數據逐步取出削減,然而透過指定運算以取得結果的結構,JDK8將這個流程結構通用化,定義了reduce()方法來達到自訂運算需求。例如,以上三個流程,也可以使用reduce()重新撰寫如下:

int sum = persons.stream()
                 .filter(person -> person.getGender() == Person.Gender.MALE)
                 .mapToInt(Person::getAge)
                 .reduce((total, age) ->  total + age)
                 .getAsInt();

long males = persons.stream()
                .filter(person -> person.getGender() == Person.Gender.MALE)
                .count();

int average = persons.stream()
                     .filter(person -> person.getGender() == Person.Gender.MALE)
                     .mapToInt(Person::getAge)
                     .reduce((total, age) ->  total + age)
                     .getAsInt() / males;
           
int max = persons.stream()
                 .filter(person -> person.getGender() == Person.Gender.MALE)
                 .mapToInt(Person::getAge)
                 .reduce(0, (currentMax, age) -> age > currentMax ? age : currentMax);

reduce()的Lambda表示式,必須接受兩個引數,第一個引數為走訪該組數據上一元素後的運算結果,第二個引數為目前走訪元素,Lambda表示式本體就是你原先在迴圈中打算進行的運算;reduce()如果如上例中首兩個程式片段沒有指定初值,就會試著使用該組數據中第一個元素,作為第一次呼叫Lambda表示式時的第一個引數值,因為考量到數據組可能為空,因此reduce()不指定初值的版本,會傳回OptionalInt(非基本型態數據組,則會是Optional)。

那麼!如果你想將一組人的男性收集至另一個List呢?在persons.stream().filter(person -> person.getGender() == Person.Gender.MALE)之後,傳回的是Stream<Person>,因為filter()Stream的中介操作,不是最終操作,使用reduce()的話,在處理完新元素後,每次都會傳回新的計算結果,作為下一次Lambda表示式接受的第一個引數,顯然不適合用來收集物件。

你可以使用Streamcollect()方法,以將一組人的男性收集至另一個List的需求來說,最簡單的方式就是:

List<Person> males = persons.stream()
                            .filter(person -> person.getGender() == Person.Gender.MALE)
                            .collect(toList()); // toList() 是 java.util.stream.Collectors 的靜態方法

CollectorstoList()方法傳回的並不是List,而是java.util.stream.Collector實例,Collector主要的四個方法是:

  • suppiler():傳回Suppiler,定義收集結果的新容器如何建立
  • accumulator():傳回BiConsumer,定義如何使用結果容器收集物件
  • combiner():傳回BinaryOperator,定義若有兩個結果容器時,如何合併為一個結果容器
  • finisher():傳回Function,選擇性地定義如何將結果轉換為最後的結果容器
來看看Streamcollect()方法另一個版本,有助於瞭解Collector這幾個方法如何使用,以下的程式片段與上面的collect()範例結果是相同的:

List<Person> males = persons.stream()
                            .filter(person -> person.getGender() == Person.Gender.MALE)
                            .collect(
                                 () -> new ArrayList<>(),
                                 (maleLt, person) -> maleLt.add(person),
                                 (maleLt1, maleLt2) -> maleLt1.addAll(maleLt2)
                            );

collect()需要收集物件時,會使用第一個Lambda來取得容器物件,這相當於Collectorsuppiler()之作用,第二個Lambda定義了如何收集物件,也就是Collectoraccumulator()之作用,在使用具有平行處理能力的Stream時,有可能會使用多個容器對原數據組進行分而治之(Divide and conquer),當每個小任務完成時,該如何合併,就是第三個Lambda要定義的,喔!別忘了可以用方法參考,因此上面可以寫成以下比較簡潔:

List<Person> males = persons.stream()
                            .filter(person -> person.getGender() == Person.Gender.MALE)
                            .collect(ArrayList::new, ArrayList::add, ArrayList::addAll);

當然,使用這個版本的collect()需要處理比較多的細節,你可以先看看Collectors上提供了哪些Collector實作。舉例來說,如果想要依性別分組,那可以使用CollectorsgroupingBy()方法,告訴它要用哪個當作分組的鍵(Key),最後傳回的Map結果會以List作為值(Key):

Map<Person.Gender, List<Person>> males = persons.stream()
                  .collect(
                      groupingBy(Person::getGender));

有的方法也兼具另一種流暢風格,例如,想在依性別分組之後,取得分組下的姓名,那可以如下撰寫:

Map<Person.Gender, List<String>> males = persons.stream()
                  .collect(
                      groupingBy(Person::getGender,
                      mapping(Person::getName,
                  toList())));

例如,想在依性別分組之後,分別取得男女年齡加總,那可以如下撰寫:

Map<Person.Gender, Integer> males = persons.stream()
                  .collect(
                          groupingBy(Person::getGender,
                               reducing(0, Person::getAge, Integer::sum))
                  );

要求得各性別下平均年齡的話,Collectors也有個averagingInt()方法可以使用:

Map<Person.Gender, Double> males = persons.stream()
                  .collect(
                          groupingBy(Person::getGender,
                               averagingInt(Person::getAge))
                  );