diff --git a/db_test.go b/db_test.go index 33e08eb..b545c06 100644 --- a/db_test.go +++ b/db_test.go @@ -526,6 +526,29 @@ func TestDB_Observable(t *testing.T) { require.False(t, ok, "expected channel to close, got event: %+v", ev) } +func TestDB_NumObjects(t *testing.T) { + t.Parallel() + + db, table, _ := newTestDB(t) + rtxn := db.ReadTxn() + assert.Equal(t, 0, table.NumObjects(rtxn)) + + txn := db.WriteTxn(table) + assert.Equal(t, 0, table.NumObjects(txn)) + table.Insert(txn, testObject{ID: uint64(1)}) + assert.Equal(t, 1, table.NumObjects(txn)) + table.Insert(txn, testObject{ID: uint64(1)}) + table.Insert(txn, testObject{ID: uint64(2)}) + assert.Equal(t, 2, table.NumObjects(txn)) + + assert.Equal(t, 0, table.NumObjects(rtxn)) + txn.Commit() + assert.Equal(t, 0, table.NumObjects(rtxn)) + + rtxn = db.ReadTxn() + assert.Equal(t, 2, table.NumObjects(rtxn)) +} + func TestDB_All(t *testing.T) { t.Parallel() @@ -888,6 +911,7 @@ func TestDB_Initialization(t *testing.T) { wtxn = db.WriteTxn(table) done1(wtxn) + require.False(t, table.Initialized(txn), "Initialized should be false") wtxn.Commit() // Old read transaction unaffected. @@ -900,10 +924,11 @@ func TestDB_Initialization(t *testing.T) { wtxn = db.WriteTxn(table) done2(wtxn) + assert.True(t, table.Initialized(wtxn), "Initialized should be true") wtxn.Commit() txn = db.ReadTxn() - require.True(t, table.Initialized(txn), "Initialized should be false") + require.True(t, table.Initialized(txn), "Initialized should be true") require.Empty(t, table.PendingInitializers(txn), "There should be no pending initializers") } diff --git a/table.go b/table.go index a12603b..7194946 100644 --- a/table.go +++ b/table.go @@ -173,7 +173,7 @@ func (t *genTable[Obj]) Initialized(txn ReadTxn) bool { return len(t.PendingInitializers(txn)) == 0 } func (t *genTable[Obj]) PendingInitializers(txn ReadTxn) []string { - return txn.getTxn().root[t.pos].pendingInitializers + return txn.getTxn().getTableEntry(t).pendingInitializers } func (t *genTable[Obj]) RegisterInitializer(txn WriteTxn, name string) func(WriteTxn) { @@ -201,12 +201,12 @@ func (t *genTable[Obj]) RegisterInitializer(txn WriteTxn, name string) func(Writ } func (t *genTable[Obj]) Revision(txn ReadTxn) Revision { - return txn.getTxn().getRevision(t) + return txn.getTxn().getTableEntry(t).revision } func (t *genTable[Obj]) NumObjects(txn ReadTxn) int { - table := &txn.getTxn().root[t.tablePos()] - return table.indexes[PrimaryIndexPos].tree.Len() + table := txn.getTxn().getTableEntry(t) + return table.numObjects() } func (t *genTable[Obj]) Get(txn ReadTxn, q Query[Obj]) (obj Obj, revision uint64, ok bool) { diff --git a/txn.go b/txn.go index 38ff9b3..c9934e0 100644 --- a/txn.go +++ b/txn.go @@ -55,14 +55,14 @@ func txnFinalizer(txn *txn) { } } -func (txn *txn) getRevision(meta TableMeta) Revision { +func (txn *txn) getTableEntry(meta TableMeta) *tableEntry { if txn.modifiedTables != nil { entry := txn.modifiedTables[meta.tablePos()] if entry != nil { - return entry.revision + return entry } } - return txn.root[meta.tablePos()].revision + return &txn.root[meta.tablePos()] } // indexReadTxn returns a transaction to read from the specific index. diff --git a/types.go b/types.go index f0e808d..f799de2 100644 --- a/types.go +++ b/types.go @@ -390,10 +390,16 @@ type tableEntry struct { func (t *tableEntry) numObjects() int { indexEntry := t.indexes[t.meta.indexPos(RevisionIndex)] + if indexEntry.txn != nil { + return indexEntry.txn.Len() + } return indexEntry.tree.Len() } func (t *tableEntry) numDeletedObjects() int { indexEntry := t.indexes[t.meta.indexPos(GraveyardIndex)] + if indexEntry.txn != nil { + return indexEntry.txn.Len() + } return indexEntry.tree.Len() }