Skip to content

Commit

Permalink
Merge pull request #168 from avianlabs/guillermo/allow-to-leave-missi…
Browse files Browse the repository at this point in the history
…ng-signatures-as-null

allow to leave missing signatures as null
  • Loading branch information
wiyarmir authored Dec 3, 2024
2 parents 4aaee7f + b6abff7 commit 93e104e
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package net.avianlabs.solana.domain.core

internal fun List<AccountMeta>.normalize(): List<AccountMeta> = groupBy { it.publicKey }
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey

internal fun List<AccountMeta>.normalize(
feePayer: PublicKey? = null,
): List<AccountMeta> = groupBy { it.publicKey }
.mapValues { (_, metas) ->
metas.reduce { acc, meta ->
AccountMeta(
Expand All @@ -11,12 +15,17 @@ internal fun List<AccountMeta>.normalize(): List<AccountMeta> = groupBy { it.pub
}
}
.values
.sortedWith(metaComparator)
.sortedWith(metaComparator(feePayer))
.toList()

private val metaComparator = Comparator<AccountMeta> { am1, am2 ->
private fun metaComparator(feePayer: PublicKey?) = Comparator<AccountMeta> { am1, am2 ->
// first sort by signer, then writable
if (am1.isSigner && !am2.isSigner) {
// and ensure feePayer is always first
if (am1.publicKey == feePayer) {
-1
} else if (am2.publicKey == feePayer) {
1
} else if (am1.isSigner && !am2.isSigner) {
-1
} else if (!am1.isSigner && am2.isSigner) {
1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public data class Message private constructor(

public fun setFeePayer(feePayer: PublicKey): Builder {
this.feePayer = feePayer
accountKeys.add(AccountMeta(feePayer, isSigner = true, isWritable = true))
return this
}

Expand All @@ -45,6 +46,6 @@ public data class Message private constructor(
}

public fun build(): Message =
Message(feePayer, recentBlockHash, accountKeys.normalize(), instructions)
Message(feePayer, recentBlockHash, accountKeys.normalize(feePayer), instructions)
}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
package net.avianlabs.solana.domain.core

import io.github.oshai.kotlinlogging.KotlinLogging
import net.avianlabs.solana.tweetnacl.TweetNaCl
import net.avianlabs.solana.tweetnacl.vendor.decodeBase58
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String
import net.avianlabs.solana.vendor.ShortVecEncoding
import okio.Buffer

private val logger = KotlinLogging.logger {}

@ConsistentCopyVisibility
public data class SignedTransaction internal constructor(
public val originalMessage: Message,
public val signedMessage: ByteArray,
public val signatures: List<String>,
public val signatures: Map<PublicKey, ByteArray>,
) : Transaction(originalMessage) {

public override fun sign(signers: List<Signer>): SignedTransaction = SignedTransaction(
originalMessage = originalMessage,
signedMessage = signedMessage,
signatures = signatures + signers.map { signer ->
TweetNaCl.Signature.sign(signedMessage, signer.secretKey).encodeToBase58String()
}
).also {
val signatureSet = signatures.toSet()
if (signatureSet.size != signatures.size) {
logger.warn { "Duplicate signatures detected" }
signatures = signatures + signers.associate { signer ->
signer.publicKey to
TweetNaCl.Signature.sign(signedMessage, signer.secretKey)
}
}
)

override fun equals(other: Any?): Boolean {
if (this === other) return true
Expand All @@ -52,21 +45,23 @@ public data class SignedTransaction internal constructor(
}

override fun toString(): String =
"SignedTransaction(message=${originalMessage}, signatures=$signatures)"
"SignedTransaction(message=${originalMessage}, " +
"signatures=${signatures.values.map { it.encodeToBase58String() }})"

public fun serialize(): SerializedTransaction {
val signaturesSize = signatures.size
val signaturesLength = ShortVecEncoding.encodeLength(signaturesSize)
public fun serialize(includeNullSignatures: Boolean = false): SerializedTransaction {
val signerKeys = message.accountKeys.filter { it.isSigner }
.mapNotNull {
signatures[it.publicKey]
?: ByteArray(TweetNaCl.Signature.SIGNATURE_BYTES).takeIf { includeNullSignatures }
}
val signaturesLength = ShortVecEncoding.encodeLength(signerKeys.size)
val bufferSize =
signaturesLength.size +
signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES +
signerKeys.size * TweetNaCl.Signature.SIGNATURE_BYTES +
signedMessage.size
val out = Buffer()
out.write(signaturesLength)
for (signature in signatures) {
val rawSignature = signature.decodeBase58()
out.write(rawSignature)
}
signerKeys.forEach(out::write)
out.write(signedMessage)
return out.readByteArray(bufferSize.toLong())
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package net.avianlabs.solana.domain.core

import io.github.oshai.kotlinlogging.KotlinLogging
import net.avianlabs.solana.tweetnacl.TweetNaCl
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String

private val logger = KotlinLogging.logger {}

public open class Transaction internal constructor(
public val message: Message,
Expand All @@ -26,13 +22,9 @@ public open class Transaction internal constructor(
val serializedMessage = message
.serialize()

val signatures = signers.map { signer ->
TweetNaCl.Signature.sign(serializedMessage, signer.secretKey).encodeToBase58String()
}

val signatureSet = signatures.toSet()
if (signatureSet.size != signatures.size) {
logger.warn { "Duplicate signatures detected" }
val signatures = signers.associate { signer ->
signer.publicKey to
TweetNaCl.Signature.sign(serializedMessage, signer.secretKey)
}

return SignedTransaction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,25 @@ class AccountKeysListTest {
isSigner = false,
isWritable = false,
),
AccountMeta(
publicKey = PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG"),
isSigner = false,
isWritable = false,
),
)

accountKeysList.addAll(meta.shuffled())

assertEquals(
listOf(
PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG"),
PublicKey.fromBase58("EtDXsqZ9Cgod7Z6j8cqu8fNMF7fu9txu2puHnxVY1wBk"),
PublicKey.fromBase58("9JGhZqi4MbnVz424uJ6vqk9a1u359xg3nJekdjzzL4d5"),
PublicKey.fromBase58("G8iheDY9bGix5qCXEitCExLcgZzZrEemngk9cbTR3CQs"),
),
accountKeysList.normalize().map { it.publicKey }
accountKeysList.normalize(
feePayer = PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG")
).map { it.publicKey }
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,70 @@ class TransactionTest {

val serialized = transaction.serialize().encodeBase64()
val expected =
"AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAAIDq" +
"OvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMGRm/l" +
"IRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQIAAAwCAAA" +
"AAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA=="
"AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAA" +
"IDqOvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" +
"MGRm/lIRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQ" +
"IAAAwCAAAAAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA=="

assertEquals(expected, serialized)
}

@Test
fun test_sign() {
val keypair = Ed25519Keypair.fromBase58(
"9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy"
)

val keypair2 = Ed25519Keypair.fromBase58(
"53iBnfgSVoZPEo9EtKnZ8yDSTyxNTxmqECrQs9nLJotjcsJVQCjTn6J7V8cgKe2umYGx9SpGdDocamV4tgkXP6Fr"
)

val transaction = Transaction.Builder()
.addInstruction(
SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1)
)
.setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF")
.setFeePayer(keypair2.publicKey)
.build()
.sign(keypair)
.serialize(includeNullSignatures = false)

val expected =
"AXWGwDjk7s+ybacDFIIwXfVqO+Wuo17TlD9hGg76MrWihSWz6mUF3mMoengeRLsKS6LS9GfUArTK9tLsBeB2YggCAA" +
"EDbtBATs2PPthWHf3bqIU5/bs3SYSvA9m1WaJcIXO3XIGo6+Z8GIyqOmaH0mDsMQB7WSDhGdx7Sm/+XL2EgXsjJA" +
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoBAg" +
"IBAQwCAAAAAQAAAAAAAAA="

assertEquals(expected, transaction.encodeBase64())
}

@Test
fun test_sign_null_signatures() {
val keypair = Ed25519Keypair.fromBase58(
"9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy"
)

val keypair2 = Ed25519Keypair.fromBase58(
"53iBnfgSVoZPEo9EtKnZ8yDSTyxNTxmqECrQs9nLJotjcsJVQCjTn6J7V8cgKe2umYGx9SpGdDocamV4tgkXP6Fr"
)

val transaction = Transaction.Builder()
.addInstruction(
SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1)
)
.setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF")
.setFeePayer(keypair2.publicKey)
.build()
.sign(keypair)
.serialize(includeNullSignatures = true)

val expected =
"AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB1hs" +
"A45O7Psm2nAxSCMF31ajvlrqNe05Q/YRoO+jK1ooUls+plBd5jKHp4HkS7Ckui0vRn1AK0yvbS7AXgdmIIAgABA2" +
"7QQE7Njz7YVh3926iFOf27N0mErwPZtVmiXCFzt1yBqOvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAA" +
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGWOYQSLh7tk7/uVmsqz1KCPMrQYTNx7EWBN0gN8wOh6AQICAQ" +
"EMAgAAAAEAAAAAAAAA"

assertEquals(expected, transaction.encodeBase64())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package net.avianlabs.solana.domain.crypto
import net.avianlabs.solana.domain.core.Transaction
import net.avianlabs.solana.domain.program.SystemProgram
import net.avianlabs.solana.domain.randomKey
import net.avianlabs.solana.tweetnacl.vendor.decodeBase58
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String
import kotlin.random.Random
import kotlin.test.Test
Expand Down Expand Up @@ -33,10 +32,8 @@ class CryptoEngineTest {
.build()
.sign(keypair)

val signature = transaction.signatures.first()
val signature = transaction.signatures[keypair.publicKey]!!

val signatureArray = signature.decodeBase58()

assertTrue(signatureArray.size == 64, "wrong signature size ${signatureArray.size}")
assertTrue(signature.size == 64, "wrong signature size ${signature.size}")
}
}

0 comments on commit 93e104e

Please sign in to comment.