Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow to leave missing signatures as null #168

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package net.avianlabs.solana.domain.core

internal fun List<AccountMeta>.normalize(): List<AccountMeta> = groupBy { it.publicKey }
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey

internal fun List<AccountMeta>.normalize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to have two separate functions ?
normalize() and normalizeWithFeePayer ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, fee payer might be set or not

feePayer: PublicKey? = null,
): List<AccountMeta> = groupBy { it.publicKey }
.mapValues { (_, metas) ->
metas.reduce { acc, meta ->
AccountMeta(
Expand All @@ -11,12 +15,17 @@ internal fun List<AccountMeta>.normalize(): List<AccountMeta> = groupBy { it.pub
}
}
.values
.sortedWith(metaComparator)
.sortedWith(metaComparator(feePayer))
.toList()

private val metaComparator = Comparator<AccountMeta> { am1, am2 ->
private fun metaComparator(feePayer: PublicKey?) = Comparator<AccountMeta> { am1, am2 ->
// first sort by signer, then writable
if (am1.isSigner && !am2.isSigner) {
// and ensure feePayer is always first
if (am1.publicKey == feePayer) {
-1
} else if (am2.publicKey == feePayer) {
1
} else if (am1.isSigner && !am2.isSigner) {
-1
} else if (!am1.isSigner && am2.isSigner) {
1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public data class Message private constructor(

public fun setFeePayer(feePayer: PublicKey): Builder {
this.feePayer = feePayer
accountKeys.add(AccountMeta(feePayer, isSigner = true, isWritable = true))
return this
}

Expand All @@ -45,6 +46,6 @@ public data class Message private constructor(
}

public fun build(): Message =
Message(feePayer, recentBlockHash, accountKeys.normalize(), instructions)
Message(feePayer, recentBlockHash, accountKeys.normalize(feePayer), instructions)
}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
package net.avianlabs.solana.domain.core

import io.github.oshai.kotlinlogging.KotlinLogging
import net.avianlabs.solana.tweetnacl.TweetNaCl
import net.avianlabs.solana.tweetnacl.vendor.decodeBase58
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String
import net.avianlabs.solana.vendor.ShortVecEncoding
import okio.Buffer

private val logger = KotlinLogging.logger {}

@ConsistentCopyVisibility
public data class SignedTransaction internal constructor(
public val originalMessage: Message,
public val signedMessage: ByteArray,
public val signatures: List<String>,
public val signatures: Map<PublicKey, ByteArray>,
) : Transaction(originalMessage) {

public override fun sign(signers: List<Signer>): SignedTransaction = SignedTransaction(
originalMessage = originalMessage,
signedMessage = signedMessage,
signatures = signatures + signers.map { signer ->
TweetNaCl.Signature.sign(signedMessage, signer.secretKey).encodeToBase58String()
}
).also {
val signatureSet = signatures.toSet()
if (signatureSet.size != signatures.size) {
logger.warn { "Duplicate signatures detected" }
signatures = signatures + signers.associate { signer ->
signer.publicKey to
TweetNaCl.Signature.sign(signedMessage, signer.secretKey)
}
}
)

override fun equals(other: Any?): Boolean {
if (this === other) return true
Expand All @@ -52,21 +45,23 @@ public data class SignedTransaction internal constructor(
}

override fun toString(): String =
"SignedTransaction(message=${originalMessage}, signatures=$signatures)"
"SignedTransaction(message=${originalMessage}, " +
"signatures=${signatures.values.map { it.encodeToBase58String() }})"

public fun serialize(): SerializedTransaction {
val signaturesSize = signatures.size
val signaturesLength = ShortVecEncoding.encodeLength(signaturesSize)
public fun serialize(includeNullSignatures: Boolean = false): SerializedTransaction {
val signerKeys = message.accountKeys.filter { it.isSigner }
.mapNotNull {
signatures[it.publicKey]
?: ByteArray(TweetNaCl.Signature.SIGNATURE_BYTES).takeIf { includeNullSignatures }
}
val signaturesLength = ShortVecEncoding.encodeLength(signerKeys.size)
val bufferSize =
signaturesLength.size +
signaturesSize * TweetNaCl.Signature.SIGNATURE_BYTES +
signerKeys.size * TweetNaCl.Signature.SIGNATURE_BYTES +
signedMessage.size
val out = Buffer()
out.write(signaturesLength)
for (signature in signatures) {
val rawSignature = signature.decodeBase58()
out.write(rawSignature)
}
signerKeys.forEach(out::write)
out.write(signedMessage)
return out.readByteArray(bufferSize.toLong())
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package net.avianlabs.solana.domain.core

import io.github.oshai.kotlinlogging.KotlinLogging
import net.avianlabs.solana.tweetnacl.TweetNaCl
import net.avianlabs.solana.tweetnacl.ed25519.PublicKey
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String

private val logger = KotlinLogging.logger {}

public open class Transaction internal constructor(
public val message: Message,
Expand All @@ -26,13 +22,9 @@ public open class Transaction internal constructor(
val serializedMessage = message
.serialize()

val signatures = signers.map { signer ->
TweetNaCl.Signature.sign(serializedMessage, signer.secretKey).encodeToBase58String()
}

val signatureSet = signatures.toSet()
if (signatureSet.size != signatures.size) {
logger.warn { "Duplicate signatures detected" }
val signatures = signers.associate { signer ->
signer.publicKey to
TweetNaCl.Signature.sign(serializedMessage, signer.secretKey)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this now ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the signatures can preserve the order of account meta list

}

return SignedTransaction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,25 @@ class AccountKeysListTest {
isSigner = false,
isWritable = false,
),
AccountMeta(
publicKey = PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG"),
isSigner = false,
isWritable = false,
),
)

accountKeysList.addAll(meta.shuffled())

assertEquals(
listOf(
PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG"),
PublicKey.fromBase58("EtDXsqZ9Cgod7Z6j8cqu8fNMF7fu9txu2puHnxVY1wBk"),
PublicKey.fromBase58("9JGhZqi4MbnVz424uJ6vqk9a1u359xg3nJekdjzzL4d5"),
PublicKey.fromBase58("G8iheDY9bGix5qCXEitCExLcgZzZrEemngk9cbTR3CQs"),
),
accountKeysList.normalize().map { it.publicKey }
accountKeysList.normalize(
feePayer = PublicKey.fromBase58("8Ta2TgXmiG36c4219H5GC1yzpzuzSqA2rYBiRuGuCmzG")
).map { it.publicKey }
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,70 @@ class TransactionTest {

val serialized = transaction.serialize().encodeBase64()
val expected =
"AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAAIDq" +
"OvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMGRm/l" +
"IRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQIAAAwCAAA" +
"AAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA=="
"AbY2fW8NzhzkDyTK3Av5Kn3/aBwxWWlGYjdMWHU4sLtT55yooXG3gKAFKCeQtYb7S86WOkWU6MVEsqP26vBw/gYBAA" +
"IDqOvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" +
"MGRm/lIRcy/+ytunLDm+e8jOW7xfcSayxDmzpAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoDAQ" +
"IAAAwCAAAAAQAAAAAAAAACAAkDAQAAAAAAAAABAgAADAIAAAABAAAAAAAAAA=="

assertEquals(expected, serialized)
}

@Test
fun test_sign() {
val keypair = Ed25519Keypair.fromBase58(
"9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy"
)

val keypair2 = Ed25519Keypair.fromBase58(
"53iBnfgSVoZPEo9EtKnZ8yDSTyxNTxmqECrQs9nLJotjcsJVQCjTn6J7V8cgKe2umYGx9SpGdDocamV4tgkXP6Fr"
)

val transaction = Transaction.Builder()
.addInstruction(
SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1)
)
.setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF")
.setFeePayer(keypair2.publicKey)
.build()
.sign(keypair)
.serialize(includeNullSignatures = false)

val expected =
"AXWGwDjk7s+ybacDFIIwXfVqO+Wuo17TlD9hGg76MrWihSWz6mUF3mMoengeRLsKS6LS9GfUArTK9tLsBeB2YggCAA" +
"EDbtBATs2PPthWHf3bqIU5/bs3SYSvA9m1WaJcIXO3XIGo6+Z8GIyqOmaH0mDsMQB7WSDhGdx7Sm/+XL2EgXsjJA" +
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZY5hBIuHu2Tv+5WayrPUoI8ytBhM3HsRYE3SA3zA6HoBAg" +
"IBAQwCAAAAAQAAAAAAAAA="

assertEquals(expected, transaction.encodeBase64())
}

@Test
fun test_sign_null_signatures() {
val keypair = Ed25519Keypair.fromBase58(
"9bCpJHMBCpjCnJHCyEvwWqcnTPf4yxWWsNXU7AMYPyS4fsR1bSEPGfjHQ8TDfaQWofAHm8MVeSgQLeEia2uqvVy"
)

val keypair2 = Ed25519Keypair.fromBase58(
"53iBnfgSVoZPEo9EtKnZ8yDSTyxNTxmqECrQs9nLJotjcsJVQCjTn6J7V8cgKe2umYGx9SpGdDocamV4tgkXP6Fr"
)

val transaction = Transaction.Builder()
.addInstruction(
SystemProgram.transfer(keypair.publicKey, keypair.publicKey, 1)
)
.setRecentBlockHash("7qS6hDXGxd6ekYqnSqD7abG1jEfTcpfpjKApxWbb4gVF")
.setFeePayer(keypair2.publicKey)
.build()
.sign(keypair)
.serialize(includeNullSignatures = true)

val expected =
"AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB1hs" +
"A45O7Psm2nAxSCMF31ajvlrqNe05Q/YRoO+jK1ooUls+plBd5jKHp4HkS7Ckui0vRn1AK0yvbS7AXgdmIIAgABA2" +
"7QQE7Njz7YVh3926iFOf27N0mErwPZtVmiXCFzt1yBqOvmfBiMqjpmh9Jg7DEAe1kg4Rnce0pv/ly9hIF7IyQAAA" +
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGWOYQSLh7tk7/uVmsqz1KCPMrQYTNx7EWBN0gN8wOh6AQICAQ" +
"EMAgAAAAEAAAAAAAAA"

assertEquals(expected, transaction.encodeBase64())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package net.avianlabs.solana.domain.crypto
import net.avianlabs.solana.domain.core.Transaction
import net.avianlabs.solana.domain.program.SystemProgram
import net.avianlabs.solana.domain.randomKey
import net.avianlabs.solana.tweetnacl.vendor.decodeBase58
import net.avianlabs.solana.tweetnacl.vendor.encodeToBase58String
import kotlin.random.Random
import kotlin.test.Test
Expand Down Expand Up @@ -33,10 +32,8 @@ class CryptoEngineTest {
.build()
.sign(keypair)

val signature = transaction.signatures.first()
val signature = transaction.signatures[keypair.publicKey]!!

val signatureArray = signature.decodeBase58()

assertTrue(signatureArray.size == 64, "wrong signature size ${signatureArray.size}")
assertTrue(signature.size == 64, "wrong signature size ${signature.size}")
}
}
Loading