From cc04ae79642a6486eed160bd14e6c5f6333b973c Mon Sep 17 00:00:00 2001 From: Guillermo Orellana Date: Thu, 7 Nov 2024 14:52:45 +0100 Subject: [PATCH 1/3] make Transaction and Message immutable also tries to cleanup a bit the serialisation logic --- libs.versions.toml | 2 + solana-kotlin/build.gradle.kts | 1 + .../solana/domain/core/AccountKeysList.kt | 62 ++---- .../solana/domain/core/CompiledInstruction.kt | 36 ---- .../avianlabs/solana/domain/core/Message.kt | 191 ++++-------------- .../solana/domain/core/SerializeMessage.kt | 119 +++++++++++ .../solana/domain/core/SignedTransaction.kt | 64 ++++++ .../solana/domain/core/Transaction.kt | 99 +++++---- .../solana/domain/core/TransactionBuilder.kt | 23 --- .../domain/core/TransactionInstruction.kt | 4 + .../solana/methods/sendTransaction.kt | 4 +- .../solana/methods/simulateTransaction.kt | 4 +- .../solana/vendor/ShortvecEncoding.kt | 6 +- .../solana/domain/core/AccountKeysListTest.kt | 14 +- .../solana/domain/core/MessageTest.kt | 3 +- .../solana/domain/core/TransactionTest.kt | 40 ++++ .../solana/domain/crypto/CryptoEngineTest.kt | 6 +- ...emProgramTest.kt => RPCIntegrationTest.kt} | 69 +++++-- 18 files changed, 408 insertions(+), 339 deletions(-) delete mode 100644 solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/CompiledInstruction.kt create mode 100644 solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt create mode 100644 solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt delete mode 100644 solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionBuilder.kt create mode 100644 solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/TransactionTest.kt rename solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/{SystemProgramTest.kt => RPCIntegrationTest.kt} (68%) diff --git a/libs.versions.toml b/libs.versions.toml index 038a3f1..9afe7a1 100644 --- a/libs.versions.toml +++ b/libs.versions.toml @@ -20,6 +20,7 @@ junitPioneer = "2.3.0" kermit = "2.0.4" khash = "1.1.3" kotlinxCoroutines = "1.9.0" +kotlinLogging = "7.0.0" ktor = "3.0.1" okhttp = "4.12.0" okio = "3.9.1" @@ -34,6 +35,7 @@ coroutinesCore = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", ver coroutinesJdk8 = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-jdk8", version.ref = "kotlinxCoroutines" } kermit = { module = "co.touchlab:kermit", version.ref = "kermit" } kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom", version.ref = "kotlin" } +kotlinLogging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlinLogging" } ktorClientContentNegotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" } ktorClientCio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" } ktorClientCore = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } diff --git a/solana-kotlin/build.gradle.kts b/solana-kotlin/build.gradle.kts index be68396..380dcc6 100644 --- a/solana-kotlin/build.gradle.kts +++ b/solana-kotlin/build.gradle.kts @@ -62,6 +62,7 @@ kotlin { implementation(libs.kermit) implementation(libs.okio) implementation(libs.skie.configurationAnnotations) + implementation(libs.kotlinLogging) } } val commonTest by getting { diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/AccountKeysList.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/AccountKeysList.kt index 5a3d81f..e77f4a2 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/AccountKeysList.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/AccountKeysList.kt @@ -1,48 +1,30 @@ package net.avianlabs.solana.domain.core -public class AccountKeysList { - private val accounts: LinkedHashMap = LinkedHashMap() - - public fun add(accountMeta: AccountMeta) { - val key = accountMeta.publicKey.toString() - val existing = accounts[key] - if (existing != null) { - accounts[key] = existing.copy( - isSigner = accountMeta.isSigner || existing.isSigner, - isWritable = accountMeta.isWritable || existing.isWritable, +internal fun List.normalize(): List = groupBy { it.publicKey } + .mapValues { (_, metas) -> + metas.reduce { acc, meta -> + AccountMeta( + publicKey = acc.publicKey, + isSigner = acc.isSigner || meta.isSigner, + isWritable = acc.isWritable || meta.isWritable, ) - } else { - accounts[key] = accountMeta - } - } - - public fun addAll(metas: Collection) { - for (meta in metas) { - add(meta) } } + .values + .sortedWith(metaComparator) + .toList() - public val list: ArrayList - get() { - val accountKeysList = ArrayList(accounts.values) - accountKeysList.sortWith(metaComparator) - return accountKeysList - } - - public companion object { - private val metaComparator = Comparator { am1, am2 -> - // first sort by signer, then writable - if (am1.isSigner && !am2.isSigner) { - -1 - } else if (!am1.isSigner && am2.isSigner) { - 1 - } else if (am1.isWritable && !am2.isWritable) { - -1 - } else if (!am1.isWritable && am2.isWritable) { - 1 - } else { - 0 - } - } +private val metaComparator = Comparator { am1, am2 -> + // first sort by signer, then writable + if (am1.isSigner && !am2.isSigner) { + -1 + } else if (!am1.isSigner && am2.isSigner) { + 1 + } else if (am1.isWritable && !am2.isWritable) { + -1 + } else if (!am1.isWritable && am2.isWritable) { + 1 + } else { + 0 } } diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/CompiledInstruction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/CompiledInstruction.kt deleted file mode 100644 index 7a40354..0000000 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/CompiledInstruction.kt +++ /dev/null @@ -1,36 +0,0 @@ -package net.avianlabs.solana.domain.core - -public data class CompiledInstruction( - /** - * Index into the transaction keys array indicating the program account that executes this instruction - */ - val programIdIndex: Byte = 0, - - /** - * Ordered indices into the transaction keys array indicating which accounts to pass to the program - */ - val accounts: List, - - /** - * The program input data - */ - val data: ByteArray, -) { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other == null || this::class != other::class) return false - - other as CompiledInstruction - - if (programIdIndex != other.programIdIndex) return false - if (accounts != other.accounts) return false - return data.contentEquals(other.data) - } - - override fun hashCode(): Int { - var result = programIdIndex.toInt() - result = 31 * result + accounts.hashCode() - result = 31 * result + data.contentHashCode() - return result - } -} diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt index e2478ff..d1772d6 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt @@ -1,173 +1,52 @@ package net.avianlabs.solana.domain.core -import net.avianlabs.solana.tweetnacl.TweetNaCl import net.avianlabs.solana.tweetnacl.ed25519.PublicKey -import net.avianlabs.solana.vendor.ShortvecEncoding -import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 -import okio.Buffer -public class Message( - public var feePayer: PublicKey? = null, - public var recentBlockHash: String? = null, - accountKeys: AccountKeysList = AccountKeysList(), - instructions: List = emptyList(), +public class Message internal constructor( + public val feePayer: PublicKey?, + public val recentBlockHash: String?, + public val accountKeys: List, + public val instructions: List, ) { - private val _accountKeys: AccountKeysList = accountKeys - private val _instructions: MutableList = instructions.toMutableList() - - - public val accountKeys: List - get() = _accountKeys.list - - public val instructions: List - get() = _instructions - - private class MessageHeader { - var numRequiredSignatures: Byte = 0 - var numReadonlySignedAccounts: Byte = 0 - var numReadonlyUnsignedAccounts: Byte = 0 - fun toByteArray(): ByteArray { - return byteArrayOf( - numRequiredSignatures, - numReadonlySignedAccounts, - numReadonlyUnsignedAccounts - ) - } + override fun toString(): String = + "Message(feePayer=$feePayer, recentBlockHash=$recentBlockHash, accountKeys=$accountKeys, instructions=$instructions)" - override fun toString(): String { - return "numRequiredSignatures: $numRequiredSignatures, numReadOnlySignedAccounts: $numReadonlySignedAccounts, numReadOnlyUnsignedAccounts: $numReadonlyUnsignedAccounts" - } + public fun newBuilder(): Builder = Builder( + feePayer = feePayer, + recentBlockHash = recentBlockHash, + accountKeys = accountKeys.toMutableList(), + instructions = instructions.toMutableList(), + ) - companion object { - const val HEADER_LENGTH = 3 + public class Builder internal constructor( + private var feePayer: PublicKey?, + private var recentBlockHash: String?, + private var accountKeys: MutableList, + private var instructions: MutableList, + ) { + public constructor() : this(null, null, mutableListOf(), mutableListOf()) - fun fromByteArray(bytes: ByteArray): MessageHeader { - val header = MessageHeader() - header.numRequiredSignatures = bytes[0] - header.numReadonlySignedAccounts = bytes[1] - header.numReadonlyUnsignedAccounts = bytes[2] - return header - } + public fun setFeePayer(feePayer: PublicKey): Builder { + this.feePayer = feePayer + return this } - } - - private class CompiledInstruction { - var programIdIndex: Byte = 0 - lateinit var keyIndicesCount: ByteArray - lateinit var keyIndices: ByteArray - lateinit var dataLength: ByteArray - lateinit var data: ByteArray - - // 1 = programIdIndex length - val length: Int - get() =// 1 = programIdIndex length - 1 + keyIndicesCount.size + keyIndices.size + dataLength.size + data.size - } - public fun addInstruction(instruction: TransactionInstruction): Message { - _accountKeys.addAll(instruction.keys) - _accountKeys.add(AccountMeta(instruction.programId, false, false)) - _instructions.add(instruction) - return this - } - - public fun serialize(): ByteArray { - requireNotNull(recentBlockHash) { "recentBlockhash required" } - require(_instructions.size != 0) { "No instructions provided" } - val messageHeader = MessageHeader() - val keysList = compileAccountKeys() - val accountKeysSize = keysList.size - val accountAddressesLength = ShortvecEncoding.encodeLength(accountKeysSize) - var compiledInstructionsLength = 0 - val compiledInstructions: MutableList = ArrayList() - for (instruction in _instructions) { - val keysSize = instruction.keys.size - val keyIndices = ByteArray(keysSize) - for (i in 0 until keysSize) { - keyIndices[i] = findAccountIndex(keysList, instruction.keys[i].publicKey).toByte() - } - val compiledInstruction = CompiledInstruction() - compiledInstruction.programIdIndex = - findAccountIndex(keysList, instruction.programId).toByte() - compiledInstruction.keyIndicesCount = ShortvecEncoding.encodeLength(keysSize) - compiledInstruction.keyIndices = keyIndices - compiledInstruction.dataLength = ShortvecEncoding.encodeLength(instruction.data.count()) - compiledInstruction.data = instruction.data - compiledInstructions.add(compiledInstruction) - compiledInstructionsLength += compiledInstruction.length - } - val instructionsLength = ShortvecEncoding.encodeLength(compiledInstructions.size) - val accountsKeyBufferSize = accountKeysSize * TweetNaCl.Signature.PUBLIC_KEY_BYTES - val bufferSize = - (MessageHeader.HEADER_LENGTH + RECENT_BLOCK_HASH_LENGTH + accountAddressesLength.size - + accountsKeyBufferSize + instructionsLength.size - + compiledInstructionsLength) - val out = Buffer() - val accountKeysBuff = Buffer() - for (accountMeta in keysList) { - accountKeysBuff.write(accountMeta.publicKey.toByteArray()) - if (accountMeta.isSigner) { - messageHeader.numRequiredSignatures = - (messageHeader.numRequiredSignatures.plus(1)).toByte() - if (!accountMeta.isWritable) { - messageHeader.numReadonlySignedAccounts = - (messageHeader.numReadonlySignedAccounts.plus(1)).toByte() - } - } else { - if (!accountMeta.isWritable) { - messageHeader.numReadonlyUnsignedAccounts = - (messageHeader.numReadonlyUnsignedAccounts.plus(1)).toByte() - } - } + public fun setRecentBlockHash(recentBlockHash: String): Builder { + this.recentBlockHash = recentBlockHash + return this } - out.write(messageHeader.toByteArray()) - out.write(accountAddressesLength) - out.write(accountKeysBuff, accountsKeyBufferSize.toLong()) - out.write(recentBlockHash!!.decodeBase58()) - out.write(instructionsLength) - for (compiledInstruction in compiledInstructions) { - out.writeByte(compiledInstruction.programIdIndex.toInt()) - out.write(compiledInstruction.keyIndicesCount) - out.write(compiledInstruction.keyIndices) - out.write(compiledInstruction.dataLength) - out.write(compiledInstruction.data) - } - return out.readByteArray(bufferSize.toLong()) - } - private fun compileAccountKeys(): List { - val keysList: MutableList = _accountKeys.list - val newList: MutableList = ArrayList() - try { - val feePayerIndex = findAccountIndex(keysList, feePayer!!) - val feePayerMeta = keysList[feePayerIndex] - newList.add(AccountMeta(feePayerMeta.publicKey, true, true)) - keysList.removeAt(feePayerIndex) - } catch (e: RuntimeException) { // Fee payer not yet in list - newList.add(AccountMeta(feePayer!!, true, true)) + public fun addInstruction(instruction: TransactionInstruction): Builder { + accountKeys.addAll( + instruction.keys + + AccountMeta(instruction.programId, isSigner = false, isWritable = false) + ) + instructions += instruction + return this } - newList.addAll(keysList) - return newList - } - private fun findAccountIndex(accountMetaList: List, key: PublicKey): Int { - for (i in accountMetaList.indices) { - if (accountMetaList[i].publicKey.equals(key)) { - return i - } - } - throw RuntimeException("unable to find account index") + public fun build(): Message = + Message(feePayer, recentBlockHash, accountKeys.normalize(), instructions) } - - override fun toString(): String = - """Message( - | header: not set, - | accountKeys: [${_accountKeys.list.joinToString()}], - | recentBlockhash: $recentBlockHash, - | instructions: [${_instructions.joinToString()}] - |)""".trimMargin() - } - -private const val RECENT_BLOCK_HASH_LENGTH = 32 diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt new file mode 100644 index 0000000..85e0de2 --- /dev/null +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt @@ -0,0 +1,119 @@ +package net.avianlabs.solana.domain.core + +import net.avianlabs.solana.tweetnacl.TweetNaCl +import net.avianlabs.solana.tweetnacl.ed25519.PublicKey +import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 +import net.avianlabs.solana.vendor.ShortVecLength +import net.avianlabs.solana.vendor.ShortvecEncoding +import okio.Buffer + +private const val RECENT_BLOCK_HASH_LENGTH = 32 + +private class Header private constructor( + val numRequiredSignatures: Byte, + val numReadonlySignedAccounts: Byte, + val numReadonlyUnsignedAccounts: Byte, +) { + + constructor(accountKeys: List) : this( + numRequiredSignatures = accountKeys.count { it.isSigner }.toByte(), + numReadonlySignedAccounts = accountKeys.count { it.isSigner && !it.isWritable }.toByte(), + numReadonlyUnsignedAccounts = accountKeys.count { !it.isSigner && !it.isWritable }.toByte(), + ) + + fun toByteArray(): ByteArray = byteArrayOf( + numRequiredSignatures, + numReadonlySignedAccounts, + numReadonlyUnsignedAccounts + ) + + override fun toString(): String = + "numRequiredSignatures: $numRequiredSignatures, " + + "numReadOnlySignedAccounts: $numReadonlySignedAccounts, " + + "numReadOnlyUnsignedAccounts: $numReadonlyUnsignedAccounts" + + companion object { + const val HEADER_LENGTH = 3 + + fun fromByteArray(bytes: ByteArray): Header = Header( + numRequiredSignatures = bytes[0], + numReadonlySignedAccounts = bytes[1], + numReadonlyUnsignedAccounts = bytes[2], + ) + } +} + +private class CompiledInstruction( + val programIdIndex: Byte, + val keyIndicesLength: ShortVecLength, + val keyIndices: ByteArray, + val dataLength: ShortVecLength, + val data: ByteArray, +) { + + val bytes: Int + get() = + 1 + // programIdIndex is one byte + keyIndicesLength.size + keyIndices.size + dataLength.size + data.size +} + +private fun List.findAccountIndex(key: PublicKey): Int = + indexOfFirst { it.publicKey == key } + .takeIf { it != -1 } ?: error("Account $key not found") + +private fun Message.compileAccountKeys(feePayer: PublicKey): List = + // ensure fee payer is the first account + listOf(AccountMeta(feePayer, isSigner = true, isWritable = true)) + + (accountKeys.filter { it.publicKey != feePayer }) + +public fun Message.serialize(): ByteArray { + requireNotNull(feePayer) { "feePayer required" } + requireNotNull(recentBlockHash) { "recentBlockhash required" } + require(instructions.isNotEmpty()) { "No instructions provided" } + val keysList = compileAccountKeys(feePayer) + val accountKeysSize = keysList.size + val accountAddressesLength = ShortvecEncoding.encodeLength(accountKeysSize) + val compiledInstructions = instructions.map { instruction -> + val keysSize = instruction.keys.size + val keyIndices = ByteArray(keysSize) + for (i in 0 until keysSize) { + keyIndices[i] = keysList.findAccountIndex(instruction.keys[i].publicKey).toByte() + } + CompiledInstruction( + programIdIndex = keysList.findAccountIndex(instruction.programId).toByte(), + keyIndicesLength = ShortvecEncoding.encodeLength(keysSize), + keyIndices = keyIndices, + dataLength = ShortvecEncoding.encodeLength(instruction.data.count()), + data = instruction.data, + ) + } + val accountsKeyBufferSize = accountKeysSize * TweetNaCl.Signature.PUBLIC_KEY_BYTES + val instructionsLength = ShortvecEncoding.encodeLength(compiledInstructions.size) + val compiledInstructionsBytes = compiledInstructions.sumOf { it.bytes } + val bufferSize = + (Header.HEADER_LENGTH + RECENT_BLOCK_HASH_LENGTH + accountAddressesLength.size + + accountsKeyBufferSize + instructionsLength.size + + compiledInstructionsBytes) + + val accountKeysBytes = keysList.map { accountMeta -> + accountMeta.publicKey.toByteArray() + }.reduce { acc, bytes -> acc + bytes } + + val messageHeader = Header(keysList) + + val buffer = Buffer().apply { + write(messageHeader.toByteArray()) + write(accountAddressesLength) + write(accountKeysBytes) + write(recentBlockHash.decodeBase58()) + write(instructionsLength) + for (compiledInstruction in compiledInstructions) { + writeByte(compiledInstruction.programIdIndex.toInt()) + write(compiledInstruction.keyIndicesLength) + write(compiledInstruction.keyIndices) + write(compiledInstruction.dataLength) + write(compiledInstruction.data) + } + } + return buffer.readByteArray(bufferSize.toLong()) +} diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt new file mode 100644 index 0000000..ddd9d7b --- /dev/null +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt @@ -0,0 +1,64 @@ +package net.avianlabs.solana.domain.core + +import io.github.oshai.kotlinlogging.KotlinLogging +import net.avianlabs.solana.tweetnacl.TweetNaCl +import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 +import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String +import net.avianlabs.solana.vendor.ShortvecEncoding +import okio.Buffer + +private val logger = KotlinLogging.logger {} + +public class SignedTransaction( + public val originalMessage: Message, + public val signedMessage: ByteArray, + public val signatures: List, +) : Transaction(originalMessage) { + + public override fun sign(signers: List): SignedTransaction = SignedTransaction( + originalMessage = originalMessage, + signedMessage = signedMessage, + signatures = signatures + signers.map { signer -> + TweetNaCl.Signature.sign(signedMessage, signer.secretKey).encodeToBase58String() + } + ).also { + val signatureSet = signatures.toSet() + if (signatureSet.size != signatures.size) { + logger.warn { "Duplicate signatures detected" } + } + } + + override fun toString(): String = + "SignedTransaction(message=${originalMessage}, signatures=$signatures)" + + public fun serialize(): ByteArray { + val signaturesSize = signatures.size + val signaturesLength = ShortvecEncoding.encodeLength(signaturesSize) + val bufferSize = + signaturesLength.size + + signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES + + signedMessage.size + val out = Buffer() + out.write(signaturesLength) + for (signature in signatures) { + val rawSignature = signature.decodeBase58() + out.write(rawSignature) + } + out.write(signedMessage) + return out.readByteArray(bufferSize.toLong()) + } + + public fun validate(): Boolean { + val message = signedMessage + val messageLength = message.size + val signaturesSize = signatures.size + val signaturesLength = ShortvecEncoding.encodeLength(signaturesSize) + val signaturesSizeBytes = signaturesLength.size + val signatureSize = TweetNaCl.Signature.SIGNATURE_BYTES + val signatureSizeBytes = ShortvecEncoding.encodeLength(signatureSize) + val signatureSizeBytesLength = signatureSizeBytes.size + val expectedSize = + messageLength + signaturesSizeBytes + signaturesSize * (signatureSize + signatureSizeBytesLength) + return expectedSize == serialize().size + } +} diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt index 1da293c..0688b43 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt @@ -1,74 +1,69 @@ package net.avianlabs.solana.domain.core +import io.github.oshai.kotlinlogging.KotlinLogging import net.avianlabs.solana.tweetnacl.TweetNaCl import net.avianlabs.solana.tweetnacl.ed25519.PublicKey -import net.avianlabs.solana.vendor.ShortvecEncoding -import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String -import okio.Buffer -public class Transaction( - public val message: Message = Message(), - private val _signatures: MutableList = mutableListOf() +private val logger = KotlinLogging.logger {} + +public open class Transaction internal constructor( + public val message: Message, ) { - public val signatures: List - get() = _signatures + public fun sign(signer: Signer): SignedTransaction = sign(listOf(signer)) - private lateinit var serializedMessage: ByteArray + public open fun sign(signers: List): SignedTransaction { + val message = when (message.feePayer) { + // fee payer is the first signer by default + null -> message.newBuilder() + .setFeePayer(signers.first().publicKey) + .build() - public fun addInstruction(instruction: TransactionInstruction): Transaction { - message.addInstruction(instruction) - return this - } + else -> message + } - public fun setRecentBlockHash(recentBlockhash: String): Transaction { - message.recentBlockHash = recentBlockhash - return this - } + val serializedMessage = message + .serialize() - public fun setFeePayer(feePayer: PublicKey?): Transaction { - message.feePayer = feePayer - return this + val signatures = signers.map { signer -> + TweetNaCl.Signature.sign(serializedMessage, signer.secretKey).encodeToBase58String() + } + + val signatureSet = signatures.toSet() + if (signatureSet.size != signatures.size) { + logger.warn { "Duplicate signatures detected" } + } + + return SignedTransaction( + originalMessage = message, + signedMessage = serializedMessage, + signatures = signatures, + ) } - public fun sign(signer: Signer): Transaction = sign(listOf(signer)) + override fun toString(): String = "Transaction(message=$message)" + + public class Builder internal constructor( + private var messageBuilder: Message.Builder, + ) { + public constructor() : this(Message.Builder()) - public fun sign(signers: List): Transaction { - require(signers.isNotEmpty()) { "No signers" } - // Fee payer defaults to first signer if not set - message.feePayer ?: let { - message.feePayer = signers[0].publicKey + public fun addInstruction(instruction: TransactionInstruction): Builder { + messageBuilder.addInstruction(instruction) + return this } - serializedMessage = message.serialize() - for (signer in signers) { - _signatures.add( - TweetNaCl.Signature.sign(serializedMessage, signer.secretKey).encodeToBase58String() - ) + + public fun setRecentBlockHash(recentBlockHash: String): Builder { + messageBuilder.setRecentBlockHash(recentBlockHash) + return this } - return this - } - public fun serialize(): ByteArray { - val signaturesSize = signatures.size - val signaturesLength = ShortvecEncoding.encodeLength(signaturesSize) - val bufferSize = - signaturesLength.size + signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES + serializedMessage.size - val out = Buffer() - out.write(signaturesLength) - for (signature in signatures) { - val rawSignature = signature.decodeBase58() - out.write(rawSignature) + public fun setFeePayer(feePayer: PublicKey): Builder { + messageBuilder.setFeePayer(feePayer) + return this } - out.write(serializedMessage) - return out.readByteArray(bufferSize.toLong()) - } - override fun toString(): String { - return """Transaction( - | signatures: [${signatures.joinToString()}], - | message: $message - |)""".trimMargin() + public fun build(): Transaction = Transaction(messageBuilder.build()) } - } diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionBuilder.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionBuilder.kt deleted file mode 100644 index c7ccf75..0000000 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionBuilder.kt +++ /dev/null @@ -1,23 +0,0 @@ -package net.avianlabs.solana.domain.core - -public class TransactionBuilder { - private val transaction: Transaction = Transaction() - public fun addInstruction(transactionInstruction: TransactionInstruction): TransactionBuilder { - transaction.addInstruction(transactionInstruction) - return this - } - - public fun setRecentBlockHash(recentBlockHash: String): TransactionBuilder { - transaction.setRecentBlockHash(recentBlockHash) - return this - } - - public fun setSigners(signers: List): TransactionBuilder { - transaction.sign(signers) - return this - } - - public fun build(): Transaction { - return transaction - } -} diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionInstruction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionInstruction.kt index 32984b0..37854df 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionInstruction.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/TransactionInstruction.kt @@ -1,5 +1,6 @@ package net.avianlabs.solana.domain.core +import io.ktor.util.* import net.avianlabs.solana.tweetnacl.ed25519.PublicKey public data class TransactionInstruction( @@ -8,6 +9,9 @@ public data class TransactionInstruction( val data: ByteArray, ) { + override fun toString(): String = + "TransactionInstruction(programId=$programId, keys=$keys, data=${data.encodeBase64()})" + override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/sendTransaction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/sendTransaction.kt index f892a45..54fde4a 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/sendTransaction.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/sendTransaction.kt @@ -8,7 +8,7 @@ import kotlinx.serialization.json.put import net.avianlabs.solana.SolanaClient import net.avianlabs.solana.client.Response import net.avianlabs.solana.domain.core.Commitment -import net.avianlabs.solana.domain.core.Transaction +import net.avianlabs.solana.domain.core.SignedTransaction /** * Send a signed transaction to the cluster @@ -21,7 +21,7 @@ import net.avianlabs.solana.domain.core.Transaction * @return The transaction signature */ public suspend fun SolanaClient.sendTransaction( - transaction: Transaction, + transaction: SignedTransaction, skipPreflight: Boolean = false, preflightCommitment: Commitment = Commitment.Finalized, maxRetries: Int? = null, diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/simulateTransaction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/simulateTransaction.kt index 32b8cf9..a03a423 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/simulateTransaction.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/methods/simulateTransaction.kt @@ -36,7 +36,7 @@ public suspend fun SolanaClient.simulateTransaction( ): Response> = invoke( method = "simulateTransaction", params = buildJsonArray { - add(transaction.serialize().encodeBase64()) + add(transaction.sign(emptyList()).serialize().encodeBase64()) addJsonObject { put("encoding", "base64") commitment?.let { put("commitment", it.value) } @@ -60,6 +60,8 @@ public suspend fun SolanaClient.simulateTransaction( public data class SimulateTransactionResponse( /** * Error if transaction failed, null if transaction succeeded. + * + * can be null, string or object */ val err: JsonElement?, /** diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt index 207e792..427463b 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt @@ -3,8 +3,10 @@ package net.avianlabs.solana.vendor import okio.Buffer import kotlin.experimental.and +internal typealias ShortVecLength = ByteArray + internal object ShortvecEncoding { - internal fun encodeLength(len: Int): ByteArray { + internal fun encodeLength(len: Int): ShortVecLength { val buffer = Buffer() var length = len while (true) { @@ -20,7 +22,7 @@ internal object ShortvecEncoding { return buffer.readByteArray() } - internal fun decodeLength(bytes: ByteArray): Int { + internal fun decodeLength(bytes: ShortVecLength): Int { var len = 0 var shift = 0 for (i in bytes.indices) { diff --git a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/AccountKeysListTest.kt b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/AccountKeysListTest.kt index 0fe82f3..fd24ee0 100644 --- a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/AccountKeysListTest.kt +++ b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/AccountKeysListTest.kt @@ -8,7 +8,7 @@ import kotlin.test.assertTrue class AccountKeysListTest { @Test fun add_different_writable_flags_preserves_isSigner() { - val accountKeysList = AccountKeysList() + val accountKeysList = mutableListOf() val accountMeta1 = AccountMeta( publicKey = PublicKey.fromBase58("4rZoSK72jVaAW1ayZLrefdMPAAStRVhCfH1PSundaoNt"), isSigner = false, @@ -25,14 +25,16 @@ class AccountKeysListTest { accountKeysList.add(accountMeta2) - assertTrue(accountKeysList.list.size == 1) - assertTrue(accountKeysList.list[0].isSigner) - assertTrue(accountKeysList.list[0].isWritable) + val normalized = accountKeysList.normalize() + + assertTrue(normalized.size == 1) + assertTrue(normalized[0].isSigner) + assertTrue(normalized[0].isWritable) } @Test fun account_order() { - val accountKeysList = AccountKeysList() + val accountKeysList = mutableListOf() val meta = listOf( AccountMeta( @@ -60,7 +62,7 @@ class AccountKeysListTest { PublicKey.fromBase58("9JGhZqi4MbnVz424uJ6vqk9a1u359xg3nJekdjzzL4d5"), PublicKey.fromBase58("G8iheDY9bGix5qCXEitCExLcgZzZrEemngk9cbTR3CQs"), ), - accountKeysList.list.map { it.publicKey } + accountKeysList.normalize().map { it.publicKey } ) } } diff --git a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/MessageTest.kt b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/MessageTest.kt index 4da431c..0d07e8b 100644 --- a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/MessageTest.kt +++ b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/MessageTest.kt @@ -10,7 +10,7 @@ class MessageTest { @Test fun account_keys_order() { - val transaction = Transaction() + val transaction = Transaction.Builder() .addInstruction( TokenProgram.transferChecked( source = PublicKey.fromBase58("9JGhZqi4MbnVz424uJ6vqk9a1u359xg3nJekdjzzL4d5"), @@ -26,6 +26,7 @@ class MessageTest { microLamports = 1u, ) ) + .build() transaction.message.accountKeys.map { println(it) } diff --git a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/TransactionTest.kt b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/TransactionTest.kt new file mode 100644 index 0000000..4ca7cc8 --- /dev/null +++ b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/core/TransactionTest.kt @@ -0,0 +1,40 @@ +package net.avianlabs.solana.domain.core + +import io.ktor.util.* +import net.avianlabs.solana.domain.program.ComputeBudgetProgram +import net.avianlabs.solana.domain.program.SystemProgram +import net.avianlabs.solana.tweetnacl.ed25519.Ed25519Keypair +import kotlin.test.Test +import kotlin.test.assertEquals + +class TransactionTest { + @Test + fun test_serialization() { + val keypair = Ed25519Keypair.fromBase58( + "9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy" + ) + + val transaction = Transaction.Builder() + .addInstruction( + SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1) + ) + .addInstruction( + ComputeBudgetProgram.setComputeUnitPrice(1u) + ) + .addInstruction( + SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1) + ) + .setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF") + .build() + .sign(keypair) + + val serialized = transaction.serialize().encodeBase64() + val expected = + "AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAAIDq" + + "OvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMGRm/l" + + "IRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQIAAAwCAAA" + + "AAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA==" + + assertEquals(expected, serialized) + } +} diff --git a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/crypto/CryptoEngineTest.kt b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/crypto/CryptoEngineTest.kt index 515bcfb..1aca5c8 100644 --- a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/crypto/CryptoEngineTest.kt +++ b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/crypto/CryptoEngineTest.kt @@ -1,6 +1,6 @@ package net.avianlabs.solana.domain.crypto -import net.avianlabs.solana.domain.core.TransactionBuilder +import net.avianlabs.solana.domain.core.Transaction import net.avianlabs.solana.domain.program.SystemProgram import net.avianlabs.solana.domain.randomKey import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 @@ -20,7 +20,7 @@ class CryptoEngineTest { @Test fun test_sign() { val keypair = randomKey() - val transaction = TransactionBuilder() + val transaction = Transaction.Builder() .addInstruction( SystemProgram.createAccount( fromPublicKey = keypair.publicKey, @@ -39,4 +39,4 @@ class CryptoEngineTest { assertTrue(signatureArray.size == 64, "wrong signature size ${signatureArray.size}") } -} \ No newline at end of file +} diff --git a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/SystemProgramTest.kt b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/RPCIntegrationTest.kt similarity index 68% rename from solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/SystemProgramTest.kt rename to solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/RPCIntegrationTest.kt index 1c991bd..5f8f61f 100644 --- a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/SystemProgramTest.kt +++ b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/domain/program/RPCIntegrationTest.kt @@ -1,42 +1,48 @@ package net.avianlabs.solana.domain.program import io.ktor.client.* -import io.ktor.client.plugins.logging.* +import io.ktor.util.* import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import net.avianlabs.solana.SolanaClient import net.avianlabs.solana.client.RpcKtorClient import net.avianlabs.solana.domain.core.Commitment import net.avianlabs.solana.domain.core.Transaction -import net.avianlabs.solana.domain.core.TransactionBuilder import net.avianlabs.solana.domain.core.decode +import net.avianlabs.solana.domain.core.serialize import net.avianlabs.solana.methods.* import net.avianlabs.solana.tweetnacl.TweetNaCl +import net.avianlabs.solana.tweetnacl.ed25519.PublicKey import kotlin.random.Random import kotlin.test.Ignore import kotlin.test.Test import kotlin.time.Duration.Companion.seconds -class SystemProgramTest { +/** + * Connects to a local Solana node. This test is ignored in CI. + * + * Run with solana-test-validator --ticks-per-slot 3 + */ +@Ignore +class RPCIntegrationTest { + + private val client = SolanaClient( + client = RpcKtorClient( + "http://localhost:8899", + httpClient = HttpClient {} + ), + ) @Test - @Ignore fun testCreateDurableNonceAccount() = runBlocking { - val client = - SolanaClient(client = RpcKtorClient("http://localhost:8899", httpClient = HttpClient { - install(Logging) { - level = LogLevel.ALL - logger = Logger.SIMPLE - } - })) - val keypair = TweetNaCl.Signature.generateKey(Random.nextBytes(32)) println("Keypair: ${keypair.publicKey}") val nonceAccount = TweetNaCl.Signature.generateKey(Random.nextBytes(32)) println("Nonce account: ${nonceAccount.publicKey}") client.requestAirdrop(keypair.publicKey, 2_000_000_000) - delay(15.seconds) + + delay(1.seconds) val balance = client.getBalance(keypair.publicKey) println("Balance: $balance") @@ -45,7 +51,7 @@ class SystemProgramTest { val blockhash = client.getLatestBlockhash().result!!.value - val initTransaction = Transaction() + val initTransaction = Transaction.Builder() .addInstruction( SystemProgram.createAccount( fromPublicKey = keypair.publicKey, @@ -61,6 +67,7 @@ class SystemProgramTest { ) ) .setRecentBlockHash(blockhash.blockhash) + .build() .sign(listOf(keypair, nonceAccount)) @@ -71,7 +78,7 @@ class SystemProgramTest { val initSignature = client.sendTransaction(initTransaction) println("Initialized nonce account: $initSignature") - delay(15.seconds) + delay(1.seconds) val lamportsPerSignature = client.getFeeForMessage(initTransaction.message.serialize()) println("Lamports per signature: $lamportsPerSignature") @@ -79,7 +86,7 @@ class SystemProgramTest { val nonce = client.getNonce(nonceAccount.publicKey, Commitment.Processed) println("Nonce account info: $nonce") - val testTransaction = TransactionBuilder() + val testTransaction = Transaction.Builder() .addInstruction( SystemProgram.nonceAdvance( nonceAccount = nonceAccount.publicKey, @@ -100,7 +107,7 @@ class SystemProgramTest { val testSignature = client.sendTransaction(testTransaction).result!! println("Advanced nonce account: $testSignature") - delay(15.seconds) + delay(1.seconds) val testTxInfo = client.getTransaction(testSignature, Commitment.Confirmed).result println("Transaction info: ${testTxInfo?.decode()}") @@ -108,4 +115,32 @@ class SystemProgramTest { val newNonce = client.getNonce(nonceAccount.publicKey, Commitment.Processed) println("New nonce account info: $newNonce") } + + @Test + fun testSimulateTransaction() = runBlocking { + val keypair = TweetNaCl.Signature.generateKey(Random.nextBytes(32)) + println("Keypair: ${keypair.publicKey}") + val initTransaction = Transaction.Builder() + .addInstruction( + SystemProgram.transfer( + keypair.publicKey, + PublicKey.fromBase58("11111111111111111111111111111111"), + 1, + ) + ) + + val toSimulate = initTransaction + .setRecentBlockHash("11111111111111111111111111111111") + .setFeePayer(keypair.publicKey) + .build() + + + val serialized = toSimulate.sign(emptyList()).serialize() + println("Serialized: ${serialized.encodeBase64()}") + + val simulated = + client.simulateTransaction(toSimulate, sigVerify = false, replaceRecentBlockhash = true) + + println("simulated: $simulated") + } } From 8a776c42db608ef4d9d7411c7079b7caa8e1f658 Mon Sep 17 00:00:00 2001 From: Guillermo Orellana Date: Thu, 7 Nov 2024 15:09:47 +0100 Subject: [PATCH 2/3] ios changes --- .../kotlin/net/avianlabs/solana/domain/core/Transaction.ios.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/solana-kotlin/src/iosMain/kotlin/net/avianlabs/solana/domain/core/Transaction.ios.kt b/solana-kotlin/src/iosMain/kotlin/net/avianlabs/solana/domain/core/Transaction.ios.kt index 12e7afd..32e8e0f 100644 --- a/solana-kotlin/src/iosMain/kotlin/net/avianlabs/solana/domain/core/Transaction.ios.kt +++ b/solana-kotlin/src/iosMain/kotlin/net/avianlabs/solana/domain/core/Transaction.ios.kt @@ -4,4 +4,4 @@ import kotlinx.cinterop.* import net.avianlabs.solana.tweetnacl.ed25519.toNSData import platform.Foundation.NSData -public fun Transaction.serializeData(): NSData = serialize().toNSData() +public fun SignedTransaction.serializeData(): NSData = serialize().toNSData() From 09eb2d6aba3d712c5acab6f6d4ad2e306f16e255 Mon Sep 17 00:00:00 2001 From: Guillermo Orellana Date: Fri, 8 Nov 2024 16:55:08 +0100 Subject: [PATCH 3/3] pr comments --- .../avianlabs/solana/domain/core/Message.kt | 6 +-- .../solana/domain/core/SerializeMessage.kt | 10 ++--- .../solana/domain/core/SignedTransaction.kt | 37 +++++++++++++++---- .../solana/domain/core/Transaction.kt | 13 +++++++ ...hortvecEncoding.kt => ShortVecEncoding.kt} | 2 +- .../solana/vendor/ShortvecEncodingTest.kt | 18 ++++----- 6 files changed, 60 insertions(+), 26 deletions(-) rename solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/{ShortvecEncoding.kt => ShortVecEncoding.kt} (95%) diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt index d1772d6..3cd3613 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Message.kt @@ -2,16 +2,14 @@ package net.avianlabs.solana.domain.core import net.avianlabs.solana.tweetnacl.ed25519.PublicKey -public class Message internal constructor( +@ConsistentCopyVisibility +public data class Message private constructor( public val feePayer: PublicKey?, public val recentBlockHash: String?, public val accountKeys: List, public val instructions: List, ) { - override fun toString(): String = - "Message(feePayer=$feePayer, recentBlockHash=$recentBlockHash, accountKeys=$accountKeys, instructions=$instructions)" - public fun newBuilder(): Builder = Builder( feePayer = feePayer, recentBlockHash = recentBlockHash, diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt index 85e0de2..eec8ed0 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SerializeMessage.kt @@ -3,8 +3,8 @@ package net.avianlabs.solana.domain.core import net.avianlabs.solana.tweetnacl.TweetNaCl import net.avianlabs.solana.tweetnacl.ed25519.PublicKey import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 +import net.avianlabs.solana.vendor.ShortVecEncoding import net.avianlabs.solana.vendor.ShortVecLength -import net.avianlabs.solana.vendor.ShortvecEncoding import okio.Buffer private const val RECENT_BLOCK_HASH_LENGTH = 32 @@ -72,7 +72,7 @@ public fun Message.serialize(): ByteArray { require(instructions.isNotEmpty()) { "No instructions provided" } val keysList = compileAccountKeys(feePayer) val accountKeysSize = keysList.size - val accountAddressesLength = ShortvecEncoding.encodeLength(accountKeysSize) + val accountAddressesLength = ShortVecEncoding.encodeLength(accountKeysSize) val compiledInstructions = instructions.map { instruction -> val keysSize = instruction.keys.size val keyIndices = ByteArray(keysSize) @@ -81,14 +81,14 @@ public fun Message.serialize(): ByteArray { } CompiledInstruction( programIdIndex = keysList.findAccountIndex(instruction.programId).toByte(), - keyIndicesLength = ShortvecEncoding.encodeLength(keysSize), + keyIndicesLength = ShortVecEncoding.encodeLength(keysSize), keyIndices = keyIndices, - dataLength = ShortvecEncoding.encodeLength(instruction.data.count()), + dataLength = ShortVecEncoding.encodeLength(instruction.data.count()), data = instruction.data, ) } val accountsKeyBufferSize = accountKeysSize * TweetNaCl.Signature.PUBLIC_KEY_BYTES - val instructionsLength = ShortvecEncoding.encodeLength(compiledInstructions.size) + val instructionsLength = ShortVecEncoding.encodeLength(compiledInstructions.size) val compiledInstructionsBytes = compiledInstructions.sumOf { it.bytes } val bufferSize = (Header.HEADER_LENGTH + RECENT_BLOCK_HASH_LENGTH + accountAddressesLength.size diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt index ddd9d7b..3b38936 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/SignedTransaction.kt @@ -4,12 +4,13 @@ import io.github.oshai.kotlinlogging.KotlinLogging import net.avianlabs.solana.tweetnacl.TweetNaCl import net.avianlabs.solana.tweetnacl.vendor.decodeBase58 import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String -import net.avianlabs.solana.vendor.ShortvecEncoding +import net.avianlabs.solana.vendor.ShortVecEncoding import okio.Buffer private val logger = KotlinLogging.logger {} -public class SignedTransaction( +@ConsistentCopyVisibility +public data class SignedTransaction internal constructor( public val originalMessage: Message, public val signedMessage: ByteArray, public val signatures: List, @@ -28,16 +29,38 @@ public class SignedTransaction( } } + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + if (!super.equals(other)) return false + + other as SignedTransaction + + if (originalMessage != other.originalMessage) return false + if (!signedMessage.contentEquals(other.signedMessage)) return false + if (signatures != other.signatures) return false + + return true + } + + override fun hashCode(): Int { + var result = super.hashCode() + result = 31 * result + originalMessage.hashCode() + result = 31 * result + signedMessage.contentHashCode() + result = 31 * result + signatures.hashCode() + return result + } + override fun toString(): String = "SignedTransaction(message=${originalMessage}, signatures=$signatures)" public fun serialize(): ByteArray { val signaturesSize = signatures.size - val signaturesLength = ShortvecEncoding.encodeLength(signaturesSize) + val signaturesLength = ShortVecEncoding.encodeLength(signaturesSize) val bufferSize = signaturesLength.size + - signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES + - signedMessage.size + signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES + + signedMessage.size val out = Buffer() out.write(signaturesLength) for (signature in signatures) { @@ -52,10 +75,10 @@ public class SignedTransaction( val message = signedMessage val messageLength = message.size val signaturesSize = signatures.size - val signaturesLength = ShortvecEncoding.encodeLength(signaturesSize) + val signaturesLength = ShortVecEncoding.encodeLength(signaturesSize) val signaturesSizeBytes = signaturesLength.size val signatureSize = TweetNaCl.Signature.SIGNATURE_BYTES - val signatureSizeBytes = ShortvecEncoding.encodeLength(signatureSize) + val signatureSizeBytes = ShortVecEncoding.encodeLength(signatureSize) val signatureSizeBytesLength = signatureSizeBytes.size val expectedSize = messageLength + signaturesSizeBytes + signaturesSize * (signatureSize + signatureSizeBytesLength) diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt index 0688b43..c2b98ab 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/domain/core/Transaction.kt @@ -42,6 +42,19 @@ public open class Transaction internal constructor( ) } + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as Transaction + + return message == other.message + } + + override fun hashCode(): Int { + return message.hashCode() + } + override fun toString(): String = "Transaction(message=$message)" public class Builder internal constructor( diff --git a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortVecEncoding.kt similarity index 95% rename from solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt rename to solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortVecEncoding.kt index 427463b..625eba5 100644 --- a/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortvecEncoding.kt +++ b/solana-kotlin/src/commonMain/kotlin/net/avianlabs/solana/vendor/ShortVecEncoding.kt @@ -5,7 +5,7 @@ import kotlin.experimental.and internal typealias ShortVecLength = ByteArray -internal object ShortvecEncoding { +internal object ShortVecEncoding { internal fun encodeLength(len: Int): ShortVecLength { val buffer = Buffer() var length = len diff --git a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/vendor/ShortvecEncodingTest.kt b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/vendor/ShortvecEncodingTest.kt index 49da154..b843ca2 100644 --- a/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/vendor/ShortvecEncodingTest.kt +++ b/solana-kotlin/src/commonTest/kotlin/net/avianlabs/solana/vendor/ShortvecEncodingTest.kt @@ -8,39 +8,39 @@ class ShortvecEncodingTest { fun encodeLength() { assertContentEquals( byteArrayOf(0) /* [0] */, - ShortvecEncoding.encodeLength(0) + ShortVecEncoding.encodeLength(0) ) assertContentEquals( byteArrayOf(1) /* [1] */, - ShortvecEncoding.encodeLength(1) + ShortVecEncoding.encodeLength(1) ) assertContentEquals( byteArrayOf(5) /* [5] */, - ShortvecEncoding.encodeLength(5) + ShortVecEncoding.encodeLength(5) ) assertContentEquals( byteArrayOf(127)/* [0x7f] */, - ShortvecEncoding.encodeLength(127) // 0x7f + ShortVecEncoding.encodeLength(127) // 0x7f ) assertContentEquals( byteArrayOf(-128, 1) /* [0x80, 0x01] */, - ShortvecEncoding.encodeLength(128) // 0x80 + ShortVecEncoding.encodeLength(128) // 0x80 ) assertContentEquals( byteArrayOf(-1, 1) /* [0xff, 0x01] */, - ShortvecEncoding.encodeLength(255) // 0xff + ShortVecEncoding.encodeLength(255) // 0xff ) assertContentEquals( byteArrayOf(-128, 2)/* [0x80, 0x02] */, - ShortvecEncoding.encodeLength(256) // 0x100 + ShortVecEncoding.encodeLength(256) // 0x100 ) assertContentEquals( byteArrayOf(-1, -1, 1)/* [0xff, 0xff, 0x01] */, - ShortvecEncoding.encodeLength(32767) // 0x7fff + ShortVecEncoding.encodeLength(32767) // 0x7fff ) assertContentEquals( byteArrayOf(-128, -128, -128, 1)/* [0x80, 0x80, 0x80, 0x01] */, - ShortvecEncoding.encodeLength(2097152) // 0x200000 + ShortVecEncoding.encodeLength(2097152) // 0x200000 ) } }