Free ~> Trampoline: recursive program crash with OutOfMemoryError

Suppose I am trying to implement a very simple domain specific language with only one operation:

printLine(line) 

Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10,000, and then calls itself with n + 1 until n reaches some maximum value N.

Omitting all the syntax noise caused by for-comprehensions, I want:

    @annotation.tailrec
    def p(n: Int): Unit = {
      if (n % 10000 == 0) printLine("line")
      if (n > N) () else p(n + 1)
    }

Essentially, it will be a kind of "fizzbuzz."

Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:

    import scalaz._

    object Demo1 {

      // define operations of a little domain specific language
      sealed trait Lang[X]
      case class PrintLine(line: String) extends Lang[Unit]

      // define the domain specific language as the free monad of operations
      type Prog[X] = Free[Lang, X]

      import Free.{liftF, pure}

      // lift operations into the free monad
      def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
      def ret: Prog[Unit] = Free.pure(())

      // write a program that is just a loop that prints current index
      // after every few iteration steps
      val mod = 100000
      val N = 1000000

      // straightforward syntax: deadly slow, exits with OutOfMemoryError
      def p0(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        _ <- (if (i > N) ret else p0(i + 1))
      } yield ()

      // Same as above, but written out without `for`
      def p1(i: Int): Prog[Unit] =
        (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
          ignore1 =>
          (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
        }

      // Same as above, with `map` attached to recursive call
      def p2(i: Int): Prog[Unit] =
        (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
          ignore1 =>
          (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
        }

      // Same as above, but without the `map`; performs ok.
      def p3(i: Int): Prog[Unit] = {
        (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
          ignore1 =>
          if (i > N) ret else p3(i + 1)
        }
      }

      // Variation of the above; Ok.
      def p4(i: Int): Prog[Unit] = (for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
      } yield ()).flatMap{ ignored2 =>
        if (i > N) ret else p4(i + 1)
      }

      // try to use the variable returned by the last generator after yield,
      // hope that the final `map` is optimized away (it is not optimized away...)
      def p5(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        stopHere <- (if (i > N) ret else p5(i + 1))
      } yield stopHere

      // define an interpreter that translates the programs into Trampoline
      import scalaz.Trampoline
      type Exec[X] = Free.Trampoline[X]
      val interpreter = new (Lang ~> Exec) {
        def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
          case PrintLine(l) => Trampoline.delay(println(l))
        }
      }

      // try it out
      def main(args: Array[String]): Unit = {
        println("\n p0")
        p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        println("\n p1")
        p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        println("\n p2")
        p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        println("\n p3")
        p3(0).foldMap(interpreter).run // ok
        println("\n p4")
        p4(0).foldMap(interpreter).run // ok
        println("\n p5")
        p5(0).foldMap(interpreter).run // OutOfMemory
      }
    }

Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish p0 and then do nothing". If I manually "unroll" the for-comprehension and write out the last flatMap explicitly (as in p3 and p4), the problem goes away and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) is not even visible in the code, because it is generated automatically by the for-comprehension.
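The hidden map can be made visible with a toy type that records which combinators the compiler invokes on it. This is only a sketch in plain Scala, without Scalaz; Traced, step and DesugarDemo are made-up names for illustration and do not correspond to anything in the library.

```scala
// Toy wrapper that logs every combinator called on it.
// `Traced` is a hypothetical name, not a Scalaz type.
final case class Traced[A](value: A, log: List[String]) {
  def map[B](f: A => B): Traced[B] =
    Traced(f(value), log :+ "map")

  def flatMap[B](f: A => Traced[B]): Traced[B] = {
    val next = f(value)
    Traced(next.value, (log :+ "flatMap") ++ next.log)
  }
}

object DesugarDemo {
  def step(name: String): Traced[Unit] = Traced((), List(name))

  // Miniature of the p0 shape: two generators followed by `yield ()`.
  val viaFor: Traced[Unit] =
    for {
      _ <- step("first")
      _ <- step("second")
    } yield ()

  // Miniature of the p3 shape: the last step is returned as-is.
  val byHand: Traced[Unit] =
    step("first").flatMap(_ => step("second"))

  def main(args: Array[String]): Unit = {
    println(viaFor.log) // contains a trailing "map" entry
    println(byHand.log) // no "map" entry
  }
}
```

The for-comprehension version logs a map after the second step even though no map appears in the source, which is the same shape as p1.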

In this older post: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ it is repeatedly advised to wrap recursive calls in a suspend. Here is an attempt with an Applicative instance and suspend:

    import scalaz._

    // Essentially same as in `Demo1`, but this time with
    // an `Applicative` and an explicit `Suspend` in the
    // `for`-comprehension
    object Demo2 {

      sealed trait Lang[H]

      case class Const[H](h: H) extends Lang[H]
      case class PrintLine[H](line: String) extends Lang[H]

      implicit object Lang extends Applicative[Lang] {
        def point[A](a: => A): Lang[A] = Const(a)
        def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
          case Const(x) => {
            f match {
              case Const(ab) => Const(ab(x))
              case _ => throw new Error
            }
          }
          case PrintLine(l) => PrintLine(l)
        }
      }

      type Prog[X] = Free[Lang, X]

      import Free.{liftF, pure}
      def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
      def ret: Prog[Unit] = Free.pure(())

      val mod = 100000
      val N = 2000000

      // try to suspend the entire second generator
      def p7(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        _ <- Free.suspend(if (i > N) ret else p7(i + 1))
      } yield ()

      // try to suspend the recursive call
      def p8(i: Int): Prog[Unit] = for {
        _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
        _ <- if (i > N) ret else Free.suspend(p8(i + 1))
      } yield ()

      import scalaz.Trampoline
      type Exec[X] = Free.Trampoline[X]
      val interpreter = new (Lang ~> Exec) {
        def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
          case Const(x) => Trampoline.done(x)
          case PrintLine(l) =>
            (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
        }
      }

      def main(args: Array[String]): Unit = {
        p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
        p8(0).foldMap(interpreter).run // same...
      }
    }

Inserting suspend did not really help: the program is still slow and crashes with OutOfMemoryErrors.

Should I use suspend somehow differently?

Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map at the end?

I would really appreciate it if someone could point out what I am doing wrong here and how to fix it.

1 answer

That extra map, added by the Scala compiler, moves the recursion from tail position to non-tail position. The free monad still makes this stack safe, but the space complexity becomes O(N) instead of O(1). (Specifically, it is still not O(N^2).)

Whether it is possible to make scalac optimize that map away is a separate question (to which I do not know the answer).

I will try to illustrate what happens when interpreting p1 compared to p3. (I will ignore the translation to Trampoline, which is redundant; see below.)

p3 (i.e. without the extra map)

Let me use the following shorthand:

    def cont(i: Int): Unit => Prog[Unit] =
      ignore1 => if (i > N) ret else p3(i + 1)

Now p3(0) is interpreted as follows:

    p3(0)
    printLine("i = " + 0) flatMap cont(0)
    // side-effect: println("i = 0")
    cont(0)
    p3(1)
    ret flatMap cont(1)
    cont(1)
    p3(2)
    ret flatMap cont(2)
    cont(2)

etc. You see that the amount of memory needed at any point does not exceed some constant upper bound.

p1 (i.e. with the extra map)

I will use the following abbreviations:

    def cont(i: Int): Unit => Prog[Unit] =
      ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

    def cpu: Unit => Prog[Unit] = // constant pure unit
      ignore => Free.pure(())

Now p1(0) is interpreted as follows:

    p1(0)
    printLine("i = " + 0) flatMap cont(0)
    // side-effect: println("i = 0")
    cont(0)
    p1(1) map { ignore2 => () }
    // Free.map is implemented via flatMap
    p1(1) flatMap cpu
    (ret flatMap cont(1)) flatMap cpu
    cont(1) flatMap cpu
    (p1(2) map { ignore2 => () }) flatMap cpu
    (p1(2) flatMap cpu) flatMap cpu
    ((ret flatMap cont(2)) flatMap cpu) flatMap cpu
    (cont(2) flatMap cpu) flatMap cpu
    ((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
    ((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
    (((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

etc. You can see that the memory consumption depends linearly on N. We just moved the evaluation from the stack to the heap.
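This linear growth can be reproduced in miniature without Scalaz. The sketch below is a hand-rolled toy and not Scalaz's Free: Toy, Done, Bind, DepthDemo, tail, nonTail and maxPending are all hypothetical names. Its interpreter keeps waiting continuations on an explicit list, which is an assumption made for easy counting rather than a description of how Scalaz works internally. It only reports how many continuations are waiting at the worst moment for the p3 shape versus the p1 shape.

```scala
// Hand-rolled toy, NOT Scalaz's Free: just enough structure to count
// pending continuations. All names here are hypothetical.
sealed trait Toy[A] {
  def flatMap[B](f: A => Toy[B]): Toy[B] = Bind(this, f)
  // like the trace above, map is expressed through flatMap
  def map[B](f: A => B): Toy[B] = flatMap(a => Done(f(a)))
}
final case class Done[A](a: A) extends Toy[A]
final case class Bind[A, B](prev: Toy[A], f: A => Toy[B]) extends Toy[B]

object DepthDemo {
  val N = 1000
  val ret: Toy[Unit] = Done(())

  // p3 shape: the recursive call is the last thing the continuation does
  def tail(i: Int): Toy[Unit] =
    ret.flatMap(_ => if (i > N) ret else tail(i + 1))

  // p1 shape: a trailing map is attached after the recursive call
  def nonTail(i: Int): Toy[Unit] =
    ret.flatMap(_ => (if (i > N) ret else nonTail(i + 1)).map(_ => ()))

  // Runs a program in a loop and returns the largest number of
  // continuations that were waiting at the same time.
  def maxPending(start: Toy[Unit]): Int = {
    var current: Toy[Any] = start.asInstanceOf[Toy[Any]]
    var pending: List[Any => Toy[Any]] = Nil
    var depth = 0
    var maxDepth = 0
    var running = true
    while (running) {
      current match {
        case Done(a) =>
          pending match {
            case k :: rest =>
              pending = rest
              depth -= 1
              current = k(a)
            case Nil =>
              running = false
          }
        case b: Bind[_, _] =>
          pending = b.f.asInstanceOf[Any => Toy[Any]] :: pending
          depth += 1
          if (depth > maxDepth) maxDepth = depth
          current = b.prev.asInstanceOf[Toy[Any]]
      }
    }
    maxDepth
  }

  def main(args: Array[String]): Unit = {
    println("tail shape:     " + maxPending(tail(0)))
    println("non-tail shape: " + maxPending(nonTail(0)))
  }
}
```

In this toy the tail shape never has more than a constant number of continuations waiting, while the non-tail shape accumulates one leftover map continuation per iteration, so its peak grows with N.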

Take away: To keep Free memory friendly, keep the recursion in "tail position", that is, on the right-hand side of flatMap (or map).

Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You can interpret directly into Id and use foldMapRec for stack-safe interpretation:

    import scalaz.Id.Id // the Id type alias lives in the scalaz.Id object

    val idInterpreter = new (Lang ~> Id) {
      def apply[A](cmd: Lang[A]): Id[A] = cmd match {
        case PrintLine(l) => println(l)
      }
    }

    p0(0).foldMapRec(idInterpreter)

This will win back some fraction of the memory, but it does not make the problem go away.


Source: https://habr.com/ru/post/1013128/

