forked from scala/scala-java8-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
WrapFnGen.scala
321 lines (285 loc) · 12.7 KB
/
WrapFnGen.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/
object WrapFnGen {
/** Banner comment placed at the top of the generated source file. */
val copyright: String =
  """
    |/*
    | * Copyright EPFL and Lightbend, Inc.
    | * This file auto-generated by WrapFnGen.scala. Do not modify directly.
    | */
    |""".stripMargin

/** Package clause placed into the generated source file. */
val packaging: String = "package scala.compat.java8"
import scala.tools.nsc._
// Compiler instance whose symbol table and tree printer are used by the rest of this object
// (symbol lookup in buildWrappersViaReflection, showCode in TreeToText).
// Any settings error reported by the compiler is turned into an exception via sys.error.
val settings = new Settings(msg => sys.error(msg))
// Presumably makes the compiler resolve classes from the JVM's own classpath -- confirm against scalac's Settings.
settings.usejavacp.value = true
val compiler = new Global(settings)
// NOTE(review): `run` is not referenced elsewhere in the visible code; presumably constructing a Run
// is what initializes the compiler before symbols are queried -- confirm before removing.
val run = new compiler.Run
import compiler._, definitions._
/** Adds an `indent` operation to a vector of text lines. */
implicit class IndentMe(v: Vector[String]) {
  /** Returns a copy of the lines with the indentation prefix prepended to each one (empty lines included). */
  def indent: Vector[String] =
    for (line <- v) yield " " + line
}
/** Adds `mkVec` to a vector of line groups. */
implicit class FlattenMe(v: Vector[Vector[String]]) {
  /** Concatenates the groups into one vector, placing `join` as a single
    * separator line before every group except the first.
    */
  def mkVec(join: String = ""): Vector[String] =
    v.zipWithIndex.flatMap { case (group, idx) =>
      if (idx == 0) group else join +: group
    }
}
/** Adds `mkVecVec` to a two-level nesting of line groups. */
implicit class DoubleFlattenMe(v: Vector[Vector[Vector[String]]]) {
  /** Flattens both levels into one vector of lines.
    *
    * Inner groups are separated by one `join` line; outer groups are
    * separated by two `join` lines.
    */
  def mkVecVec(join: String = ""): Vector[String] =
    v.zipWithIndex.flatMap { case (outer, outerIdx) =>
      val flattened = outer.zipWithIndex.flatMap { case (inner, innerIdx) =>
        if (innerIdx == 0) inner else join +: inner
      }
      if (outerIdx == 0) flattened else join +: join +: flattened
    }
}
/** Small string helpers used while assembling and comparing generated text. */
implicit class SplitMyLinesAndStuff(s: String) {
  /** The lines of the string, as a vector. */
  def toVec: Vector[String] = s.linesIterator.toVector

  /** True when the string contains something other than whitespace after trimming. */
  def nonBlank: Boolean = s.trim.nonEmpty
}
/** Renders a compiler tree as lines of source text. */
implicit class TreeToText(t: Tree) {
  /** The tree printed with `showCode`, with every `$` character removed, split into lines. */
  def text: Vector[String] = {
    val rendered = showCode(t)
    val withoutDollars = rendered.replace("$", "")
    withoutDollars.linesIterator.toVector
  }
}
/** A block of generated lines tagged with an integer priority. */
case class Prioritized(lines: Vector[String], priority: Int) {
  /** Same lines, re-tagged with priority `i`. */
  def withPriority(i: Int): Prioritized = Prioritized(lines, i)
}
/** All generated text fragments for one Java functional interface.
  *
  * `base` is the name the fragments are sorted by; the remaining fields each
  * hold the source lines of one generated definition. `implicitToJava`
  * additionally carries the priority level it should be emitted at.
  */
case class SamConversionCode(
  base: String,
  wrappedAsScala: Vector[String],
  asScalaAnyVal: Vector[String],
  implicitToScala: Vector[String],
  asScalaDef: Vector[String],
  wrappedAsJava: Vector[String],
  asJavaAnyVal: Vector[String],
  implicitToJava: Prioritized,
  asJavaDef: Vector[String]
) {
  /** The wrapper classes and value classes, Scala direction first, then Java direction. */
  def impls: Vector[Vector[String]] =
    Vector(
      wrappedAsScala,
      asScalaAnyVal,
      wrappedAsJava,
      asJavaAnyVal
    )

  /** The explicit conversion methods: to Scala, then to Java. */
  def defs: Vector[Vector[String]] =
    Vector(asScalaDef, asJavaDef)

  /** Copy of this record whose `implicitToJava` is re-tagged with priority `i`. */
  def withPriority(i: Int): SamConversionCode = {
    val retagged = implicitToJava.withPriority(i)
    copy(implicitToJava = retagged)
  }
}
/** Assembles the sections of the generated file from individual conversions. */
object SamConversionCode {
  /** Lays out the given conversions as source text.
    *
    * @return a pair: the lines of the `functionConverterImpls` package, and one
    *         line vector per priority trait (highest priority number first)
    *         followed by the line vector of the package object.
    */
  def apply(scc: SamConversionCode*): (Vector[String], Vector[Vector[String]]) = {
    val sccDepthSet = scc.map(_.implicitToJava.priority).toSet

    // Priorities must form the range 0..k; when they do not, renumber them by rank.
    val alreadyDense = sccDepthSet == (0 to sccDepthSet.max).toSet
    val renumbered =
      if (alreadyDense) scc
      else {
        val rank = sccDepthSet.toList.sorted.zipWithIndex.toMap
        scc.map(code => code.withPriority(rank(code.implicitToJava.priority)))
      }
    val codes = renumbered.toVector.sortBy(_.base)

    // Name of the container for priority level `n`. With `pure` only the bare
    // identifier is returned; otherwise the full declaration header, which
    // extends the next level's container unless `n` is the last level.
    def priorityName(n: Int, pure: Boolean = false): String = {
      val bare =
        if (n <= 0) "FunctionConverters"
        else s"Priority${n}FunctionConverters"
      if (pure) bare
      else {
        val header =
          if (n <= 0) s"package object $bare"
          else s"trait $bare"
        if (n < sccDepthSet.size - 1) s"$header extends ${priorityName(n + 1, pure = true)}"
        else header
      }
    }

    val impls =
      ("package functionConverterImpls {" +: codes.map(_.impls).mkVecVec().indent) :+ "}"

    // One trait per non-zero priority level, emitted from the highest level down.
    val lowerPriority = codes.filter(_.implicitToJava.priority > 0)
    val traits =
      lowerPriority.groupBy(_.implicitToJava.priority).toVector.sortBy { case (level, _) => -level }.map {
        case (level, members) =>
          val body = members.map(_.implicitToJava.lines).mkVec().indent
          (Vector(s"${priorityName(level)} {", " import functionConverterImpls._", " ") ++ body) :+ "}"
      }

    val explicitDefs = codes.map(_.defs).mkVecVec()
    val spacer = Vector.fill(3)(" ")
    val levelZeroImplicits =
      codes.filter(_.implicitToJava.priority == 0).map(_.implicitToJava.lines).mkVec().indent
    val toScalaImplicits = codes.map(_.implicitToScala).mkVec().indent

    // Package object: explicit defs, then level-0 implicits to Java, then implicits to Scala.
    val packageObj =
      (Vector(s"${priorityName(0)} {", " import functionConverterImpls._", " ") ++
        explicitDefs.indent ++
        spacer ++
        levelZeroImplicits ++
        spacer ++
        toScalaImplicits) :+ "}"

    (impls, traits :+ packageObj)
  }
}
/** Builds the conversion code for every interface in `java.util.function`
  * that is abstract and has exactly one abstract member.
  *
  * The interfaces are discovered through the compiler's symbol table, and the
  * generated definitions are built as quasiquote trees and then rendered to text.
  *
  * @return one [[SamConversionCode]] per discovered interface
  */
private def buildWrappersViaReflection: Seq[SamConversionCode] = {
  val pack: Symbol = rootMirror.getPackageIfDefined("java.util.function")

  // A Java functional interface paired with its single abstract method.
  case class Jfn(iface: Symbol, sam: Symbol) {
    lazy val genericCount = iface.typeParams.length
    lazy val name = sam.name.toTermName
    lazy val title = iface.name.encoded
    lazy val params = sam.info.params
    // Signature of the method as seen from the interface's own type.
    lazy val sig = sam typeSignatureIn iface.info
    lazy val pTypes = sig.params.map(_.info)
    lazy val rType = sig.resultType
    def arity = params.length
  }

  // Keep abstract declarations whose list of abstract members has exactly one element.
  val sams = pack.info.decls.
    map(d => (d, d.typeSignature.members.filter(_.isAbstract).toList)).
    collect{ case (d, m :: Nil) if d.isAbstract => Jfn(d, m) }

  def generate(jfn: Jfn): SamConversionCode = {
    // Type parameters are referenced by name; every other type is spliced in as-is.
    def mkRef(tp: Type): Tree = if (tp.typeSymbol.isTypeParameter) Ident(tp.typeSymbol.name.toTypeName) else tq"$tp"

    // Types for the Java SAM and the corresponding Scala function, plus all type parameters
    val scalaType = gen.mkAttributedRef(FunctionClass(jfn.arity))
    val javaType = gen.mkAttributedRef(jfn.iface)
    val tnParams: List[TypeName] = jfn.iface.typeParams.map(_.name.toTypeName)
    val tdParams: List[TypeDef] = tnParams.map(TypeDef(NoMods, _, Nil, EmptyTree))
    val javaTargs: List[Tree] = tdParams.map(_.name).map(Ident(_))
    val scalaTargs: List[Tree] = jfn.pTypes.map(mkRef) :+ mkRef(jfn.rType)

    // Conversion wrappers have three or four components that we need to name
    // (1) The wrapper class that wraps a Java SAM as Scala function, or vice versa (ClassN)
    // (2) A value class that provides .asJava or .asScala to request the conversion (ValCN)
    // (3) A name for an explicit conversion method (DefN)
    // (4) An implicit conversion method name (ImpN) that invokes the value class

    // Names for Java conversions to Scala
    val j2sClassN = TypeName("FromJava" + jfn.title)
    val j2sValCN = TypeName("Rich" + jfn.title + "As" + scalaType.name.encoded)
    val j2sDefN = TermName("asScalaFrom" + jfn.title)
    val j2sImpN = TermName("enrichAsScalaFrom" + jfn.title)

    // Names for Scala conversions to Java
    val s2jClassN = TypeName("AsJava" + jfn.title)
    val s2jValCN = TypeName("Rich" + scalaType.name.encoded + "As" + jfn.title)
    val s2jDefN = TermName("asJava" + jfn.title)
    val s2jImpN = TermName("enrichAsJava" + jfn.title)

    // Argument lists for the function / SAM
    val vParams = (jfn.params zip jfn.pTypes).map{ case (p,t) =>
      ValDef(NoMods, p.name.toTermName, if (t.typeSymbol.isTypeParameter) Ident(t.typeSymbol.name) else gen.mkAttributedRef(t.typeSymbol), EmptyTree)
    }
    val vParamRefs = vParams.map(_.name).map(Ident(_))

    val j2sClassTree =
      q"""class $j2sClassN[..$tdParams](jf: $javaType[..$javaTargs]) extends $scalaType[..$scalaTargs] {
        def apply(..$vParams) = jf.${jfn.name}(..$vParamRefs)
      }"""

    val j2sValCTree =
      q"""class $j2sValCN[..$tdParams](private val underlying: $javaType[..$javaTargs]) extends AnyVal {
        @inline def asScala: $scalaType[..$scalaTargs] = new $j2sClassN[..$tnParams](underlying)
      }"""

    val j2sDefTree =
      q"""@inline def $j2sDefN[..$tdParams](jf: $javaType[..$javaTargs]): $scalaType[..$scalaTargs] = new $j2sClassN[..$tnParams](jf)"""

    val j2sImpTree =
      q"""@inline implicit def $j2sImpN[..$tdParams](jf: $javaType[..$javaTargs]): $j2sValCN[..$tnParams] = new $j2sValCN[..$tnParams](jf)"""

    val s2jClassTree =
      q"""class $s2jClassN[..$tdParams](sf: $scalaType[..$scalaTargs]) extends $javaType[..$javaTargs] {
        def ${jfn.name}(..$vParams) = sf.apply(..$vParamRefs)
      }"""

    val s2jValCTree =
      q"""class $s2jValCN[..$tdParams](private val underlying: $scalaType[..$scalaTargs]) extends AnyVal {
        @inline def asJava: $javaType[..$javaTargs] = new $s2jClassN[..$tnParams](underlying)
      }"""

    val s2jDefTree =
      q"""@inline def $s2jDefN[..$tdParams](sf: $scalaType[..$scalaTargs]): $javaType[..$javaTargs] = new $s2jClassN[..$tnParams](sf)"""

    // This is especially tricky because functions are contravariant in their arguments
    // Need to prevent e.g. Any => String from "downcasting" itself to Int => String; we want the more exact conversion
    // The Int paired with the tree becomes the priority of the implicit (see Prioritized below).
    val s2jImpTree: (Tree, Int) =
      if (jfn.pTypes.forall(! _.isFinalType) && jfn.sig == jfn.sam.typeSignature)
        (
          q"""@inline implicit def $s2jImpN[..$tdParams](sf: $scalaType[..$scalaTargs]): $s2jValCN[..$tnParams] = new $s2jValCN[..$tnParams](sf)""",
          tdParams.length
        )
      else {
        // Some types are not generic or are re-used; we had better catch those.
        // Making up new type names, so switch everything to TypeName or TypeDef
        // Instead of foo[A](f: (Int, A) => Long): Fuu[A] = new Foo[A](f)
        //    we want foo[X, A](f: (X, A) => Long)(implicit evX: Int =:= X): Fuu[A] = new Foo[A](f.asInstanceOf[(Int, A) => Long])
        // Instead of bar[A](f: A => A): Brr[A] = new Foo[A](f)
        //    we want bar[A, B](f: A => B)(implicit evB: A =:= B): Brr[A] = new Foo[A](f.asInstanceOf[A => B])
        val An = "A(\\d+)".r
        // Indices n for which a type name An is already taken.
        val numberedA = collection.mutable.Set.empty[Int]
        // (fresh generic name, type it must be proven equal to)
        val evidences = collection.mutable.ArrayBuffer.empty[(TypeName, TypeName)]
        numberedA ++= scalaTargs.map(_.toString).collect{ case An(digits) if (digits.length < 10) => digits.toInt }
        val scalafnTnames = (jfn.pTypes :+ jfn.rType).zipWithIndex.map{
          // A final parameter type, or a non-final type repeating an earlier parameter type, gets a fresh An name.
          case (pt, i) if (i < jfn.pTypes.length && pt.isFinalType) || (!pt.isFinalType && jfn.pTypes.take(i).exists(_ == pt)) =>
            val j = Iterator.from(i).dropWhile(numberedA).next()
            val genericName = TypeName(s"A$j")
            numberedA += j
            evidences += ((genericName, pt.typeSymbol.name.toTypeName))
            genericName
          case (pt, _) => pt.typeSymbol.name.toTypeName
        }
        // A final result type is not turned into a type parameter of the generated method.
        val scalafnTdefs = scalafnTnames.
          map(TypeDef(NoMods, _, Nil, EmptyTree)).
          dropRight(if (jfn.rType.isFinalType) 1 else 0)
        val evs = evidences.map{ case (generic, specific) => ValDef(NoMods, TermName("ev"+generic.toString), tq"$generic =:= $specific", EmptyTree) }
        val tree =
          q"""@inline implicit def $s2jImpN[..$scalafnTdefs](sf: $scalaType[..$scalafnTnames])(implicit ..$evs): $s2jValCN[..$tnParams] =
            new $s2jValCN[..$tnParams](sf.asInstanceOf[$scalaType[..$scalaTargs]])
          """
        (tree, tdParams.length)
      }

    SamConversionCode(
      base = jfn.title,
      wrappedAsScala = j2sClassTree.text,
      asScalaAnyVal = j2sValCTree.text,
      implicitToScala = j2sImpTree.text,
      asScalaDef = j2sDefTree.text,
      wrappedAsJava = s2jClassTree.text,
      asJavaAnyVal = s2jValCTree.text,
      implicitToJava = s2jImpTree match { case (t,d) => Prioritized(t.text, d) },
      asJavaDef = s2jDefTree.text
    )
  }

  sams.toSeq.map(generate)
}
/** Full text of the generated source file: banner, package clause and import,
  * then the implementation package, then each priority trait and the package
  * object, with the sections separated by blank lines.
  */
lazy val converterContents: String = {
  val preamble =
    s"""
       |$copyright
       |
       |$packaging
       |
       |import language.implicitConversions
       |
       |
       |""".stripMargin
  val (impls, defs) = SamConversionCode(buildWrappersViaReflection: _*)
  val sectionGap = "\n\n\n\n"
  val implText = impls.mkString("\n")
  val defText = defs.map(_.mkString("\n")).mkString(sectionGap)
  preamble + (implText + sectionGap + defText)
}
/** Reports whether file `f` already contains `text`, ignoring blank lines.
  *
  * Both sides are split into lines, lines that are empty after trimming are
  * dropped, and the remaining lines are compared in order.
  *
  * @param f    existing file to read (with the platform default codec)
  * @param text candidate contents
  * @return true when the non-blank lines of the file and of `text` are equal
  */
def sameText(f: java.io.File, text: String): Boolean = {
  val x = scala.io.Source.fromFile(f)
  val lines = try { x.getLines().toVector } finally { x.close() }
  // Materialize both sides before comparing: `==` on two Iterators is reference
  // equality, so comparing the iterators directly would always be false.
  def significant(it: Iterator[String]): Vector[String] = it.filter(_.trim.nonEmpty).toVector
  significant(lines.iterator) == significant(text.linesIterator)
}
/** Writes `text` (followed by a line terminator) to `f`, unless the file
  * already exists and `sameText` reports its contents as matching.
  */
def write(f: java.io.File, text: String): Unit = {
  val upToDate = f.exists && sameText(f, text)
  if (!upToDate) {
    val out = new java.io.PrintWriter(f)
    try out.println(text)
    finally out.close()
  }
}
/** Entry point: generates the converters and writes them to the file named
  * by the first argument. Any further arguments are ignored.
  *
  * @param args command-line arguments; `args(0)` is the output file path
  * @throws NoSuchElementException if no output path is supplied
  */
def main(args: Array[String]): Unit =
  args.headOption match {
    case Some(path) =>
      write(new java.io.File(path), converterContents)
    case None =>
      // Same exception type as before, but with a message that says what is missing.
      throw new NoSuchElementException("WrapFnGen: expected the output file path as the first argument")
  }
}