Fight the Future

Java言語とJVM、そしてJavaエコシステム全般にまつわること

ScalaのAdvanced Exampleを写経する(2)-for comprehension

for文もJavaとは違ってるね。
Advanced Exampleになってコードが長い!
理解したり試したり、時間がかかります。
fors.scala | The Scala Programming Language

package sample.snippet

object Persons {

  var persons = List(
    new Person("Bob", 17),
    new Person("John", 40),
    new Person("Richerd", 68)
  )
  
  class Person(val name: String, val age:Int)
  
  /** Return an iterator over persons older than 20.
   */
  def olderThan20(xs: Seq[Person]): Iterator[String] = 
    olderThan20(xs.elements)
  
  def olderThan20(xs: Iterator[Person]): Iterator[String] = {
    
    /** The first expression is called a 'generator'
     *  and makes 'p' take values from 'xs'.
     *  The second expression is called a 'filter'
     *  and it is a boolean expression
     *  which selects only persons older than 20.
     */
    for (p <- xs if p.age > 20) yield p.name
  }
  
  /** Some functions over lists of numbers which demonstrate
   *  the use of comprehensions.
   */
  object Numeric {
    
    def divisors(n: Int): List[Int] = 
      for (i <- List.range(1, n+1) if n % i == 0) yield i
    
    def isPrime(n: Int) = divisors(n).length == 2
    
    /** Return pairs of numbers whose sum is prime. */
    def findNums(n: Int): Iterable[(Int, Int)] = {
      for(i <- 1 until n;
          j <- 1 until (i -1);
          if isPrime(i + j)) 
            yield (i, j)
    }
    
    def sum(xs: List[Double]): Double = 
      xs.foldLeft(0.0) { (x, y) => x + y }
    
    /** Return the sum of pairwise prodect of the two lists. */
    def scalProd(xs: List[Double], ys: List[Double]) = 
      sum(for((x, y) <- xs zip ys) yield x * y)
    
    def removeDuplicates[A](xs: List[A]): List[A] = {
      if (xs.isEmpty)
        xs
      else 
        xs.head :: removeDuplicates(for (x <- xs.tail if x != xs.head) yield x)
     }
    
  }
  
    def main(args: Array[String]) {
      print("Persons over 20:")
      olderThan20(persons) foreach {x => print(" " + x)}
      
      println
      
      import Numeric._
      
      println("divisors(34) = " + divisors(34))
      
      print("findNums(15) =")
      findNums(15) foreach { x => print(" " + x) }
      println
      
      val xs = List(3.5, 5.0, 4.5)
      println("average(" + xs + ") = " + sum(xs) / xs.length)
      
      val ys = List(2.0, 1.0, 3.0)
      println("scalProd(" + xs + "," + ys + ") = " + scalProd(xs, ys)) 
      
      val l = removeDuplicates(List(1,1,2,3,4,4,4,4,5,6,7,8,9))
      println(l)
    }
}

細かく見ていきます。

for (p <- xs if p.age > 20) yield p.name

yieldは実際にはmap()になるみたい。だからIterator[Person]がIterator[String]で名前が入ってるのに変換するってこと。

    def divisors(n: Int): List[Int] = 
      for (i <- List.range(1, n+1) if n % i == 0) yield i

    def isPrime(n: Int) = divisors(n).length == 2

何気に感動したんだけど、これ約数を見つける関数。書き方がちょっとかっこいい。
素数を見つけるのも約数が2つの数と。エレガントな感じ。

    /** Return pairs of numbers whose sum is prime. */
    def findNums(n: Int): Iterable[(Int, Int)] = {
      for(i <- 1 until n;
          j <- 1 until (i -1);
          if isPrime(i + j)) 
            yield (i, j)
    }

このfor文が新しいタイプ。これ2重ループなんよね。
iがあってjがあって。
(i, j)というのは複数の値を扱うとき使う。タプルという。
なので(i, j)という2つの値を関数の戻り値にできたりする。

    /** Return the sum of pairwise prodect of the two lists. */
    def scalProd(xs: List[Double], ys: List[Double]) = 
      sum(for((x, y) <- xs zip ys) yield x * y)

zipはListクラスのメソッドで。2つのリストを同輪で回すイメージ。
つまりxsの0番目とysの0番目、つぎは2つとも1番目。。。みたいな。
どっちかが長いとあまりは無視される。

    def removeDuplicates[A](xs: List[A]): List[A] = {
      if (xs.isEmpty)
        xs
      else 
        xs.head :: removeDuplicates(for (x <- xs.tail if x != xs.head) yield x)
     }

この関数がカルチャーショック的な感動だった。
僕ほとんどJavaしか知らないから。この考え方はJavaにはない(と思う)。
リストの重複を除くのに再帰を使う。関数型言語っぽくない?


たとえば「List(1 ,1 ,2 ,3 ,4 ,4 ,4 ,4 ,5 ,6 ,7 ,8 ,9)」を引数に渡すと、再帰でこういう引数になっていく。

xs: List = List(1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 8, 9)
xs: List = List(2, 3, 4, 4, 4, 4, 5, 6, 7, 8, 9)
xs: List = List(3, 4, 4, 4, 4, 5, 6, 7, 8, 9)
xs: List = List(4, 4, 4, 4, 5, 6, 7, 8, 9)
xs: List = List(5, 6, 7, 8, 9)
xs: List = List(6, 7, 8, 9)
xs: List = List(7, 8, 9)
xs: List = List(8, 9)
xs: List = List(9)
xs: List = List()

最後までいったら結合。「::」はリストに要素を結合する記法。
だから空のリストに9を結合して、次にList(9)に8を結合して。。。となっていく。


さて、実行結果。

Persons over 20: John Richerd
divisors(34) = List(1, 2, 17, 34)
findNums(15) = (4,1) (5,2) (6,1) (7,4) (8,3) (8,5) (9,2) (9,4) (10,1) (10,3) (10,7) (11,2) (11,6) (11,8) (12,1) (12,5) (12,7) (13,4) (13,6) (13,10) (14,3) (14,5) (14,9)
average(List(3.5, 5.0, 4.5)) = 4.333333333333333
scalProd(List(3.5, 5.0, 4.5),List(2.0, 1.0, 3.0)) = 25.5
List(1, 2, 3, 4, 5, 6, 7, 8, 9)