Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Notify] Add JWT validation #1216

Merged
merged 18 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/ci_instrumented_tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ inputs:
required: false
testTimeoutSeconds:
description: 'Seconds for test timeout'
default: '40'
default: '120'

runs:
using: "composite"
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/ci_instrumented_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
testTimeoutSeconds:
description: 'Seconds for test timeout'
required: true
default: 40 # should be same as env.TEST_TIMEOUT_SECONDS
default: 120 # should be same as env.TEST_TIMEOUT_SECONDS
pull_request:
branches:
- develop
Expand All @@ -17,7 +17,7 @@ on:
- 'protocol/notify/**'

env:
TEST_TIMEOUT_SECONDS: 40 # Predefined timeout for integration tests
TEST_TIMEOUT_SECONDS: 120 # Predefined timeout for integration tests

concurrency:
# Support push/pr as event types with different behaviors each:
Expand All @@ -35,6 +35,8 @@ jobs:
group: apple-silicon
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Check what modules were changed
id: what_modules_changed
Expand Down
2 changes: 1 addition & 1 deletion protocol/notify/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ android {
buildConfigField("String", "PROJECT_ID", "\"${System.getenv("WC_CLOUD_PROJECT_ID") ?: ""}\"")
buildConfigField("String", "PROD_GM_PROJECT_ID", "\"${System.getenv("PROD_GM_PROJECT_ID") ?: ""}\"")
buildConfigField("String", "PROD_GM_SECRET", "\"${System.getenv("PROD_GM_SECRET") ?: ""}\"")
buildConfigField("Integer", "TEST_TIMEOUT_SECONDS", "${System.getenv("TEST_TIMEOUT_SECONDS") ?: 10}")
buildConfigField("Integer", "TEST_TIMEOUT_SECONDS", "${System.getenv("TEST_TIMEOUT_SECONDS") ?: 60}")

testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
testInstrumentationRunnerArguments += mutableMapOf("clearPackageData" to "true")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package com.walletconnect.notify.common.model

import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.foundation.common.model.PublicKey
import com.walletconnect.foundation.common.model.Topic

data class RegisteredAccount(
val accountId: AccountId,
val publicIdentityKey: PublicKey,
val isLimited: Boolean,
val appDomain: String?
val appDomain: String?,
val notifyServerWatchTopic: Topic?,
val notifyServerAuthenticationKey: PublicKey?,
)
Talhaali00 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,33 @@
package com.walletconnect.notify.data.jwt

import com.walletconnect.foundation.util.jwt.JwtClaims
import java.util.concurrent.TimeUnit

internal interface NotifyJwtBase: JwtClaims {
internal interface NotifyJwtBase : JwtClaims {
val action: String
val issuedAt: Long
val expiration: Long
val requiredActionValue: String

private fun throwIdIssuedAtIsInvalid() {
val currentTimeSeconds = TimeUnit.SECONDS.convert(System.currentTimeMillis(), TimeUnit.MILLISECONDS) + TimeUnit.SECONDS.convert(5, TimeUnit.MINUTES)
if (issuedAt > currentTimeSeconds)
throw IllegalArgumentException("Invalid issuedAt claim was $issuedAt instead of lower than $currentTimeSeconds")
}

private fun throwExpirationAtIsInvalid() {
val currentTimeSeconds = TimeUnit.SECONDS.convert(System.currentTimeMillis(), TimeUnit.MILLISECONDS) - TimeUnit.SECONDS.convert(5, TimeUnit.MINUTES)
if (expiration < currentTimeSeconds)
throw IllegalArgumentException("Invalid expiration claim was $expiration instead of greater than $currentTimeSeconds")
}

private fun throwIfActionIsInvalid() {
if (action != requiredActionValue) throw IllegalArgumentException("Invalid action claim was $action instead of $requiredActionValue")
}

fun throwIfBaseIsInvalid() {
throwIdIssuedAtIsInvalid()
throwExpirationAtIsInvalid()
throwIfActionIsInvalid()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class DeleteRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_delete",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_delete"
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class DeleteResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_delete_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_delete_response"
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ internal data class MessageRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "msg") val message: Message,
@Json(name = "act") override val action: String = "notify_message",
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE

@JsonClass(generateAdapter = true)
data class Message(
Expand All @@ -25,4 +26,8 @@ internal data class MessageRequestJwtClaim(
@Json(name = "url") val url: String,
@Json(name = "type") val type: String,
)
}

}

private const val ACTION_CLAIM_VALUE = "notify_message"

Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class MessageResponseJwtClaim(
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "sub") val subject: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_message_response",
) : NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_message_response"
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ internal data class SubscriptionRequestJwtClaim(
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "scp") val scope: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_subscription",
) : NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscription"


Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class SubscriptionResponseJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "sub") val subject: String,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_subscription_response",
) : NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscription_response"
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class SubscriptionsChangedRequestJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "sbs") val subscriptions: List<ServerSubscription>,
@Json(name = "act") override val action: String = "notify_subscriptions_changed",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscriptions_changed"
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class SubscriptionsChangedResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "exp") override val expiration: Long,
@Json(name = "act") override val action: String = "notify_subscriptions_changed_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_subscriptions_changed_response"
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ internal data class UpdateRequestJwtClaim(
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "app") val app: String,
@Json(name = "scp") val scope: String,
@Json(name = "act") override val action: String = "notify_update",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_update"
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ internal data class UpdateResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "app") val app: String,
@Json(name = "act") override val action: String = "notify_update_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_update_response"
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ internal data class WatchSubscriptionsRequestJwtClaim(
@Json(name = "exp") override val expiration: Long,
@Json(name = "ksu") val keyserverUrl: String,
@Json(name = "app") val appDidWeb: String?,
@Json(name = "act") override val action: String = "notify_watch_subscriptions",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_watch_subscriptions"
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ internal data class WatchSubscriptionsResponseJwtClaim(
@Json(name = "iat") override val issuedAt: Long,
@Json(name = "exp") override val expiration: Long,
@Json(name = "sbs") val subscriptions: List<ServerSubscription>,
@Json(name = "act") override val action: String = "notify_watch_subscriptions_response",
): NotifyJwtBase
@Json(name = "act") override val action: String = ACTION_CLAIM_VALUE,
) : NotifyJwtBase {
override val requiredActionValue: String = ACTION_CLAIM_VALUE
}

private const val ACTION_CLAIM_VALUE = "notify_watch_subscriptions_response"
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,32 @@ package com.walletconnect.notify.data.storage

import com.walletconnect.android.internal.common.model.AccountId
import com.walletconnect.foundation.common.model.PublicKey
import com.walletconnect.foundation.common.model.Topic
import com.walletconnect.notify.common.model.RegisteredAccount
import com.walletconnect.notify.common.storage.data.dao.RegisteredAccountsQueries
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

internal class RegisteredAccountsRepository(private val registeredAccounts: RegisteredAccountsQueries) {

suspend fun insertOrIgnoreAccount(
accountId: AccountId,
publicIdentityKey: PublicKey,
isLimited: Boolean,
appDomain: String?
) = withContext(Dispatchers.IO) {
suspend fun insertOrIgnoreAccount(accountId: AccountId, publicIdentityKey: PublicKey, isLimited: Boolean, appDomain: String?) = withContext(Dispatchers.IO) {
registeredAccounts.insertOrIgnoreAccount(accountId.value, publicIdentityKey.keyAsHex, isLimited, appDomain)
}

suspend fun updateNotifyServerData(accountId: AccountId, notifyServerWatchTopic: Topic, notifyServerAuthenticationKey: PublicKey) = withContext(Dispatchers.IO) {
registeredAccounts.updateNotifyServerData(
accountId = accountId.value, notifyServerWatchTopic = notifyServerWatchTopic.value, notifyServerAuthenticationKey = notifyServerAuthenticationKey.keyAsHex
)
}

suspend fun getAccountByAccountId(accountId: String): RegisteredAccount = withContext(Dispatchers.IO) {
registeredAccounts.getAccountByAccountId(accountId, ::toRegisterAccount).executeAsOne()
}

suspend fun getAccountByIdentityKey(identityPublicKey: String): RegisteredAccount = withContext(Dispatchers.IO) {
registeredAccounts.getAccountByIdentityKey(identityPublicKey, ::toRegisterAccount).executeAsOne()
}

suspend fun getAllAccounts(): List<RegisteredAccount> = withContext(Dispatchers.IO) {
registeredAccounts.getAllAccounts(::toRegisterAccount).executeAsList()
}
Expand All @@ -33,9 +39,7 @@ internal class RegisteredAccountsRepository(private val registeredAccounts: Regi
}

private fun toRegisterAccount(
accountId: String,
publicIdentityKey: String,
isLimited: Boolean,
appDomain: String?
): RegisteredAccount = RegisteredAccount(AccountId(accountId), PublicKey(publicIdentityKey), isLimited, appDomain)
accountId: String, publicIdentityKey: String, isLimited: Boolean, appDomain: String?, notifyServerWatchTopic: String?, notifyServerAuthenticationKey: String?,
): RegisteredAccount =
RegisteredAccount(AccountId(accountId), PublicKey(publicIdentityKey), isLimited, appDomain, notifyServerWatchTopic?.let { Topic(it) }, notifyServerAuthenticationKey?.let { PublicKey(it) })
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ internal fun engineModule() = module {
single {
StopWatchingSubscriptionsUseCase(
jsonRpcInteractor = get(),
keyManagementRepository = get(),
extractPublicKeysFromDidJsonUseCase = get(),
notifyServerUrl = get(),
getSelfKeyForWatchSubscriptionUseCase = get()
registeredAccountsRepository = get()
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ internal fun requestModule() = module {
extractPublicKeysFromDidJsonUseCase = get(),
jsonRpcInteractor = get(),
logger = get(),
notifyServerUrl = get()
notifyServerUrl = get(),
registeredAccountsRepository = get(),
watchSubscriptionsForEveryRegisteredAccountUseCase = get(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ internal fun responseModule() = module {
single {
OnWatchSubscriptionsResponseUseCase(
setActiveSubscriptionsUseCase = get(),
extractPublicKeysFromDidJsonUseCase = get(),
watchSubscriptionsForEveryRegisteredAccountUseCase = get(),
accountsRepository = get(),
notifyServerUrl = get()
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ internal class UnregisterUseCase(
identitiesInteractor.unregisterIdentity(accountId, keyserverUrl).fold(
onFailure = { error -> onFailure(error) },
onSuccess = { identityPublicKey ->
runCatching { registeredAccountsRepository.deleteAccountByAccountId(account) }.fold(
runCatching {
stopWatchingSubscriptionsUseCase(accountId, onFailure = { error -> onFailure(error) })
registeredAccountsRepository.deleteAccountByAccountId(account)
}.fold(
onFailure = { error -> onFailure(error) },
onSuccess = {
stopWatchingSubscriptionsUseCase(accountId, onFailure)
subscriptionRepository.getAccountActiveSubscriptions(accountId).map { it.notifyTopic.value }.map { topic ->
jsonRpcInteractor.unsubscribe(Topic(topic)) { error -> onFailure(error) }
subscriptionRepository.deleteSubscriptionByNotifyTopic(topic)
Expand Down
Loading