From 575d3d87cffb535e320b1951b303dfcccfc20275 Mon Sep 17 00:00:00 2001 From: Alexander Ioffe Date: Wed, 20 Dec 2023 19:26:38 -0500 Subject: [PATCH 1/3] carry position across ast --- .../context/VerifyFreeVariables.scala | 20 ++- .../io/getquill/quotation/Liftables.scala | 9 +- .../scala/io/getquill/quotation/Parsing.scala | 38 +++-- .../io/getquill/quotation/Unliftables.scala | 8 +- .../io/getquill/util/MacroContextExt.scala | 6 + .../quotation/FreeVariablesSpec.scala | 34 ++--- .../main/scala/io/getquill/MirrorIdiom.scala | 2 +- .../scala/io/getquill/MirrorSqlDialect.scala | 3 +- .../scala/io/getquill/OracleDialect.scala | 4 +- .../src/main/scala/io/getquill/ast/Ast.scala | 45 ++++-- .../main/scala/io/getquill/ast/AstOps.scala | 7 - .../io/getquill/norm/NormalizeReturning.scala | 2 +- .../norm/capture/AvoidAliasConflict.scala | 48 +++---- .../io/getquill/quotation/FreeVariables.scala | 50 ++++--- .../main/scala/io/getquill/sql/SqlQuery.scala | 54 ++++---- .../io/getquill/sql/idiom/SqlIdiom.scala | 6 +- .../getquill/sql/idiom/VerifySqlQuery.scala | 23 +-- .../sql/norm/HideTopLevelFilterAlias.scala | 4 +- .../sql/norm/RemoveUnusedSelects.scala | 4 +- .../sql/norm/SelectPropertyProtractor.scala | 10 +- .../scala/io/getquill/util/Messages.scala | 2 +- .../main/scala/io/getquill/util/Text.scala | 131 ++++++++++++++++++ .../context/orientdb/OrientDBQuerySpec.scala | 4 +- 23 files changed, 349 insertions(+), 165 deletions(-) create mode 100644 quill-engine/src/main/scala/io/getquill/util/Text.scala diff --git a/quill-core/src/main/scala/io/getquill/context/VerifyFreeVariables.scala b/quill-core/src/main/scala/io/getquill/context/VerifyFreeVariables.scala index dd5db73c14..1d4034ba16 100644 --- a/quill-core/src/main/scala/io/getquill/context/VerifyFreeVariables.scala +++ b/quill-core/src/main/scala/io/getquill/context/VerifyFreeVariables.scala @@ -2,14 +2,28 @@ package io.getquill.context import scala.reflect.macros.whitebox.{Context => MacroContext} import io.getquill.quotation.FreeVariables -import io.getquill.ast.Ast +import io.getquill.ast.{Ast, Ident, Pos} import io.getquill.util.MacroContextExt._ object VerifyFreeVariables { - def apply(c: MacroContext)(ast: Ast): Ast = + def apply(c: MacroContext)(ast: Ast): Ast = { + import c.universe.{Ident => _, _} + FreeVariables.verify(ast) match { case Right(ast) => ast - case Left(msg) => c.fail(msg) + case Left(err) => + err.freeVars match { + // we we have a single position from the encosing context in the same file we can actually fail + // at the right position and point the compiler to that location since we can modify the position + // by the `point` info that we have from our position + case List(Ident.WithPos(_, Pos.Real(fileName, _, _, point, _))) if (c.enclosingPosition.source.path == fileName) => + c.failAtPoint(err.msgNoPos, point) + + case _ => + c.fail(err.msg) + } + } + } } diff --git a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala index 7aae532825..c4e39d64ed 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala @@ -124,6 +124,11 @@ trait Liftables extends QuatLiftable { case Visibility.Hidden => q"$pack.Visibility.Hidden" } + implicit val positionLiftable: Liftable[Pos] = Liftable[Pos] { + case Pos.Real(a, b, c, d, e) => q"$pack.Pos.Real($a, $b, $c, $d, $e)" + case Pos.Synthetic => q"$pack.Pos.Synthetic" + } + implicit val queryLiftable: Liftable[Query] = Liftable[Query] { case Entity.Opinionated(a, b, quat, renameable) => q"$pack.Entity.Opinionated($a, $b, $quat, $renameable)" case Filter(a, b, c) => q"$pack.Filter($a, $b, $c)" @@ -206,8 +211,8 @@ trait Liftables extends QuatLiftable { case CaseClass(n, a) => q"$pack.CaseClass($n, $a)" } - implicit val identLiftable: Liftable[Ident] = Liftable[Ident] { case Ident(a, quat) => - q"$pack.Ident($a, $quat)" + implicit val identLiftable: Liftable[Ident] = Liftable[Ident] { case Ident.Opinionated(a, quat, vis, pos) => + q"$pack.Ident.Opinionated($a, $quat, $vis, $pos)" } implicit val externalIdentLiftable: Liftable[ExternalIdent] = Liftable[ExternalIdent] { case ExternalIdent(a, quat) => q"$pack.ExternalIdent($a, $quat)" diff --git a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala index d8b8d822c3..872c236297 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala @@ -27,7 +27,7 @@ trait Parsing extends ValueComputation with QuatMaking with MacroUtilBase { // Variables that need to be sanitized out in various places due to internal conflicts with the way // macros hard handled in MetaDsl - private[getquill] val dangerousVariables: Set[IdentName] = Set(IdentName("v")) + private[getquill] val dangerousVariables: Set[Ident] = Set(Ident.trivial("v")) case class Parser[T](p: PartialFunction[Tree, T])(implicit ct: ClassTag[T]) { @@ -81,14 +81,14 @@ trait Parsing extends ValueComputation with QuatMaking with MacroUtilBase { case q"{..$exprs}" if exprs.size > 1 => Block(exprs.map(astParser(_))) } - val valParser: Parser[Val] = Parser[Val] { case q"val $name: $typ = $body" => + val valParser: Parser[Val] = Parser[Val] { case wholeExpr @ q"val $name: $typ = $body" => // for some reason inferQuat(typ.tpe) causes a compile hang in scala.reflect.internal val bodyAst = astParser(body) - Val(ident(name, bodyAst.quat), bodyAst) + Val(ident(name, bodyAst.quat, wholeExpr.pos), bodyAst) } - val patMatchValParser: Parser[Val] = Parser[Val] { case q"$mods val $name: $typ = ${patMatchParser(value)}" => - Val(ident(name, inferQuat(q"$typ".tpe)), value) + val patMatchValParser: Parser[Val] = Parser[Val] { case wholeExpr @ q"$mods val $name: $typ = ${patMatchParser(value)}" => + Val(ident(name, inferQuat(q"$typ".tpe), wholeExpr.pos), value) } val patMatchParser: Parser[Ast] = Parser[Ast] { case q"$expr match { case ($fields) => $body }" => @@ -462,16 +462,28 @@ trait Parsing extends ValueComputation with QuatMaking with MacroUtilBase { val identParser: Parser[Ident] = Parser[Ident] { // TODO Check to see that all these conditions work case t: ValDef => - identClean(Ident(t.name.decodedName.toString, inferQuat(t.symbol.typeSignature))) - case id @ c.universe.Ident(TermName(name)) => identClean(Ident(name, inferQuat(id.symbol.typeSignature))) - case t @ q"$cls.this.$i" => identClean(Ident(i.decodedName.toString, inferQuat(t.symbol.typeSignature))) + identClean(t.name.decodedName.toString, inferQuat(t.symbol.typeSignature), t.pos) + case id @ c.universe.Ident(TermName(name)) => + identClean(name, inferQuat(id.symbol.typeSignature), id.pos) + case t @ q"$cls.this.$i" => + identClean(i.decodedName.toString, inferQuat(t.symbol.typeSignature), t.pos) case t @ c.universe.Bind(TermName(name), c.universe.Ident(termNames.WILDCARD)) => - identClean( - Ident(name, inferQuat(t.symbol.typeSignature)) - ) // TODO Verify Quat what is the type of this thing? In what cases does it happen? Do we need to do something more clever with the tree and get a TypeRef? + // TODO Verify Quat what is the type of this thing? In what cases does it happen? Do we need to do something more clever with the tree and get a TypeRef? + identClean(name, inferQuat(t.symbol.typeSignature), t.pos) } - private def identClean(x: Ident): Ident = x.copy(name = x.name.replace("$", "")) - private def ident(x: TermName, quat: Quat): Ident = identClean(Ident(x.decodedName.toString, quat)) + private def identClean(name: String, quat: Quat, pos: Position): Ident = + Ident.Opinionated( + name.replace("$", ""), + quat, + Visibility.Visible, + if (pos != NoPosition) + Pos.Real(pos.source.path, pos.line, pos.column, pos.point, 0) + else + Pos.Synthetic + ) + + private def ident(x: TermName, quat: Quat, pos: Position): Ident = + identClean(x.decodedName.toString, quat, pos) /** * In order to guarantee consistent behavior across multiple databases, we diff --git a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala index d194f69f40..6df314cb45 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala @@ -155,6 +155,11 @@ trait Unliftables extends QuatUnliftable { case q"$pack.Visibility.Hidden" => Visibility.Hidden } + implicit val positionUnliftable: Unliftable[Pos] = Unliftable[Pos] { + case q"$pack.Pos.Real.apply(${file: String}, ${row: Int}, ${column: Int}, ${point: Int}, ${width: Int})" => Pos.Real(file, row, column, point, width) + case q"$pack.Pos.Synthetic" => Pos.Synthetic + } + implicit val propertyUnliftable: Unliftable[Property] = Unliftable[Property] { case q"$pack.Property.apply(${a: Ast}, ${b: String})" => Property(a, b) case q"$pack.Property.Opinionated.apply(${a: Ast}, ${b: String}, ${renameable: Renameable}, ${visibility: Visibility})" => @@ -209,7 +214,8 @@ trait Unliftables extends QuatUnliftable { } implicit val identUnliftable: Unliftable[Ident] = Unliftable[Ident] { - case q"$pack.Ident.apply(${a: String}, ${quat: Quat})" => Ident(a, quat) + case q"$pack.Ident.Opinionated.apply(${a: String}, ${quat: Quat}, ${vis: Visibility}, ${pos: Pos})" => + Ident.Opinionated(a, quat, vis, pos) } implicit val externalIdentUnliftable: Unliftable[ExternalIdent] = Unliftable[ExternalIdent] { case q"$pack.ExternalIdent.apply(${a: String}, ${quat: Quat})" => ExternalIdent(a, quat) diff --git a/quill-core/src/main/scala/io/getquill/util/MacroContextExt.scala b/quill-core/src/main/scala/io/getquill/util/MacroContextExt.scala index a5b2bb4ee1..2d7d1a8292 100644 --- a/quill-core/src/main/scala/io/getquill/util/MacroContextExt.scala +++ b/quill-core/src/main/scala/io/getquill/util/MacroContextExt.scala @@ -5,6 +5,7 @@ import io.getquill.util.IndentUtil._ import io.getquill.util.Messages.{debugEnabled, errorPrefix, prettyPrint} import io.getquill.quat.VerifyNoBranches +import scala.reflect.api.Position import scala.reflect.macros.blackbox.{Context => MacroContext} object MacroContextExt { @@ -16,6 +17,11 @@ object MacroContextExt { def error(msg: String): Unit = c.error(c.enclosingPosition, if (errorPrefix) s"[quill] $msg" else msg) + def failAtPoint(msg: String, point: Int): Nothing = { + val errorPos = c.enclosingPosition.withPoint(point) + c.abort(errorPos, if (errorPrefix) s"[quill] $msg" else msg) + } + def fail(msg: String): Nothing = c.abort(c.enclosingPosition, if (errorPrefix) s"[quill] $msg" else msg) diff --git a/quill-core/src/test/scala/io/getquill/quotation/FreeVariablesSpec.scala b/quill-core/src/test/scala/io/getquill/quotation/FreeVariablesSpec.scala index 546000ef78..9f6200f935 100644 --- a/quill-core/src/test/scala/io/getquill/quotation/FreeVariablesSpec.scala +++ b/quill-core/src/test/scala/io/getquill/quotation/FreeVariablesSpec.scala @@ -1,12 +1,12 @@ package io.getquill.quotation -import io.getquill.ast.IdentName import io.getquill.base.Spec import io.getquill.MirrorContexts.testContext.implicitOrd import io.getquill.MirrorContexts.testContext.qr1 import io.getquill.MirrorContexts.testContext.qr2 import io.getquill.MirrorContexts.testContext.quote import io.getquill.MirrorContexts.testContext.unquote +import io.getquill.ast.Ident class FreeVariablesSpec extends Spec { @@ -16,35 +16,35 @@ class FreeVariablesSpec extends Spec { "detects references to values outside of the quotation (free variables)" - { "ident" in { val q = quote(s) - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "function" in { val q = quote { (a: String) => s } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "filter" in { val q = quote { qr1.filter(_.s == s) } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "map" in { val q = quote { qr1.map(_ => s) } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "flatMap" in { val q = quote { qr1.map(_ => s).flatMap(_ => qr2) } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "concatMap" in { val a = Seq(1, 2) @@ -52,41 +52,41 @@ class FreeVariablesSpec extends Spec { quote { qr1.concatMap(_ => a).flatMap(_ => qr2) } - FreeVariables(q.ast) mustEqual Set(IdentName("a")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("a")) } "sortBy" in { val q = quote { qr1.sortBy(_ => s) } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "groupBy" in { val q = quote { qr1.groupBy(_ => s) } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "take" in { val q = quote { qr1.take(i) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } "conditional outer join" in { val q = quote { qr1.leftJoin(qr2).on((a, b) => a.s == s) } - FreeVariables(q.ast) mustEqual Set(IdentName("s")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("s")) } "assignment" in { val q = quote { qr1.insert(_.i -> i) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } "join" in { val i = 1 @@ -95,7 +95,7 @@ class FreeVariablesSpec extends Spec { .join(qr2.filter(_.i == i)) .on((t1, t2) => t1.i == t2.i) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } "option operators" - { "map" in { @@ -103,14 +103,14 @@ class FreeVariablesSpec extends Spec { val q = quote { qr1.map(_.o.map(_ == i)) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } "forall" in { val i = 1 val q = quote { qr1.filter(_.o.forall(_ == i)) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } "exists" in { @@ -118,14 +118,14 @@ class FreeVariablesSpec extends Spec { val q = quote { qr1.filter(_.o.exists(_ == i)) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } "contains" in { val i = 1 val q = quote { qr1.filter(_.o.contains(i)) } - FreeVariables(q.ast) mustEqual Set(IdentName("i")) + FreeVariables(q.ast) mustEqual Set(Ident.trivial("i")) } } } diff --git a/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala b/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala index 2921d846df..d7b0ac510d 100644 --- a/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala @@ -270,7 +270,7 @@ trait MirrorIdiomBase extends Idiom { } implicit final val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] { - case Ident.Opinionated(name, _, visibility) => + case Ident.Opinionated(name, _, visibility, _) => stmt"${bracketIfHidden(name, visibility).token}" } diff --git a/quill-engine/src/main/scala/io/getquill/MirrorSqlDialect.scala b/quill-engine/src/main/scala/io/getquill/MirrorSqlDialect.scala index aadf948409..96647e0b67 100644 --- a/quill-engine/src/main/scala/io/getquill/MirrorSqlDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/MirrorSqlDialect.scala @@ -1,5 +1,6 @@ package io.getquill +import io.getquill.ast.Ident import io.getquill.context.sql.idiom.{ConcatSupport, QuestionMarkBindVariables, SqlIdiom} import io.getquill.context._ import io.getquill.norm.ProductAggregationToken @@ -56,7 +57,7 @@ object MirrorSqlDialect extends MirrorSqlDialect { trait StrategizeElements extends SqlIdiom with QuestionMarkBindVariables with ConcatSupport with CanReturnField { override def tokenizeIdentName(strategy: NamingStrategy, name: String): String = strategy.default(name) - override def tokenizeTableAlias(strategy: NamingStrategy, table: String): String = strategy.default(table) + override def tokenizeTableAlias(strategy: NamingStrategy, table: Ident): String = strategy.default(table.name) override def tokenizeColumnAlias(strategy: NamingStrategy, column: String): String = strategy.default(column) override def tokenizeFixedColumn(strategy: NamingStrategy, column: String): String = strategy.default(column) override def prepareForProbing(string: String) = string diff --git a/quill-engine/src/main/scala/io/getquill/OracleDialect.scala b/quill-engine/src/main/scala/io/getquill/OracleDialect.scala index b957110c6f..f670bfd5a4 100644 --- a/quill-engine/src/main/scala/io/getquill/OracleDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/OracleDialect.scala @@ -82,8 +82,8 @@ trait OracleDialect override protected def tokenizeColumnAlias(strategy: NamingStrategy, column: String): String = tokenizeEscapeUnderscores(strategy, column, None) - override protected def tokenizeTableAlias(strategy: NamingStrategy, column: String): String = - tokenizeEscapeUnderscores(strategy, column, None) + override protected def tokenizeTableAlias(strategy: NamingStrategy, tableName: Ident): String = + tokenizeEscapeUnderscores(strategy, tableName.name, None) private def tokenizeEscapeUnderscores( strategy: NamingStrategy, diff --git a/quill-engine/src/main/scala/io/getquill/ast/Ast.scala b/quill-engine/src/main/scala/io/getquill/ast/Ast.scala index f47ec6a82c..a0b1b98efd 100644 --- a/quill-engine/src/main/scala/io/getquill/ast/Ast.scala +++ b/quill-engine/src/main/scala/io/getquill/ast/Ast.scala @@ -318,7 +318,7 @@ final case class Function(params: List[Ident], body: Ast) extends Ast { override def bestQuat: Quat = body.bestQuat } -final class Ident private (val name: String)(theQuat: => Quat)(val visibility: Visibility) extends Terminal with Ast { +final class Ident private (val name: String)(theQuat: => Quat)(val visibility: Visibility, val pos: Pos) extends Terminal with Ast { override lazy val quat: Quat = theQuat override def bestQuat: Quat = quat @@ -333,11 +333,11 @@ final class Ident private (val name: String)(theQuat: => Quat)(val visibility: V override def hashCode: Int = id.hashCode() override def withQuat(quat: => Quat): Ident = - Ident.Opinionated(this.name, quat, this.visibility) + Ident.Opinionated(this.name, quat, this.visibility, this.pos) // need to define a copy which will propagate current value of visibility into the copy def copy(name: String = this.name, quat: => Quat = this.quat): Ident = - Ident.Opinionated(name, quat, this.visibility) + Ident.Opinionated(name, quat, this.visibility, this.pos) } /** @@ -356,14 +356,23 @@ final class Ident private (val name: String)(theQuat: => Quat)(val visibility: V */ object Ident { private final case class Id(name: String) - def apply(name: String, quat: => Quat = Quat.Value) = new Ident(name)(quat)(Visibility.Visible) + def apply(name: String, quat: => Quat = Quat.Value) = new Ident(name)(quat)(Visibility.Visible, Pos.Synthetic) def unapply(p: Ident): Option[(String, Quat)] = Some((p.name, p.quat)) + // Represents an identifier used for temporary purposes (e.g. comparison to symbols) and or for various + // operational reasons. For example VerifySqlQuery needs Ident.trivial("*") to check if there are any + // star-operators that have been created within the query. + def trivial(name: String) = Ident(name, Quat.Unknown) + + object WithPos { + def unapply(id: Ident): Option[(String, Pos)] = Some((id.name, id.pos)) + } + object Opinionated { - def apply(name: String, quatNew: => Quat, visibilityNew: Visibility) = - new Ident(name)(quatNew)(visibilityNew) - def unapply(p: Ident): Option[(String, Quat, Visibility)] = - Some((p.name, p.quat, p.visibility)) + def apply(name: String, quatNew: => Quat, visibilityNew: Visibility, posNew: Pos) = + new Ident(name)(quatNew)(visibilityNew, posNew) + def unapply(p: Ident): Option[(String, Quat, Visibility, Pos)] = + Some((p.name, p.quat, p.visibility, p.pos)) } } @@ -416,6 +425,26 @@ sealed trait OpinionValues[T <: Opinion[T]] { def neutral: T } +sealed trait Pos extends Opinion[Pos] { + def print: String +} +object Pos extends OpinionValues[Pos] { + // Identifier was introduced by the Quill compilation phases + // Notably, the 'point' field is used specifically so that we can offset the scala-compiler to the right Position + // in the VerifyFreeVariables functionality. Scala derives line/column from the SourceFile and `point` data + // so for the sake of scala macros, only `point` is needed. The files `line` and `column` are used if when we build + // up error messages with variable positions and do not have the compiler to help us (e.g. if there are multiple + // places with error locations (e.g. multiple free variables have been found or the error is being thrown at runtime). + case class Real(fileName: String, line: Int, column: Int, point: Int, width: Int = 0) extends Pos { + override def print: String = s"${fileName}:${line}:${column}" + } + case object Synthetic extends Pos { + override def print: String = "" + } + + override def neutral: Pos = Synthetic +} + sealed trait Visibility extends Opinion[Visibility] object Visibility extends OpinionValues[Visibility] { case object Visible extends Visibility with Opinion[Visibility] diff --git a/quill-engine/src/main/scala/io/getquill/ast/AstOps.scala b/quill-engine/src/main/scala/io/getquill/ast/AstOps.scala index 39b5df4c15..dfd505d872 100644 --- a/quill-engine/src/main/scala/io/getquill/ast/AstOps.scala +++ b/quill-engine/src/main/scala/io/getquill/ast/AstOps.scala @@ -1,13 +1,6 @@ package io.getquill.ast -// Represents an Ident without a Quat -case class IdentName(name: String) - object Implicits { - implicit final class IdentOps(private val id: Ident) extends AnyVal { - def idName: IdentName = IdentName(id.name) - } - implicit final class AstOpsExt(private val body: Ast) extends AnyVal { def +||+(other: Ast): BinaryOperation = BinaryOperation(body, BooleanOperator.`||`, other) def +&&+(other: Ast): BinaryOperation = BinaryOperation(body, BooleanOperator.`&&`, other) diff --git a/quill-engine/src/main/scala/io/getquill/norm/NormalizeReturning.scala b/quill-engine/src/main/scala/io/getquill/norm/NormalizeReturning.scala index e35f5fa50b..2edc6673a9 100644 --- a/quill-engine/src/main/scala/io/getquill/norm/NormalizeReturning.scala +++ b/quill-engine/src/main/scala/io/getquill/norm/NormalizeReturning.scala @@ -47,7 +47,7 @@ class NormalizeReturning(normalize: Normalize) { */ private def dealiasBody(body: Ast, alias: Ident): Ast = Transform(body) { case q: Query => - AvoidAliasConflict.sanitizeQuery(q, Set(alias.idName), normalize) + AvoidAliasConflict.sanitizeQuery(q, Set(alias), normalize) } private def apply(e: Action, body: Ast, returningIdent: Ident): Action = e match { diff --git a/quill-engine/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala b/quill-engine/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala index 0ce9cee323..8b1a0607d8 100644 --- a/quill-engine/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala +++ b/quill-engine/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala @@ -44,8 +44,8 @@ import scala.collection.immutable.Set * called once from the top-level inside `SqlNormalize` at the very end of the * transformation pipeline. */ -private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: Boolean, traceConfig: TraceConfig) - extends StatefulTransformer[Set[IdentName]] { +private[getquill] case class AvoidAliasConflict(state: Set[Ident], detemp: Boolean, traceConfig: TraceConfig) + extends StatefulTransformer[Set[Ident]] { val interp = new Interpolator(TraceType.AvoidAliasConflict, traceConfig, 3) import interp._ @@ -82,7 +82,7 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B private def recurseAndApply[T <: Query]( elem: T - )(ext: T => (Ast, Ident, Ast))(f: (Ast, Ident, Ast) => T): (T, StatefulTransformer[Set[IdentName]]) = + )(ext: T => (Ast, Ident, Ast))(f: (Ast, Ident, Ast) => T): (T, StatefulTransformer[Set[Ident]]) = trace"Uncapture RecurseAndApply $elem ".andReturnIf { val (newElem, newTrans) = super.apply(elem) val ((query, alias, body), state) = @@ -94,10 +94,10 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B BetaReduction(body, alias -> fresh) }(pr => pr != body) - (f(query, fresh, pr), AvoidAliasConflict(state + fresh.idName, detemp, traceConfig)) + (f(query, fresh, pr), AvoidAliasConflict(state + fresh, detemp, traceConfig)) }(_._1 != elem) - private def applyBodies[T <: Query](pairs: List[(Ident, Ast)]): (List[(Ident, Ast)], List[IdentName]) = + private def applyBodies[T <: Query](pairs: List[(Ident, Ast)]): (List[(Ident, Ast)], List[Ident]) = trace"Uncapture ApplyBodies $pairs ".andReturnIf { val newPairs = pairs.map { case (alias, body) => @@ -109,11 +109,11 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B (fresh, newBody) } - val newIdNames = newPairs.map(_._1.idName) + val newIdNames = newPairs.map(_._1) (newPairs, newIdNames) }(_._1 != pairs) - override def apply(qq: Query): (Query, StatefulTransformer[Set[IdentName]]) = + override def apply(qq: Query): (Query, StatefulTransformer[Set[Ident]]) = trace"Uncapture $qq ".andReturnIf { qq match { @@ -181,14 +181,14 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B val (ar, art) = apply(a) val (br, brt) = art.apply(b) val freshA = freshIdent(iA, brt.state) - val freshB = freshIdent(iB, brt.state + freshA.idName) + val freshB = freshIdent(iB, brt.state + freshA) val or = trace"Uncapturing Join: Replace $iA -> $freshA, $iB -> $freshB".andReturnIf { BetaReduction(o, iA -> freshA, iB -> freshB) }(_ != o) val (orr, orrt) = trace"Uncapturing Join: Recurse with state: ${brt.state} + $freshA + $freshB".andReturnIf { - AvoidAliasConflict(brt.state + freshA.idName + freshB.idName, detemp, traceConfig)(or) + AvoidAliasConflict(brt.state + freshA + freshB, detemp, traceConfig)(or) }(_._1 != or) (Join(t, ar, br, freshA, freshB, orr), orrt) @@ -204,7 +204,7 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B }(_ != o) val (orr, orrt) = trace"Uncapturing FlatJoin: Recurse with state: ${art.state} + $freshA".andReturnIf { - AvoidAliasConflict(art.state + freshA.idName, detemp, traceConfig)(or) + AvoidAliasConflict(art.state + freshA, detemp, traceConfig)(or) }(_._1 != or) (FlatJoin(t, ar, freshA, orr), orrt) @@ -216,7 +216,7 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B } }(_._1 != qq) - private def apply[Q](x: Ident, p: Ast)(f: (Ident, Ast) => Q): (Q, StatefulTransformer[Set[IdentName]]) = + private def apply[Q](x: Ident, p: Ast)(f: (Ident, Ast) => Q): (Q, StatefulTransformer[Set[Ident]]) = trace"Uncapture Apply ($x, $p)".andReturnIf { val fresh = freshIdent(x) val pr = @@ -225,7 +225,7 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B }(_ != p) val (prr, t) = trace"Uncapture Apply Recurse".andReturnIf { - AvoidAliasConflict(state + fresh.idName, detemp, traceConfig)(pr) + AvoidAliasConflict(state + fresh, detemp, traceConfig)(pr) }(_._1 != pr) (f(fresh, prr), t) @@ -236,7 +236,7 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B * variable and the make sure it does not conflict with any other variables of * outer clauses in the AST (freshIdent does that part). */ - private def freshIdent(x: Ident, state: Set[IdentName] = state): Ident = + private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = x match { case TemporaryIdent(tid) if (detemp) => dedupeIdent(Ident("x", tid.quat), state) @@ -244,15 +244,15 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B dedupeIdent(x, state) } - private def dedupeIdent(x: Ident, state: Set[IdentName] = state): Ident = { + private def dedupeIdent(x: Ident, state: Set[Ident] = state): Ident = { def loop(x: Ident, n: Int): Ident = { val fresh = Ident(s"${x.name}$n", x.quat) - if (!state.contains(fresh.idName)) + if (!state.contains(fresh)) fresh else loop(x, n + 1) } - if (!state.contains(x.idName)) + if (!state.contains(x)) x else loop(x, 1) @@ -275,7 +275,7 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B case ((body, state, newParams), param) => { val fresh = freshIdent(param) val pr = BetaReduction(body, param -> fresh) - val (prr, t) = AvoidAliasConflict(state + fresh.idName, false, traceConfig)(pr) + val (prr, t) = AvoidAliasConflict(state + fresh, false, traceConfig)(pr) (prr, t.state, newParams :+ fresh) } } @@ -285,14 +285,14 @@ private[getquill] case class AvoidAliasConflict(state: Set[IdentName], detemp: B private def applyForeach(f: Foreach): Foreach = { val fresh = freshIdent(f.alias) val pr = BetaReduction(f.body, f.alias -> fresh) - val (prr, _) = AvoidAliasConflict(state + fresh.idName, false, traceConfig)(pr) + val (prr, _) = AvoidAliasConflict(state + fresh, false, traceConfig)(pr) Foreach(f.query, fresh, prr) } } private[getquill] class AvoidAliasConflictApply(traceConfig: TraceConfig) { def apply(q: Query, detemp: Boolean = false): Query = - AvoidAliasConflict(Set[IdentName](), detemp, traceConfig)(q) match { + AvoidAliasConflict(Set[Ident](), detemp, traceConfig)(q) match { case (q, _) => q } } @@ -300,12 +300,12 @@ private[getquill] class AvoidAliasConflictApply(traceConfig: TraceConfig) { private[getquill] object AvoidAliasConflict { def Ast(q: Ast, detemp: Boolean = false, traceConfig: TraceConfig): Ast = - new AvoidAliasConflict(Set[IdentName](), detemp, traceConfig)(q) match { + new AvoidAliasConflict(Set[Ident](), detemp, traceConfig)(q) match { case (q, _) => q } def apply(q: Query, detemp: Boolean = false, traceConfig: TraceConfig): Query = - AvoidAliasConflict(Set[IdentName](), detemp, traceConfig)(q) match { + AvoidAliasConflict(Set[Ident](), detemp, traceConfig)(q) match { case (q, _) => q } @@ -314,14 +314,14 @@ private[getquill] object AvoidAliasConflict { * function. Do this by walking through the function's subtree and * transforming and queries encountered. */ - def sanitizeVariables(f: Function, dangerousVariables: Set[IdentName], traceConfig: TraceConfig): Function = + def sanitizeVariables(f: Function, dangerousVariables: Set[Ident], traceConfig: TraceConfig): Function = AvoidAliasConflict(dangerousVariables, false, traceConfig).applyFunction(f) /** Same is `sanitizeVariables` but for Foreach * */ - def sanitizeVariables(f: Foreach, dangerousVariables: Set[IdentName], traceConfig: TraceConfig): Foreach = + def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident], traceConfig: TraceConfig): Foreach = AvoidAliasConflict(dangerousVariables, false, traceConfig).applyForeach(f) - def sanitizeQuery(q: Query, dangerousVariables: Set[IdentName], normalize: Normalize): Query = + def sanitizeQuery(q: Query, dangerousVariables: Set[Ident], normalize: Normalize): Query = AvoidAliasConflict(dangerousVariables, false, normalize.traceConf).apply(q) match { // Propagate aliasing changes to the rest of the query case (q, _) => normalize(q) diff --git a/quill-engine/src/main/scala/io/getquill/quotation/FreeVariables.scala b/quill-engine/src/main/scala/io/getquill/quotation/FreeVariables.scala index 4d640d6de2..781e95f141 100644 --- a/quill-engine/src/main/scala/io/getquill/quotation/FreeVariables.scala +++ b/quill-engine/src/main/scala/io/getquill/quotation/FreeVariables.scala @@ -1,19 +1,20 @@ package io.getquill.quotation import io.getquill.ast._ -import io.getquill.ast.Implicits._ +import io.getquill.util.Text + import collection.immutable.Set -case class State(seen: Set[IdentName], free: Set[IdentName]) +case class State(seen: Set[Ident], free: Set[Ident]) case class FreeVariables(state: State) extends StatefulTransformer[State] { override def apply(ast: Ast): (Ast, StatefulTransformer[State]) = ast match { - case ident: Ident if (!state.seen.contains(ident.idName)) => - (ident, FreeVariables(State(state.seen, state.free + ident.idName))) + case ident: Ident if (!state.seen.contains(ident)) => + (ident, FreeVariables(State(state.seen, state.free + ident))) case f @ Function(params, body) => - val (_, t) = FreeVariables(State(state.seen ++ params.map(_.idName), state.free))(body) + val (_, t) = FreeVariables(State(state.seen ++ params, state.free))(body) (f, FreeVariables(State(state.seen, state.free ++ t.state.free))) case q @ Foreach(a, b, c) => (q, free(a, b, c)) @@ -48,7 +49,7 @@ case class FreeVariables(state: State) extends StatefulTransformer[State] { override def apply(e: Assignment): (Assignment, StatefulTransformer[State]) = e match { case Assignment(a, b, c) => - val t = FreeVariables(State(state.seen + a.idName, state.free)) + val t = FreeVariables(State(state.seen + a, state.free)) val (bt, btt) = t(b) val (ct, ctt) = t(c) (Assignment(a, bt, ct), FreeVariables(State(state.seen, state.free ++ btt.state.free ++ ctt.state.free))) @@ -57,7 +58,7 @@ case class FreeVariables(state: State) extends StatefulTransformer[State] { override def apply(e: AssignmentDual): (AssignmentDual, StatefulTransformer[State]) = e match { case AssignmentDual(a1, a2, b, c) => - val t = FreeVariables(State(state.seen + a1.idName + a2.idName, state.free)) + val t = FreeVariables(State(state.seen + a1 + a2, state.free)) val (bt, btt) = t(b) val (ct, ctt) = t(c) ( @@ -97,44 +98,41 @@ case class FreeVariables(state: State) extends StatefulTransformer[State] { case q @ Join(t, a, b, iA, iB, on) => val (_, freeA) = apply(a) val (_, freeB) = apply(b) - val (_, freeOn) = FreeVariables(State(state.seen + iA.idName + iB.idName, Set.empty))(on) + val (_, freeOn) = FreeVariables(State(state.seen + iA + iB, Set.empty))(on) (q, FreeVariables(State(state.seen, state.free ++ freeA.state.free ++ freeB.state.free ++ freeOn.state.free))) case _: Entity | _: Take | _: Drop | _: Union | _: UnionAll | _: Aggregation | _: Distinct | _: Nested => super.apply(query) } - private def free(a: Ast, ident: Ident, c: Ast): FreeVariables = - free(a, ident.idName, c) - - private def free(a: Ast, ident: IdentName, c: Ast) = { + private def free(a: Ast, ident: Ident, c: Ast) = { val (_, ta) = apply(a) val (_, tc) = FreeVariables(State(state.seen + ident, state.free))(c) FreeVariables(State(state.seen, state.free ++ ta.state.free ++ tc.state.free)) } } +case class FreeVariableError(freeVars: List[Ident]) extends Exception { + lazy val msg = Text.FreeVariablesExitError(freeVars, true) + // For compile-time flows where the position is already passed down the the compiler do + // showing it again would just cause confusion + lazy val msgNoPos = Text.FreeVariablesExitError(freeVars, false) + + override def getMessage: String = msg +} + object FreeVariables { - def apply(ast: Ast): Set[IdentName] = + def apply(ast: Ast): Set[Ident] = new FreeVariables(State(Set.empty, Set.empty))(ast) match { case (_, transformer) => transformer.state.free } - def verify(ast: Ast): Either[String, Ast] = + def verify(ast: Ast): Either[FreeVariableError, Ast] = apply(ast) match { case free if free.isEmpty => Right(ast) case free => - val firstVar = free.headOption.map(_.name).getOrElse("someVar") - Left( - s""" - |Found the following variables: ${free.map(_.name).toList} that seem to originate outside of a `quote {...}` or `run {...}` block. - |Quotes and run blocks cannot use values outside their scope directly (with the exception of inline expressions in Scala 3). - |In order to use runtime values in a quotation, you need to lift them, so instead - |of this `$firstVar` do this: `lift($firstVar)`. - |Here is a more complete example: - |Instead of this: `def byName(n: String) = quote(query[Person].filter(_.name == n))` - | Do this: `def byName(n: String) = quote(query[Person].filter(_.name == lift(n)))` - """.stripMargin - ) + val error = + FreeVariableError(free.toList) + Left(error) } } diff --git a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala index 61cfc9b85c..1970ae57b4 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala @@ -12,13 +12,13 @@ import io.getquill.sql.Common.ContainsImpurities final case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering) sealed trait FromContext { def quat: Quat } -final case class TableContext(entity: Entity, alias: String) extends FromContext { +final case class TableContext(entity: Entity, alias: Ident) extends FromContext { override def quat: Quat = entity.quat } -final case class QueryContext(query: SqlQuery, alias: String) extends FromContext { +final case class QueryContext(query: SqlQuery, alias: Ident) extends FromContext { override def quat: Quat = query.quat } -final case class InfixContext(infix: Infix, alias: String) extends FromContext { override def quat: Quat = infix.quat } +final case class InfixContext(infix: Infix, alias: Ident) extends FromContext { override def quat: Quat = infix.quat } final case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast) extends FromContext { override def quat: Quat = Quat.Tuple(a.quat, b.quat) } @@ -133,15 +133,15 @@ class SqlQueryApply(traceConfig: TraceConfig) { } case TakeDropFlatten(q, limit, offset) => trace"Construct SqlQuery from: TakeDropFlatten" andReturn { - flatten(q, "x").copy(limit = limit, offset = offset)(q.quat) + flatten(q, Ident("x", q.quat)).copy(limit = limit, offset = offset)(q.quat) } case q: Query => trace"Construct SqlQuery from: Query" andReturn { - flatten(q, "x") + flatten(q, Ident("x", q.quat)) } case infix: Infix => trace"Construct SqlQuery from: Infix" andReturn { - flatten(infix, "x") + flatten(infix, Ident("x", infix.quat)) } case other => trace"Construct SqlQuery from: other" andReturn { @@ -149,7 +149,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { } } - private def flatten(query: Ast, alias: String): FlattenSqlQuery = + private def flatten(query: Ast, alias: Ident): FlattenSqlQuery = trace"Flattening: ${query}" andReturn { val (sources, finalFlatMapBody) = flattenContexts(query) flatten(sources, finalFlatMapBody, alias, nestNextMap = false) @@ -163,7 +163,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { val cc = CaseClassMake.fromQuat(flatJoin.quat)(name) flattenContexts(FlatMap(q, id, Map(flatJoin, alias, cc))) } - case FlatMap(q @ (_: Query | _: Infix), Ident(alias, _), p: Query) => + case FlatMap(q @ (_: Query | _: Infix), alias: Ident, p: Query) => trace"Flattening Flatmap with Query" andReturn { val source = this.source(q, alias) val (nestedContexts, finalFlatMapBody) = flattenContexts(p) @@ -182,13 +182,13 @@ class SqlQueryApply(traceConfig: TraceConfig) { private def flatten( sources: List[FromContext], finalFlatMapBody: Ast, - alias: String, + alias: Ident, nestNextMap: Boolean ): FlattenSqlQuery = { - def select(alias: String, quat: Quat): List[SelectValue] = SelectValue(Ident(alias, quat), None) :: Nil + def select(alias: Ident, quat: Quat): List[SelectValue] = SelectValue(alias, None) :: Nil - def base(q: Ast, alias: String, nestNextMap: Boolean): FlattenSqlQuery = + def base(q: Ast, alias: Ident, nestNextMap: Boolean): FlattenSqlQuery = trace"Computing Base (nestingMaps=${nestNextMap}) for Query: $q" andReturn { def nest(ctx: FromContext): FlattenSqlQuery = trace"Computing FlattenSqlQuery for: $ctx" andReturn { FlattenSqlQuery(from = sources :+ ctx, select = select(alias, q.quat))(q.quat) @@ -218,7 +218,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { case Join(tpe, a, b, iA, iB, on) => trace"base| Collecting join aliases $q" andReturn { val ctx = source(q, alias) - def aliases(ctx: FromContext): List[(String, Quat)] = + def aliases(ctx: FromContext): List[(Ident, Quat)] = ctx match { case TableContext(_, alias) => (alias, ctx.quat) :: Nil case QueryContext(_, alias) => (alias, ctx.quat) :: Nil @@ -226,7 +226,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { case JoinContext(_, a, b, _) => aliases(a) ::: aliases(b) case FlatJoinContext(_, a, _) => aliases(a) } - val collectedAliases = aliases(ctx).map { case (a, quat) => Ident(a, quat) } + val collectedAliases = aliases(ctx).map { case (a, quat) => a } val select = Tuple(collectedAliases) FlattenSqlQuery( from = ctx :: Nil, @@ -247,7 +247,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { trace"Flattening (alias = $alias) sources $sources from $finalFlatMapBody" andReturn { finalFlatMapBody match { - case ConcatMap(q, Ident(alias, _), p) => + case ConcatMap(q, alias: Ident, p) => trace"Flattening| ConcatMap" andReturn { FlattenSqlQuery( from = source(q, alias) :: Nil, @@ -268,7 +268,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { // Map(GroupBy(people,p=>p.name),a:(_1:name,_2:people) => p:(_1,MAX(_2.map(_.age))) // more concretely: // Map(GroupBy(q:people,x:p,g:p.name),a:(_1:name,_2:people), p:(_1,MAX(_2.map(_.age))) - case Map(GroupBy(q, x @ Ident(alias, _), g), a, p) => + case Map(GroupBy(q, x: Ident, g), a, p) => trace"Flattening| Map(GroupBy)" andReturn { // In the case that we have a map-to a Product before a GroupBy, we need to have a sub-nesting first. @@ -293,7 +293,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { // We fixed this particular case by the `case Map(_, _, ContainsImpurities()) =>` clause which will nest the map clause first leading to the correct query // but there are other potential cases that are not covered by that. As a fallback we need to forcibly nest the inner clause of this groupBy // if it is a map. That is what the `nestNextMap=true` argument does to `base` does - val b = base(q, alias, nestNextMap = true) + val b = base(q, x, nestNextMap = true) // use ExpandSelection logic to break down OrderBy clause // In the case that GroupBy(people,p=>p) make it into: GroupBy(people,p=> List(p.name,p.age) /*return this*/ ) @@ -324,9 +324,9 @@ class SqlQueryApply(traceConfig: TraceConfig) { // GroupByMap(people,p=>p.name)(a:Person => p:(a.name,MAX(a.age))) // more concretely: // GroupBy(q:people,x:p,g:p.name)(a:Person, p:(a.name,MAX(a.age))) - case GroupByMap(q, x @ Ident(alias, _), g, a, p) => + case GroupByMap(q, x: Ident, g, a, p) => trace"Flattening| GroupByMap" andReturn { - val b = base(q, alias, nestNextMap = true) + val b = base(q, x, nestNextMap = true) // Same as ExpandSelection in Map(GroupBy) val flatGroupByAsts = new ExpandSelection(b.from).ofSubselect(List(SelectValue(g))).map(_.ast) val groupByClause = @@ -344,7 +344,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { b.copy(groupBy = Some(groupByClause), select = this.selectValues(realiasedSelect))(quat) } - case Map(q, Ident(alias, _), p) => + case Map(q, alias: Ident, p) => val b = base(q, alias, nestNextMap = false) val agg = b.select.collect { case s @ SelectValue(_: Aggregation, _, _) => s @@ -359,7 +359,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { select = selectValues(p) )(quat) - case Filter(q, Ident(alias, _), p) => + case Filter(q, alias: Ident, p) => // If it's a filter, pass on the value of nestNextMap in case there is a future map we need to nest val b = base(q, alias, nestNextMap) // If the filter body uses the filter alias, make sure it matches one of the aliases in the fromContexts @@ -377,7 +377,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { select = select(alias, quat) )(quat) - case SortBy(q, Ident(alias, _), p, o) => + case SortBy(q, alias: Ident, p, o) => val b = base(q, alias, nestNextMap = false) val criteria = orderByCriteria(p, o, b.from) // If the sortBy body uses the filter alias, make sure it matches one of the aliases in the fromContexts @@ -453,7 +453,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { trace"Flattening| Distinct" andReturn b.copy(distinct = DistinctKind.Distinct)(quat) - case DistinctOn(q, Ident(alias, _), fields) => + case DistinctOn(q, alias: Ident, fields) => val distinctList = fields match { case Tuple(values) => values @@ -496,12 +496,12 @@ class SqlQueryApply(traceConfig: TraceConfig) { case _ => SelectValue(ast) :: Nil } - private def source(ast: Ast, alias: String): FromContext = + private def source(ast: Ast, alias: Ident): FromContext = ast match { case entity: Entity => TableContext(entity, alias) case infix: Infix => InfixContext(infix, alias) - case Join(t, a, b, ia, ib, on) => JoinContext(t, source(a, ia.name), source(b, ib.name), on) - case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia.name), on) + case Join(t, a, b, ia, ib, on) => JoinContext(t, source(a, ia), source(b, ib), on) + case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia), on) case Nested(q) => QueryContext(apply(q), alias) case other => QueryContext(apply(other), alias) } @@ -518,7 +518,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { case _ => fail(s"Invalid order by criteria $ast") } - private def collectAliases(contexts: List[FromContext]): List[String] = + private def collectAliases(contexts: List[FromContext]): List[Ident] = contexts.flatMap { case c: TableContext => List(c.alias) case c: QueryContext => List(c.alias) @@ -527,7 +527,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { case FlatJoinContext(_, from, _) => collectAliases(List(from)) } - private def collectTableAliases(contexts: List[FromContext]): List[String] = + private def collectTableAliases(contexts: List[FromContext]): List[Ident] = contexts.flatMap { case c: TableContext => List(c.alias) case _: QueryContext => List.empty diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala index 9f6d4af794..da8afd94c6 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala @@ -277,8 +277,8 @@ trait SqlIdiom extends Idiom { protected def tokenizeFixedColumn(strategy: NamingStrategy, column: String): String = column - protected def tokenizeTableAlias(strategy: NamingStrategy, table: String): String = - table + protected def tokenizeTableAlias(strategy: NamingStrategy, table: Ident): String = + table.name protected def tokenizeIdentName(strategy: NamingStrategy, name: String): String = name @@ -498,7 +498,7 @@ trait SqlIdiom extends Idiom { stmt"${actionAlias.map(alias => stmt"${scopedTokenizer(alias)}.").getOrElse(emptyStatement)}${TokenizeProperty(name, prefix, strategy, renameable)}" // In the rare case that the Ident is invisible, do not show it. See the Ident documentation for more info. - case (Ident.Opinionated(_, _, Hidden), prefix) => + case (Ident.Opinionated(_, _, Hidden, _), prefix) => stmt"${TokenizeProperty(name, prefix, strategy, renameable)}" // The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala index 3c33064475..bb138d0628 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala @@ -4,22 +4,11 @@ import io.getquill.ast._ import io.getquill.context.sql._ import io.getquill.quotation.FreeVariables import io.getquill.quat.Quat +import io.getquill.util.Text case class Error(free: List[Ident], ast: Ast) case class InvalidSqlQuery(errors: List[Error]) { - override def toString = { - val allVars = errors.flatMap(_.free).distinct - val firstVar = errors.headOption.flatMap(_.free.headOption).getOrElse("someVar") - s""" - |When synthesizing Joins, Quill found some variables that could not be traced back to their - |origin: ${allVars.map(_.name)}. Typically this happens when there are some flatMapped - |clauses that are missing data once they are flattened. - |Sometimes this is the result of a internal error in Quill. If that is the case, please - |reach out on our discord channel https://discord.gg/2ccFBr4 and/or file an issue - |on https://github.com/zio/zio-quill. - |""".stripMargin + - errors.map(error => s"Faulty expression: '${error.ast}'. Free variables: '${error.free}'.").mkString(",\n") - } + override def toString = Text.JoinSynthesisError(errors) } object VerifySqlQuery { @@ -36,7 +25,7 @@ object VerifySqlQuery { private def verifyFlatJoins(q: FlattenSqlQuery) = { - def loop(l: List[FromContext], available: Set[String]): Set[String] = + def loop(l: List[FromContext], available: Set[Ident]): Set[Ident] = l.foldLeft(available) { case (av, TableContext(_, alias)) => Set(alias) case (av, InfixContext(_, alias)) => Set(alias) @@ -45,7 +34,7 @@ object VerifySqlQuery { av ++ loop(a :: Nil, av) ++ loop(b :: Nil, av) case (av, FlatJoinContext(_, a, on)) => val nav = av ++ loop(a :: Nil, av) - val free = FreeVariables(on).map(_.name) + val free = FreeVariables(on) val invalid = free -- nav require( invalid.isEmpty, @@ -61,7 +50,7 @@ object VerifySqlQuery { verifyFlatJoins(query) - val aliases = query.from.flatMap(this.aliases).map(IdentName(_)) :+ IdentName("*") :+ IdentName("?") + val aliases = query.from.flatMap(this.aliases) :+ Ident.trivial("*") :+ Ident.trivial("?") def verifyAst(ast: Ast) = { val freeVariables = @@ -111,7 +100,7 @@ object VerifySqlQuery { } } - private def aliases(s: FromContext): List[String] = + private def aliases(s: FromContext): List[Ident] = s match { case s: TableContext => List(s.alias) case s: QueryContext => List(s.alias) diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/HideTopLevelFilterAlias.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/HideTopLevelFilterAlias.scala index 5c1c02387c..3ce0e0af16 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/HideTopLevelFilterAlias.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/HideTopLevelFilterAlias.scala @@ -20,13 +20,13 @@ import io.getquill.norm.BetaReduction // but with output causes the alias becomes OUTPUT so it can be different in those cases. object HideTopLevelFilterAlias extends StatelessTransformer { def hideAlias(alias: Ident, in: Ast) = { - val newAlias = Ident.Opinionated(alias.name, alias.quat, Visibility.Hidden) + val newAlias = Ident.Opinionated(alias.name, alias.quat, Visibility.Hidden, alias.pos) (newAlias, BetaReduction(in, alias -> newAlias)) } def hideAssignmentAlias(assignment: Assignment) = { val alias = assignment.alias - val newAlias = Ident.Opinionated(alias.name, alias.quat, Visibility.Hidden) + val newAlias = Ident.Opinionated(alias.name, alias.quat, Visibility.Hidden, alias.pos) val newValue = BetaReduction(assignment.value, alias -> newAlias) val newProperty = BetaReduction(assignment.property, alias -> newAlias) val newAssignment = Assignment(newAlias, newProperty, newValue) diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveUnusedSelects.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveUnusedSelects.scala index dcdb7adc25..a78335e8a5 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveUnusedSelects.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveUnusedSelects.scala @@ -98,8 +98,8 @@ object RemoveUnusedSelects { case _: TableContext | _: InfixContext => (s, new mutable.LinkedHashSet[Property]()) } - private def references(alias: String, asts: List[Ast]) = - LinkedHashSet.empty ++ (References(State(Ident(alias, Quat.Value), Nil))(asts)(_.apply)._2.state.references) + private def references(alias: Ident, asts: List[Ast]) = + LinkedHashSet.empty ++ (References(State(alias, Nil))(asts)(_.apply)._2.state.references) } case class State(ident: Ident, references: List[Property]) diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/SelectPropertyProtractor.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/SelectPropertyProtractor.scala index 7439d9d4dd..434df3dfd4 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/SelectPropertyProtractor.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/SelectPropertyProtractor.scala @@ -52,20 +52,20 @@ case class InContext(from: List[FromContext]) { def contextReferenceType(ast: Ast) = { val references = collectTableAliases(from) ast match { - case Ident(v, _) => references.get(v) - case PropertyMatryoshka(Ident(v, _), _, _) => references.get(v) - case _ => None + case id: Ident => references.get(id) + case PropertyMatryoshka(id: Ident, _, _) => references.get(id) + case _ => None } } - private def collectTableAliases(contexts: List[FromContext]): Map[String, InContextType] = + private def collectTableAliases(contexts: List[FromContext]): Map[Ident, InContextType] = contexts.map { case c: TableContext => Map(c.alias -> InTableContext) case c: QueryContext => Map(c.alias -> InQueryContext) case c: InfixContext => Map(c.alias -> InInfixContext) case JoinContext(_, a, b, _) => collectTableAliases(List(a)) ++ collectTableAliases(List(b)) case FlatJoinContext(_, from, _) => collectTableAliases(List(from)) - }.foldLeft(Map[String, InContextType]())(_ ++ _) + }.foldLeft(Map[Ident, InContextType]())(_ ++ _) } object InContext { sealed trait InContextType diff --git a/quill-engine/src/main/scala/io/getquill/util/Messages.scala b/quill-engine/src/main/scala/io/getquill/util/Messages.scala index 32b7098bba..5b112c10c3 100644 --- a/quill-engine/src/main/scala/io/getquill/util/Messages.scala +++ b/quill-engine/src/main/scala/io/getquill/util/Messages.scala @@ -132,7 +132,7 @@ object Messages { // implicit val t = new EnableTrace { override type Trace = TraceType.Normalizations :: HNil } // Otherwise it would have to be override type Trace = TraceType.Normalizations.type - // Specifically for situations where what needs to be printed is a type of warning to the user as opposed to an expansion + // Specifically for situations where what needs to be printed is a type of warning to the user as opposOed to an expansion // This kind of trace is always on by default and does not need to be enabled by the user. sealed trait Warning extends TraceType { val value = "warning" } sealed trait SqlNormalizations extends TraceType { val value = "sql" } diff --git a/quill-engine/src/main/scala/io/getquill/util/Text.scala b/quill-engine/src/main/scala/io/getquill/util/Text.scala new file mode 100644 index 0000000000..9d5f325f65 --- /dev/null +++ b/quill-engine/src/main/scala/io/getquill/util/Text.scala @@ -0,0 +1,131 @@ +package io.getquill.util + +import io.getquill.ast.{Ast, Ident, Pos} +import io.getquill.context.sql.idiom.Error + +// Intentionally put all comments in 1st line. Want this to be a place +// where example code can be put +// format: off + +object Text { + +implicit class StringExt(str: String) { + def trimLeft = str.dropWhile(_.isWhitespace) +} + +implicit class SeqExt[T](seq: Seq[T]) { + def sortByVariant[R](pf: PartialFunction[T, R])(implicit ord: Ordering[R]) = { + seq.filter(v => pf.isDefinedAt(v)).sortBy(v => pf(v)) + } +} + + +def JoinSynthesisError(errors: List[Error]) = + errors match { + case List(Error(List(id), ast)) => joinSynthesisErrorSingle(id, ast) + case _ => joinSynthesisErrorMulti(errors) + } + + +// make a special printout case for a single error since this is what happens 95% of the time +private def joinSynthesisErrorSingle(id: Ident, ast: Ast) = +s""" +When synthesizing Joins, Quill found a variable that could not be traced back to its origin: ${id.name} +originally at: ${id.pos} + +with the following faulty expression: +${ast} + +${joinSynthesisExplanation} +""".trimLeft + + +private def joinSynthesisErrorMulti(errors: List[Error]) = { +val allVars = errors.flatMap(_.free).distinct +val firstVar = errors.headOption.flatMap(_.free.headOption).getOrElse("someVar") + +def printError(error: Error) = +s""" +====== Faulty Expression ====== +${error.ast} +Variables: +${error.free.map(id => s"${id.name} - ${id.pos.print}").mkString("\n")} +""".trimLeft + +s""" +When synthesizing Joins, Quill found some variables that could not be traced back to their +origin: ${allVars.map(_.name)}. + +${joinSynthesisExplanation} +""".trimLeft + +errors.map(printError(_)).mkString(",\n") +} + + +private lazy val joinSynthesisExplanation = +s""" +Typically this happens when there are some flatMapped +clauses that are missing data once they are flattened. +Sometimes this is the result of a internal error in Quill. If that is the case, please +reach out on our discord channel https://discord.gg/2ccFBr4 and/or file an issue +on https://github.com/zio/zio-quill. +""".trimLeft + +// ========================================= FreeVariablesExitError ========================================= + +def FreeVariablesExitError(freeVars: Seq[Ident], showPos: Boolean = true): String = + if (freeVars.size == 1) + freeVariablesSingle(freeVars.head, showPos) + else + freeVariablesMulti(freeVars) + +// Most of the time there are free varaibles it's just one so make a specific message optimizing for that +// if we're in a compile-time flow we don't need to show the position because it will be conveyed directly to the compiler. +private def freeVariablesSingle(freeVar: Ident, showPos: Boolean) = +s""" +Found the following variable: ${freeVar} that originates outside of a `quote {...}` or `run {...}` block. +${if (showPos) s"Here: ${freeVar.pos.print}\n" else ""} +${freeVariablesExplanation(freeVar.name)} +""".trimLeft + +private def freeVariablesMulti(freeVarsUnordered: Seq[Ident]) = { + val knowPosVars = + freeVarsUnordered.sortByVariant { + case value @ Ident.WithPos(_, Pos.Real(file, line, col, _, _)) => (file, line, col) + } + val unknownPosVars = + freeVarsUnordered.sortByVariant { + case Ident.WithPos(name, Pos.Synthetic) => name + } + val allVars = knowPosVars ++ unknownPosVars + val free = allVars.map(_.name) + val firstVar = free.headOption.getOrElse("x") + val locations = + if (knowPosVars.nonEmpty) { + knowPosVars.map(v => s" ${v.name} - ${v.pos.print}").mkString("\n") + "\n" + } else + "" + +s""" +Found the following variables: ${free} that seem to originate outside of a `quote {...}` or `run {...}` block. +${locations} +${freeVariablesExplanation(firstVar)} +""".trimLeft + + +} + +private def freeVariablesExplanation(varExample: String) = +s""" +Quotes and run blocks cannot use values outside their scope directly (with the exception of inline expressions in Scala 3). +In order to use runtime values in a quotation, you need to lift them, so instead +of this `$varExample` do this: `lift($varExample)`. +Here is a more complete example: +Instead of this: `def byName(n: String) = quote(query[Person].filter(_.name == n))` + Do this: `def byName(n: String) = quote(query[Person].filter(_.name == lift(n)))` +} +""".trimLeft + + +} +// format: on diff --git a/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala b/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala index 731e5b56b3..b36237d318 100644 --- a/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala +++ b/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala @@ -215,10 +215,10 @@ class OrientDBQuerySpec extends Spec { t.token(e.copy(select = List(x, x), distinct = DistinctKind.Distinct)(Quat.Value)) ).getMessage mustBe "OrientDB DISTINCT with multiple columns is not supported" - val tb = TableContext(Entity("tb", Nil, QEP), "x1") + val tb = TableContext(Entity("tb", Nil, QEP), Ident.trivial("x1")) t.token(e.copy(from = List(tb, tb))(Quat.Value)) mustBe stmt"SELECT * FROM tb" - val jn = FlatJoinContext(InnerJoin, tb.copy(alias = "x2"), Ident("x")) + val jn = FlatJoinContext(InnerJoin, tb.copy(alias = Ident.trivial("x2")), Ident("x")) intercept[IllegalStateException](t.token(e.copy(from = List(tb, jn))(Quat.Value))) t.token( From abb7e08664fe46127858351e4f6bd2c5a493b9f0 Mon Sep 17 00:00:00 2001 From: Alexander Ioffe Date: Wed, 3 Jul 2024 18:59:58 -0400 Subject: [PATCH 2/3] Replace SheathLeafClauses with simpler SqlQuery-oriented transforms --- .../main/scala/io/getquill/sql/SqlQuery.scala | 37 ++++++++- .../io/getquill/sql/idiom/SqlIdiom.scala | 6 +- .../io/getquill/sql/norm/ExpandDistinct.scala | 15 ++-- .../getquill/sql/norm/RemoveExtraAlias.scala | 80 ++++++++++++++++++- .../io/getquill/sql/norm/SqlNormalize.scala | 4 +- 5 files changed, 128 insertions(+), 14 deletions(-) diff --git a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala index 1970ae57b4..5e15f4bf44 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala @@ -11,14 +11,27 @@ import io.getquill.sql.Common.ContainsImpurities final case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering) -sealed trait FromContext { def quat: Quat } +sealed trait FromContext { + def quat: Quat + def mapAst(f: Ast => Ast): FromContext = this match { + case c: TableContext => c + case QueryContext(query, alias) => QueryContext(query.mapAst(f), alias) + case c: InfixContext => c.mapAsts(f) + case JoinContext(t, a, b, on) => JoinContext(t, a.mapAst(f), b.mapAst(f), f(on)) + case FlatJoinContext(t, a, on) => FlatJoinContext(t, a.mapAst(f), f(on)) + } +} final case class TableContext(entity: Entity, alias: Ident) extends FromContext { override def quat: Quat = entity.quat } final case class QueryContext(query: SqlQuery, alias: Ident) extends FromContext { override def quat: Quat = query.quat } -final case class InfixContext(infix: Infix, alias: Ident) extends FromContext { override def quat: Quat = infix.quat } +final case class InfixContext(infix: Infix, alias: Ident) extends FromContext { + override def quat: Quat = infix.quat + def mapAsts(f: Ast => Ast): InfixContext = + copy(infix = infix.copy(params = infix.params.map(f))) +} final case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast) extends FromContext { override def quat: Quat = Quat.Tuple(a.quat, b.quat) } @@ -29,6 +42,16 @@ final case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast) extends F sealed trait SqlQuery { def quat: Quat + def mapAst(f: Ast => Ast): SqlQuery = + this match { + case flatten: FlattenSqlQuery => + flatten.mapAsts(f) + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(a.mapAst(f), op, b.mapAst(f))(quat) + case UnaryOperationSqlQuery(op, q) => + UnaryOperationSqlQuery(op, q.mapAst(f))(quat) + } + override def toString: String = { import io.getquill.MirrorSqlDialect._ import io.getquill.idiom.StatementInterpolator._ @@ -83,6 +106,16 @@ final case class FlattenSqlQuery( )(quatType: Quat) extends SqlQuery { override def quat: Quat = quatType + + def mapAsts(f: Ast => Ast): FlattenSqlQuery = + copy( + where = where.map(f), + groupBy = groupBy.map(f), + orderBy = orderBy.map(o => o.copy(ast = f(o.ast))), + limit = limit.map(f), + offset = offset.map(f), + select = select.map(s => s.copy(ast = f(s.ast))) + )(quatType) } object TakeDropFlatten { diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala index da8afd94c6..9eb9c13290 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala @@ -17,7 +17,7 @@ import io.getquill.norm.ConcatBehavior.AnsiConcat import io.getquill.norm.EqualityBehavior.AnsiEquality import io.getquill.norm.{ConcatBehavior, EqualityBehavior, ExpandReturning, NormalizeCaching, ProductAggregationToken} import io.getquill.quat.Quat -import io.getquill.sql.norm.{HideTopLevelFilterAlias, NormalizeFilteredActionAliases, RemoveExtraAlias, RemoveUnusedSelects} +import io.getquill.sql.norm.{HideTopLevelFilterAlias, NormalizeFilteredActionAliases, RemoveExtraAlias, RemoveUnusedSelects, ValueizeSingleLeafSelects} import io.getquill.util.{Interleave, Interpolator, Messages, TraceConfig} import io.getquill.util.Messages.{TraceType, fail, trace} @@ -82,7 +82,9 @@ trait SqlIdiom extends Idiom { val sql = querifyAst(q, idiomContext.traceConfig) trace"SQL: ${sql}".andLog() VerifySqlQuery(sql).map(fail) - val expanded = ExpandNestedQueries(sql, topLevelQuat) + val valueized = ValueizeSingleLeafSelects(naming)(sql, topLevelQuat) + trace"Valueized SQL: ${valueized}".andLog() + val expanded = ExpandNestedQueries(valueized, topLevelQuat) trace"Expanded SQL: ${expanded}".andLog() val refined = if (Messages.pruneColumns) RemoveUnusedSelects(expanded) else expanded trace"Filtered SQL (only used selects): ${refined}".andLog() diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/ExpandDistinct.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/ExpandDistinct.scala index 818c11bfd0..24b4b83402 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/ExpandDistinct.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/ExpandDistinct.scala @@ -77,12 +77,15 @@ class ExpandDistinct(traceConfig: TraceConfig) { // Problems with distinct were first discovered in #1032. Basically, unless // the distinct is "expanded" adding an outer map, Ident's representing a Table will end up in invalid places // such as "ORDER BY tableIdent" etc... - case Distinct(Map(q, x, p)) => - val newMap = Map(q, x, Tuple(List(p))) - val newQuat = Quat.Tuple(valueQuat(p.quat)) // force quat recomputation for perf purposes - val newIdent = Ident(x.name, newQuat) - trace"ExpandDistinct Distinct(Map(other))" andReturn - Map(Distinct(newMap), newIdent, Property(newIdent, "_1")) + + // TODO EXPERIMENTING WITH THIS CLAUSE, TRY TO DISABLE`` + + // case Distinct(Map(q, x, p)) => + // val newMap = Map(q, x, Tuple(List(p))) + // val newQuat = Quat.Tuple(valueQuat(p.quat)) // force quat recomputation for perf purposes + // val newIdent = Ident(x.name, newQuat) + // trace"ExpandDistinct Distinct(Map(other))" andReturn + // Map(Distinct(newMap), newIdent, Property(newIdent, "_1")) } } } diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveExtraAlias.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveExtraAlias.scala index 95d4f4cfd0..22729e967b 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveExtraAlias.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/RemoveExtraAlias.scala @@ -1,8 +1,84 @@ package io.getquill.sql.norm import io.getquill.NamingStrategy -import io.getquill.ast.{Property, Renameable} -import io.getquill.context.sql.{FlattenSqlQuery, SelectValue} +import io.getquill.ast.Ast.LeafQuat +import io.getquill.ast.{Ast, CollectAst, Ident, Property, Renameable} +import io.getquill.context.sql.{FlatJoinContext, FlattenSqlQuery, FromContext, InfixContext, JoinContext, QueryContext, SelectValue, TableContext} +import io.getquill.norm.{BetaReduction, TypeBehavior} +import io.getquill.quat.Quat + +// If we run this right after SqlQuery we know that in every place with a single select-value it is a leaf clause e.g. `SELECT x FROM (SELECT p.name from Person p)) AS x` +// in that case we know that SelectValue(x) is a leaf clause that we should expand into a `x.value`. +// MAKE SURE THIS RUNS BEFORE ExpandNestedQueries otherwise it will be incorrect, it should only run for single-selects from atomic values, +// if the ExpandNestedQueries ran it could be a single field that is coming from a case class e.g. case class MySingleValue(stuff: Int) that is being selected from +case class ValueizeSingleLeafSelects(strategy: NamingStrategy) extends StatelessQueryTransformer { + protected def productize(ast: Ident) = + Ident(ast.name, Quat.Product("", "value" -> Quat.Value)) + + protected def valueize(ast: Ident) = + Property(productize(ast), "value") + + // Turn every `SELECT primitive-x` into a `SELECT case-class-x.primitive-value` + override protected def expandNested(q: FlattenSqlQuery, level: QueryLevel): FlattenSqlQuery = { + // get the alises before we transform (i.e. Valueize) the contexts inside turning the leaf-quat alises into product-quat alises + val leafValuedFroms = collectAliases(q.from).filter(!_.quat.isProduct) + // now transform the inner clauses + val from = q.from.map(expandContext(_)) + + def containsAlias(ast: Ast): Boolean = + CollectAst.byType[Ident](ast).exists(id => leafValuedFroms.contains(id)) + + // If there is one single select clause that has a primitive (i.e. Leaf) quat then we can alias it to "value" + // This is the case of `SELECT primitive FROM (SELECT p.age from Person p) AS primitive` + // where we turn it into `SELECT p.name AS value FROM Person p` + def aliasSelects(selectValues: List[SelectValue]) = + selectValues match { + case List(sv @ SelectValue(LeafQuat(ast), _, _)) => List(sv.copy(alias = Some("value"))) + case other => other + } + + val valuizedQuery = + q.copy(from = from)(q.quat).mapAsts { ast => + if (containsAlias(ast)) { + val reductions = CollectAst.byType[Ident](ast).filter(id => leafValuedFroms.contains(id)).map(id => id -> valueize(id)) + BetaReduction(ast, TypeBehavior.ReplaceWithReduction, reductions: _*) + } else { + ast + } + } + + valuizedQuery.copy(select = aliasSelects(valuizedQuery.select))(q.quat) + } + + // Turn every `FROM primitive-x` into a `FROM case-class(x.primitive)` + override protected def expandContext(s: FromContext): FromContext = + super.expandContext(s) match { + case QueryContext(query, LeafQuat(id: Ident)) => + QueryContext(query, productize(id)) + case other => + other + } + + // protected def expandContext(s: FromContext): FromContext = + // s match { + // case QueryContext(q, alias) => + // QueryContext(apply(q, QueryLevel.Inner), alias) + // case JoinContext(t, a, b, on) => + // JoinContext(t, expandContext(a), expandContext(b), on) + // case FlatJoinContext(t, a, on) => + // FlatJoinContext(t, expandContext(a), on) + // case _: TableContext | _: InfixContext => s + // } + + private def collectAliases(contexts: List[FromContext]): List[Ident] = + contexts.flatMap { + case c: TableContext => List(c.alias) + case c: QueryContext => List(c.alias) + case c: InfixContext => List(c.alias) + case JoinContext(_, a, b, _) => collectAliases(List(a)) ++ collectAliases(List(b)) + case FlatJoinContext(_, from, _) => collectAliases(List(from)) + } +} /** * Remove aliases at the top level of the AST since they are not needed (quill diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/SqlNormalize.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/SqlNormalize.scala index 9d1622f686..529f14b715 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/SqlNormalize.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/SqlNormalize.scala @@ -63,8 +63,8 @@ class SqlNormalize( .andThen(demarcate("ExpandJoin")) .andThen(ExpandMappedInfix.apply _) .andThen(demarcate("ExpandMappedInfix")) - .andThen(SheathLeafClausesPhase.apply _) - .andThen(demarcate("SheathLeaves")) + // .andThen(SheathLeafClausesPhase.apply _) + // .andThen(demarcate("SheathLeaves")) .andThen { ast => // In the final stage of normalization, change all temporary aliases into // shorter ones of the form x[0-9]+. From 6292a404215c470fa0de7260cd66a1f885e57ad6 Mon Sep 17 00:00:00 2001 From: Alexander Ioffe Date: Thu, 11 Jul 2024 19:59:07 -0400 Subject: [PATCH 3/3] Better verify message, fix contains-checking in SqlQuery case Filter --- .../main/scala/io/getquill/sql/SqlQuery.scala | 4 ++-- .../scala/io/getquill/sql/idiom/SqlIdiom.scala | 2 +- .../io/getquill/sql/idiom/VerifySqlQuery.scala | 16 ++++++++++------ .../src/main/scala/io/getquill/util/Text.scala | 10 +++++++--- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala index 5e15f4bf44..e582bf468d 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala @@ -397,7 +397,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { val b = base(q, alias, nestNextMap) // If the filter body uses the filter alias, make sure it matches one of the aliases in the fromContexts if ( - b.where.isEmpty && (!CollectAst.byType[Ident](p).map(_.name).contains(alias) || collectAliases(b.from) + b.where.isEmpty && (!CollectAst.byType[Ident](p).contains(alias) || collectAliases(b.from) .contains(alias)) ) trace"Flattening| Filter(Ident) [Simple]" andReturn @@ -415,7 +415,7 @@ class SqlQueryApply(traceConfig: TraceConfig) { val criteria = orderByCriteria(p, o, b.from) // If the sortBy body uses the filter alias, make sure it matches one of the aliases in the fromContexts if ( - b.orderBy.isEmpty && (!CollectAst.byType[Ident](p).map(_.name).contains(alias) || collectAliases(b.from) + b.orderBy.isEmpty && (!CollectAst.byType[Ident](p).contains(alias) || collectAliases(b.from) .contains(alias)) ) trace"Flattening| SortBy(Ident) [Simple]" andReturn diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala index 9eb9c13290..af07e0ff59 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala @@ -81,7 +81,7 @@ trait SqlIdiom extends Idiom { case q: Query => val sql = querifyAst(q, idiomContext.traceConfig) trace"SQL: ${sql}".andLog() - VerifySqlQuery(sql).map(fail) + VerifySqlQuery(sql).verifyOrFail().map(fail) val valueized = ValueizeSingleLeafSelects(naming)(sql, topLevelQuat) trace"Valueized SQL: ${valueized}".andLog() val expanded = ExpandNestedQueries(valueized, topLevelQuat) diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala index bb138d0628..619e3b2398 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/VerifySqlQuery.scala @@ -7,14 +7,14 @@ import io.getquill.quat.Quat import io.getquill.util.Text case class Error(free: List[Ident], ast: Ast) -case class InvalidSqlQuery(errors: List[Error]) { - override def toString = Text.JoinSynthesisError(errors) +case class InvalidSqlQuery(errors: List[Error], query: SqlQuery) { + override def toString = Text.JoinSynthesisError(errors, query) } -object VerifySqlQuery { +class VerifySqlQuery(originalQuery: SqlQuery) { - def apply(query: SqlQuery): Option[String] = - verify(query).map(_.toString) + def verifyOrFail(): Option[String] = + verify(originalQuery).map(_.toString) private def verify(query: SqlQuery): Option[InvalidSqlQuery] = query match { @@ -96,7 +96,7 @@ object VerifySqlQuery { (freeVariableErrors ++ nestedErrors) match { case Nil => None - case errors => Some(InvalidSqlQuery(errors)) + case errors => Some(InvalidSqlQuery(errors, originalQuery)) } } @@ -139,3 +139,7 @@ object VerifySqlQuery { }) } } + +object VerifySqlQuery { + def apply(query: SqlQuery) = new VerifySqlQuery(query) +} diff --git a/quill-engine/src/main/scala/io/getquill/util/Text.scala b/quill-engine/src/main/scala/io/getquill/util/Text.scala index 9d5f325f65..43616643b6 100644 --- a/quill-engine/src/main/scala/io/getquill/util/Text.scala +++ b/quill-engine/src/main/scala/io/getquill/util/Text.scala @@ -1,6 +1,7 @@ package io.getquill.util import io.getquill.ast.{Ast, Ident, Pos} +import io.getquill.context.sql.SqlQuery import io.getquill.context.sql.idiom.Error // Intentionally put all comments in 1st line. Want this to be a place @@ -20,15 +21,15 @@ implicit class SeqExt[T](seq: Seq[T]) { } -def JoinSynthesisError(errors: List[Error]) = +def JoinSynthesisError(errors: List[Error], query: SqlQuery) = errors match { - case List(Error(List(id), ast)) => joinSynthesisErrorSingle(id, ast) + case List(Error(List(id), ast)) => joinSynthesisErrorSingle(id, ast, query) case _ => joinSynthesisErrorMulti(errors) } // make a special printout case for a single error since this is what happens 95% of the time -private def joinSynthesisErrorSingle(id: Ident, ast: Ast) = +private def joinSynthesisErrorSingle(id: Ident, ast: Ast, query: SqlQuery) = s""" When synthesizing Joins, Quill found a variable that could not be traced back to its origin: ${id.name} originally at: ${id.pos} @@ -36,6 +37,9 @@ originally at: ${id.pos} with the following faulty expression: ${ast} +in the query: +${query} + ${joinSynthesisExplanation} """.trimLeft