Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Notify] Add JWT validation #1216

Merged
merged 18 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package com.walletconnect.notify.common.model

import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.foundation.common.model.PublicKey
import com.walletconnect.foundation.common.model.Topic

data class RegisteredAccount(
val accountId: AccountId,
val publicIdentityKey: PublicKey,
val isLimited: Boolean,
val appDomain: String?
val appDomain: String?,
val notifyServerWatchTopic: Topic?,
val notifyServerAuthenticationKey: PublicKey?,
)
Talhaali00 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@ package com.walletconnect.notify.data.jwt

import com.walletconnect.foundation.util.jwt.JwtClaims

internal interface NotifyJwtBase: JwtClaims {
internal interface NotifyJwtBase : JwtClaims {
val action: String
val issuedAt: Long
val expiration: Long

fun throwIdIssuedAtIsInvalid() {
if (issuedAt < System.currentTimeMillis()) throw IllegalArgumentException("Invalid issuedAt claim was $issuedAt instead of lower than ${System.currentTimeMillis()}")
}

fun throwExpirationAtIsInvalid() {
if (expiration > System.currentTimeMillis()) throw IllegalArgumentException("Invalid expiration claim was $expiration instead of greater than ${System.currentTimeMillis()}")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal data class MessageRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "msg") val message: Message,
@Json(name = "act") override val action: String = "notify_message",
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {

@JsonClass(generateAdapter = true)
Expand All @@ -25,4 +25,19 @@ internal data class MessageRequestJwtClaim(
@Json(name = "url") val url: String,
@Json(name = "type") val type: String,
)
}


private fun throwIfActionIsInvalid() {
if (action != ACTION_CLAIM_VALUE) throw IllegalArgumentException("Invalid action claim was $action instead of $ACTION_CLAIM_VALUE")
}

fun throwIfBaseIsInvalid() {
throwIdIssuedAtIsInvalid()
throwExpirationAtIsInvalid()
throwIfActionIsInvalid()
}

}

private const val ACTION_CLAIM_VALUE = "notify_message"

Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,19 @@ internal data class SubscriptionsChangedRequestJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "sbs") val subscriptions: List<ServerSubscription>,
@Json(name = "act") override val action: String = "notify_subscriptions_changed",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
): NotifyJwtBase {

private fun throwIfActionIsInvalid() {
if (action != ACTION_CLAIM_VALUE) throw IllegalArgumentException("Invalid action claim was $action instead of $ACTION_CLAIM_VALUE")
}

fun throwIfBaseIsInvalid() {
throwIdIssuedAtIsInvalid()
throwExpirationAtIsInvalid()
throwIfActionIsInvalid()
}

}

private const val ACTION_CLAIM_VALUE = "notify_subscriptions_changed"
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,19 @@ internal data class WatchSubscriptionsResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "sbs") val subscriptions: List<ServerSubscription>,
@Json(name = "act") override val action: String = "notify_watch_subscriptions_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {

private fun throwIfActionIsInvalid() {
if (action != ACTION_CLAIM_VALUE) throw IllegalArgumentException("Invalid action claim was $action instead of $ACTION_CLAIM_VALUE")
}

fun throwIfBaseIsInvalid() {
throwIdIssuedAtIsInvalid()
throwExpirationAtIsInvalid()
throwIfActionIsInvalid()
}

}

private const val ACTION_CLAIM_VALUE = "notify_watch_subscriptions_response"
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@ package com.walletconnect.notify.data.storage

import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.foundation.common.model.PublicKey
import com.walletconnect.foundation.common.model.Topic
import com.walletconnect.notify.common.model.RegisteredAccount
import com.walletconnect.notify.common.storage.data.dao.RegisteredAccountsQueries
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

internal class RegisteredAccountsRepository(private val registeredAccounts: RegisteredAccountsQueries) {

suspend fun insertOrIgnoreAccount(
accountId: AccountId,
publicIdentityKey: PublicKey,
isLimited: Boolean,
appDomain: String?
) = withContext(Dispatchers.IO) {
suspend fun insertOrIgnoreAccount(accountId: AccountId, publicIdentityKey: PublicKey, isLimited: Boolean, appDomain: String?) = withContext(Dispatchers.IO) {
registeredAccounts.insertOrIgnoreAccount(accountId.value, publicIdentityKey.keyAsHex, isLimited, appDomain)
}

suspend fun updateNotifyServerData(accountId: AccountId, notifyServerWatchTopic: Topic, notifyServerAuthenticationKey: PublicKey) = withContext(Dispatchers.IO) {
registeredAccounts.updateNotifyServerData(accountId.value, notifyServerWatchTopic.value, notifyServerAuthenticationKey.keyAsHex)
}

suspend fun getAccountByAccountId(accountId: String): RegisteredAccount = withContext(Dispatchers.IO) {
registeredAccounts.getAccountByAccountId(accountId, ::toRegisterAccount).executeAsOne()
}

suspend fun getAccountByIdentityKey(identityPublicKey: String): RegisteredAccount = withContext(Dispatchers.IO) {
registeredAccounts.getAccountByIdentityKey(identityPublicKey, ::toRegisterAccount).executeAsOne()
}

suspend fun getAllAccounts(): List<RegisteredAccount> = withContext(Dispatchers.IO) {
registeredAccounts.getAllAccounts(::toRegisterAccount).executeAsList()
}
Expand All @@ -33,9 +37,7 @@ internal class RegisteredAccountsRepository(private val registeredAccounts: Regi
}

private fun toRegisterAccount(
accountId: String,
publicIdentityKey: String,
isLimited: Boolean,
appDomain: String?
): RegisteredAccount = RegisteredAccount(AccountId(accountId), PublicKey(publicIdentityKey), isLimited, appDomain)
accountId: String, publicIdentityKey: String, isLimited: Boolean, appDomain: String?, notifyServerWatchTopic: String?, notifyServerAuthenticationKey: String?,
): RegisteredAccount =
RegisteredAccount(AccountId(accountId), PublicKey(publicIdentityKey), isLimited, appDomain, notifyServerWatchTopic?.let { Topic(it) }, notifyServerAuthenticationKey?.let { PublicKey(it) })
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ internal fun engineModule() = module {
single {
StopWatchingSubscriptionsUseCase(
jsonRpcInteractor = get(),
keyManagementRepository = get(),
extractPublicKeysFromDidJsonUseCase = get(),
notifyServerUrl = get(),
getSelfKeyForWatchSubscriptionUseCase = get()
registeredAccountsRepository = get()
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ internal fun requestModule() = module {
extractPublicKeysFromDidJsonUseCase = get(),
jsonRpcInteractor = get(),
logger = get(),
notifyServerUrl = get()
notifyServerUrl = get(),
registeredAccountsRepository = get(),
watchSubscriptionsForEveryRegisteredAccountUseCase = get(),
accountsRepository = get(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ internal fun responseModule() = module {
single {
OnWatchSubscriptionsResponseUseCase(
setActiveSubscriptionsUseCase = get(),
extractPublicKeysFromDidJsonUseCase = get(),
watchSubscriptionsForEveryRegisteredAccountUseCase = get(),
accountsRepository = get(),
notifyServerUrl = get()
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ internal class UnregisterUseCase(
identitiesInteractor.unregisterIdentity(accountId, keyserverUrl).fold(
onFailure = { error -> onFailure(error) },
onSuccess = { identityPublicKey ->
runCatching { registeredAccountsRepository.deleteAccountByAccountId(account) }.fold(
runCatching {
stopWatchingSubscriptionsUseCase(accountId, onFailure = { error -> onFailure(error) })
registeredAccountsRepository.deleteAccountByAccountId(account)
}.fold(
onFailure = { error -> onFailure(error) },
onSuccess = {
stopWatchingSubscriptionsUseCase(accountId, onFailure)
subscriptionRepository.getAccountActiveSubscriptions(accountId).map { it.notifyTopic.value }.map { topic ->
jsonRpcInteractor.unsubscribe(Topic(topic)) { error -> onFailure(error) }
subscriptionRepository.deleteSubscriptionByNotifyTopic(topic)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,24 @@

package com.walletconnect.notify.engine.domain

import com.walletconnect.android.internal.common.crypto.kmr.KeyManagementRepository
import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.android.internal.common.model.type.JsonRpcInteractorInterface
import com.walletconnect.notify.common.NotifyServerUrl
import com.walletconnect.notify.data.storage.RegisteredAccountsRepository
import kotlinx.coroutines.supervisorScope

internal class StopWatchingSubscriptionsUseCase(
private val jsonRpcInteractor: JsonRpcInteractorInterface,
private val keyManagementRepository: KeyManagementRepository,
private val extractPublicKeysFromDidJsonUseCase: ExtractPublicKeysFromDidJsonUseCase,
private val getSelfKeyForWatchSubscriptionUseCase: GetSelfKeyForWatchSubscriptionUseCase,
private val notifyServerUrl: NotifyServerUrl,
private val registeredAccountsRepository: RegisteredAccountsRepository,
) {

suspend operator fun invoke(accountId: AccountId, onFailure: (Throwable) -> Unit) = supervisorScope {
val (peerPublicKey, _) = extractPublicKeysFromDidJsonUseCase(notifyServerUrl.toUri()).getOrElse { useCaseError -> return@supervisorScope onFailure(useCaseError) }
val watchTopic = runCatching { registeredAccountsRepository.getAccountByAccountId(accountId.value).notifyServerWatchTopic }
.getOrElse { error -> return@supervisorScope onFailure(error) }

val requestTopic = keyManagementRepository.getTopicFromKey(peerPublicKey)

val selfPublicKey = getSelfKeyForWatchSubscriptionUseCase(requestTopic, accountId)
val responseTopic = keyManagementRepository.generateTopicFromKeyAgreement(selfPublicKey, peerPublicKey)

jsonRpcInteractor.unsubscribe(responseTopic) { error -> onFailure(error) }
if (watchTopic == null) {
return@supervisorScope onFailure(IllegalStateException("Watch topic is null"))
} else {
jsonRpcInteractor.unsubscribe(watchTopic) { error -> onFailure(error) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ internal class WatchSubscriptionsUseCase(
val account = registeredAccountsRepository.getAccountByAccountId(accountId.value)
val didJwt = fetchDidJwtInteractor.watchSubscriptionsRequest(accountId, authenticationPublicKey, if(account.isLimited) account.appDomain else null)
.getOrElse { error -> return@supervisorScope onFailure(error) }

registeredAccountsRepository.updateNotifyServerData(accountId, responseTopic, authenticationPublicKey)
val watchSubscriptionsParams = CoreNotifyParams.WatchSubscriptionsParams(didJwt.value)
val request = NotifyRpc.NotifyWatchSubscriptions(params = watchSubscriptionsParams)
val irnParams = IrnParams(Tags.NOTIFY_WATCH_SUBSCRIPTIONS, Ttl(THIRTY_SECONDS))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.walletconnect.android.internal.common.jwt.did.extractVerifiedDidJwtCl
import com.walletconnect.android.internal.common.model.AppMetaData
import com.walletconnect.android.internal.common.model.AppMetaDataType
import com.walletconnect.android.internal.common.model.IrnParams
import com.walletconnect.android.internal.common.model.SDKError
import com.walletconnect.android.internal.common.model.Tags
import com.walletconnect.android.internal.common.model.WCRequest
import com.walletconnect.android.internal.common.model.params.ChatNotifyResponseAuthParams
Expand All @@ -16,6 +17,8 @@ import com.walletconnect.android.internal.common.model.type.JsonRpcInteractorInt
import com.walletconnect.android.internal.common.storage.MetadataStorageRepositoryInterface
import com.walletconnect.android.internal.utils.MONTH_IN_SECONDS
import com.walletconnect.foundation.common.model.Ttl
import com.walletconnect.foundation.util.jwt.decodeDidWeb
import com.walletconnect.foundation.util.jwt.decodeEd25519DidKey
import com.walletconnect.notify.common.model.NotifyMessage
import com.walletconnect.notify.common.model.NotifyRecord
import com.walletconnect.notify.data.jwt.message.MessageRequestJwtClaim
Expand All @@ -26,7 +29,7 @@ import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.supervisorScope
import timber.log.Timber
import java.net.URI

internal class OnNotifyMessageUseCase(
private val jsonRpcInteractor: JsonRpcInteractorInterface,
Expand All @@ -39,14 +42,15 @@ internal class OnNotifyMessageUseCase(
val events: SharedFlow<EngineEvent> = _events.asSharedFlow()

suspend operator fun invoke(request: WCRequest, params: CoreNotifyParams.MessageParams) = supervisorScope {
extractVerifiedDidJwtClaims<MessageRequestJwtClaim>(params.messageAuth).onSuccess { messageJwt ->
val activeSubscription = subscriptionRepository.getActiveSubscriptionByNotifyTopic(request.topic.value)
?: throw IllegalStateException("No active subscription for topic: ${request.topic.value}")

val metadata: AppMetaData = metadataStorageRepository.getByTopicAndType(activeSubscription.notifyTopic, AppMetaDataType.PEER)
?: throw Exception("No metadata found for topic ${activeSubscription.notifyTopic}")

/* TODO: Add validation after ETHNY
* jwtClaims.iat - compare with current time. Has to be lower
* jwtClaims.exp - compare with current time. Has to be higher
* jwtClaims.act == "notify_message"
* jwtClaims.iss - did:key of dapp authentication key. Add logic when cached value does not match jwtClaims.iss then fetch value again and if value still does not match then throw
* jwtClaims.app - did:web of app domain that this request is associated with. Must match domain of active subscription by topic */

extractVerifiedDidJwtClaims<MessageRequestJwtClaim>(params.messageAuth).onSuccess { messageJwt ->
messageJwt.throwIfIsInvalid(URI(metadata.url).host, activeSubscription.authenticationPublicKey.keyAsHex)

if (messagesRepository.doesMessagesExistsByRequestId(request.id)) {
messagesRepository.updateMessageWithPublishedAtByRequestId(request.publishedAt, request.id)
Expand Down Expand Up @@ -76,14 +80,7 @@ internal class OnNotifyMessageUseCase(
)
_events.emit(notifyRecord)
}
}.mapCatching { _ ->

val activeSubscription = subscriptionRepository.getActiveSubscriptionByNotifyTopic(request.topic.value)
?: throw IllegalStateException("No active subscription for topic: ${request.topic.value}")

val metadata: AppMetaData = metadataStorageRepository.getByTopicAndType(activeSubscription.notifyTopic, AppMetaDataType.PEER)
?: throw Exception("No metadata found for topic ${activeSubscription.notifyTopic}")

}.runCatching {
val messageResponseJwt = fetchDidJwtInteractor.messageResponse(
account = activeSubscription.account,
app = metadata.url,
Expand All @@ -94,12 +91,30 @@ internal class OnNotifyMessageUseCase(
val irnParams = IrnParams(Tags.NOTIFY_MESSAGE_RESPONSE, Ttl(MONTH_IN_SECONDS))

jsonRpcInteractor.respondWithParams(request.id, request.topic, messageResponseParams, irnParams) { throw it }
}.getOrElse { e ->
}.getOrElse { error ->
_events.emit(SDKError(error))
jsonRpcInteractor.respondWithError(
request,
Uncategorized.GenericError("Cannot handle the notify message: ${e.message}, topic: ${request.topic}"),
Uncategorized.GenericError("Cannot handle the notify message: ${error.message}, topic: ${request.topic}"),
IrnParams(Tags.NOTIFY_MESSAGE_RESPONSE, Ttl(MONTH_IN_SECONDS))
)
}
}

private fun MessageRequestJwtClaim.throwIfIsInvalid(expectedApp: String, expectedIssuer: String) {
throwIfBaseIsInvalid()
throwIfAppIsInvalid(expectedApp)
throwIfIssuerIsInvalid(expectedIssuer)
}

private fun MessageRequestJwtClaim.throwIfAppIsInvalid(expectedAppDomain: String) {
val decodedAppDomain = decodeDidWeb(app)
if (decodedAppDomain != expectedAppDomain) throw IllegalStateException("Invalid app claim was $decodedAppDomain instead of $expectedAppDomain")
}


private fun MessageRequestJwtClaim.throwIfIssuerIsInvalid(expectedIssuerAsHex: String) {
val decodedIssuerAsHex = decodeEd25519DidKey(issuer).keyAsHex
if (decodedIssuerAsHex != expectedIssuerAsHex) throw IllegalStateException("Invalid issuer claim was $decodedIssuerAsHex instead of $expectedIssuerAsHex")
}
}
Loading