Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into checkpoint-disabling
Browse files Browse the repository at this point in the history
sshivam95 committed Nov 12, 2024
2 parents d4e0129 + 7001e31 commit 60ec608
Showing 5 changed files with 39 additions and 668 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/github-actions-python-package.yml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ jobs:
- name: Lint with ruff
run: |
ruff --select=E501 --line-length=200 dicee/
ruff check dicee/ --select=E501 --line-length=200
- name: Test with pytest
run: |
wget https://files.dice-research.org/datasets/dice-embeddings/KGs.zip --no-check-certificate && unzip KGs.zip
682 changes: 21 additions & 661 deletions LICENSE

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -37,6 +37,8 @@ Deploy a pre-trained embedding model without writing a single line of code.
git clone https://github.com/dice-group/dice-embeddings.git
conda create -n dice python=3.10.13 --no-default-packages && conda activate dice
pip3 install -e .
# or
pip3 install -e .["dev"]
```
or
```bash
17 changes: 13 additions & 4 deletions dicee/knowledge_graph_embeddings.py
Original file line number Diff line number Diff line change
@@ -10,10 +10,11 @@
import numpy as np
import sys


# import gradio as gr

import traceback


class KGE(BaseInteractiveKGE):
""" Knowledge Graph Embedding Class for interactive usage of pre-trained models"""

@@ -22,6 +23,13 @@ def __init__(self, path=None, url=None, construct_ensemble=False,
apply_semantic_constraint=False):
super().__init__(path=path, url=url, construct_ensemble=construct_ensemble, model_name=model_name)

def __str__(self):
return "KGE | " + str(self.model)

def to(self, device: str) -> None:
assert "cpu" in device or "cuda" in device, "Device must be either cpu or cuda"
self.model.to(device)

def get_transductive_entity_embeddings(self,
indices: Union[torch.LongTensor, List[str]],
as_pytorch=False,
@@ -118,9 +126,6 @@ def generate(self, h="", r=""):
counter += 1
print(self.enc.decode(tokens), end=f"\t {score}\n")

def __str__(self):
return "KGE | " + str(self.model)

# given a string, return is bpe encoded embeddings
def eval_lp_performance(self, dataset=List[Tuple[str, str, str]], filtered=True):
assert isinstance(dataset, list) and len(dataset) > 0
@@ -175,6 +180,7 @@ def predict_missing_head_entity(self, relation: Union[List[str], str], tail_enti
x = torch.stack((head_entity,
relation.repeat(self.num_entities, ),
tail_entity.repeat(self.num_entities, )), dim=1)
x = x.to(self.model.device)
return self.model(x)

def predict_missing_relations(self, head_entity: Union[List[str], str],
@@ -283,6 +289,7 @@ def predict_missing_tail_entity(self, head_entity: Union[List[str], str],
x = torch.stack((head_entity.repeat(self.num_entities, ),
relation.repeat(self.num_entities, ),
tail_entity), dim=1)
x = x.to(self.model.device)
return self.model(x)

def predict(self, *, h: Union[List[str], str] = None, r: Union[List[str], str] = None,
@@ -483,6 +490,7 @@ def triple_score(self, h: Union[List[str], str] = None, r: Union[List[str], str]
raise NotImplementedError()
else:
with torch.no_grad():
x = x.to(self.model.device)
if logits:
return self.model(x)
else:
@@ -1202,6 +1210,7 @@ def train_triples(self, h: List[str], r: List[str], t: List[str], labels: List[f
# (5) Eval
self.set_model_eval_mode()
with torch.no_grad():
x = x.to(self.model.device)
outputs = self.model(x)
loss = self.model.loss(outputs, labels)
print(f"Eval Mode:\tLoss:{loss.item()}")
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
"lightning>=2.1.3",
"pandas>=2.1.0",
"numpy==1.26.4",
"polars>=0.16.14",
"polars==0.16.14",
"scikit-learn>=1.2.2",
"pyarrow>=11.0.0",
"pykeen>=1.10.2",
@@ -55,7 +55,7 @@ def deps_list(*pkgs):
setup(
name="dicee",
description="Dice embedding is an hardware-agnostic framework for large-scale knowledge graph embedding applications",
version="0.1.4",
version="0.1.5",
packages=find_packages(),
extras_require=extras,
install_requires=list(install_requires),

0 comments on commit 60ec608

Please sign in to comment.