Skip to content

Commit

Permalink
Add support for record pattern matching (#5185)
Browse files Browse the repository at this point in the history
Add support for record pattern matching
  • Loading branch information
johannescoetzee authored Jan 2, 2025
1 parent c879793 commit 717cd8b
Show file tree
Hide file tree
Showing 10 changed files with 1,051 additions and 756 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trait AstForNameExpressionsCreator { this: AstCreator =>

case SimpleVariable(ScopePatternVariable(localNode, typePatternExpr)) =>
scope.enclosingMethod.flatMap(_.getPatternVariableInfo(typePatternExpr)) match {
case Some(PatternVariableInfo(typePatternExpr, _, initializerAst, _, false)) =>
case Some(PatternVariableInfo(typePatternExpr, _, initializerAst, _, false, _)) =>
scope.enclosingMethod.foreach(_.registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr))
initializerAst
case _ =>
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ trait AstForSimpleExpressionsCreator { this: AstCreator =>
.typePatternExprsExposedToChild(expr.getRight)
.asScala
.flatMap(pattern => scope.enclosingMethod.flatMap(_.getPatternVariableInfo(pattern)))
.foreach { case PatternVariableInfo(pattern, local, _, _, _) =>
.foreach { case PatternVariableInfo(pattern, local, _, _, _, _) =>
scope.enclosingBlock.foreach(_.addPatternLocal(local, pattern))
}

Expand Down Expand Up @@ -318,7 +318,7 @@ trait AstForSimpleExpressionsCreator { this: AstCreator =>
val lhsAst = astsForExpression(expr.getExpression, ExpectedType.empty).head
expr.getPattern.toScala
.map { patternExpression =>
astForInstanceOfWithPattern(expr.getExpression, lhsAst, patternExpression)
instanceOfAstForPattern(patternExpression, lhsAst)
}
.getOrElse {
val booleanTypeFullName = Some(TypeConstants.Boolean)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,21 +295,22 @@ trait AstForSimpleStatementsCreator { this: AstCreator =>

val selectorNode = selectorAst.root.get

val selectorMustBeIdentifier = stmt.getEntries.asScala.flatMap(_.getLabels.asScala).exists(_.isPatternExpr)
val selectorMustBeIdentifierOrFieldAccess =
stmt.getEntries.asScala.flatMap(_.getLabels.asScala).exists(_.isPatternExpr)

val (selectorInitializer, selectorIdentifier, selectorRefsTo) = if (selectorMustBeIdentifier) {
val (init, ident, refs) = astIdentifierAndRefsForPatternLhs(stmt.getSelector, selectorAst)
(init, Option(ident), refs)
val (initializerAst, referenceAst) = if (selectorMustBeIdentifierOrFieldAccess) {
val initAndRefAsts = initAndRefAstsForPatternInitializer(stmt.getSelector, selectorAst)
(initAndRefAsts.get, Option(initAndRefAsts.get))
} else {
(selectorAst, None, None)
(selectorAst, None)
}

val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry(_, selectorIdentifier, selectorRefsTo))
val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry(_, referenceAst))

val switchBodyAst = Ast(NewBlock()).withChildren(entryAsts)

Ast(switchNode)
.withChild(selectorInitializer)
.withChild(initializerAst)
.withChild(switchBodyAst)
.withConditionEdge(switchNode, selectorNode)
}
Expand Down Expand Up @@ -355,11 +356,7 @@ trait AstForSimpleStatementsCreator { this: AstCreator =>
(defaultAst ++ explicitLabelAsts).toList
}

private def astForSwitchEntry(
entry: SwitchEntry,
selectorIdentifier: Option[NewIdentifier],
selectorRefsTo: Option[NewVariableNode]
): Seq[Ast] = {
private def astForSwitchEntry(entry: SwitchEntry, selectorReferenceAst: Option[Ast]): Seq[Ast] = {
// Fallthrough to/from a pattern is a compile error, so an entry can only have a pattern label if that is
// the only label
val labels = entry.getLabels.asScala.toList
Expand All @@ -368,8 +365,8 @@ trait AstForSimpleStatementsCreator { this: AstCreator =>
val entryContext = new SwitchEntryContext(entry, new CombinedTypeSolver())

val instanceOfAst = labels.lastOption.collect { case patternExpr: PatternExpr =>
selectorIdentifier.map { selector =>
astForInstanceOfWithPattern(patternExpr, Ast(selector), patternExpr)
selectorReferenceAst.map { selectorAst =>
instanceOfAstForPattern(patternExpr, selectorAst)
}
}.flatten

Expand All @@ -382,7 +379,7 @@ trait AstForSimpleStatementsCreator { this: AstCreator =>
scope.addLocalsForPatternsToEnclosingBlock(patternsExposedToBody)
val patternAstsToAdd = patternsExposedToBody
.flatMap(typePattern => scope.enclosingMethod.get.getPatternVariableInfo(typePattern))
.flatMap { case PatternVariableInfo(typePatternExpr, patternLocal, initializerAst, _, _) =>
.flatMap { case PatternVariableInfo(typePatternExpr, patternLocal, initializerAst, _, _, _) =>
scope.enclosingMethod.get.registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr)
scope.enclosingMethod.get.registerPatternVariableLocalToBeAddedToGraph(typePatternExpr)
Ast(patternLocal) :: initializerAst :: Nil
Expand Down Expand Up @@ -413,7 +410,7 @@ trait AstForSimpleStatementsCreator { this: AstCreator =>
instanceOfAst
.map { instanceOfAst =>
val ifNode = controlStructureNode(entry, ControlStructureTypes.IF, s"if (${instanceOfAst.rootCodeOrEmpty})")
labelAsts :+ Ast(ifNode).withChild(instanceOfAst).withChild(statementsAst)
labelAsts :+ controlStructureAst(ifNode, Option(instanceOfAst), statementsAst :: Nil)
}
.getOrElse(labelAsts :+ statementsAst)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,14 @@ trait AstForStatementsCreator extends AstForSimpleStatementsCreator with AstForF

patternSet.asScala
.flatMap(patternExpr => scope.enclosingMethod.flatMap(_.getPatternVariableInfo(patternExpr)))
.toArray
.sortBy(_.index)
.foreach {
case PatternVariableInfo(pattern, variableLocal, _, _, true) =>
case PatternVariableInfo(pattern, variableLocal, _, _, true, _) =>
scope.enclosingMethod.foreach(_.registerPatternVariableLocalToBeAddedToGraph(pattern))
astsAddedBeforeStmt.addOne(Ast(variableLocal))

case PatternVariableInfo(pattern, variableLocal, initializer, _, false) =>
case PatternVariableInfo(pattern, variableLocal, initializer, _, false, _) =>
if (patternsIntroducedByStmt.contains(pattern)) {
if (patternsIntroducedToBody.contains(pattern) || patternsIntroducedToElse.contains(pattern)) {
astsAddedBeforeStmt.addOne(Ast(variableLocal))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ case class PatternVariableInfo(
typeVariableLocal: NewLocal,
typeVariableInitializer: Ast,
localAddedToAst: Boolean = false,
initializerAddedToAst: Boolean = false
initializerAddedToAst: Boolean = false,
index: Int
)

object JavaScopeElement {
Expand Down Expand Up @@ -113,6 +114,8 @@ object JavaScopeElement {
private val temporaryLocals = mutable.ListBuffer[NewLocal]()
private val patternVariableInfoIdentityMap: mutable.Map[TypePatternExpr, PatternVariableInfo] =
new util.IdentityHashMap[TypePatternExpr, PatternVariableInfo]().asScala
// The insertion order should be preserved to ensure stable results when getting unadded variable asts
private var patternVariableIndex = 0

def addParameter(parameter: NewMethodParameterIn): Unit = {
addVariableToScope(ScopeParameter(parameter))
Expand All @@ -131,29 +134,31 @@ object JavaScopeElement {
): Unit = {
patternVariableInfoIdentityMap.put(
typePatternExpr,
PatternVariableInfo(typePatternExpr, typeVariableLocal, typeVariableInitializer)
PatternVariableInfo(typePatternExpr, typeVariableLocal, typeVariableInitializer, index = patternVariableIndex)
)
patternVariableIndex += 1
}

def getPatternVariableInfo(typePatternExpr: TypePatternExpr): Option[PatternVariableInfo] = {
patternVariableInfoIdentityMap.get(typePatternExpr)
}

def registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr: TypePatternExpr): Unit = {
patternVariableInfoIdentityMap.get(typePatternExpr).foreach { case patternVariableInfo =>
patternVariableInfoIdentityMap.put(typePatternExpr, patternVariableInfo.copy(initializerAddedToAst = true))
patternVariableInfoIdentityMap.get(typePatternExpr).foreach { patternVariableInfo =>
patternVariableInfoIdentityMap
.put(typePatternExpr, patternVariableInfo.copy(initializerAddedToAst = true))
}
}

def registerPatternVariableLocalToBeAddedToGraph(typePatternExpr: TypePatternExpr): Unit = {
patternVariableInfoIdentityMap.get(typePatternExpr).foreach { case patternVariableInfo =>
patternVariableInfoIdentityMap.get(typePatternExpr).foreach { patternVariableInfo =>
patternVariableInfoIdentityMap.put(typePatternExpr, patternVariableInfo.copy(localAddedToAst = true))
}
}

def getUnaddedPatternVariableAstsAndMarkAdded(): List[Ast] = {
val result = mutable.ListBuffer[Ast]()
patternVariableInfoIdentityMap.values.foreach { patternInfo =>
patternVariableInfoIdentityMap.values.toArray.sortBy(_.index).foreach { patternInfo =>
if (!patternInfo.localAddedToAst) {
result.addOne(Ast(patternInfo.typeVariableLocal))
registerPatternVariableLocalToBeAddedToGraph(patternInfo.typePatternExpr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class Scope(implicit val withSchemaValidation: ValidationMode, val disableTypeFa

def addLocalsForPatternsToEnclosingBlock(patterns: List[TypePatternExpr]): Unit = {
patterns.flatMap(enclosingMethod.get.getPatternVariableInfo(_)).foreach {
case PatternVariableInfo(typePatternExpr, variableLocal, _, _, _) =>
case PatternVariableInfo(typePatternExpr, variableLocal, _, _, _, _) =>
enclosingBlock.get.addPatternLocal(variableLocal, typePatternExpr)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.github.javaparser.resolution.logic.InferenceVariableType
import com.github.javaparser.resolution.model.typesystem.{LazyType, NullType}
import com.github.javaparser.resolution.types.*
import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap
import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserRecordDeclaration
import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.{TypeConstants, TypeNameConstants}
import io.joern.x2cpg.datastructures.Global
import org.slf4j.LoggerFactory
Expand Down
Loading

0 comments on commit 717cd8b

Please sign in to comment.