Even More Lightweight Monadic Regions

Posted by Jonathan Immanuel Brachthäuser on January 17, 2019

Delimiting the lifetime of a resource to a particular scope is a common problem. In this post, I revisit “Lightweight Monadic Regions” (Kiselyov and Shan, 2008) and show how to generalize the ST-trick to nest monadic regions in a new way.

TL;DR: The core idea is to index a monadic type by the set of regions it requires to be “alive”. This is not a new idea, but typically the set is represented by some type-level list. Region subtyping then requires implicit evidence that one type-level list is a sublist of another one. As a more lightweight approach, I propose to use Scala’s support for intersection types to represent the set of regions. Every region can be uniquely identified by its (path-dependent) singleton type, and nesting of regions translates to the intersection of these singleton types. The nice thing about intersection types is that region subtyping is now just ordinary (contravariant) subtyping.
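The last point can be checked directly by the compiler. The following Scala 3 sketch assumes a phantom trait ST that is contravariant in its region parameter (as in the definitions later in this post) and two hypothetical region values r1 and r2. Without any extra evidence, plain subtyping shows that a computation needing only r1 can be used where both r1 and r2 are alive.

```scala
object ContravariantRegions {
  // Phantom stand-in for the monadic type: only the variance matters here.
  trait ST[+R, -Regions]

  // Hypothetical region labels; each value contributes its own singleton type.
  class Region
  val r1: Region = new Region
  val r2: Region = new Region

  // Requiring fewer live regions is a subtype of requiring more of them:
  // ST[Int, r1.type] <: ST[Int, r1.type & r2.type], witnessed by the compiler.
  val widen = summon[ST[Int, r1.type] <:< ST[Int, r1.type & r2.type]]
}
```

No sublist evidence has to be constructed; the subtyping check on intersections does all the work.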

As it turns out, the idea of using a (contravariant) intersection type to track regions is not completely new and has been used by Parreaux et al. (2018) to guarantee macro hygiene.

This post is only about type and region safety. That is, we won’t actually implement functional local heaps or the like.

The code of this post and more variants of it can be found in this Scastie.

The ST-Trick

Launchbury & Sabry (1997) introduce a typed version of the ST monad. They use rank-2 types to describe regions in which resources (like references) are valid.

The following example uses the ST-monad to create mutable reference cells Ref.

def prog[S]: ST[Int, S] = for {
  r <- Ref[S](0)
  x <- r.get
  _ <- r.put(x + 1)
  y <- r.get
} yield y

Let’s look at the interface of the ST-monad translated to Scala and specialized to references that can only store Ints:

trait ST[+R, Scope] {
  def flatMap[B](f: R => ST[B, Scope]): ST[B, Scope]
  def map[B](f: R => B): ST[B, Scope]
}
trait Ref[Scope] {
  def get: ST[Int, Scope]
  def put(value: Int): ST[Unit, Scope]
}
def Ref[Scope](value: Int): ST[Ref[Scope], Scope]

One essential requirement for the ST-trick is that programs which use ST have to be fully parametric in Scope. This is captured in the type Program:

trait Program[+R] {
  def apply[Scope]: ST[R, Scope]
}
def runST[R](r: Program[R]): R

The type Program encodes a rank-2 type. Our example program prog has the right type and is parametric enough, so we can call runST:

runST { prog }

Remark

In fact, since SAM conversion does not work for rank-2 types like Program (its single abstract method is itself polymorphic), we actually need to write

runST { new Program[Int] { def apply[S] = prog[S] } }

However, for conciseness, I sometimes pretend that we can write runST { prog }.
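To experiment with the interface above, it can be given a throwaway implementation. The following Scala 3 sketch is a hypothetical choice, not part of the original library: ST is modeled as a suspended computation (a thunk), Ref as a plain mutable Int cell, both as concrete classes instead of traits, and runST simply instantiates the scope parameter with Any. Region safety comes only from the types; nothing is actually freed.

```scala
object Rank2ST {
  // A suspended computation producing an R; Scope is a phantom parameter.
  final class ST[+R, Scope](private[Rank2ST] val unsafeRun: () => R) {
    def flatMap[B](f: R => ST[B, Scope]): ST[B, Scope] =
      new ST[B, Scope](() => f(unsafeRun()).unsafeRun())
    def map[B](f: R => B): ST[B, Scope] =
      new ST[B, Scope](() => f(unsafeRun()))
  }

  // A mutable Int cell tagged with the scope it was allocated in.
  final class Ref[Scope] private[Rank2ST] (private var value: Int) {
    def get: ST[Int, Scope] = new ST[Int, Scope](() => value)
    def put(v: Int): ST[Unit, Scope] = new ST[Unit, Scope](() => { value = v })
  }
  def Ref[Scope](value: Int): ST[Ref[Scope], Scope] =
    new ST[Ref[Scope], Scope](() => new Ref[Scope](value))

  // Programs must be parametric in the scope (rank-2 polymorphism).
  trait Program[+R] {
    def apply[Scope]: ST[R, Scope]
  }
  // Any concrete scope type works here, since programs cannot depend on it.
  def runST[R](p: Program[R]): R = p.apply[Any].unsafeRun()

  def prog[S]: ST[Int, S] = for {
    r <- Ref[S](0)
    x <- r.get
    _ <- r.put(x + 1)
    y <- r.get
  } yield y

  val result: Int = runST { new Program[Int] { def apply[S] = prog[S] } }
}
```

With this sketch, Rank2ST.result evaluates to 1: the cell starts at 0, is incremented once, and is read back.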

The idea behind the ST-trick is now the following: since programs have to work for all scopes (notice the type parameter Scope on Program.apply), references introduced in one scope cannot be used in a different scope. For example, we can try to leak a reference like this:

var leak: Ref[_] = _
runST {
  new Program[Unit] {
    def apply[S] = for { r <- Ref[S](0); _ = leak = r } yield ()
  }
}

However, the type variable S is not in scope outside of runST, and thus we can only use an existential Ref[_] to express “it is a reference for some scope we don’t know”. References can thus escape, but they can never be used again outside of the call to runST. That is, the following does not type check:

runST {
  new Program[Int] { def apply[S] = leak.get }
   //                               ^^^^^^^^
   //                               found:    ST[Int, _]
   //                               required: ST[Int, S]
}

The existential type cannot be unified with S! The rank-2 type thus ensures that we can free the allocated resources after executing runST, since they can never be used in any other call to runST.

Remark

In the setting of delimited control, Dybvig et al. (2007) use the same trick to prevent prompts from escaping the region of one run.

Using Singleton Types

The same region safety can be achieved with a slightly different encoding using Scala’s support for singleton types. We can express the interfaces of ST and Ref exactly as before, but use a different definition for Program:

trait Scope {}
trait Program[+R] {
  def apply(scope: Scope): ST[R, scope.type]
}

This way, the rank-2 universal quantification effectively moves from the type-level to the term-level.

Using ST now looks as follows:

def prog(scope: Scope): ST[Int, scope.type] = for {
  r <- Ref[scope.type](0)
  x <- r.get
  _ <- r.put(x + 1)
  y <- r.get
} yield y

runST[Int] { prog }
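A runnable sketch of this singleton-type encoding shows where the quantification went: runST itself creates a fresh Scope value, and a program only ever sees it as an abstract argument. The thunk-based ST and the mutable-cell Ref are hypothetical implementations for illustration only. In the spirit of the earlier remark about SAM conversion, the anonymous Program class is spelled out instead of writing runST[Int] { prog }.

```scala
object SingletonST {
  // Thunk-based ST; S is a phantom scope index.
  final class ST[+R, S](private[SingletonST] val unsafeRun: () => R) {
    def flatMap[B](f: R => ST[B, S]): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()).unsafeRun())
    def map[B](f: R => B): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()))
  }

  final class Ref[S] private[SingletonST] (private var value: Int) {
    def get: ST[Int, S] = new ST[Int, S](() => value)
    def put(v: Int): ST[Unit, S] = new ST[Unit, S](() => { value = v })
  }
  def Ref[S](value: Int): ST[Ref[S], S] =
    new ST[Ref[S], S](() => new Ref[S](value))

  // Scopes are ordinary values; their singleton types index computations.
  trait Scope
  trait Program[+R] {
    def apply(scope: Scope): ST[R, scope.type]
  }
  // runST creates the fresh scope itself, so no program can know it in advance.
  def runST[R](p: Program[R]): R = {
    val scope: Scope = new Scope {}
    p(scope).unsafeRun()
  }

  def prog(scope: Scope): ST[Int, scope.type] = for {
    r <- Ref[scope.type](0)
    x <- r.get
    _ <- r.put(x + 1)
    y <- r.get
  } yield y

  val result: Int = runST(new Program[Int] {
    def apply(scope: Scope): ST[Int, scope.type] = prog(scope)
  })
}
```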

Additionally, we can improve type inference a bit by making Ref a member of Scope:

trait Scope { scope =>
  trait Ref {
    def get: ST[Int, scope.type]
    def put(value: Int): ST[Unit, scope.type]
  }
  def Ref(value: Int): ST[Ref, scope.type]
}

The user program then changes to

def prog(scope: Scope): ST[Int, scope.type] = for {
  r <- /*>>>*/ scope.Ref(0) /*<<<*/
  x <- r.get
  _ <- r.put(x + 1)
  y <- r.get
} yield y
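With Ref as a member, the scope index no longer has to be written at allocation sites. In the following sketch (an assumption, not the post’s library), Scope is a concrete class so that runST can instantiate it, Ref is an inner class whose operations are indexed by the enclosing scope’s singleton type, and ST is again a hypothetical thunk.

```scala
object MemberRefST {
  final class ST[+R, S](private[MemberRefST] val unsafeRun: () => R) {
    def flatMap[B](f: R => ST[B, S]): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()).unsafeRun())
    def map[B](f: R => B): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()))
  }

  // Ref is an inner class, so every reference is tied to the
  // singleton type of the scope that allocated it.
  class Scope { scope =>
    final class Ref(private var value: Int) {
      def get: ST[Int, scope.type] = new ST[Int, scope.type](() => value)
      def put(v: Int): ST[Unit, scope.type] =
        new ST[Unit, scope.type](() => { value = v })
    }
    def Ref(value: Int): ST[Ref, scope.type] =
      new ST[Ref, scope.type](() => new Ref(value))
  }

  trait Program[+R] {
    def apply(scope: Scope): ST[R, scope.type]
  }
  def runST[R](p: Program[R]): R = {
    val scope = new Scope
    p(scope).unsafeRun()
  }

  def prog(scope: Scope): ST[Int, scope.type] = for {
    r <- scope.Ref(0) // no explicit type argument needed any more
    x <- r.get
    _ <- r.put(x + 1)
    y <- r.get
  } yield y

  val result: Int = runST(new Program[Int] {
    def apply(scope: Scope): ST[Int, scope.type] = prog(scope)
  })
}
```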

Generalizing to Multiple Regions

We are now ready to generalize the previous encodings to multiple, potentially nested regions:

def prog(scope1: Scope, scope2: Scope): ST[Int, scope1.type & scope2.type] =
  for {
    r <- scope1.Ref(0)
    s <- scope2.Ref(0)
    x <- r.get
    _ <- s.put(x + 1)
    y <- s.get
  } yield y

Here, we use intersection types to express that a program uses references of multiple nested scopes. Every call to scoped introduces one scope and removes the corresponding scope.type from the intersection. To run the program, we define type Global = Any for the top-level scope (which has no references). We then introduce two nested scopes with scoped and finally run the computation with runST:

runST[Int] {
  scoped[Int, Global] { scope1 =>
    scoped[Int, scope1.type] { scope2 =>
      prog(scope1, scope2)                // ST[Int, scope1.type & scope2.type]
    } // free all resources in scope2.    // ST[Int, scope1.type]
  }   // free all resources in scope1.    // ST[Int, Global]
}     // done.                            // Int

It might look a bit strange to have an intersection type of singleton types. In fact, most of those types don’t have inhabitants. However, we just use the intersection as a set of scope labels. It is a phantom type and does not have any operational relevance.

Here are the definitions of ST, Scope and Program:

trait ST[+R, -Scope] {
  def flatMap[B, S](f: R => ST[B, S]): ST[B, Scope & S]
  def map[B](f: R => B): ST[B, Scope]
}
trait Scope { scope =>
  trait Ref {
    def get: ST[Int, scope.type]
    def put(value: Int): ST[Unit, scope.type]
  }
  def Ref(value: Int): ST[Ref, scope.type]
}
trait Program[+R, -S] {
  def apply(scope: Scope): ST[R, scope.type & S]
}
def scoped[R, S](r: Program[R, S]): ST[R, S]
def runST[R](r: ST[R, Global]): R
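These definitions can be filled in as in the following Scala 3 sketch. It rests on hypothetical implementation choices: ST is a thunk, Ref a mutable cell, and Scope a concrete class. scoped creates a fresh scope, runs the body in it, and returns a computation whose region set no longer mentions that scope; runST only accepts computations that need nothing but Global.

```scala
object NestedRegions {
  type Global = Any

  // ST is contravariant in S, the intersection of regions that must be alive.
  final class ST[+R, -S](private[NestedRegions] val unsafeRun: () => R) {
    // Sequencing requires the regions of both computations.
    def flatMap[B, S2](f: R => ST[B, S2]): ST[B, S & S2] =
      new ST[B, S & S2](() => f(unsafeRun()).unsafeRun())
    def map[B](f: R => B): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()))
  }

  class Scope { scope =>
    final class Ref(private var value: Int) {
      def get: ST[Int, scope.type] = new ST[Int, scope.type](() => value)
      def put(v: Int): ST[Unit, scope.type] =
        new ST[Unit, scope.type](() => { value = v })
    }
    def Ref(value: Int): ST[Ref, scope.type] =
      new ST[Ref, scope.type](() => new Ref(value))
  }

  trait Program[+R, -S] {
    def apply(scope: Scope): ST[R, scope.type & S]
  }
  // Create a fresh scope, run the body in it, and drop scope.type from the index.
  def scoped[R, S](p: Program[R, S]): ST[R, S] =
    new ST[R, S](() => {
      val scope = new Scope
      p(scope).unsafeRun()
    })
  // Only computations that need no region besides Global can be run.
  def runST[R](st: ST[R, Global]): R = st.unsafeRun()

  def prog(scope1: Scope, scope2: Scope): ST[Int, scope1.type & scope2.type] =
    for {
      r <- scope1.Ref(0)
      s <- scope2.Ref(0)
      x <- r.get
      _ <- s.put(x + 1)
      y <- s.get
    } yield y

  val result: Int = runST {
    scoped[Int, Global](new Program[Int, Global] {
      def apply(scope1: Scope): ST[Int, scope1.type & Global] =
        scoped[Int, scope1.type](new Program[Int, scope1.type] {
          def apply(scope2: Scope): ST[Int, scope2.type & scope1.type] =
            prog(scope1, scope2)
        })
    })
  }
}
```

The inner body has type ST[Int, scope1.type & scope2.type], which conforms to the expected ST[Int, scope2.type & scope1.type] because intersections are commutative under subtyping.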

In fact, since Dotty supports dependent function types, we can write Program as a type alias:

type Program[+R, -S] = (scope: Scope) => ST[R, scope.type & S]
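With the type alias, the nested example can be written with plain lambdas, just like the runST call with two nested scoped blocks shown earlier. The sketch below uses hypothetical implementations (a thunk-based ST, mutable cells for Ref, a concrete Scope class) and, for simplicity, leaves out the variance annotations on the alias; only the shape of Program differs from a trait-based version.

```scala
object DependentPrograms {
  type Global = Any

  final class ST[+R, -S](private[DependentPrograms] val unsafeRun: () => R) {
    def flatMap[B, S2](f: R => ST[B, S2]): ST[B, S & S2] =
      new ST[B, S & S2](() => f(unsafeRun()).unsafeRun())
    def map[B](f: R => B): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()))
  }

  class Scope { scope =>
    final class Ref(private var value: Int) {
      def get: ST[Int, scope.type] = new ST[Int, scope.type](() => value)
      def put(v: Int): ST[Unit, scope.type] =
        new ST[Unit, scope.type](() => { value = v })
    }
    def Ref(value: Int): ST[Ref, scope.type] =
      new ST[Ref, scope.type](() => new Ref(value))
  }

  // A dependent function type: the result mentions the argument's singleton type.
  type Program[R, S] = (scope: Scope) => ST[R, scope.type & S]

  def scoped[R, S](p: Program[R, S]): ST[R, S] =
    new ST[R, S](() => {
      val scope = new Scope
      p(scope).unsafeRun()
    })
  def runST[R](st: ST[R, Global]): R = st.unsafeRun()

  def prog(scope1: Scope, scope2: Scope): ST[Int, scope1.type & scope2.type] =
    for {
      r <- scope1.Ref(0)
      s <- scope2.Ref(0)
      x <- r.get
      _ <- s.put(x + 1)
      y <- s.get
    } yield y

  // Lambdas are typed against the expected dependent function type.
  val result: Int = runST {
    scoped[Int, Global] { scope1 =>
      scoped[Int, scope1.type] { scope2 =>
        prog(scope1, scope2)
      }
    }
  }
}
```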

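The final version of the library is as follows. Compared to the definitions above, it replaces the singleton type scope.type with an abstract type member Region, generalizes Ref to arbitrary element types T, and adds a transfer operation that moves a reference into another scope while both regions are alive: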

trait Tests extends Regions {

  // this example also shows transfer from one nested region to another
  val ex1 = scoped[Int, Global] { s1 => for {
    r1 <- s1.Ref(1)
    r3 <- scoped[s1.Ref[Int], s1.Region] { s2 => for {
      r2 <- s2.Ref(0)
      x  <- r1.get
      _  <- r2 put (x + 1)
      // here we extend the lifetime of r2 to scope 1
      r3 <- r2 transfer s1
    // as an example, try returning r2 here. This won't typecheck.
    } yield r3 }
    y <- r3.get
    _ <- r1.put(y)
    z <- r1.get
  } yield z }

  val ex1res: Int = runST { ex1 }
}

trait Regions {

  type Global = Any
  trait ST[+R, -Scope] {
    def flatMap[B, S](f: R => ST[B, S]): ST[B, Scope & S]
    def map[B](f: R => B): ST[B, Scope]
  }
  trait Scope {
    type Region
    trait Ref[T] {
      def get: ST[T, Region]
      def put(value: T): ST[Unit, Region]

      // the effect type here asserts that both regions are
      // properly nested for the transfer
      def transfer(to: Scope): ST[to.Ref[T], Region & to.Region]
    }
    def Ref[T](value: T): ST[Ref[T], Region]
  }
  type Program[+R, -S] = (s: Scope) => ST[R, s.Region & S]

  def scoped[R, S](r: Program[R, S]): ST[R, S]
  def runST[R](r: ST[R, Global]): R
}
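To try the final library end to end, one needs implementations of scoped, runST and transfer. The sketch below is one possible, hypothetical choice and not the implementation from the Scastie: ST is a thunk, Ref a mutable cell, every scope created by scoped fixes Region = Any internally (callers only see the abstract member), and transfer copies the current value into a fresh cell of the target scope. The variance annotations on Program are left out for simplicity.

```scala
object ConcreteRegions {
  type Global = Any

  final class ST[+R, -S](private[ConcreteRegions] val unsafeRun: () => R) {
    def flatMap[B, S2](f: R => ST[B, S2]): ST[B, S & S2] =
      new ST[B, S & S2](() => f(unsafeRun()).unsafeRun())
    def map[B](f: R => B): ST[B, S] =
      new ST[B, S](() => f(unsafeRun()))
  }

  trait Scope {
    // Abstract region label; programs never learn its definition.
    type Region
    final class Ref[T](private var value: T) {
      def get: ST[T, Region] = new ST[T, Region](() => value)
      def put(v: T): ST[Unit, Region] = new ST[Unit, Region](() => { value = v })
      // Assumption: transfer copies the current value into a fresh cell of `to`.
      def transfer(to: Scope): ST[to.Ref[T], Region & to.Region] =
        new ST[to.Ref[T], Region & to.Region](() => new to.Ref[T](value))
    }
    def Ref[T](value: T): ST[Ref[T], Region] =
      new ST[Ref[T], Region](() => new Ref[T](value))
  }

  type Program[R, S] = (s: Scope) => ST[R, s.Region & S]

  def scoped[R, S](p: Program[R, S]): ST[R, S] =
    new ST[R, S](() => {
      // The concrete Region is hidden behind the static type Scope.
      val scope: Scope = new Scope { type Region = Any }
      p(scope).unsafeRun()
    })
  def runST[R](st: ST[R, Global]): R = st.unsafeRun()

  // The ex1 example from Tests, unchanged apart from layout.
  val ex1: ST[Int, Global] = scoped[Int, Global] { s1 =>
    for {
      r1 <- s1.Ref(1)
      r3 <- scoped[s1.Ref[Int], s1.Region] { s2 =>
        for {
          r2 <- s2.Ref(0)
          x  <- r1.get
          _  <- r2.put(x + 1)
          r3 <- r2.transfer(s1)
        } yield r3
      }
      y <- r3.get
      _ <- r1.put(y)
      z <- r1.get
    } yield z
  }

  val ex1res: Int = runST(ex1)
}
```

With these choices, ex1res evaluates to 2: the inner scope stores 1 + 1 in r2, transfers it to s1, and the outer scope copies that value into r1.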