Skip to content

Commit

Permalink
"Refactor SimilarityIndex initialization and setupDimension method fo…
Browse files Browse the repository at this point in the history
…r improved efficiency and clarity."
  • Loading branch information
buhe committed Feb 13, 2024
1 parent eef5ab0 commit 7bf8423
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
8 changes: 3 additions & 5 deletions Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,16 @@ public class SimilarityIndex {

// MARK: - Initializers

public init(name: String? = nil, model: (any EmbeddingsProtocol)? = nil, metric: (any DistanceMetricProtocol)? = nil, vectorStore: (any VectorStoreProtocol)? = nil) async {
public init(name: String? = nil, model: (any EmbeddingsProtocol)? = nil, metric: (any DistanceMetricProtocol)? = nil, vectorStore: (any VectorStoreProtocol)? = nil) {
// Setup index with defaults
self.indexName = name ?? "SimilaritySearchKitIndex"
self.indexModel = model ?? NativeEmbeddings()
self.indexMetric = metric ?? CosineSimilarity()
self.vectorStore = vectorStore ?? JsonStore()

// Run the model once to discover dimention size
await setupDimension()
}

private func setupDimension() async {
// Run the model once to discover dimention size
public func setupDimension() async {
if let testVector = await indexModel.encode(sentence: "Test sentence") {
dimension = testVector.count
} else {
Expand Down
7 changes: 4 additions & 3 deletions Tests/SimilaritySearchKitTests/BenchmarkTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class BenchmarkTests: XCTestCase {
let expectation = XCTestExpectation(description: "Encoding passage texts")

Task {
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings())
let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings())
await similarityIndex.setupDimension()
await similarityIndex.addItems(
ids: [UUID().uuidString],
texts: [searchPassage.text],
Expand Down Expand Up @@ -125,8 +126,8 @@ class BenchmarkTests: XCTestCase {

Task {
print("\nGenerating similarity index for \(testAmount) passages")
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings())

let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings())
await similarityIndex.setupDimension()
var startTime = CFAbsoluteTimeGetCurrent()
await similarityIndex.addItems(
ids: passageIds,
Expand Down
24 changes: 12 additions & 12 deletions Tests/SimilaritySearchKitTests/SimilaritySearchKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class SimilaritySearchKitTests: XCTestCase {
}

func testSavingJsonIndex() async {
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())

let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())
await similarityIndex.setupDimension()
await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"], embedding: [0.1, 0.2, 0.3])

let successPath = try! similarityIndex.saveIndex(name: "TestIndexForSaving")
Expand All @@ -31,24 +31,24 @@ class SimilaritySearchKitTests: XCTestCase {
}

func testLoadingJsonIndex() async {
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())

let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())
await similarityIndex.setupDimension()
await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"])

let successPath = try! similarityIndex.saveIndex(name: "TestIndexForLoading")

XCTAssertNotNil(successPath)

let similarityIndex2 = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())

let similarityIndex2 = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: JsonStore())
await similarityIndex2.setupDimension()
let loadedItems = try! similarityIndex2.loadIndex(name: "TestIndexForLoading")

XCTAssertNotNil(loadedItems)
}

func testSavingBinaryIndex() async {
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore())

let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore())
await similarityIndex.setupDimension()
await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"], embedding: [0.1, 0.2, 0.3])

let successPath = try! similarityIndex.saveIndex(name: "TestIndexForSaving")
Expand All @@ -57,16 +57,16 @@ class SimilaritySearchKitTests: XCTestCase {
}

func testLoadingBinaryIndex() async {
let similarityIndex = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore())

let similarityIndex = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore())
await similarityIndex.setupDimension()
await similarityIndex.addItem(id: "1", text: "Example text", metadata: ["source": "test source"])

let successPath = try! similarityIndex.saveIndex(name: "TestIndexForLoading")

XCTAssertNotNil(successPath)

let similarityIndex2 = await SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore())

let similarityIndex2 = SimilarityIndex(model: DistilbertEmbeddings(), vectorStore: BinaryStore())
await similarityIndex.setupDimension()
let loadedItems = try! similarityIndex2.loadIndex(name: "TestIndexForLoading")

XCTAssertNotNil(loadedItems)
Expand Down

0 comments on commit 7bf8423

Please sign in to comment.