From 2aeec3b3a8d0c0456e0bea6c6fa9b66df2a04cb2 Mon Sep 17 00:00:00 2001
From: z4yx <z4yx@pixel>
Date: Thu, 1 Aug 2024 13:57:45 +0800
Subject: [PATCH] speed up DC lookup with cache

---
 applets/ctap/ctap-internal.h | 13 +++++
 applets/ctap/ctap.c          | 98 ++++++++++++++++++++++--------------
 2 files changed, 72 insertions(+), 39 deletions(-)

diff --git a/applets/ctap/ctap-internal.h b/applets/ctap/ctap-internal.h
index f6cc1362..d4bc97da 100644
--- a/applets/ctap/ctap-internal.h
+++ b/applets/ctap/ctap-internal.h
@@ -23,6 +23,7 @@
 #define DC_FILE         "ctap_dc"
 #define DC_GENERAL_ATTR 0x00
 #define DC_META_FILE    "ctap_dm"
+#define DC_INDEX_FILE   "ctap_di"
 #define LB_FILE         "ctap_lb"
 #define LB_FILE_TMP     "ctap_lbt"
 
@@ -229,6 +230,12 @@
 #define LARGE_BLOB_SIZE_LIMIT         4096
 #define MAX_FRAGMENT_LENGTH           (MAX_CTAP_BUFSIZE - 64)
 
+#define BUILD_RPID_HASH0(hash0, vld)  ((0x7F & (hash0)) | (vld))
+enum {
+  INDEX_MATCH_VALID = 1,
+  INDEX_MATCH_RPID = 0xFF,
+};
+
 typedef struct {
   uint8_t id[USER_ID_MAX_SIZE];
   uint8_t id_size;
@@ -252,6 +259,12 @@ typedef struct {
   uint8_t cred_blob[MAX_CRED_BLOB_LENGTH];
 } __packed CTAP_discoverable_credential;
 
+typedef struct {
+  // bit[0] is valid
+  // bit[7:1] equals rp_id_hash[0][7:1]
+  uint8_t rp_id_hash0[MAX_DC_NUM];
+} __packed discoverable_credential_idx;
+
 typedef struct {
   uint8_t numbers;
   uint8_t index; // enough when MAX_DC_NUM == 64
diff --git a/applets/ctap/ctap.c b/applets/ctap/ctap.c
index 6c3b6e61..ba1238ef 100644
--- a/applets/ctap/ctap.c
+++ b/applets/ctap/ctap.c
@@ -15,6 +15,7 @@
 #include <hmac.h>
 #include <memzero.h>
 #include <rand.h>
+#include <stddef.h>
 
 #define CHECK_PARSER_RET(ret)                                                                                          \
   do {                                                                                                                 \
@@ -71,9 +72,11 @@ uint8_t ctap_install(uint8_t reset) {
     return 0;
   }
   uint8_t kh_key[KH_KEY_SIZE] = {0};
+  discoverable_credential_idx dc_idx = {0};
   if (write_file(DC_FILE, NULL, 0, 0, 1) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
   if (write_attr(DC_FILE, DC_GENERAL_ATTR, kh_key, sizeof(CTAP_dc_general_attr)) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
   if (write_file(DC_META_FILE, NULL, 0, 0, 1) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
+  if (write_file(DC_INDEX_FILE, &dc_idx, 0, sizeof(dc_idx), 1) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
   if (write_file(CTAP_CERT_FILE, NULL, 0, 0, 0) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
   if (write_attr(CTAP_CERT_FILE, SIGN_CTR_ATTR, kh_key, 4) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
   if (write_attr(CTAP_CERT_FILE, PIN_ATTR, NULL, 0) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
@@ -222,6 +225,9 @@ int ctap_consistency_check(void) {
         break;
       }
     }
+    int ret = 0; // write a zero (valid=0)
+    ret = write_file(DC_INDEX_FILE, &ret, attr.index * sizeof(uint8_t), sizeof(uint8_t), 0);
+    if (ret < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
     if (attr.pending_delete)
       attr.numbers--;
 
@@ -232,6 +238,16 @@ int ctap_consistency_check(void) {
   return 0;
 }
 
+bool find_dc_index(discoverable_credential_idx *dc_idx, int *i, uint8_t mask, uint8_t *hash, uint8_t valid) {
+  uint8_t val = BUILD_RPID_HASH0(hash[0], valid);
+  for (; *i < MAX_DC_NUM; ++(*i)) {
+    if ((dc_idx->rp_id_hash0[*i] & mask) == (val & mask)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 uint8_t ctap_make_auth_data(uint8_t *rp_id_hash, uint8_t *buf, uint8_t flags, const uint8_t *extension,
                             size_t extension_size, size_t *len, int32_t alg_type, bool dc, uint8_t cred_protect) {
   // See https://www.w3.org/TR/webauthn/#sec-authenticator-data
@@ -667,6 +683,9 @@ static uint8_t ctap_make_credential(CborEncoder *encoder, uint8_t *params, size_
     if (write_file(DC_META_FILE, &meta, meta_pos * (int) sizeof(CTAP_rp_meta),
                    sizeof(CTAP_rp_meta), 0) < 0)
       return CTAP2_ERR_UNHANDLED_REQUEST;
+    uint8_t rp_id_hash0 = BUILD_RPID_HASH0(mc.rp_id_hash[0], 1);
+    ret = write_file(DC_INDEX_FILE, &rp_id_hash0, pos * sizeof(uint8_t), sizeof(uint8_t), 0);
+    if (ret < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
     attr.pending_add = 0;
     ++attr.numbers;
     if (write_attr(DC_FILE, DC_GENERAL_ATTR, &attr, sizeof(attr)) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
@@ -756,6 +775,7 @@ static uint8_t ctap_get_assertion(CborEncoder *encoder, uint8_t *params, size_t
   static uint32_t timer;
 
   CTAP_discoverable_credential dc = {0}; // We use dc to store the selected credential
+  discoverable_credential_idx dc_idx;
   uint8_t data_buf[sizeof(CTAP_auth_data) + CLIENT_DATA_HASH_SIZE];
   ecc_key_t key;  // TODO: cleanup
   CborParser parser;
@@ -915,6 +935,8 @@ static uint8_t ctap_get_assertion(CborEncoder *encoder, uint8_t *params, size_t
   //       MUST NOT be returned if user verification is not done by the authenticator.
   if (ga.allow_list_size > 0) { // Step 11
     size_t i;
+    if (read_file(DC_INDEX_FILE, &dc_idx, 0, sizeof(dc_idx) < 0))
+      return CTAP2_ERR_UNHANDLED_REQUEST;
     for (i = 0; i < ga.allow_list_size; ++i) {
       parse_credential_descriptor(&ga.allow_list, (uint8_t *) &dc.credential_id);
       // compare the rp_id first
@@ -931,15 +953,12 @@ static uint8_t ctap_get_assertion(CborEncoder *encoder, uint8_t *params, size_t
           if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
           int n_dc = (int) (size / sizeof(CTAP_discoverable_credential));
           bool found = false;
+          int j = 0;
           DBG_MSG("%d discoverable credentials\n", n_dc);
-          for (int j = 0; j < n_dc; ++j) {
+          while (find_dc_index(&dc_idx, &j, INDEX_MATCH_RPID, ga.rp_id_hash, true)) {
             if (read_file(DC_FILE, &dc, j * (int) sizeof(CTAP_discoverable_credential),
                           sizeof(CTAP_discoverable_credential)) < 0)
               return CTAP2_ERR_UNHANDLED_REQUEST;
-            if (dc.deleted) {
-              DBG_MSG("Skipped DC at %d\n", j);
-              continue;
-            }
             if (memcmp_s(ga.rp_id_hash, dc.credential_id.rp_id_hash, SHA256_DIGEST_LENGTH) == 0 &&
                 memcmp_s(data_buf, dc.credential_id.nonce, sizeof(dc.credential_id.nonce)) == 0) {
               found = true;
@@ -964,20 +983,18 @@ static uint8_t ctap_get_assertion(CborEncoder *encoder, uint8_t *params, size_t
     }
     number_of_credentials = 1;
   } else { // Step 12
-    int size;
     if (credential_counter == 0) {
-      size = get_file_size(DC_FILE);
-      if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
-      int n_dc = (int) (size / sizeof(CTAP_discoverable_credential));
+      int i = 0;
       number_of_credentials = 0;
-      for (int i = n_dc - 1; i >= 0; --i) {  // 12-b-1
-        if (read_file(DC_FILE, &dc, i * (int) sizeof(CTAP_discoverable_credential),
-                      sizeof(CTAP_discoverable_credential)) < 0)
+      if (read_file(DC_INDEX_FILE, &dc_idx, 0, sizeof(dc_idx) < 0))
+        return CTAP2_ERR_UNHANDLED_REQUEST;
+      // 12-b-1
+      while (find_dc_index(&dc_idx, &i, INDEX_MATCH_RPID, ga.rp_id_hash, true)) {
+        // read dc.credential_id for comparsion
+        if (read_file(DC_FILE, &dc.credential_id,
+                      (int)(i*sizeof(CTAP_discoverable_credential)+offsetof(CTAP_discoverable_credential, credential_id)),
+                      sizeof(dc.credential_id)) < 0)
           return CTAP2_ERR_UNHANDLED_REQUEST;
-        if (dc.deleted) {
-          DBG_MSG("Skipped DC at %d\n", i);
-          continue;
-        }
         // Skip the credential which is protected
         if (!check_credential_protect_requirements(&dc.credential_id, false, uv)) continue;
         if (memcmp_s(ga.rp_id_hash, dc.credential_id.rp_id_hash, SHA256_DIGEST_LENGTH) == 0)
@@ -1655,6 +1672,7 @@ static uint8_t ctap_credential_management(CborEncoder *encoder, const uint8_t *p
   CTAP_rp_meta meta;
   CTAP_discoverable_credential dc;
   bool include_numbers;
+  discoverable_credential_idx dc_idx;
 
   if (cm.sub_command == CM_CMD_GET_CREDS_METADATA ||
       cm.sub_command == CM_CMD_ENUMERATE_RPS_BEGIN ||
@@ -1777,23 +1795,25 @@ static uint8_t ctap_credential_management(CborEncoder *encoder, const uint8_t *p
     case CM_CMD_ENUMERATE_CREDENTIALS_BEGIN:
       if (!cp_verify_rp_id(cm.rp_id_hash)) return CTAP2_ERR_PIN_AUTH_INVALID;
       if (numbers == 0) return CTAP2_ERR_NO_CREDENTIALS;
+      if (read_file(DC_INDEX_FILE, &dc_idx, 0, sizeof(dc_idx) < 0))
+        return CTAP2_ERR_UNHANDLED_REQUEST;
+
       include_numbers = true;
-      size = get_file_size(DC_META_FILE);
-      n_rp = size / (int) sizeof(CTAP_rp_meta);
+      slots = 0ull;
       KEEPALIVE();
-      for (idx = 0; idx < n_rp; ++idx) {
-        size = read_file(DC_META_FILE, &meta, idx * (int) sizeof(CTAP_rp_meta), sizeof(CTAP_rp_meta));
+      for (idx = 0; find_dc_index(&dc_idx, &idx, INDEX_MATCH_RPID, cm.rp_id_hash, true); ) {
+        size = read_file(DC_FILE, &dc.credential_id.rp_id_hash,
+                        (int)(idx*sizeof(CTAP_discoverable_credential)+offsetof(CTAP_discoverable_credential, credential_id.rp_id_hash)),
+                        sizeof(CTAP_discoverable_credential));
         if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
-        if (meta.slots == 0) continue;
-        if (memcmp_s(meta.rp_id_hash, cm.rp_id_hash, SHA256_DIGEST_LENGTH) == 0) break;
+        if (memcmp_s(dc.credential_id.rp_id_hash, cm.rp_id_hash, SHA256_DIGEST_LENGTH) == 0) {
+          slots |= 1ull << idx;
+        }
       }
-      if (idx == n_rp) {
+      if (slots == 0ull) {
         DBG_MSG("Specified RP not found\n");
         return CTAP2_ERR_NO_CREDENTIALS;
       }
-      DBG_MSG("Use meta at slot %d: ", idx);
-      PRINT_HEX((const uint8_t *) &meta, sizeof(meta));
-      slots = meta.slots;
     generate_credential_response:
       DBG_MSG("Current slot bitmap: 0x%llx\n", slots);
       idx = get_next_slot(&slots, &numbers);
@@ -1897,21 +1917,20 @@ static uint8_t ctap_credential_management(CborEncoder *encoder, const uint8_t *p
     case CM_CMD_DELETE_CREDENTIAL:
       if (!cp_verify_rp_id(cm.credential_id.rp_id_hash)) return CTAP2_ERR_PIN_AUTH_INVALID;
       if (numbers == 0) return CTAP2_ERR_NO_CREDENTIALS;
-      size = get_file_size(DC_FILE);
-      if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
-      numbers = size / sizeof(CTAP_discoverable_credential);
-      for (idx = 0; idx < numbers; ++idx) {
+      if (read_file(DC_INDEX_FILE, &dc_idx, 0, sizeof(dc_idx) < 0))
+        return CTAP2_ERR_UNHANDLED_REQUEST;
+
+      for (idx = 0; find_dc_index(&dc_idx, &idx, INDEX_MATCH_RPID, cm.credential_id.rp_id_hash, true); ) {
         size = read_file(DC_FILE, &dc, idx * (int) sizeof(CTAP_discoverable_credential),
                          sizeof(CTAP_discoverable_credential));
         if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
-        if (dc.deleted) continue;
         if (memcmp_s(&dc.credential_id, &cm.credential_id, sizeof(credential_id)) == 0) {
           DBG_MSG("Found, credential_id: ");
           PRINT_HEX((const uint8_t *) &dc.credential_id, sizeof(credential_id));
           break;
         }
       }
-      if (idx == numbers) return CTAP2_ERR_NO_CREDENTIALS;
+      if (idx == MAX_DC_NUM) return CTAP2_ERR_NO_CREDENTIALS;
 
       CTAP_dc_general_attr attr;
       if (read_attr(DC_FILE, DC_GENERAL_ATTR, &attr, sizeof(attr)) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
@@ -1943,6 +1962,9 @@ static uint8_t ctap_credential_management(CborEncoder *encoder, const uint8_t *p
           break;
         }
       }
+      size = 0;
+      size = write_file(DC_INDEX_FILE, &size, idx * sizeof(uint8_t), sizeof(uint8_t), 0);
+      if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
       attr.numbers--;
       attr.pending_delete = 0;
       if (write_attr(DC_FILE, DC_GENERAL_ATTR, &attr, sizeof(attr)) < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
@@ -1951,23 +1973,21 @@ static uint8_t ctap_credential_management(CborEncoder *encoder, const uint8_t *p
     case CM_CMD_UPDATE_USER_INFORMATION:
       if (!cp_verify_rp_id(cm.credential_id.rp_id_hash)) return CTAP2_ERR_PIN_AUTH_INVALID;
       if (numbers == 0) return CTAP2_ERR_NO_CREDENTIALS;
-      // TODO: refactor this
-      size = get_file_size(DC_FILE);
-      if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
-      numbers = size / sizeof(CTAP_discoverable_credential);
+      if (read_file(DC_INDEX_FILE, &dc_idx, 0, sizeof(dc_idx) < 0))
+        return CTAP2_ERR_UNHANDLED_REQUEST;
+
       KEEPALIVE();
-      for (idx = 0; idx < numbers; ++idx) {
+      for (idx = 0; find_dc_index(&dc_idx, &idx, INDEX_MATCH_RPID, cm.credential_id.rp_id_hash, true); ) {
         size = read_file(DC_FILE, &dc, idx * (int) sizeof(CTAP_discoverable_credential),
                          sizeof(CTAP_discoverable_credential));
         if (size < 0) return CTAP2_ERR_UNHANDLED_REQUEST;
-        if (dc.deleted) continue;
         if (memcmp_s(&dc.credential_id, &cm.credential_id, sizeof(credential_id)) == 0) {
           DBG_MSG("Found, credential_id: ");
           PRINT_HEX((const uint8_t *) &dc.credential_id, sizeof(credential_id));
           break;
         }
       }
-      if (idx == numbers) {
+      if (idx == MAX_DC_NUM) {
         DBG_MSG("No matching credential\n");
         return CTAP2_ERR_NO_CREDENTIALS;
       }