Suppose I am trying to implement a very simple domain specific language with only one operation:
printLine(line)
Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k, and then calls itself with n + 1, until n reaches some maximum value N.
Omitting all the syntax noise caused by for-comprehensions, I want:
    @annotation.tailrec
    def p(n: Int): Unit = {
      if (n % 10000 == 0) printLine("line")
      if (n > N) () else p(n + 1)
    }
Essentially, it will be a kind of "fizzbuzz."
Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:
    import scalaz._

    object Demo1 {

      // define operations of a little domain specific language
      sealed trait Lang[X]
      case class PrintLine(line: String) extends Lang[Unit]

      // define the domain specific language as the free monad of operations
      type Prog[X] = Free[Lang, X]

      import Free.{liftF, pure}

      // lift operations into the free monad
      def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
      def ret: Prog[Unit] = Free.pure(())

      // write a program that is just a loop that prints current index
      // after every few iteration steps
      val mod = 100000
      val N = 1000000

      // straightforward syntax: deadly slow, exits with OutOfMemoryError
      def p0(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        _ <- (if (i > N) ret else p0(i + 1))
      } yield ()

      // Same as above, but written out without `for`
      def p1(i: Int): Prog[Unit] =
        (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
          ignore1 =>
          (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
        }

      // Same as above, with `map` attached to recursive call
      def p2(i: Int): Prog[Unit] =
        (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
          ignore1 =>
          (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
        }

      // Same as above, but without the `map`; performs ok.
      def p3(i: Int): Prog[Unit] = {
        (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
          ignore1 =>
          if (i > N) ret else p3(i + 1)
        }
      }

      // Variation of the above; Ok.
      def p4(i: Int): Prog[Unit] = (for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
      } yield ()).flatMap{ ignored2 =>
        if (i > N) ret else p4(i + 1)
      }

      // try to use the variable returned by the last generator after yield,
      // hope that the final `map` is optimized away (it is not optimized away...)
      def p5(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        stopHere <- (if (i > N) ret else p5(i + 1))
      } yield stopHere

      // define an interpreter that translates the programs into Trampoline
      import scalaz.Trampoline
      type Exec[X] = Free.Trampoline[X]
      val interpreter = new (Lang ~> Exec) {
        def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
          case PrintLine(l) => Trampoline.delay(println(l))
        }
      }

      // try it out
      def main(args: Array[String]): Unit = {
        println("\n p0")
        p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        println("\n p1")
        p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        println("\n p2")
        p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        println("\n p3")
        p3(0).foldMap(interpreter).run // ok
        println("\n p4")
        p4(0).foldMap(interpreter).run // ok
        println("\n p5")
        p5(0).foldMap(interpreter).run // OutOfMemory
      }
    }
Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill up all the memory with reminders to "finish p0 and then do nothing". If I manually "unroll" the for-comprehension and write out the last flatMap explicitly (as in p3 and p4), the problem goes away and everything runs smoothly. This, however, is an extremely fragile workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) is not even visible in the code, because it is generated automatically by the for-comprehension.
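To make the hidden map concrete, here is a minimal sketch that does not use Scalaz at all. Traced, step, withFor and byHand are hypothetical names introduced only for this illustration: Traced simply records which combinators get called on it, so the difference between the shape of p0 (with `yield ()`) and the shape of p3 (explicit final flatMap) shows up in the log.

```scala
// Hypothetical helper, not part of Scalaz: a value plus a log of the
// combinators that were invoked while building it.
object DesugarDemo {
  final case class Traced[A](value: A, log: Vector[String]) {
    def map[B](f: A => B): Traced[B] =
      Traced(f(value), log :+ "map")

    def flatMap[B](f: A => Traced[B]): Traced[B] = {
      val next = f(value)
      Traced(next.value, (log :+ "flatMap") ++ next.log)
    }
  }

  // A named step that does nothing except show up in the log.
  def step(name: String): Traced[Unit] = Traced((), Vector(name))

  // Same shape as p0: the compiler rewrites this to
  //   step("a").flatMap(_ => step("b").map(_ => ()))
  val withFor: Traced[Unit] =
    for {
      _ <- step("a")
      _ <- step("b")
    } yield ()

  // Same shape as p3: the last generator is returned directly, no map.
  val byHand: Traced[Unit] =
    step("a").flatMap(_ => step("b"))

  def main(args: Array[String]): Unit = {
    println(withFor.log.mkString(", ")) // a, flatMap, b, map
    println(byHand.log.mkString(", "))  // a, flatMap, b
  }
}
```

The extra "map" entry in the first log is the call that, in the recursive Free programs, ends up wrapped around every recursive call. This only illustrates the desugaring; it says nothing about how Free itself handles the resulting structure.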
In this older post here: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ it is repeatedly recommended to wrap recursive calls in a suspend. Here is an attempt with an Applicative instance and suspend:
    import scalaz._

    // Essentially same as in `Demo1`, but this time with
    // an `Applicative` and an explicit `Suspend` in the
    // `for`-comprehension
    object Demo2 {

      sealed trait Lang[H]

      case class Const[H](h: H) extends Lang[H]
      case class PrintLine[H](line: String) extends Lang[H]

      implicit object Lang extends Applicative[Lang] {
        def point[A](a: => A): Lang[A] = Const(a)
        def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
          case Const(x) => {
            f match {
              case Const(ab) => Const(ab(x))
              case _ => throw new Error
            }
          }
          case PrintLine(l) => PrintLine(l)
        }
      }

      type Prog[X] = Free[Lang, X]

      import Free.{liftF, pure}
      def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
      def ret: Prog[Unit] = Free.pure(())

      val mod = 100000
      val N = 2000000

      // try to suspend the entire second generator
      def p7(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        _ <- Free.suspend(if (i > N) ret else p7(i + 1))
      } yield ()

      // try to suspend the recursive call
      def p8(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        _ <- if (i > N) ret else Free.suspend(p8(i + 1))
      } yield ()

      import scalaz.Trampoline
      type Exec[X] = Free.Trampoline[X]
      val interpreter = new (Lang ~> Exec) {
        def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
          case Const(x) => Trampoline.done(x)
          case PrintLine(l) =>
            (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
        }
      }

      def main(args: Array[String]): Unit = {
        p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        p8(0).foldMap(interpreter).run // same...
      }
    }
Inserting suspend did not really help: the program is still slow and crashes with OutOfMemoryErrors.
Should I use suspend in some different way?
Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating a map at the end?
I would really appreciate it if someone could point out what I am doing wrong here, and how to fix it.