你的Scala编译得还不够慢

Jun 10, 2021

见习魔法师

Scala被人诟病的缺点之一是Scala的编译速度非常慢,虽然你可以在编译的时候补补番什么的,但是想要马上拿到结果的话可能令人焦虑。其实编译慢的原因有很多,比如类型推断是一个耗时的工作,虽然Scala用的是局部类型推断,ast中的每个元素在被构造的时候必须知道类型,但是本身也是比较耗时的。同时,Scala会生成更多的方法和类,比如伴生对象就会导致class喜加一,一个文件可能含多个类等等。Scala的implicit search会在整个scope中搜索变量,也是耗时的工作,诸如此类导致了Scala的编译很慢。

但是,我觉得,Scala编译还能更慢一点,因此本文讨论如何让Scala的编译更慢,从而增加划水时间,可以在Scala代码里面写一些类型体操,宏,内联函数什么的让Scala编译器多忙活一阵。

Implicits! 更多的 Implicits!

Implicit (或者given, using)可以增加编译器的工作量。除了必要的implicits以外,要注意多多增加用于证明的implicits,比如erased terms。在代码中加入更多的隐式转换,更多的DSL,更多的extension methods,都可以达到减慢速度的效果。

我们拿最经典的例子来说明:

sealed trait State
final class On extends State
final class Off extends State

@implicitNotFound("State must be on")
class IsOn[S <: State]
object IsOn:
  given IsOn[On] = new IsOn[On]

@implicitNotFound("State must be off")
class IsOff[S <: State]
object IsOff:
  given IsOff[Off] = new IsOff[Off]

class Light[S <: State] private ():
  def turnedOn(using IsOff[S]) = new Light[On]
  def turnedOff(using IsOn[S]) = new Light[Off]
object Light extends Light[Off]

Light.turnedOn.turnedOn // Error: State must be off

在灯被点亮后,灯的State变成On,此时要再开灯,需要一个StateOff的evidence,然而这个时候是不存在的,因此编译失败。 因为现在erased terms还在实验阶段,所以scalac就会生成所有的class,并且在编译器爽做implicit search。

还可以看一看之前写的Scala3 Contextual Abstractions里面的contextual functions,里面的做法都能增加一堆的implicit search。

类型体操!更多的类型体操!

事实上,Scala的类型体操有些还是需要implicit的支持的,这也从侧面说明implicit的强大。 首先,我们使用Scala3新出的match type,来增加难度。

type Elem[X] = X match
  case String => Char
  case Array[t] => t
  case Seq[t] => t
def firstElem[X](xs: X): Elem[X] = xs match
  case s: String => s charAt 0
  case x: Array[?] => x(0)
  case x: Seq[?] => x.head

你可能会问,区区这种东西,typescript也能做到,甚至能裁剪字符串,修改interface结构等高级操作,那么Scala3区别在哪儿呢?

在ts中我们可以这么写:

type Elem<T> = 
  T extends string ? string : T extends (infer U)[] ? U : never

首先,这个三目运算符十分不友好,看看type-challenges里面的题解就知道这个语法在类型复杂了之后有多难懂。

和typescript这种妖艳贱货不一样的是,Scala可以用compile-time-operations将类型信息带到运行时(classtag等),或者直接在inline函数中处理,或者直接生成计算后的值。

在加入implicits之后,Scala的编译时操作就比typescript更加强悍了。

import scala.compiletime._
import compiletime.ops.string.`+`
import compiletime.ops.int.{%, <, <=}
import compiletime.ops.any.==
import compiletime.ops.boolean._
import scala.annotation.implicitNotFound

@implicitNotFound("Unexpected case: ${T}")
class Expected[T <: Boolean] private ()
object Expected:
  given Expected[true] = new Expected[true]

这里我们定义了一个Expected类,这个类接受了一个类型参数,必须是Boolean的子类型。 同时构造器是私有的,意味着无法在外界构造实例,同时给出了仅有的一个实例,即类型参数为true的时候Expected类才有意义。

你可能会问了,Boolean在Scala中是final的,为何会有子类型呢?事实上,Scala中任何可以作为字面量类型的类型都是final的,因此这里的子类型只能代表相应的字面量类型,没有任何歧义。

我们就可以提出第一个需求了:一个range必须保证开始小于等于结束,如何在编译期确保这件事情?

Scala3提供了inline关键字,和Scala2的注解不同的是,inline不再是建议,而是命令,命令编译器必须将代码内联,同时做必要的编译期计算。还提供了constValue[T]。如果T是一个字面量类型,则返回该值,否则编译错误。此时我们就可以如此定义range函数了:

inline def range[S <: Int, T <: Int](using Expected[S <= T]) = constValue[S].to(constValue[T])

range[2, 4] // 2 to 4
range[4, 2] // Unexpected case: (4 : Int) < (2 : Int)

<=这里是一个infix type,S <= T其实等价于<=[S, T],这个类型的全名是scala.compiletime.ops.int.<=,作用是在编译期计算constValue[S] <= constValue[T]的值。 如果觉得数字传入参数有点奇怪,也可以这么定义:

def range[S <: Int & Singleton, T <: Int & Singleton](s: S, t: T)(using Expected[S <= T]) = s.to(t)
注意只有Int & SingletonScala才会选择将参数推断为字面类型,否则永远推断为Int,此时就无法比较了。

typescript的类型挑战有一道题:Join String,用Scala我们也可以在编译时写出这样的函数:

inline def joinStr[D <: String, Xs <: Tuple]: String = 
  inline erasedValue[Xs] match
    case _: EmptyTuple => ""
    case _: (t *: EmptyTuple) => constValue[t & String]
    case _: (t *: ts) => constValue[t & String] + constValue[D] + joinStr[D, ts]

val s = joinStr["+", ("1", "2", "3")] // "1+2+3"

除了match type之外也要多多使用HigherKindTypes,多多使用像shapeless, cats, scalaz这样的库。如果能多多使用形如这样的代码,你就离成功更近了一步:

infix type ~>[F[_], G[_]] = [t] => F[t] => G[t]
type Id[T] = T
val wrap: Id ~> Option = [t] => Option(_: t)
val unwrap: Option ~> Id = [t] => (op: Option[t]) => op.get

宏!更多的宏!

Scala3的宏允许了编译器在编译时执行项目中的部分代码,从而生成新的代码,所以宏就有这样的限制:

  • 宏的实现必须是一个静态方法

  • 宏的调用和实现不能在同一个文件中

宏的展开发生在typer阶段,这个时候还没有做代码的优化,可以拿到几乎是原本的AST。当然和Scala2的macro不同的是,Scala3提供了一种不关心AST,只关心类型的抽象,即Expr[T]。同时你可以使用反射模组来获取AST。

第一个宏

import scala.quoted._
object Macros:
  def evil: Unit = ${ evilImpl }
  def evilImpl(using Quotes): Expr[Unit] = 
    System.exit(0)
    '{ () }
def balabala =
  Macros.evil
  println(114514)

可以看见宏的基本特征是:一个内联函数和其实现,因为宏展开后可能不是函数调用了,所以这个函数必须是inline的,会接受一些参数,但是他的函数体一定是${ ...impl }调用一个实现函数,传入的所有参数必须quote一遍,形如'expr

宏的实现就是接受一系列Expr作为参数,返回一个Expr作为返回值,同时macro的实现必须接受scala.quoted.Quotes作为隐式参数,如果生成的表达式还用到了类型参数的类型,则还必须传入scala.quoted.Type[T]作为隐式参数。

Expr[T]是表达式的一种表示,表示代表的term的类型为T,为了将两者转化,Scala提供了两个操作:quote和splice。

quote 形如 'expr,将T变为Expr[T]

splice 形如 $expr,将Expr[T]变为T

因此对任意表达式e:

${'{e}} = e
'{${e}} = e

说了这么多,这个宏干了啥呢?这个宏相当恶心,在编译balabala函数的时候,Scalac执行了evilImpl函数,然后Scalac退出,什么都没有了。也就是说在展开这个函数的时候,编译器退出了,这个时候编译时间就是无限大了,我们成功达成了终极目标。 这也说明了在宏的实现内不能出现含全局副作用的代码,会有奇怪的行为。

上文提到的joinStr也可以不用类型体操实现,可以写宏,你可能会觉得更加如鱼得水一点,因为Expr[T]提供了一个扩展方法valueOrError,如果编译器的值是已知的,可以获取值,否则抛出错误。 当然还有value,方法,返回一个Option。

def joinImpl(delim: Expr[String], xs: Expr[Seq[String]])(using Quotes): Expr[String] =
  import quotes.reflect.report
  val d = delim.valueOrError
  val strs = xs match
    case Varargs(ys) => ys.map(_.valueOrError)
    case _ => xs.valueOrError
  val result = strs.mkString(d)
  Expr(result)

inline def joinStr(inline delim: String)(inline xs: String*): String = 
  ${ joinImpl('delim, 'xs) }

Varargs是可以判断传进来的表达式是不是一个可变参数,如果是可变参数,如果是,将Expr[Seq[T]]解构为Seq[Expr[T]]。然后我们在编译期把string连接的结果算出来,然后直接替换为字符串字面量。

val sep = joinStr(",")("1", "2", "3") // Compiles to
val sep = "1,2,3"

我们知道,for在Scala中是语法糖,左右其实是调用map, flatMap, withFilter等方法,在强大的同时引入了一定的运行时开销。那我们可以写一个最原始的C风格的for来解决这个问题:

def forEachImpl[T](t: Expr[Array[T]], fn: Expr[T => Unit])(using Quotes, Type[T]): Expr[Unit] =
  '{
    val xs = $t
    var i = 0
    val n = xs.length
    while i < n do
      ${Expr.betaReduce('{$fn(xs(i))})}
      i += 1
  }

extension[T] (arr: Array[T])
  inline def forEach(inline cons: T => Unit): Unit =
    ${ forEachImpl('arr, 'cons) }

val arr = Array(1, 2, 3)
arr.forEach(println) // Expands to
val xs = arr
var i = 0
val n = xs.length
while i < n do
  println(xs(i))
  i += 1

宏也可以改变闭包语义,因为相当于直接插入到展开位置。根据这个特性我们可以实现C++中的引用。

case class Ref[@specialized(Specializable.BestOfBreed) T] (get: () => T, set: T => Unit):
  def :=(v: T) = set(v)
  def value: T = get()
  override def toString = value.toString

object Ref:
  given[T]: Conversion[Ref[T], T] = _.value

  import scala.quoted.*
  inline def apply[T](inline v: T): Ref[T] =
    ${ make('v) }
  def make[T](v: Expr[T])(using Quotes, Type[T]): Expr[Ref[T]] =
    import quotes.reflect.{Ref => ReflectRef, *}
    v.asTerm match {
      case Inlined(_, _, term @ (Ident(_) | Select(_, _))) =>
      '{
        Ref(() => $v, nv => {
          ${ Assign(term, 'nv.asTerm).asExpr }
        })
      }
      case _ =>
        report.throwError(s"${v.show} is not an assignable value.", v)
    }

var x = 0
val ref = Ref(x)

ref := 1
assert(x == 1)
assert(ref * 2 == 2)

可以看见,Ref宏的参数也是inline的,所以可以能发现原本的AST到底是不是可以赋值的AST(C++中等同的概念是左值)。而且在JVM中,栈上的值是无法引用到闭包里的,所以这里Scala还把x转换为了scala.runtime.IntRef, 如此多隐含的开销够编译器喝一壶的了。

Mirror!更多的Mirror!

Mirror 会在生成以下类的时候一起生成:

  • case classes, case objects

  • enum classes, enums

  • 子类只有case classes, case objects的sealed trait 或者 sealed class

Mirror包含了类型的元素信息,包括类型和名称,你可以在编译期直接constValue召唤出来。因此你可以直接分析类的结构,可以做一些typeclass的derivation。这也是在Scala里面模仿代数数据类型的方法。

在Scala3中,任何class除了extends之外,现在可以加入derives了。那么,什么是derives呢?可以认为是做了这样的事情:

trait HistoryTrend[T] // A typeclass
object HistoryTrend:
  def derived[T]: HistoryTrend[T] = ??? // or given[T]: HistoryTrend[T] = ???

trait SelfStruggle
class OnesFate extends SelfStruggle derives HistoryTrend

那么derives其实是语法糖,其实在OnesFate类的伴生对象里定义了HistoryTrend[OnesFate]的一个given对象。

class OnesFate extends SelfStruggle
object OnesFate:
  given HistoryTrend[OnesFate] = HistoryTrend.derived

derived方法的实现有很多,甚至可以用宏实现。总之只要生成typeclass相应的instance即可。可以看见,typeclass实现了很多我们一直想要做到的事情,引入了implicit,使用了宏,还要编译器处理Mirror instances,简直是增加编译时间的大杀器。

在具有Mirror的时候,你就可以不用宏实现typeclass的派生。比如,大名鼎鼎的Monoid。其定义如下:

存在这样的集合M,以及二元运算f: M \times M \to M,满足:

  • 结合律:\forall a, b, c \in M, f(a, f(b, c)) = f(f(a, b), c)

  • 单位元:\exists e \in M s.t. \forall a \in M, f(e, a) = f(a, e) = a

trait Monoid[A]:
  def empty: A
  def combine(x: A, y: A): A
  extension (x: A) def |+| (y: A): A = combine(x, y)

给出一些基本的Monoid, 比如Numeric,String,Boolean

object Monoid:
  def zero[A](using m: Monoid[A]) = m.empty
  given Monoid[Unit] with
    def empty: Unit = ()
    def combine(x: Unit, y: Unit): Unit = ()

  given Monoid[Boolean] with
    def empty: Boolean = false
    def combine(x: Boolean, y: Boolean): Boolean = x || y

  given[T](using num: Numeric[T]): Monoid[T] with
    def empty: T = num.zero
    def combine(x: T, y: T): T = num.plus(x, y)

  given Monoid[String] with
    def empty: String = ""
    def combine(x: String, y: String): String = x + y

在拥有一个Mirror和其包含的所有子类的Moniod实例后,该怎么办呢?我们先来看看Mirror是如何工作的:

case class ISB(i: Int, s: String, b: Boolean)

因为ISB是一个case class,所以会生成Mirror instance,而且ISB是一个Product class,因为ISB的值域是Int * String * Boolean,即这三个类型的笛卡尔积的值域。 对这样的class,我们可以认为会产生这样的代码:

case class ISB(i: Int, s: String, b: Boolean) extends Product:
  def productElement(n: Int): any = n match
    case 0 => i
    case 1 => s
    case 2 => b
    case _ => throw new IndexOutOfBoundsException(n)
  def productArity: Int = 3
  def productPrefix = "ISB"
import scala.deriving._
object ISB extends Mirror.Product:
  type MirroredType = ISB
  type MirroredMonoType = ISB
  type MirroredLabel <: String = "ISB"
  type MirroredElemLabels <: Tuple = ("i", "s", "b")
  type MirroredElemTypes <: Tuple = (Int, String, Boolean)

  def fromProduct(p: scala.Product): MirroredMonoType = 
    new ISB(p.productElement(0).asInstanceOf[Int], p.productElement(1).asInstanceOf[String], p.productElement(2).asInstanceOf[Boolean])

可以发现,这么简单的一行代码,编译器会帮你产生那么多代码,可见为了达成本文的目标,在你的Scala项目中一定要多多使用case class。 还有很多随着case class生成的方法,这里省略,仅关心和Mirror有关的元素。我们发现了思路: MirrorElemeTypes(Int, String, Boolean),如果生成(Monoid[Int], Monoid[String], Monoid[Boolean])的实例,就可以使用这些实例生成empty和combine方法了。 我们先假定我们拥有了这样的实例,那么就可以这么生成:

object Monoid:
  // Previous defs...
  import scala.deriving._
  inline def genProductMonoid[A](m: Mirror.ProductOf[A], inst: IArray[Monoid[?]]): Monoid[A] = new Monoid[A]:
    extension[T](t: T) def iterate = t.asInstanceOf[Product].productIterator
    def empty: A = m.fromProduct(Tuple.fromIArray(inst.map(_.empty)))
    def combine(x: A, y: A): A = 
      val arr = inst.zip(x.iterate).zip(y.iterate).map {
        case ((m, x), y) => m.combine(x.asInstanceOf, y.asInstanceOf)
      }
      m.fromProduct(Tuple.fromIArray(arr))

可以看见我们的做法是把元素转为一个Tuple,然后使用fromProduct方法构造这样的类,因为Tuple实现了Product。获取类里面的元素也是利用了case class实现了Product接口,所以可以直接利用这个生成元素, 然后再让里面的元素通过Monoid实例处理,就能实现Monoid的组合。

那我们要如何将Monoid召唤出来呢?考虑到MirroredElemLabels是一个Tuple,而Scala3对Tuple的解构和List类似,(Int, String, Boolean)其实等价于Int *: String *: Boolean *: EmptyTuple, 所以可以写一个递归的算法来实现:

inline def summonMonoids[T <: Tuple]: List[Monoid[?]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts) => summonInline[Monoid[t]] :: summonMonoids[ts]

这里erasedValue[T]是你没有T的实例,却想要对T的类型做模式匹配的时候可以使用的方法,仅能在inline函数中做inline match。 在拥有了这些东西之后,我们就可以编写最重要的derived函数了,因为derives语法糖就是调用这个函数。

object Monoid:
  // Previous defs...
  inline def derived[A](using m: Mirror.ProductOf[A]) = 
    val insts = IArray.from(summonMonoids[m.MirroredElemTypes])
    genProductMonoid(m, insts)

此时就可以大胆为ISB写derives了:

case class ISB(i: Int, s: String, b: Boolean) derives Monoid

val yj = ISB(114, "yj", false)
val snpi = ISB(400, "snpi", true)
val yjsnpi = ISB(514, "yjsnpi", true)
assert(yj |+| snpi == yjsnpi)

注意typeclass是可以组合的,即我们可以:

case class Homo(age: Int, data: ISB) derives Monoid

val hYj = Homo(21, yj)
val hSnpi = Homo(3, snpi)
val hYjsnpi = Homo(24, yjsnpi)
assert(hYj |+| hSnpi == hYjsnpi)

总结

其实让Scala编译变慢的方法十分简单,多用Scala特有的特性就行了。 这些特性其实道理是一致的,就是让编译器帮你做了你本该做的活。事实上,虽然你的编译时间增加了,你实际上消耗的时间却减少了,因为编译器对你正确性的证明做的比人快。