Skip to content

Commit

Permalink
[rubysrc2cpg] Comprehensive follow up to joernio#3708's fix (joernio#…
Browse files Browse the repository at this point in the history
…3714)

Generalized the solution from PR joernio#3708 by applying it to other call argument construction.

Fixes joernio#3693
  • Loading branch information
DavidBakerEffendi authored Oct 4, 2023
1 parent 37c83f0 commit 662bd23
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class AstCreator(
case ctx: MethodDefinitionPrimaryContext => astForMethodDefinitionContext(ctx.methodDefinition())
case ctx: ProcDefinitionPrimaryContext => astForProcDefinitionContext(ctx.procDefinition())
case ctx: YieldWithOptionalArgumentPrimaryContext =>
Seq(astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments())))
astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments()))
case ctx: IfExpressionPrimaryContext => Seq(astForIfExpression(ctx.ifExpression()))
case ctx: UnlessExpressionPrimaryContext => Seq(astForUnlessExpression(ctx.unlessExpression()))
case ctx: CaseExpressionPrimaryContext => astForCaseExpressionPrimaryContext(ctx)
Expand All @@ -224,8 +224,8 @@ class AstCreator(
case ctx: RegexInterpolationPrimaryContext =>
astForRegexInterpolationPrimaryContext(ctx.regexInterpolation)
case ctx: QuotedRegexInterpolationPrimaryContext => astForQuotedRegexInterpolation(ctx.quotedRegexInterpolation)
case ctx: IsDefinedPrimaryContext => Seq(astForIsDefinedPrimaryExpression(ctx))
case ctx: SuperExpressionPrimaryContext => Seq(astForSuperExpression(ctx))
case ctx: IsDefinedPrimaryContext => astForIsDefinedPrimaryExpression(ctx)
case ctx: SuperExpressionPrimaryContext => astForSuperExpression(ctx)
case ctx: IndexingExpressionPrimaryContext => astForIndexingExpressionPrimaryContext(ctx)
case ctx: MethodOnlyIdentifierPrimaryContext => astForMethodOnlyIdentifier(ctx.methodOnlyIdentifier())
case ctx: InvocationWithBlockOnlyPrimaryContext => astForInvocationWithBlockOnlyPrimaryContext(ctx)
Expand Down Expand Up @@ -279,9 +279,10 @@ class AstCreator(
.asScala
.flatMap(astForExpressionContext)
.toSeq
val splatAsts = astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand())
val callNode = createOpCall(ctx.COMMA, Operators.arrayInitializer, text(ctx))
Seq(callAst(callNode, expAsts ++ splatAsts))
val splatAsts = astForExpressionOrCommand(ctx.splattingArgument().expressionOrCommand())
val callNode = createOpCall(ctx.COMMA, Operators.arrayInitializer, text(ctx))
val (argAsts, otherAsts) = (expAsts ++ splatAsts).partitionExprAst
otherAsts :+ callAst(callNode, argAsts)
case ctx: AssociationsOnlyIndexingArgumentsContext =>
astForAssociationsContext(ctx.associations())
case ctx: RubyParser.SplattingOnlyIndexingArgumentsContext =>
Expand Down Expand Up @@ -309,11 +310,11 @@ class AstCreator(
ctx.AMPDOT()
}

val argsAst = if (ctx.argumentsWithParentheses() != null) {
astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses())
} else {
Seq()
}
val (argsAst, otherAst) = (if (ctx.argumentsWithParentheses() != null) {
astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses())
} else {
Seq()
}).partitionExprAst

if (hasBlockStmt) {
val blockName = methodNameAst.head.nodes.head
Expand Down Expand Up @@ -361,7 +362,7 @@ class AstCreator(
if (callNode.name == "call" && ctx.primary().isInstanceOf[ProcDefinitionPrimaryContext]) {
// this is a proc.call
val baseCallNode = primaryAst.head.nodes.head.asInstanceOf[NewCall]
Seq(callAst(baseCallNode, argsAst))
otherAst :+ callAst(baseCallNode, argsAst)
} else {
callNode
.code(text(ctx))
Expand All @@ -375,9 +376,9 @@ class AstCreator(
.methodFullName(methodNode.fullName)
.typeFullName(Defines.Any)
blockMethods.addOne(primaryAst.head)
Seq(callAst(callNode, Seq(Ast(methodRefNode)) ++ argsAst))
otherAst :+ callAst(callNode, Seq(Ast(methodRefNode)) ++ argsAst)
case _ =>
Seq(callAst(callNode, argsAst, primaryAst.headOption))
otherAst :+ callAst(callNode, argsAst, primaryAst.headOption)
}
}
}
Expand Down Expand Up @@ -418,7 +419,8 @@ class AstCreator(
.code(text(ctx))
.lineNumber(ctx.COLON2.lineNumber)
.columnNumber(ctx.COLON2.columnNumber)
Seq(callAst(callNode, baseAst ++ blocksAst))
val (argsAst, otherAst) = (baseAst ++ blocksAst).partitionExprAst
otherAst :+ callAst(callNode, argsAst)
}

private def astForChainedScopedConstantReferencePrimaryContext(
Expand Down Expand Up @@ -736,8 +738,8 @@ class AstCreator(
val operatorText = getOperatorName(terminalNode.getSymbol)
val expressions = ctx.expression.asScala

val callArgs =
Option(ctx.keyword) match {
val (callArgs, otherAst) =
(Option(ctx.keyword) match {
case Some(ctxKeyword) =>
val expr1Ast = astForCallNode(ctx, ctxKeyword.getText)
val expr2Asts = astForExpressionContext(expressions.head)
Expand All @@ -746,10 +748,10 @@ class AstCreator(
val expr1Asts = astForExpressionContext(expressions.head)
val expr2Asts = expressions.lift(1).flatMap(astForExpressionContext)
expr1Asts ++ expr2Asts
}
}).partitionExprAst

val callNode = createOpCall(terminalNode, operatorText, text(ctx))
Seq(callAst(callNode, callArgs))
otherAst ++ Seq(callAst(callNode, callArgs))
}

private def astForAssociationsContext(ctx: AssociationsContext): Seq[Ast] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,22 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As

}

implicit class AstIterExt(a: Iterable[Ast]) {

/** Partitions a sequence of Ast objects into those with roots start with an expression, and those that don't.
* @return
* a tuple of sequences where the first has expression roots and the second does not.
*/
def partitionExprAst: (Seq[Ast], Seq[Ast]) = {
val (as, bs) = a.partition(_.root match
case Some(_: ExpressionNew) => true
case _ => false
)
(as.toSeq, bs.toSeq)
}

}

}

object RubyOperators {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{
AstNodeNew,
NewCall,
NewIdentifier,
NewMethod,
NewType,
NewTypeDecl
}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import org.antlr.v4.runtime.ParserRuleContext
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -98,23 +91,19 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
): Seq[Ast] = {
val (argsAst, otherAst) = arguments
.flatMap(astForExpressionContext)
.partition(_.root match
case Some(_: NewMethod) => false
case Some(_: NewTypeDecl) => false
case _ => true
)
.partitionExprAst
val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH)
otherAst.toSeq :+ callAst(call, argsAst.toList)
otherAst :+ callAst(call, argsAst.toList)
}

protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, RubyOperators.defined, Seq(ctx.expression()))

// TODO: Maybe merge (in RubyParser.g4) isDefinedExpression with isDefinedPrimaryExpression?
protected def astForIsDefinedPrimaryExpression(ctx: IsDefinedPrimaryContext): Ast = {
val argsAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val call = callNode(ctx, text(ctx), RubyOperators.defined, RubyOperators.defined, DispatchTypes.STATIC_DISPATCH)
callAst(call, argsAst.toList)
protected def astForIsDefinedPrimaryExpression(ctx: IsDefinedPrimaryContext): Seq[Ast] = {
val (argsAst, otherAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionExprAst
val call = callNode(ctx, text(ctx), RubyOperators.defined, RubyOperators.defined, DispatchTypes.STATIC_DISPATCH)
otherAst :+ callAst(call, argsAst.toList)
}

protected def astForLiteralPrimaryExpression(ctx: LiteralPrimaryContext): Seq[Ast] = ctx.literal() match {
Expand Down Expand Up @@ -169,10 +158,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case ctx: VariableIdentifierOnlySingleLeftHandSideContext =>
Seq(astForVariableIdentifierHelper(ctx.variableIdentifier, true))
case ctx: PrimaryInsideBracketsSingleLeftHandSideContext =>
val primaryAsts = astForPrimaryContext(ctx.primary)
val argsAsts = astForArguments(ctx.arguments)
val indexAccessCall = createOpCall(ctx.LBRACK, Operators.indexAccess, text(ctx))
Seq(callAst(indexAccessCall, primaryAsts ++ argsAsts))
val primaryAsts = astForPrimaryContext(ctx.primary)
val (argsAsts, otherAst) = astForArguments(ctx.arguments).partitionExprAst
val indexAccessCall = createOpCall(ctx.LBRACK, Operators.indexAccess, text(ctx))
otherAst :+ callAst(indexAccessCall, primaryAsts ++ argsAsts)
case ctx: XdotySingleLeftHandSideContext =>
// TODO handle obj.foo=arg being interpreted as obj.foo(arg) here.
val xAsts = astForPrimaryContext(ctx.primary)
Expand Down Expand Up @@ -220,8 +209,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
* This is multiple RHS packed into a single LHS. That is, packing left hand side.
* This is as good as multiple RHS packed into an array and put into a single LHS
*/
val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true)
Seq(callAst(opCallNode, leftAst ++ packedRHS))
val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true)
val (argsAst, otherAst) = (leftAst ++ packedRHS).partitionExprAst
otherAst :+ callAst(opCallNode, argsAst)
}
} else {
Seq(callAst(opCallNode, leftAst ++ rightAst))
Expand Down Expand Up @@ -271,8 +261,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case Some(node) if node.name == "Array" => Operators.arrayInitializer
case _ => Operators.indexAccess

val callNode = createOpCall(ctx.LBRACK, operator, text(ctx))
Seq(callAst(callNode, lhsExpressionAst ++ rhsExpressionAst))
val callNode = createOpCall(ctx.LBRACK, operator, text(ctx))
val (argsAst, otherAst) = (lhsExpressionAst ++ rhsExpressionAst).partitionExprAst
otherAst :+ callAst(callNode, argsAst)

}

Expand All @@ -288,7 +279,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
.typeFullName(Defines.Any)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.code(if (wrapInBrackets) s"[$code]" else code)
Seq(callAst(callNode, astsToConcat))
val (argsAst, otherAst) = astsToConcat.partitionExprAst
otherAst :+ callAst(callNode, argsAst)
}

def astForStringInterpolationContext(ctx: InterpolatedStringExpressionContext): Seq[Ast] = {
Expand Down Expand Up @@ -377,7 +369,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala)

protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Ast = {
protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Seq[Ast] = {
val argsAst = Option(ctx.argumentsWithParentheses()) match
case Some(ctxArgs) => astForArgumentsWithParenthesesContext(ctxArgs)
case None => Seq()
Expand All @@ -387,17 +379,20 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
// TODO: Handle the optional block.
// NOTE: `super` is quite complicated semantically speaking. We'll need
// to revisit how to represent them.
protected def astForSuperCall(ctx: ParserRuleContext, arguments: Seq[Ast]): Ast = {
protected def astForSuperCall(ctx: ParserRuleContext, arguments: Seq[Ast]): Seq[Ast] = {
val call =
callNode(ctx, text(ctx), RubyOperators.superKeyword, RubyOperators.superKeyword, DispatchTypes.STATIC_DISPATCH)
callAst(call, arguments.toList)

val (argsAst, otherAst) = arguments.partitionExprAst
otherAst :+ callAst(call, argsAst)
}

protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Ast = {
protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Seq[Ast] = {
val args = argumentsCtx.map(astForArguments).getOrElse(Seq())
val call =
callNode(ctx, text(ctx), Defines.UNRESOLVED_YIELD, Defines.UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH)
callAst(call, args)
val (argsAst, otherAst) = args.partitionExprAst
otherAst :+ callAst(call, argsAst)
}

protected def astForUntilExpression(ctx: UntilExpressionContext): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,53 +208,54 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V
// TODO: return Ast instead of Seq[Ast]
protected def astForExpressionOrCommand(ctx: ExpressionOrCommandContext): Seq[Ast] = ctx match {
case ctx: InvocationExpressionOrCommandContext => astForInvocationExpressionOrCommandContext(ctx)
case ctx: NotExpressionOrCommandContext => Seq(astForNotKeywordExpressionOrCommand(ctx))
case ctx: OrAndExpressionOrCommandContext => Seq(astForOrAndExpressionOrCommand(ctx))
case ctx: NotExpressionOrCommandContext => astForNotKeywordExpressionOrCommand(ctx)
case ctx: OrAndExpressionOrCommandContext => astForOrAndExpressionOrCommand(ctx)
case ctx: ExpressionExpressionOrCommandContext => astForExpressionContext(ctx.expression())
case _ =>
logger.error(s"astForExpressionOrCommand() $relativeFilename, ${text(ctx)} All contexts mismatched.")
Seq(Ast())
}

private def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Ast = {
val exprOrCommandAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val call = callNode(ctx, text(ctx), Operators.not, Operators.not, DispatchTypes.STATIC_DISPATCH)
callAst(call, exprOrCommandAst)
private def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Seq[Ast] = {
val exprOrCommandAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val call = callNode(ctx, text(ctx), Operators.not, Operators.not, DispatchTypes.STATIC_DISPATCH)
val (argsAst, otherAst) = exprOrCommandAst.partitionExprAst
otherAst :+ callAst(call, argsAst)
}

private def astForOrAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = ctx.op.getType match {
private def astForOrAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = ctx.op.getType match {
case OR => astForOrExpressionOrCommand(ctx)
case AND => astForAndExpressionOrCommand(ctx)
}

private def astForOrExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = {
val argsAst = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand)
val call = callNode(ctx, text(ctx), Operators.or, Operators.or, DispatchTypes.STATIC_DISPATCH)
callAst(call, argsAst.toList)
private def astForOrExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = {
val args = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand)
val call = callNode(ctx, text(ctx), Operators.or, Operators.or, DispatchTypes.STATIC_DISPATCH)
val (argsAst, otherAst) = args.partitionExprAst
otherAst :+ callAst(call, argsAst)
}

private def astForAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Ast = {
val argsAst = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand)
val call = callNode(ctx, text(ctx), Operators.and, Operators.and, DispatchTypes.STATIC_DISPATCH)
callAst(call, argsAst.toList)
private def astForAndExpressionOrCommand(ctx: OrAndExpressionOrCommandContext): Seq[Ast] = {
val args = ctx.expressionOrCommand().asScala.flatMap(astForExpressionOrCommand)
val call = callNode(ctx, text(ctx), Operators.and, Operators.and, DispatchTypes.STATIC_DISPATCH)
val (argsAst, otherAst) = args.partitionExprAst
otherAst :+ callAst(call, argsAst)
}

private def astForSuperCommand(ctx: SuperCommandContext): Ast =
private def astForSuperCommand(ctx: SuperCommandContext): Seq[Ast] =
astForSuperCall(ctx, astForArguments(ctx.argumentsWithoutParentheses().arguments()))

private def astForYieldCommand(ctx: YieldCommandContext): Ast =
private def astForYieldCommand(ctx: YieldCommandContext): Seq[Ast] =
astForYieldCall(ctx, Option(ctx.argumentsWithoutParentheses().arguments()))

private def astForSimpleMethodCommand(ctx: SimpleMethodCommandContext): Seq[Ast] = {
val methodIdentifierAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx))
methodIdentifierAsts.headOption.foreach(methodNameAsIdentifierStack.push)
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())

/* get args without the method def in it */
val argAstsWithoutMethods = argsAsts.filterNot(_.root.exists(_.isInstanceOf[NewMethod]))
val args = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val (argsAst, _) = args.partitionExprAst

/* isolate methods from the original args and create identifier ASTs from it */
val methodDefAsts = argsAsts.filter(_.root.exists(_.isInstanceOf[NewMethod]))
val methodDefAsts = args.filter(_.root.exists(_.isInstanceOf[NewMethod]))
val methodToIdentifierAsts = methodDefAsts.flatMap {
_.nodes.collectFirst { case methodNode: NewMethod =>
Ast(
Expand All @@ -277,17 +278,17 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V
if (callNodes.size == 1) {
val callNode = callNodes.head
if (callNode.name == "require" || callNode.name == "load") {
resolveRequireOrLoadPath(argsAsts, callNode)
resolveRequireOrLoadPath(args, callNode)
} else if (callNode.name == "require_relative") {
resolveRelativePath(filename, argsAsts, callNode)
resolveRelativePath(filename, args, callNode)
} else if (prefixMethods.contains(callNode.name)) {
/* we remove the method definition AST from argument and add its corresponding identifier form */
Seq(callAst(callNode, argAstsWithoutMethods ++ methodToIdentifierAsts))
Seq(callAst(callNode, argsAst ++ methodToIdentifierAsts))
} else {
Seq(callAst(callNode, argsAsts))
Seq(callAst(callNode, argsAst))
}
} else {
argsAsts
args
}
}

Expand Down Expand Up @@ -325,8 +326,8 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V
methodRefNode(ctx, s"def ${methodNode.name}(...)", methodNode.fullName, Defines.Any)

protected def astForCommand(ctx: CommandContext): Seq[Ast] = ctx match {
case ctx: YieldCommandContext => Seq(astForYieldCommand(ctx))
case ctx: SuperCommandContext => Seq(astForSuperCommand(ctx))
case ctx: YieldCommandContext => astForYieldCommand(ctx)
case ctx: SuperCommandContext => astForSuperCommand(ctx)
case ctx: SimpleMethodCommandContext => astForSimpleMethodCommand(ctx)
case ctx: MemberAccessCommandContext => astForMemberAccessCommand(ctx)
}
Expand Down
Loading

0 comments on commit 662bd23

Please sign in to comment.