Skip to content

Commit

Permalink
Add Cosine Similarity Algorithm for Strings (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalkwst committed Aug 1, 2024
1 parent b0838cb commit 9eb2196
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 0 deletions.
84 changes: 84 additions & 0 deletions Algorithms.Tests/Strings/Similarity/CosineSimilarityTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using System;
using Algorithms.Strings.Similarity;
using NUnit.Framework;

namespace Algorithms.Tests.Strings.Similarity;

[TestFixture]
public class CosineSimilarityTests
{
[Test]
public void Calculate_IdenticalStrings_ReturnsOne()
{
var str1 = "test";
var str2 = "test";
var result = CosineSimilarity.Calculate(str1, str2);
Assert.That(result, Is.EqualTo(1.0).Within(1e-6), "Identical strings should have a cosine similarity of 1.");
}

[Test]
public void Calculate_CompletelyDifferentStrings_ReturnsZero()
{
var str1 = "abc";
var str2 = "xyz";
var result = CosineSimilarity.Calculate(str1, str2);
Assert.That(result, Is.EqualTo(0.0).Within(1e-6), "Completely different strings should have a cosine similarity of 0.");
}

[Test]
public void Calculate_EmptyStrings_ReturnsZero()
{
var str1 = "";
var str2 = "";
var result = CosineSimilarity.Calculate(str1, str2);
Assert.That(result, Is.EqualTo(0.0).Within(1e-6), "Empty strings should have a cosine similarity of 0.");
}

[Test]
public void Calculate_OneEmptyString_ReturnsZero()
{
var str1 = "test";
var str2 = "";
var result = CosineSimilarity.Calculate(str1, str2);
Assert.That(result, Is.EqualTo(0.0).Within(1e-6), "Empty string should have a cosine similarity of 0.");
}

[Test]
public void Calculate_SameCharactersDifferentCases_ReturnsOne()
{
var str1 = "Test";
var str2 = "test";
var result = CosineSimilarity.Calculate(str1, str2);
Assert.That(result, Is.EqualTo(1.0).Within(1e-6), "The method should be case-insensitive.");
}

[Test]
public void Calculate_SpecialCharacters_ReturnsCorrectValue()
{
var str1 = "hello!";
var str2 = "hello!";
var result = CosineSimilarity.Calculate(str1, str2);
Assert.That(result, Is.EqualTo(1.0).Within(1e-6), "Strings with special characters should have a cosine similarity of 1.");
}

[Test]
public void Calculate_DifferentLengthWithCommonCharacters_ReturnsCorrectValue()
{
var str1 = "hello";
var str2 = "hello world";
var result = CosineSimilarity.Calculate(str1, str2);
var expected = 10 / (Math.Sqrt(7) * Math.Sqrt(19)); // calculated manually
Assert.That(result, Is.EqualTo(expected).Within(1e-6), "Strings with different lengths but some common characters should have the correct cosine similarity.");
}

[Test]
public void Calculate_PartiallyMatchingStrings_ReturnsCorrectValue()
{
var str1 = "night";
var str2 = "nacht";
var result = CosineSimilarity.Calculate(str1, str2);
// Assuming the correct calculation gives an expected value
var expected = 3.0 / 5.0;
Assert.That(result, Is.EqualTo(expected).Within(1e-6), "Partially matching strings should have the correct cosine similarity.");
}
}
136 changes: 136 additions & 0 deletions Algorithms/Strings/Similarity/CosineSimilarity.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
using System;
using System.Collections.Generic;

namespace Algorithms.Strings.Similarity;

public static class CosineSimilarity
{
/// <summary>
/// Calculates the Cosine Similarity between two strings.
/// Cosine Similarity is a measure of similarity between two non-zero vectors of an inner product space.
/// It measures the cosine of the angle between the two vectors.
/// </summary>
/// <param name="left">The first string.</param>
/// <param name="right">The second string.</param>
/// <returns>
/// A double value between 0 and 1 that represents the similarity
/// of the two strings.
/// </returns>
public static double Calculate(string left, string right)
{
// Step 1: Get the vectors for the two strings
// Each vector represents the frequency of each character in the string.
var vectors = GetVectors(left.ToLowerInvariant(), right.ToLowerInvariant());
var leftVector = vectors.leftVector;
var rightVector = vectors.rightVector;

// Step 2: Calculate the intersection of the two vectors
// The intersection is the set of characters that appear in both strings.
var intersection = GetIntersection(leftVector, rightVector);

// Step 3: Calculate the dot product of the two vectors
// The dot product is the sum of the products of the corresponding values of the characters in the intersection.
var dotProduct = DotProduct(leftVector, rightVector, intersection);

// Step 4: Calculate the square magnitude of each vector
// The magnitude is the square root of the sum of the squares of the values in the vector.
var mLeft = 0.0;
foreach (var value in leftVector.Values)
{
mLeft += value * value;
}

var mRight = 0.0;
foreach (var value in rightVector.Values)
{
mRight += value * value;
}

// Step 5: Check if either vector is zero
// If either vector is zero (i.e., all characters are unique), the Cosine Similarity is 0.
if (mLeft <= 0 || mRight <= 0)
{
return 0.0;
}

// Step 6: Calculate and return the Cosine Similarity
// The Cosine Similarity is the dot product divided by the product of the magnitudes.
return dotProduct / (Math.Sqrt(mLeft) * Math.Sqrt(mRight));
}

/// <summary>
/// Calculates the vectors for the given strings.
/// </summary>
/// <param name="left">The first string.</param>
/// <param name="right">The second string.</param>
/// <returns>A tuple containing the vectors for the two strings.</returns>
private static (Dictionary<char, int> leftVector, Dictionary<char, int> rightVector) GetVectors(string left, string right)
{
var leftVector = new Dictionary<char, int>();
var rightVector = new Dictionary<char, int>();

// Calculate the frequency of each character in the left string
foreach (var character in left)
{
leftVector.TryGetValue(character, out var frequency);
leftVector[character] = ++frequency;
}

// Calculate the frequency of each character in the right string
foreach (var character in right)
{
rightVector.TryGetValue(character, out var frequency);
rightVector[character] = ++frequency;
}

return (leftVector, rightVector);
}

/// <summary>
/// Calculates the dot product between two vectors represented as dictionaries of character frequencies.
/// The dot product is the sum of the products of the corresponding values of the characters in the intersection of the two vectors.
/// </summary>
/// <param name="leftVector">The vector of the left string.</param>
/// <param name="rightVector">The vector of the right string.</param>
/// <param name="intersection">The intersection of the two vectors, represented as a set of characters.</param>
/// <returns>The dot product of the two vectors.</returns>
private static double DotProduct(Dictionary<char, int> leftVector, Dictionary<char, int> rightVector, HashSet<char> intersection)
{
// Initialize the dot product to 0
double dotProduct = 0;

// Iterate over each character in the intersection of the two vectors
foreach (var character in intersection)
{
// Calculate the product of the corresponding values of the characters in the left and right vectors
dotProduct += leftVector[character] * rightVector[character];
}

// Return the dot product
return dotProduct;
}

/// <summary>
/// Calculates the intersection of two vectors, represented as dictionaries of character frequencies.
/// </summary>
/// <param name="leftVector">The vector of the left string.</param>
/// <param name="rightVector">The vector of the right string.</param>
/// <returns>A HashSet containing the characters that appear in both vectors.</returns>
private static HashSet<char> GetIntersection(Dictionary<char, int> leftVector, Dictionary<char, int> rightVector)
{
// Initialize a HashSet to store the intersection of the two vectors.
var intersection = new HashSet<char>();

// Iterate over each key-value pair in the left vector.
foreach (var kvp in leftVector)
{
// If the right vector contains the same key, add it to the intersection.
if (rightVector.ContainsKey(kvp.Key))
{
intersection.Add(kvp.Key);
}
}

return intersection;
}
}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ find more than one implementation for the same objective but using different alg
* [A181391 Van Eck's](./Algorithms/Sequences/VanEcksSequence.cs)
* [String](./Algorithms/Strings)
* [Similarity](./Algorithms/Strings/Similarity/)
* [Cosine Similarity](./Algorithms/Strings/Similarity/CosineSimilarity.cs)
* [Hamming Distance](./Algorithms/Strings/Similarity/HammingDistance.cs)
* [Jaro Similarity](./Algorithms/Strings/Similarity/JaroSimilarity.cs)
* [Jaro-Winkler Distance](./Algorithms/Strings/Similarity/JaroWinklerDistance.cs)
Expand Down

0 comments on commit 9eb2196

Please sign in to comment.