diff --git a/jacodb-api-jvm/src/main/kotlin/org/jacodb/api/jvm/cfg/JcRawInst.kt b/jacodb-api-jvm/src/main/kotlin/org/jacodb/api/jvm/cfg/JcRawInst.kt index ebd098759..de038957e 100644 --- a/jacodb-api-jvm/src/main/kotlin/org/jacodb/api/jvm/cfg/JcRawInst.kt +++ b/jacodb-api-jvm/src/main/kotlin/org/jacodb/api/jvm/cfg/JcRawInst.kt @@ -862,6 +862,12 @@ data class JcRawArgument( } } +enum class LocalVarKind { + UNKNOWN, + ORIGINAL, + NAMED_LOCAL, +} + /** * @param name isn't considered in `equals` and `hashcode` */ @@ -869,6 +875,7 @@ data class JcRawLocalVar( val index: Int, override val name: String, override val typeName: TypeName, + val kind: LocalVarKind = LocalVarKind.UNKNOWN, ) : JcRawLocal { override fun toString(): String = name diff --git a/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/JcInstListBuilder.kt b/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/JcInstListBuilder.kt index 0672a8b64..1c3dd68d6 100644 --- a/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/JcInstListBuilder.kt +++ b/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/JcInstListBuilder.kt @@ -75,7 +75,7 @@ class JcInstListBuilder(val method: JcMethod,val instList: JcInstList inst.lhv.let { unprocessedLhv -> if (unprocessedLhv is JcRawLocalVar && unprocessedLhv.typeName == UNINIT_THIS) { convertedLocalVars.getOrPut(unprocessedLhv) { - JcRawLocalVar(unprocessedLhv.index, unprocessedLhv.name, inst.rhv.typeName) + JcRawLocalVar(unprocessedLhv.index, unprocessedLhv.name, inst.rhv.typeName, unprocessedLhv.kind) } } else { unprocessedLhv diff --git a/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/RawInstListBuilder.kt b/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/RawInstListBuilder.kt index e06f6aabd..3248ebd6c 100644 --- a/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/RawInstListBuilder.kt +++ b/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/RawInstListBuilder.kt @@ -96,6 +96,7 @@ import org.jacodb.api.jvm.cfg.JcRawUshrExpr import org.jacodb.api.jvm.cfg.JcRawValue import org.jacodb.api.jvm.cfg.JcRawVirtualCallExpr import org.jacodb.api.jvm.cfg.JcRawXorExpr +import org.jacodb.api.jvm.cfg.LocalVarKind import org.jacodb.impl.cfg.util.CLASS_CLASS import org.jacodb.impl.cfg.util.ExprMapper import org.jacodb.impl.cfg.util.METHOD_HANDLES_CLASS @@ -365,6 +366,10 @@ class RawInstListBuilder( private val localMergeAssignments = mutableListOf() private val stackMergeAssignments = mutableListOf() + private val registerRelationTree = identityMap() + private val isRegisterNamed = identityMap() + private val registerToLocalName = identityMap() + private var argCounter = 0 private var generatedLocalVarsCounter = 0 @@ -372,8 +377,10 @@ class RawInstListBuilder( buildGraph() buildInstructions() + markOriginalAssigns() buildRequiredAssignments() buildRequiredGotos() + fillRegisterNames() val generatedInstructions = mutableListOf() instructionLists.forEach { insnList -> insnList?.let { generatedInstructions += it } } @@ -390,6 +397,102 @@ class RawInstListBuilder( return Simplifier().simplify(localsNormalizedInstructionList) } + private fun isNamed(localVar: JcRawLocalVar): Boolean { + return localVar.name[0] != LOCAL_VAR_START_CHARACTER + } + + private fun findRelationRoot(start: Int): Int { + var curId = start + while (registerRelationTree.contains(curId)) { + curId = registerRelationTree[curId] ?: break + } + return curId + } + + private fun uniteRegisters(fst: JcRawLocalVar, snd: JcRawLocalVar) { + isRegisterNamed[fst.index] = isNamed(fst) + isRegisterNamed[snd.index] = isNamed(snd) + val fstRoot = findRelationRoot(fst.index) + val sndRoot = findRelationRoot(snd.index) + if (fstRoot == sndRoot || + (isRegisterNamed.getOrDefault(fstRoot, false) && isRegisterNamed.getOrDefault(sndRoot, false))) + return + if (isNamed(fst) || registerToLocalName.contains(fstRoot)) { + registerRelationTree[sndRoot] = fstRoot + return + } + registerRelationTree[fstRoot] = sndRoot + } + + private fun findRegisterName(localVar: JcRawLocalVar): String? { + var curRegId = localVar.index + while (registerRelationTree.contains(curRegId)) { + curRegId = registerRelationTree[curRegId] ?: break + } + if (registerToLocalName.contains(curRegId)) { + return registerToLocalName[curRegId] + } + return null + } + + private fun changeAssigns(assignChanger: (JcRawAssignInst) -> JcRawAssignInst?) { + instructionLists.forEach { insnList -> + insnList?.indices?.forEach { i -> + val insn = insnList[i] + if (insn is JcRawAssignInst) { + val newAssign = assignChanger(insn) + if (newAssign != null) + insnList[i] = newAssign + } + } + } + } + + private fun putRegisterNameForAssign(assignInst: JcRawAssignInst): JcRawAssignInst? { + if (assignInst.lhv !is JcRawLocalVar) { + return null + } + val lhv = assignInst.lhv as JcRawLocalVar + if (isNamed(lhv)) { + return null + } + val name = findRegisterName(lhv) ?: return null + val newLhv = lhv.copy(name = name) + return JcRawAssignInst( + assignInst.owner, + newLhv, + assignInst.rhv + ) + } + + private fun fillRegisterNames() { + changeAssigns { putRegisterNameForAssign(it) } + } + + private fun markAssignAsOriginal(assignInst: JcRawAssignInst): JcRawAssignInst? { + if (assignInst.lhv !is JcRawLocalVar) { + return null + } + val newLhv = (assignInst.lhv as JcRawLocalVar).copy(kind = LocalVarKind.ORIGINAL) + return JcRawAssignInst( + assignInst.owner, + newLhv, + assignInst.rhv + ) + } + + private fun markOriginalAssigns() { + changeAssigns { markAssignAsOriginal(it) } + } + + private fun createRawAssign(owner: JcMethod, lhv: JcRawValue, rhv: JcRawExpr): JcRawAssignInst { + if (lhv is JcRawLocalVar && rhv is JcRawLocalVar) { + if (lhv.kind != LocalVarKind.ORIGINAL && lhv.kind != LocalVarKind.NAMED_LOCAL) + uniteRegisters(lhv, rhv) + } + return JcRawAssignInst(owner, lhv, rhv) + } + private fun MutableList.ensureFirstInstIsLineNumber() { val firstLineNumber = indexOfFirst { it is JcRawLineNumberInst } if (firstLineNumber == -1 || firstLineNumber == 0) return @@ -532,7 +635,7 @@ class RawInstListBuilder( value: JcRawSimpleValue, expr: JcRawSimpleValue ) { - val assignment = JcRawAssignInst(method, value, expr) + val assignment = createRawAssign(method, value, expr) if (insn.isTerminateInst) { insnList.addInst(assignment, insnList.lastIndex) } else if (insn.isBranchingInst) { @@ -547,7 +650,7 @@ class RawInstListBuilder( insnList.addInst(assignment, branchInstIdx) if (branchInst.key.dependsOn(value)) { val freshVar = generateFreshLocalVar(branchInst.key.typeName) - insnList.addInst(JcRawAssignInst(method, freshVar, branchInst.key), index = 0) + insnList.addInst(createRawAssign(method, freshVar, branchInst.key), index = 0) insnList[branchInstIdx + 2] = JcRawSwitchInst( owner = branchInst.owner, key = freshVar, @@ -561,7 +664,7 @@ class RawInstListBuilder( insnList.addInst(assignment, branchInstIdx) if (branchInst.condition.dependsOn(value)) { val freshVar = generateFreshLocalVar(value.typeName) - insnList.addInst(JcRawAssignInst(method, freshVar, value), index = 0) + insnList.addInst(createRawAssign(method, freshVar, value), index = 0) val updatedCondition = branchInst.condition.replace(fromValue = value, toValue = freshVar) insnList[branchInstIdx + 2] = JcRawIfInst( @@ -670,7 +773,7 @@ class RawInstListBuilder( val value = frame.findLocal(variable) ?: return@Array null if (value is JcRawLocalVar && value.typeName != type && type !in blackListForTypeRefinement) { - JcRawLocalVar(value.index, value.name, type).also { newLocal -> + JcRawLocalVar(value.index, value.name, type, value.kind).also { newLocal -> localTypeRefinement[value] = newLocal } } else { @@ -684,7 +787,7 @@ class RawInstListBuilder( .map { (index, value) -> val type = stackTypes[index] if (value is JcRawLocalVar && value.typeName != type && type !in blackListForTypeRefinement) { - JcRawLocalVar(value.index, value.name, type).also { newLocal -> + JcRawLocalVar(value.index, value.name, type, value.kind).also { newLocal -> localTypeRefinement[value] = newLocal } } else value @@ -775,8 +878,8 @@ class RawInstListBuilder( expr: JcRawSimpleValue, insn: AbstractInsnNode, ): JcRawAssignInst? { + val infoFromLocalVars = findLocalVariableWithInstruction(variable, insn) val oldVar = currentFrame.findLocal(variable)?.let { - val infoFromLocalVars = findLocalVariableWithInstruction(variable, insn) val isArg = variable < argCounter && infoFromLocalVars != null && infoFromLocalVars.start == firstLabelOrNull if (expr.typeName.isPrimitive.xor(it.typeName.isPrimitive) @@ -788,26 +891,37 @@ class RawInstListBuilder( it } } + + fun createAssignWithNextDeclared(type: TypeName): JcRawAssignInst { + val assignment = nextRegisterDeclaredVariable(type, variable, insn) + currentFrame = currentFrame.putLocal(variable, assignment) + return createRawAssign(method, assignment, expr) + } + return if (oldVar != null) { if (oldVar is JcRawArgument) { currentFrame = currentFrame.putLocal(variable, expr) null } else if (oldVar.typeName == expr.typeName || (expr is JcRawNullConstant && !oldVar.typeName.isPrimitive)) { if (!currentFrame.valueUsedInStack(oldVar) && !currentFrame.valueUsedInLocals(oldVar, variable)) { - // optimization: old variable has no other usages. So we can overwrite it - JcRawAssignInst(method, oldVar, expr) + // reuse register if the same variable is used; otherwise, create a new register + // it helps to track down which registers represent which variables from the original code + if (oldVar is JcRawLocalVar && infoFromLocalVars?.name == registerToLocalName[oldVar.index]) { + createRawAssign(method, oldVar, expr) + } + else { + createAssignWithNextDeclared(expr.typeName) + } } else { currentFrame = currentFrame.putLocal(variable, expr) null } } else { - val assignment = nextRegisterDeclaredVariable(expr.typeName, variable, insn) - currentFrame = currentFrame.putLocal(variable, assignment) - JcRawAssignInst(method, assignment, expr) + createAssignWithNextDeclared(expr.typeName) } } else { // We have to get type if rhv expression is NULL - val typeOfNewAssigment = + val typeOfNewAssignment = if (expr.typeName.typeName == PredefinedPrimitives.Null) { findLocalVariableWithInstruction(variable, insn) ?.desc?.typeNameFromJvmName() @@ -816,10 +930,7 @@ class RawInstListBuilder( } else { expr.typeName } - val newLocal = nextRegisterDeclaredVariable(typeOfNewAssigment, variable, insn) - val result = JcRawAssignInst(method, newLocal, expr) - currentFrame = currentFrame.putLocal(variable, newLocal) - result + createAssignWithNextDeclared(typeOfNewAssignment) } } @@ -881,7 +992,8 @@ class RawInstListBuilder( val lvName = lvNode?.name?.takeIf { keepLocalVariableNames } if (lvName != null) { - nextVar = nextVar.copy(name = lvName) + nextVar = nextVar.copy(name = lvName, kind = LocalVarKind.NAMED_LOCAL) + registerToLocalName[nextVar.index] = lvName } val declaredTypeName = lvNode?.desc?.typeNameFromJvmName() @@ -1050,7 +1162,7 @@ class RawInstListBuilder( val read = JcRawArrayAccess(arrayRef, index, arrayRef.typeName.elementType()) val assignment = nextRegister(read.typeName) - addInstruction(insn, JcRawAssignInst(method, assignment, read)) + addInstruction(insn, createRawAssign(method, assignment, read)) push(assignment) } @@ -1059,7 +1171,7 @@ class RawInstListBuilder( val index = pop() val arrayRef = pop() addInstruction( - insn, JcRawAssignInst( + insn, createRawAssign( method, JcRawArrayAccess(arrayRef, index, arrayRef.typeName.elementType()), value @@ -1204,7 +1316,7 @@ class RawInstListBuilder( else -> error("Unknown binary opcode: $opcode") } val assignment = nextRegister(resolvedType) - addInstruction(insn, JcRawAssignInst(method, assignment, expr)) + addInstruction(insn, createRawAssign(method, assignment, expr)) push(assignment) } @@ -1234,7 +1346,7 @@ class RawInstListBuilder( else -> error("Unknown unary opcode $opcode") } val assignment = nextRegister(expr.typeName) - addInstruction(insn, JcRawAssignInst(method, assignment, expr)) + addInstruction(insn, createRawAssign(method, assignment, expr)) push(assignment) } @@ -1251,7 +1363,7 @@ class RawInstListBuilder( else -> error("Unknown cast opcode $opcode") } val assignment = nextRegister(targetType) - addInstruction(insn, JcRawAssignInst(method, assignment, JcRawCastExpr(targetType, operand))) + addInstruction(insn, createRawAssign(method, assignment, JcRawCastExpr(targetType, operand))) push(assignment) } @@ -1265,7 +1377,7 @@ class RawInstListBuilder( else -> error("Unknown cmp opcode $opcode") } val assignment = nextRegister(PredefinedPrimitives.Int.typeName()) - addInstruction(insn, JcRawAssignInst(method, assignment, expr)) + addInstruction(insn, createRawAssign(method, assignment, expr)) push(assignment) } @@ -1306,7 +1418,7 @@ class RawInstListBuilder( Opcodes.GETFIELD -> { val assignment = nextRegister(fieldType) val field = JcRawFieldRef(frame.pop(), declaringClass, fieldName, fieldType) - addInstruction(insnNode, JcRawAssignInst(method, assignment, field)) + addInstruction(insnNode, createRawAssign(method, assignment, field)) frame.push(assignment) } @@ -1314,20 +1426,20 @@ class RawInstListBuilder( val value = frame.pop() val instance = frame.pop() val fieldRef = JcRawFieldRef(instance, declaringClass, fieldName, fieldType) - addInstruction(insnNode, JcRawAssignInst(method, fieldRef, value)) + addInstruction(insnNode, createRawAssign(method, fieldRef, value)) } Opcodes.GETSTATIC -> { val assignment = nextRegister(fieldType) val field = JcRawFieldRef(declaringClass, fieldName, fieldType) - addInstruction(insnNode, JcRawAssignInst(method, assignment, field)) + addInstruction(insnNode, createRawAssign(method, assignment, field)) frame.push(assignment) } Opcodes.PUTSTATIC -> { val value = frame.pop() val fieldRef = JcRawFieldRef(declaringClass, fieldName, fieldType) - addInstruction(insnNode, JcRawAssignInst(method, fieldRef, value)) + addInstruction(insnNode, createRawAssign(method, fieldRef, value)) } } } @@ -1368,7 +1480,7 @@ class RawInstListBuilder( val resolvedType = resolveType(local.typeName, rhv.typeName) val expr = JcRawAddExpr(resolvedType, local, rhv) val assignment = nextRegister(resolvedType) - addInstruction(insnNode, JcRawAssignInst(method, assignment, expr)) + addInstruction(insnNode, createRawAssign(method, assignment, expr)) val localAssign = local(variable, assignment, insnNode) if (localAssign != null) { @@ -1384,7 +1496,7 @@ class RawInstListBuilder( Opcodes.NEWARRAY -> { val expr = JcRawNewArrayExpr(operand.toPrimitiveType().asArray(), pop()) val assignment = nextRegister(expr.typeName) - addInstruction(insnNode, JcRawAssignInst(method, assignment, expr)) + addInstruction(insnNode, createRawAssign(method, assignment, expr)) push(assignment) } @@ -1443,7 +1555,7 @@ class RawInstListBuilder( addInstruction(insnNode, JcRawCallInst(method, expr)) } else { val result = nextRegister(Type.getReturnType(desc).descriptor.typeNameFromJvmName()) - addInstruction(insnNode, JcRawAssignInst(method, result, expr)) + addInstruction(insnNode, createRawAssign(method, result, expr)) push(result) } } @@ -1743,7 +1855,7 @@ class RawInstListBuilder( is Type -> { val assignment = nextRegister(CLASS_CLASS.typeNameFromJvmName()) addInstruction( - insnNode, JcRawAssignInst( + insnNode, createRawAssign( method, assignment, when (cst.sort) { @@ -1763,7 +1875,7 @@ class RawInstListBuilder( is Handle -> { val assignment = nextRegister(CLASS_CLASS.typeNameFromJvmName()) addInstruction( - insnNode, JcRawAssignInst( + insnNode, createRawAssign( method, assignment, ldcValue(cst) @@ -1794,7 +1906,7 @@ class RawInstListBuilder( else -> { val lookupAssignment = nextRegister(METHOD_HANDLES_LOOKUP_CLASS.typeNameFromJvmName()) addInstruction( - insnNode, JcRawAssignInst( + insnNode, createRawAssign( method, lookupAssignment, JcRawStaticCallExpr( @@ -1820,7 +1932,7 @@ class RawInstListBuilder( ) } } - addInstruction(insnNode, JcRawAssignInst(method, assignment, methodCall)) + addInstruction(insnNode, createRawAssign(method, assignment, methodCall)) push(assignment) } @@ -1897,7 +2009,7 @@ class RawInstListBuilder( addInstruction(insnNode, JcRawCallInst(method, expr)) } else { val result = nextRegister(Type.getReturnType(insnNode.desc).descriptor.typeNameFromJvmName()) - addInstruction(insnNode, JcRawAssignInst(method, result, expr)) + addInstruction(insnNode, createRawAssign(method, result, expr)) push(result) } } @@ -1909,7 +2021,7 @@ class RawInstListBuilder( } val expr = JcRawNewArrayExpr(insnNode.desc.typeNameFromJvmName(), dimensions.reversed()) val assignment = nextRegister(expr.typeName) - addInstruction(insnNode, JcRawAssignInst(method, assignment, expr)) + addInstruction(insnNode, createRawAssign(method, assignment, expr)) push(assignment) } @@ -1927,7 +2039,7 @@ class RawInstListBuilder( when (insnNode.opcode) { Opcodes.NEW -> { val assignment = nextRegister(type) - addInstruction(insnNode, JcRawAssignInst(method, assignment, JcRawNewExpr(type))) + addInstruction(insnNode, createRawAssign(method, assignment, JcRawNewExpr(type))) push(assignment) } @@ -1935,7 +2047,7 @@ class RawInstListBuilder( val length = pop() val assignment = nextRegister(type.asArray()) addInstruction( - insnNode, JcRawAssignInst( + insnNode, createRawAssign( method, assignment, JcRawNewArrayExpr(type.asArray(), length) @@ -1946,14 +2058,14 @@ class RawInstListBuilder( Opcodes.CHECKCAST -> { val assignment = nextRegister(type) - addInstruction(insnNode, JcRawAssignInst(method, assignment, JcRawCastExpr(type, pop()))) + addInstruction(insnNode, createRawAssign(method, assignment, JcRawCastExpr(type, pop()))) push(assignment) } Opcodes.INSTANCEOF -> { val assignment = nextRegister(PredefinedPrimitives.Boolean.typeName()) addInstruction( - insnNode, JcRawAssignInst( + insnNode, createRawAssign( method, assignment, JcRawInstanceOfExpr(PredefinedPrimitives.Boolean.typeName(), pop(), type) diff --git a/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/Simplifier.kt b/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/Simplifier.kt index cf73da2d2..1b40a583b 100644 --- a/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/Simplifier.kt +++ b/jacodb-core/src/main/kotlin/org/jacodb/impl/cfg/Simplifier.kt @@ -294,7 +294,7 @@ internal class Simplifier { } val replacement = types.filterValues { it.size > 1 } .mapValues { - JcRawLocalVar(it.key.index, it.key.name, it.key.typeName) + JcRawLocalVar(it.key.index, it.key.name, it.key.typeName, it.key.kind) } return instList.map(ExprMapper(replacement.toMap())) }