Skip to content

Commit

Permalink
allow to leave missing signatures as null
Browse files Browse the repository at this point in the history
this helps backend do partial signatures in the right places
  • Loading branch information
wiyarmir committed Dec 3, 2024
1 parent 6c22fb7 commit b6abff7
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package net.avianlabs.solana.domain.core

internal fun List<AccountMeta>.normalize(): List<AccountMeta> = groupBy { it.publicKey }
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey

internal fun List<AccountMeta>.normalize(
feePayer: PublicKey? = null,
): List<AccountMeta> = groupBy { it.publicKey }
.mapValues { (_, metas) ->
metas.reduce { acc, meta ->
AccountMeta(
Expand All @@ -11,12 +15,17 @@ internal fun List<AccountMeta>.normalize(): List<AccountMeta> = groupBy { it.pub
}
}
.values
.sortedWith(metaComparator)
.sortedWith(metaComparator(feePayer))
.toList()

private val metaComparator = Comparator<AccountMeta> { am1, am2 ->
private fun metaComparator(feePayer: PublicKey?) = Comparator<AccountMeta> { am1, am2 ->
// first sort by signer, then writable
if (am1.isSigner && !am2.isSigner) {
// and ensure feePayer is always first
if (am1.publicKey == feePayer) {
-1
} else if (am2.publicKey == feePayer) {
1
} else if (am1.isSigner && !am2.isSigner) {
-1
} else if (!am1.isSigner && am2.isSigner) {
1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public data class Message private constructor(

public fun setFeePayer(feePayer: PublicKey): Builder {
this.feePayer = feePayer
accountKeys.add(AccountMeta(feePayer, isSigner = true, isWritable = true))
return this
}

Expand All @@ -45,6 +46,6 @@ public data class Message private constructor(
}

public fun build(): Message =
Message(feePayer, recentBlockHash, accountKeys.normalize(), instructions)
Message(feePayer, recentBlockHash, accountKeys.normalize(feePayer), instructions)
}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
package net.avianlabs.solana.domain.core

import io.github.oshai.kotlinlogging.KotlinLogging
import net.avianlabs.solana.tweetnacl.TweetNaCl
import net.avianlabs.solana.tweetnacl.vendor.decodeBase58
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String
import net.avianlabs.solana.vendor.ShortVecEncoding
import okio.Buffer

private val logger = KotlinLogging.logger {}

@ConsistentCopyVisibility
public data class SignedTransaction internal constructor(
public val originalMessage: Message,
public val signedMessage: ByteArray,
public val signatures: List<String>,
public val signatures: Map<PublicKey, ByteArray>,
) : Transaction(originalMessage) {

public override fun sign(signers: List<Signer>): SignedTransaction = SignedTransaction(
originalMessage = originalMessage,
signedMessage = signedMessage,
signatures = signatures + signers.map { signer ->
TweetNaCl.Signature.sign(signedMessage, signer.secretKey).encodeToBase58String()
}
).also {
val signatureSet = signatures.toSet()
if (signatureSet.size != signatures.size) {
logger.warn { "Duplicate signatures detected" }
signatures = signatures + signers.associate { signer ->
signer.publicKey to
TweetNaCl.Signature.sign(signedMessage, signer.secretKey)
}
}
)

override fun equals(other: Any?): Boolean {
if (this === other) return true
Expand All @@ -52,21 +45,23 @@ public data class SignedTransaction internal constructor(
}

override fun toString(): String =
"SignedTransaction(message=${originalMessage}, signatures=$signatures)"
"SignedTransaction(message=${originalMessage}, " +
"signatures=${signatures.values.map { it.encodeToBase58String() }})"

public fun serialize(): SerializedTransaction {
val signaturesSize = signatures.size
val signaturesLength = ShortVecEncoding.encodeLength(signaturesSize)
public fun serialize(includeNullSignatures: Boolean = false): SerializedTransaction {
val signerKeys = message.accountKeys.filter { it.isSigner }
.mapNotNull {
signatures[it.publicKey]
?: ByteArray(TweetNaCl.Signature.SIGNATURE_BYTES).takeIf { includeNullSignatures }
}
val signaturesLength = ShortVecEncoding.encodeLength(signerKeys.size)
val bufferSize =
signaturesLength.size +
signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES +
signerKeys.size * TweetNaCl.Signature.SIGNATURE_BYTES +
signedMessage.size
val out = Buffer()
out.write(signaturesLength)
for (signature in signatures) {
val rawSignature = signature.decodeBase58()
out.write(rawSignature)
}
signerKeys.forEach(out::write)
out.write(signedMessage)
return out.readByteArray(bufferSize.toLong())
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package net.avianlabs.solana.domain.core

import io.github.oshai.kotlinlogging.KotlinLogging
import net.avianlabs.solana.tweetnacl.TweetNaCl
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String

private val logger = KotlinLogging.logger {}

public open class Transaction internal constructor(
public val message: Message,
Expand All @@ -26,13 +22,9 @@ public open class Transaction internal constructor(
val serializedMessage = message
.serialize()

val signatures = signers.map { signer ->
TweetNaCl.Signature.sign(serializedMessage, signer.secretKey).encodeToBase58String()
}

val signatureSet = signatures.toSet()
if (signatureSet.size != signatures.size) {
logger.warn { "Duplicate signatures detected" }
val signatures = signers.associate { signer ->
signer.publicKey to
TweetNaCl.Signature.sign(serializedMessage, signer.secretKey)
}

return SignedTransaction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,25 @@ class AccountKeysListTest {
isSigner = false,
isWritable = false,
),
AccountMeta(
publicKey = PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG"),
isSigner = false,
isWritable = false,
),
)

accountKeysList.addAll(meta.shuffled())

assertEquals(
listOf(
PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG"),
PublicKey.fromBase58("EtDXsqZ9Cgod7Z6j8cqu8fNMF7fu9txu2puHnxVY1wBk"),
PublicKey.fromBase58("9JGhZqi4MbnVz424uJ6vqk9a1u359xg3nJekdjzzL4d5"),
PublicKey.fromBase58("G8iheDY9bGix5qCXEitCExLcgZzZrEemngk9cbTR3CQs"),
),
accountKeysList.normalize().map { it.publicKey }
accountKeysList.normalize(
feePayer = PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG")
).map { it.publicKey }
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,70 @@ class TransactionTest {

val serialized = transaction.serialize().encodeBase64()
val expected =
"AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAAIDq" +
"OvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMGRm/l" +
"IRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQIAAAwCAAA" +
"AAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA=="
"AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAA" +
"IDqOvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" +
"MGRm/lIRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQ" +
"IAAAwCAAAAAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA=="

assertEquals(expected, serialized)
}

@Test
fun test_sign() {
val keypair = Ed25519Keypair.fromBase58(
"9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy"
)

val keypair2 = Ed25519Keypair.fromBase58(
"53iBnfgSVoZPEo9EtKnZ8yDSTyxNTxmqECrQs9nLJotjcsJVQCjTn6J7V8cgKe2umYGx9SpGdDocamV4tgkXP6Fr"
)

val transaction = Transaction.Builder()
.addInstruction(
SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1)
)
.setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF")
.setFeePayer(keypair2.publicKey)
.build()
.sign(keypair)
.serialize(includeNullSignatures = false)

val expected =
"AXWGwDjk7s+ybacDFIIwXfVqO+Wuo17TlD9hGg76MrWihSWz6mUF3mMoengeRLsKS6LS9GfUArTK9tLsBeB2YggCAA" +
"EDbtBATs2PPthWHf3bqIU5/bs3SYSvA9m1WaJcIXO3XIGo6+Z8GIyqOmaH0mDsMQB7WSDhGdx7Sm/+XL2EgXsjJA" +
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoBAg" +
"IBAQwCAAAAAQAAAAAAAAA="

assertEquals(expected, transaction.encodeBase64())
}

@Test
fun test_sign_null_signatures() {
val keypair = Ed25519Keypair.fromBase58(
"9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy"
)

val keypair2 = Ed25519Keypair.fromBase58(
"53iBnfgSVoZPEo9EtKnZ8yDSTyxNTxmqECrQs9nLJotjcsJVQCjTn6J7V8cgKe2umYGx9SpGdDocamV4tgkXP6Fr"
)

val transaction = Transaction.Builder()
.addInstruction(
SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1)
)
.setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF")
.setFeePayer(keypair2.publicKey)
.build()
.sign(keypair)
.serialize(includeNullSignatures = true)

val expected =
"AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB1hs" +
"A45O7Psm2nAxSCMF31ajvlrqNe05Q/YRoO+jK1ooUls+plBd5jKHp4HkS7Ckui0vRn1AK0yvbS7AXgdmIIAgABA2" +
"7QQE7Njz7YVh3926iFOf27N0mErwPZtVmiXCFzt1yBqOvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAA" +
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGWOYQSLh7tk7/uVmsqz1KCPMrQYTNx7EWBN0gN8wOh6AQICAQ" +
"EMAgAAAAEAAAAAAAAA"

assertEquals(expected, transaction.encodeBase64())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package net.avianlabs.solana.domain.crypto
import net.avianlabs.solana.domain.core.Transaction
import net.avianlabs.solana.domain.program.SystemProgram
import net.avianlabs.solana.domain.randomKey
import net.avianlabs.solana.tweetnacl.vendor.decodeBase58
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String
import kotlin.random.Random
import kotlin.test.Test
Expand Down Expand Up @@ -33,10 +32,8 @@ class CryptoEngineTest {
.build()
.sign(keypair)

val signature = transaction.signatures.first()
val signature = transaction.signatures[keypair.publicKey]!!

val signatureArray = signature.decodeBase58()

assertTrue(signatureArray.size == 64, "wrong signature size ${signatureArray.size}")
assertTrue(signature.size == 64, "wrong signature size ${signature.size}")
}
}

0 comments on commit b6abff7

Please sign in to comment.