Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Notify] Add JWT validation #1216

Merged
merged 18 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci_instrumented_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ jobs:
group: apple-silicon
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Check what modules were changed
id: what_modules_changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package com.walletconnect.notify.common.model

import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.foundation.common.model.PublicKey
import com.walletconnect.foundation.common.model.Topic

data class RegisteredAccount(
val accountId: AccountId,
val publicIdentityKey: PublicKey,
val isLimited: Boolean,
val appDomain: String?
val appDomain: String?,
val notifyServerWatchTopic: Topic?,
val notifyServerAuthenticationKey: PublicKey?,
)
Talhaali00 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,27 @@ package com.walletconnect.notify.data.jwt

import com.walletconnect.foundation.util.jwt.JwtClaims

internal interface NotifyJwtBase: JwtClaims {
internal interface NotifyJwtBase : JwtClaims {
val action: String
val issuedAt: Long
val expiration: Long
val requiredActionValue: String

private fun throwIdIssuedAtIsInvalid() {
if (issuedAt < System.currentTimeMillis()) throw IllegalArgumentException("Invalid issuedAt claim was $issuedAt instead of lower than ${System.currentTimeMillis()}")
}

private fun throwExpirationAtIsInvalid() {
if (expiration > System.currentTimeMillis()) throw IllegalArgumentException("Invalid expiration claim was $expiration instead of greater than ${System.currentTimeMillis()}")
}

private fun throwIfActionIsInvalid() {
if (action != requiredActionValue) throw IllegalArgumentException("Invalid action claim was $action instead of $requiredActionValue")
}

fun throwIfBaseIsInvalid() {
throwIdIssuedAtIsInvalid()
throwExpirationAtIsInvalid()
throwIfActionIsInvalid()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class DeleteRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_delete",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_delete"
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class DeleteResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_delete_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_delete_response"
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ internal data class MessageRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "msg") val message: Message,
@Json(name = "act") override val action: String = "notify_message",
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE

@JsonClass(generateAdapter = true)
data class Message(
Expand All @@ -25,4 +26,8 @@ internal data class MessageRequestJwtClaim(
@Json(name = "url") val url: String,
@Json(name = "type") val type: String,
)
}

}

private const val ACTION_CLAIM_VALUE = "notify_message"

Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class MessageResponseJwtClaim(
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "sub") val subject: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_message_response",
) : NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_message_response"
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ internal data class SubscriptionRequestJwtClaim(
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "scp") val scope: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_subscription",
) : NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscription"


Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class SubscriptionResponseJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "sub") val subject: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_subscription_response",
) : NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscription_response"
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class SubscriptionsChangedRequestJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "sbs") val subscriptions: List<ServerSubscription>,
@Json(name = "act") override val action: String = "notify_subscriptions_changed",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscriptions_changed"
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class SubscriptionsChangedResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "exp") override val expiration: Long,
@Json(name = "act") override val action: String = "notify_subscriptions_changed_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscriptions_changed_response"
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ internal data class UpdateRequestJwtClaim(
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "app") val app: String,
@Json(name = "scp") val scope: String,
@Json(name = "act") override val action: String = "notify_update",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_update"
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class UpdateResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_update_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_update_response"
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ internal data class WatchSubscriptionsRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "app") val appDidWeb: String?,
@Json(name = "act") override val action: String = "notify_watch_subscriptions",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_watch_subscriptions"
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class WatchSubscriptionsResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "sbs") val subscriptions: List<ServerSubscription>,
@Json(name = "act") override val action: String = "notify_watch_subscriptions_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_watch_subscriptions_response"
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@ package com.walletconnect.notify.data.storage

import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.foundation.common.model.PublicKey
import com.walletconnect.foundation.common.model.Topic
import com.walletconnect.notify.common.model.RegisteredAccount
import com.walletconnect.notify.common.storage.data.dao.RegisteredAccountsQueries
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

internal class RegisteredAccountsRepository(private val registeredAccounts: RegisteredAccountsQueries) {

suspend fun insertOrIgnoreAccount(
accountId: AccountId,
publicIdentityKey: PublicKey,
isLimited: Boolean,
appDomain: String?
) = withContext(Dispatchers.IO) {
suspend fun insertOrIgnoreAccount(accountId: AccountId, publicIdentityKey: PublicKey, isLimited: Boolean, appDomain: String?) = withContext(Dispatchers.IO) {
registeredAccounts.insertOrIgnoreAccount(accountId.value, publicIdentityKey.keyAsHex, isLimited, appDomain)
}

suspend fun updateNotifyServerData(accountId: AccountId, notifyServerWatchTopic: Topic, notifyServerAuthenticationKey: PublicKey) = withContext(Dispatchers.IO) {
registeredAccounts.updateNotifyServerData(accountId.value, notifyServerWatchTopic.value, notifyServerAuthenticationKey.keyAsHex)
}

suspend fun getAccountByAccountId(accountId: String): RegisteredAccount = withContext(Dispatchers.IO) {
registeredAccounts.getAccountByAccountId(accountId, ::toRegisterAccount).executeAsOne()
}

suspend fun getAccountByIdentityKey(identityPublicKey: String): RegisteredAccount = withContext(Dispatchers.IO) {
registeredAccounts.getAccountByIdentityKey(identityPublicKey, ::toRegisterAccount).executeAsOne()
}

suspend fun getAllAccounts(): List<RegisteredAccount> = withContext(Dispatchers.IO) {
registeredAccounts.getAllAccounts(::toRegisterAccount).executeAsList()
}
Expand All @@ -33,9 +37,7 @@ internal class RegisteredAccountsRepository(private val registeredAccounts: Regi
}

private fun toRegisterAccount(
accountId: String,
publicIdentityKey: String,
isLimited: Boolean,
appDomain: String?
): RegisteredAccount = RegisteredAccount(AccountId(accountId), PublicKey(publicIdentityKey), isLimited, appDomain)
accountId: String, publicIdentityKey: String, isLimited: Boolean, appDomain: String?, notifyServerWatchTopic: String?, notifyServerAuthenticationKey: String?,
): RegisteredAccount =
RegisteredAccount(AccountId(accountId), PublicKey(publicIdentityKey), isLimited, appDomain, notifyServerWatchTopic?.let { Topic(it) }, notifyServerAuthenticationKey?.let { PublicKey(it) })
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ internal fun engineModule() = module {
single {
StopWatchingSubscriptionsUseCase(
jsonRpcInteractor = get(),
keyManagementRepository = get(),
extractPublicKeysFromDidJsonUseCase = get(),
notifyServerUrl = get(),
getSelfKeyForWatchSubscriptionUseCase = get()
registeredAccountsRepository = get()
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ internal fun requestModule() = module {
extractPublicKeysFromDidJsonUseCase = get(),
jsonRpcInteractor = get(),
logger = get(),
notifyServerUrl = get()
notifyServerUrl = get(),
registeredAccountsRepository = get(),
watchSubscriptionsForEveryRegisteredAccountUseCase = get(),
accountsRepository = get(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ internal fun responseModule() = module {
single {
OnWatchSubscriptionsResponseUseCase(
setActiveSubscriptionsUseCase = get(),
extractPublicKeysFromDidJsonUseCase = get(),
watchSubscriptionsForEveryRegisteredAccountUseCase = get(),
accountsRepository = get(),
notifyServerUrl = get()
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ internal class UnregisterUseCase(
identitiesInteractor.unregisterIdentity(accountId, keyserverUrl).fold(
onFailure = { error -> onFailure(error) },
onSuccess = { identityPublicKey ->
runCatching { registeredAccountsRepository.deleteAccountByAccountId(account) }.fold(
runCatching {
stopWatchingSubscriptionsUseCase(accountId, onFailure = { error -> onFailure(error) })
registeredAccountsRepository.deleteAccountByAccountId(account)
}.fold(
onFailure = { error -> onFailure(error) },
onSuccess = {
stopWatchingSubscriptionsUseCase(accountId, onFailure)
subscriptionRepository.getAccountActiveSubscriptions(accountId).map { it.notifyTopic.value }.map { topic ->
jsonRpcInteractor.unsubscribe(Topic(topic)) { error -> onFailure(error) }
subscriptionRepository.deleteSubscriptionByNotifyTopic(topic)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,24 @@

package com.walletconnect.notify.engine.domain

import com.walletconnect.android.internal.common.crypto.kmr.KeyManagementRepository
import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.android.internal.common.model.type.JsonRpcInteractorInterface
import com.walletconnect.notify.common.NotifyServerUrl
import com.walletconnect.notify.data.storage.RegisteredAccountsRepository
import kotlinx.coroutines.supervisorScope

internal class StopWatchingSubscriptionsUseCase(
private val jsonRpcInteractor: JsonRpcInteractorInterface,
private val keyManagementRepository: KeyManagementRepository,
private val extractPublicKeysFromDidJsonUseCase: ExtractPublicKeysFromDidJsonUseCase,
private val getSelfKeyForWatchSubscriptionUseCase: GetSelfKeyForWatchSubscriptionUseCase,
private val notifyServerUrl: NotifyServerUrl,
private val registeredAccountsRepository: RegisteredAccountsRepository,
) {

suspend operator fun invoke(accountId: AccountId, onFailure: (Throwable) -> Unit) = supervisorScope {
val (peerPublicKey, _) = extractPublicKeysFromDidJsonUseCase(notifyServerUrl.toUri()).getOrElse { useCaseError -> return@supervisorScope onFailure(useCaseError) }
val watchTopic = runCatching { registeredAccountsRepository.getAccountByAccountId(accountId.value).notifyServerWatchTopic }
.getOrElse { error -> return@supervisorScope onFailure(error) }

val requestTopic = keyManagementRepository.getTopicFromKey(peerPublicKey)

val selfPublicKey = getSelfKeyForWatchSubscriptionUseCase(requestTopic, accountId)
val responseTopic = keyManagementRepository.generateTopicFromKeyAgreement(selfPublicKey, peerPublicKey)

jsonRpcInteractor.unsubscribe(responseTopic) { error -> onFailure(error) }
if (watchTopic == null) {
return@supervisorScope onFailure(IllegalStateException("Watch topic is null"))
} else {
jsonRpcInteractor.unsubscribe(watchTopic) { error -> onFailure(error) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ internal class WatchSubscriptionsUseCase(
val account = registeredAccountsRepository.getAccountByAccountId(accountId.value)
val didJwt = fetchDidJwtInteractor.watchSubscriptionsRequest(accountId, authenticationPublicKey, if(account.isLimited) account.appDomain else null)
.getOrElse { error -> return@supervisorScope onFailure(error) }

registeredAccountsRepository.updateNotifyServerData(accountId, responseTopic, authenticationPublicKey)
val watchSubscriptionsParams = CoreNotifyParams.WatchSubscriptionsParams(didJwt.value)
val request = NotifyRpc.NotifyWatchSubscriptions(params = watchSubscriptionsParams)
val irnParams = IrnParams(Tags.NOTIFY_WATCH_SUBSCRIPTIONS, Ttl(THIRTY_SECONDS))
Expand Down
Loading