Skip to content

Commit

Permalink
[gosrc2cpg] - Handling for lambda call node for invocation (joernio#3873
Browse files Browse the repository at this point in the history
)

1. Handling for lambda call node for its invocation along with few more
unit tests.
2. Data flow unit tests for the lambda.
  • Loading branch information
pandurangpatil authored Dec 1, 2023
1 parent 46806ac commit fa6393b
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ trait AstForLambdaCreator(implicit withSchemaValidation: ValidationMode) { this:
.collectFirst({ case m: NewMethod if !m.fullName.endsWith(parserResult.filename) => m.fullName })
.getOrElse(fullyQualifiedPackage)
val fullName = s"$baseFullName.$lambdaName"
val (signature, methodReturn, params, genericTypeMethodMap) = generateLambdaSignature(
val (signature, returnTypeStr, methodReturn, params, genericTypeMethodMap) = generateLambdaSignature(
createParserNodeInfo(funcLiteral.json(ParserKeys.Type))
)
val methodNode_ = methodNode(funcLiteral, lambdaName, funcLiteral.code, fullName, Some(signature), relPathFileName)
Expand All @@ -40,14 +40,16 @@ trait AstForLambdaCreator(implicit withSchemaValidation: ValidationMode) { this:
case _ =>
methodNode_.astParentType(NodeTypes.METHOD).astParentFullName(baseFullName)
Ast.storeInDiffGraph(astForMethod, diffGraph)
val typeFullName = GoGlobal.lambdaSignatureToLambdaTypeMap.getOrDefault(signature, fullName)
// TODO: Create TypeDecl for lambda function for which we didnt find the type.
Seq(Ast(methodRefNode(funcLiteral, funcLiteral.code, fullName, typeFullName)))
// We need to create TypeDecl for every lambda function and set its inheritance with all the matching lambda types.
// val typeFullName = GoGlobal.lambdaSignatureToLambdaTypeMap.getOrDefault(signature, fullName)
GoGlobal.recordFullNameToReturnType(fullName, returnTypeStr, signature)
Seq(Ast(methodRefNode(funcLiteral, funcLiteral.code, fullName, fullName)))
}

private def generateLambdaSignature(
funcType: ParserNodeInfo
): (String, NewMethodReturn, Value, Map[String, List[String]]) = {
): (String, String, NewMethodReturn, Value, Map[String, List[String]]) = {
val genericTypeMethodMap: Map[String, List[String]] = Map()
val (returnTypeStr, returnTypeInfo) =
getReturnType(funcType.json, genericTypeMethodMap).headOption
Expand All @@ -57,6 +59,6 @@ trait AstForLambdaCreator(implicit withSchemaValidation: ValidationMode) { this:
val params = funcType.json(ParserKeys.Params)(ParserKeys.List)
val signature =
s"${XDefines.ClosurePrefix}(${parameterSignature(params, genericTypeMethodMap)})$returnTypeStr"
(signature, methodReturn, params, genericTypeMethodMap)
(signature, returnTypeStr, methodReturn, params, genericTypeMethodMap)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,19 @@ trait AstForMethodCallExpressionCreator(implicit withSchemaValidation: Validatio
.getOrDefault(methodFullName, (Defines.anyTypeName, s"$methodFullName()"))
val (signature, fullName, returnTypeFullName) =
Defines.builtinFunctions.getOrElse(methodName, (signatureCache, methodFullName, returnTypeFullNameCache))
(methodName, signature, fullName, returnTypeFullName, Seq.empty)
val lambdaOption = scope.lookupVariable(methodName)
val (postLambdaFullname, postLambdaSignature, postLambdaReturnTypeFullName) = lambdaOption match
case Some((_, lambdaTypeFullName)) =>
val (lambdaReturnTypeFullNameCache, lambdaSignatureCache) =
GoGlobal.methodFullNameReturnTypeMap
.getOrDefault(lambdaTypeFullName, (returnTypeFullName, signature))
if (lambdaSignatureCache == signature) then
// This means we didn't find the lambda signature in methodFullNameReturnTypeMap cache.
(fullName, lambdaSignatureCache, lambdaReturnTypeFullNameCache)
else (lambdaTypeFullName, lambdaSignatureCache, lambdaReturnTypeFullNameCache)
case _ =>
(fullName, signature, returnTypeFullName)
(methodName, postLambdaSignature, postLambdaFullname, postLambdaReturnTypeFullName, Seq.empty)
case Some(xnode) =>
xnode.node match
case Ident =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.joern.go2cpg.dataflow

import io.joern.dataflowengineoss.language.*
import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.shiftleft.semanticcpg.language.*

class AnonymousFuncDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) {
"Simple Lambda expression dataflow" should {
val cpg = code("""
|package main
|
|import "fmt"
|
|func main() {
| // Define a lambda function and assign it to a variable
| add := func(a, b int) int {
| println(a)
| return a + b
| }
|
| // Call the lambda function
| result := add(3, 5)
| fmt.Println("Result:", result) // Output: 8
|}
|""".stripMargin)
"work dataflow within lambda from parameter to println sink" in {
val source = cpg.identifier("a")
val sink = cpg.call("println")
sink.reachableByFlows(source).size shouldBe 1
}

"work dataflow from literal parameter passed to lambda invocation to println sink" in {
val source = cpg.literal("3")
val sink = cpg.call("println")
sink.reachableByFlows(source).size shouldBe 1
}

"work dataflow from literal parameter passed to lambda invocation to outside println sink" in {
val source = cpg.literal("3")
val sink = cpg.call("Println")
sink.reachableByFlows(source).size shouldBe 1
}
}

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package io.joern.go2cpg.dataflow
import io.shiftleft.semanticcpg.language._
import io.joern.dataflowengineoss.language._

import io.joern.dataflowengineoss.language.*
import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.shiftleft.semanticcpg.language.*

class ArrayDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) {
"Source to sink dataflow through arrays" should {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,19 @@ class AnonymousFuncTests extends GoCodeToCpgSuite {
cpg.methodRef.referencedMethod.fullName.l shouldBe List(s"main.main.${Defines.ClosurePrefix}0")
}

"reflectes into lhs side TypeFullName" in {
"reflects into lhs side TypeFullName" in {
cpg.local("add").typeFullName.l shouldBe List(s"main.main.${Defines.ClosurePrefix}0")
}

"have call node created for lambda invocation" in {
cpg.call("add").methodFullName.l shouldBe List(s"main.main.${Defines.ClosurePrefix}0")
}

"able to traverse from call node to callee" in {
cpg.call("add").callee.fullName.l shouldBe List(s"main.main.${Defines.ClosurePrefix}0")
}

// TODO: Add TypeDecl test
// TODO: Add unit test for lambda defined outside method and part of package directly
}
}

0 comments on commit fa6393b

Please sign in to comment.