Skip to content

Commit

Permalink
Fixed include for pca; pass in kwargs for projection algorithm; version 2.4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed May 19, 2023
1 parent 1494b7b commit a7ce3ab
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 30 deletions.
2 changes: 1 addition & 1 deletion backend/kangas/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
# All rights reserved #
######################################################

version_info = (2, 4, 0)
version_info = (2, 4, 1)
__version__ = ".".join(map(str, version_info))
55 changes: 39 additions & 16 deletions backend/kangas/datatypes/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
source=None,
unserialize=False,
dimensions=PROJECTION_DIMENSIONS,
**kwargs
):
"""
Create an embedding vector.
Expand All @@ -79,6 +80,7 @@ def __init__(
projection. Useful if you want to see one part of the datagrid in
the project of another.
dimensions: (int) maximum number of dimensions
kwargs: (dict) optional keyword arguments for projection algorithm
Example:
Expand All @@ -91,6 +93,11 @@ def __init__(
>>> dg.save("embeddings.datagrid")
```
"""
if not include and projection not in ["pca"]:
raise Exception(
"projection '%s' does not allow embeddings to be excluded; change projection or set include=True"
)

super().__init__(source)
if unserialize:
self._unserialize = unserialize
Expand All @@ -115,6 +122,7 @@ def __init__(
self.metadata["projection"] = projection
self.metadata["include"] = include
self.metadata["dimensions"] = dimensions
self.metadata["kwargs"] = kwargs

if file_name:
if is_valid_file_path(file_name):
Expand Down Expand Up @@ -174,7 +182,9 @@ def get_statistics(cls, datagrid, col_name, field_name):

projection = None
batch = []
asset_ids = []
batch_asset_ids = []
not_included = []
not_included_asset_ids = []

for row in datagrid.conn.execute(
"""SELECT {field_name} as assetId, asset_data, asset_metadata from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
Expand All @@ -186,20 +196,11 @@ def get_statistics(cls, datagrid, col_name, field_name):
continue

asset_metdata = json.loads(asset_metadata_json)

projection = asset_metdata["projection"]
include = asset_metdata["include"]
dimensions = asset_metdata["dimensions"]

# Skip if explicitly False
if not include:
continue

asset_data = json.loads(asset_data_json)
vector = prepare_embedding(asset_data["vector"], dimensions, seed)

# Save asset_id to update assets next
batch.append(vector)
asset_ids.append(asset_id)
kwargs = asset_metdata["kwargs"]

if projection == "pca":
projection_name = "pca"
Expand All @@ -208,19 +209,37 @@ def get_statistics(cls, datagrid, col_name, field_name):
elif projection == "umap":
projection_name = "umap"
else:
raise Exception("projection not found")
raise Exception("projection not found for %s" % asset_id)

asset_data = json.loads(asset_data_json)
vector = prepare_embedding(asset_data["vector"], dimensions, seed)

if include:
batch.append(vector)
batch_asset_ids.append(asset_id)
else:
not_included.append(vector)
not_included_asset_ids.append(asset_id)

if projection_name == "pca":
from sklearn.decomposition import PCA

projection = PCA(n_components=2)
if "n_components" not in kwargs:
kwargs["n_components"] = 2

projection = PCA(**kwargs)
transformed = projection.fit_transform(np.array(batch))
if not_included:
transformed_not_included = projection.transform(np.array(not_included))
else:
transformed_not_included = np.array([])

elif projection_name == "t-sne":
from sklearn.manifold import TSNE

projection = TSNE()
projection = TSNE(**kwargs)
transformed = projection.fit_transform(np.array(batch))
transformed_not_included = np.array([])

elif projection_name == "umap":
pass # TODO
Expand All @@ -244,7 +263,11 @@ def get_statistics(cls, datagrid, col_name, field_name):

# update assets with transformed
cursor = datagrid.conn.cursor()
for asset_id, tran in zip(asset_ids, transformed):
if not_included_asset_ids:
batch_asset_ids = batch_asset_ids + not_included_asset_ids
transformed = np.concatenate((transformed, transformed_not_included))

for asset_id, tran in zip(batch_asset_ids, transformed):
sql = """SELECT asset_data from assets WHERE asset_id = ?;"""
asset_data_json = datagrid.conn.execute(sql, (asset_id,)).fetchone()[0]
asset_data = json.loads(asset_data_json)
Expand Down
40 changes: 27 additions & 13 deletions notebooks/Visualizing_embeddings_in_Kangas.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"1001it [00:00, 3257.63it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 4498.83it/s]\n"
"1001it [00:00, 2097.76it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2324.51it/s]\n"
]
}
],
Expand Down Expand Up @@ -237,7 +237,10 @@
" row[8] = kg.Embedding(\n",
" ast.literal_eval(row[8]), \n",
" name=str(row[3]), \n",
" text=\"%s - %.10s\" % (row[3], row[4])\n",
" text=\"%s - %.10s\" % (row[3], row[4]),\n",
" projection=\"t-sne\",\n",
" learning_rate=10.0,\n",
" n_iter=500,\n",
" )\n",
" dg.append(row)"
]
Expand Down Expand Up @@ -309,7 +312,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 34767.11it/s]"
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 28791.81it/s]"
]
},
{
Expand Down Expand Up @@ -337,7 +340,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 636.93it/s]\n"
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 377.84it/s]\n"
]
},
{
Expand All @@ -351,7 +354,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00, 7.00it/s]\n"
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:08<00:00, 1.23it/s]\n"
]
}
],
Expand All @@ -365,7 +368,11 @@
"source": [
"### 3. Render 2D Projections\n",
"\n",
"To render the data directly in the notebook, simply show it. Note that each row contains an embedding projection. Group by \"Score\" to see rows of each group."
"To render the data directly in the notebook, simply show it. Note that each row contains an embedding projection. \n",
"\n",
"Scroll to far right to see embeddings projection per row.\n",
"\n",
"The color of the point in projection space represents the Score."
]
},
{
Expand All @@ -387,15 +394,15 @@
" <iframe\n",
" width=\"100%\"\n",
" height=\"750px\"\n",
" src=\"http://127.0.1.1:4000/?datagrid=openai_embeddings.datagrid&timestamp=1684437830.776708\"\n",
" src=\"http://127.0.1.1:4000/?datagrid=openai_embeddings.datagrid&timestamp=1684538685.5566168\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7fb5663d4bb0>"
"<IPython.lib.display.IFrame at 0x7f1d99683340>"
]
},
"metadata": {},
Expand All @@ -406,6 +413,13 @@
"dg.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Group by \"Score\" to see rows of each group. Again, scroll right to see groups of embeddings."
]
},
{
"cell_type": "code",
"execution_count": 10,
Expand All @@ -418,23 +432,23 @@
" <iframe\n",
" width=\"100%\"\n",
" height=\"750px\"\n",
" src=\"http://127.0.1.1:4000/?datagrid=openai_embeddings.datagrid&timestamp=1684437830.776708&group=Score&sort=Score\"\n",
" src=\"http://127.0.1.1:4000/?datagrid=openai_embeddings.datagrid&timestamp=1684538685.5566168&group=Score&sort=Score&rows=5\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7fb526293ac0>"
"<IPython.lib.display.IFrame at 0x7f1e655b6a70>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dg.show(group=\"Score\", sort=\"Score\")"
"dg.show(group=\"Score\", sort=\"Score\", rows=5)"
]
},
{
Expand Down Expand Up @@ -470,7 +484,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.11"
},
"vscode": {
"interpreter": {
Expand Down

0 comments on commit a7ce3ab

Please sign in to comment.