
fix bug in NCES dataset class
Jean-KOUAGOU committed Jan 22, 2025
1 parent 1398ff5 commit 5b1b47c
Showing 5 changed files with 128 additions and 40 deletions.
4 changes: 2 additions & 2 deletions examples/train_nces.py
@@ -54,7 +54,7 @@ def start(args):
else:
synthesizer = ROCES(knowledge_base_path=args.kb, auto_train=False, k=5, max_length=48, proj_dim=128, embedding_dim=args.embedding_dim,
drop_prob=0.1, num_heads=4, num_seeds=1, m=[32, 64, 128], load_pretrained=args.load_pretrained, path_of_trained_models=args.path_of_trained_models, verbose=True)
-synthesizer.train(training_data, epochs=args.epochs, batch_size=args.batch_size, learning_rate=args.lr, tmax=args.tmax, max_num_lps=args.max_num_lps, refinement_expressivity=args.refinement_expressivity, refs_sample_size=args.sample_size, storage_path=args.storage_path)
+synthesizer.train(training_data, epochs=args.epochs, batch_size=args.batch_size, learning_rate=args.lr, clip_value=1.0, tmax=args.tmax, max_num_lps=args.max_num_lps, refinement_expressivity=args.refinement_expressivity, refs_sample_size=args.sample_size, storage_path=args.storage_path)

if __name__ == '__main__':
set_seed(42)
@@ -72,7 +72,7 @@ def start(args):
parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs')
parser.add_argument('--dicee_model', type=str, default="DeCaL", help='The model to use for DICE embeddings (only for NCES)')
parser.add_argument('--dicee_emb_dim', type=int, default=128, help='Number of embedding dimensions for DICE embeddings (only for NCES)')
-parser.add_argument('--dicee_epochs', type=int, default=100, help='Number of training epochs for the NCES (DICE) embeddings (only for NCES)')
+parser.add_argument('--dicee_epochs', type=int, default=300, help='Number of training epochs for the NCES (DICE) embeddings (only for NCES)')
parser.add_argument('--dicee_lr', type=float, default=0.01, help='Learning rate for computing DICE embeddings (only for NCES)')
parser.add_argument('--batch_size', type=int, default=256, help='Minibatch size for training')
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for training. The optimizer is Adam.')
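Note: the functional change in this example script is the new clip_value=1.0 argument to synthesizer.train (the --dicee_epochs default also rises from 100 to 300). In PyTorch training loops a clip value is normally applied between backward() and step(). The sketch below shows that standard pattern; whether NCESTrainer uses clip_grad_value_ or clip_grad_norm_ internally is an assumption, since the trainer is not part of this diff.

import torch

def training_step(model, x, y, loss_fn, optimizer, clip_value=1.0):
    # Hedged sketch of value-based gradient clipping, not the actual
    # NCESTrainer code.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # Clamp every gradient element to [-clip_value, clip_value] so one
    # outlier batch cannot blow up the update; clip_value=1.0 matches
    # the argument added above.
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
    optimizer.step()
    return loss.item()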
41 changes: 23 additions & 18 deletions ontolearn/concept_learner.py
@@ -687,7 +687,7 @@ def pos_neg_to_tensor(self, pos: Union[Set[OWLNamedIndividual]], neg: Union[Set[
assert self.load_pretrained and self.pretrained_predictor_name, \
"No pretrained model found. Please first train length predictors, see the <<train>> method below"

-dataset = CLIPDatasetInference([("", pos_str, neg_str)], self.instance_embeddings, False, False)
+dataset = CLIPDatasetInference([("", pos_str, neg_str)], self.instance_embeddings, self.num_examples, False, False)
dataloader = DataLoader(dataset, batch_size=1, num_workers=self.num_workers,
collate_fn=self.collate_batch_inference, shuffle=False)
x_pos, x_neg = next(iter(dataloader))
@@ -780,7 +780,7 @@ def fit(self, *args, **kwargs):
def train(self, data: Iterable[List[Tuple]], epochs=300, batch_size=256, learning_rate=1e-3, decay_rate=0.0,
clip_value=5.0, save_model=True, storage_path=None, optimizer='Adam', record_runtime=True,
example_sizes=None, shuffle_examples=False):
-train_dataset = CLIPDataset(data, self.instance_embeddings, shuffle_examples=shuffle_examples, example_sizes=example_sizes)
+train_dataset = CLIPDataset(data, self.instance_embeddings, num_examples=self.num_examples, shuffle_examples=shuffle_examples, example_sizes=example_sizes)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=self.num_workers,
collate_fn=self.collate_batch, shuffle=True)
if storage_path is None:
@@ -814,15 +814,15 @@ def __init__(self, knowledge_base_path, nces2_or_roces=False,
self.rnn_n_layers = rnn_n_layers
self.sorted_examples = sorted_examples
self._set_prerequisites()
+self.has_renamed_inds = False

+def _rename_individuals(self, individual_name):
+if isinstance(individual_name, str) and '/' in individual_name:
+return individual_name.split('/')[-1]
+return individual_name

def _set_prerequisites(self):

-def _rename_individuals(individual_name):
-if isinstance(individual_name, str) and '/' in individual_name:
-return individual_name.split('/')[-1]
-return individual_name

-if self.path_of_embeddings is None or (os.path.isdir(self.path_of_embeddings) and not glob.glob(self.path_of_embeddings+'*_entity_embeddings.csv')) or not self.path_of_embeddings.endswith('.csv'):
+if self.path_of_embeddings is None or (os.path.isdir(self.path_of_embeddings) and not glob.glob(self.path_of_embeddings+'*_entity_embeddings.csv')) or not os.path.exists(self.path_of_embeddings) or not self.path_of_embeddings.endswith('.csv'):
if not os.path.exists(self.knowledge_base_path):
raise ValueError(f"{knowledge_base_path} not found")
try:
@@ -844,7 +844,6 @@ def _rename_individuals(individual_name):
if self.auto_train:
print("\n"+"\x1b[0;30;43m"+f"Will also train {self.name} for 5 epochs"+"\x1b[0m"+"\n")
self.instance_embeddings = read_csv(self.path_of_embeddings)
-self.instance_embeddings.index = self.instance_embeddings.index.map(_rename_individuals)
self.input_size = self.instance_embeddings.shape[1]
self.model = self.get_synthesizer(self.path_of_trained_models)
print(f"\nUsing embeddings at: {self.path_of_embeddings} with {self.input_size} dimensions.\n")
@@ -854,7 +853,6 @@ def _rename_individuals(individual_name):
self.refresh(self.path_of_trained_models)
else:
self.instance_embeddings = read_csv(self.path_of_embeddings)
-self.instance_embeddings.index = self.instance_embeddings.index.map(_rename_individuals)
self.input_size = self.instance_embeddings.shape[1]
self.model = self.get_synthesizer(self.path_of_trained_models)

@@ -1007,7 +1005,7 @@ def fit_one(self, pos: Union[Set[OWLNamedIndividual], Set[str]], neg: Union[Set[

assert self.load_pretrained and self.learner_names, "No pretrained model found. Please first train NCES, see the <<train>> method below"

dataset = NCESDatasetInference([("", Pos_str, Neg_str) for (Pos_str, Neg_str) in zip(Pos, Neg)], self.instance_embeddings,
dataset = NCESDatasetInference([("", Pos_str, Neg_str) for (Pos_str, Neg_str) in zip(Pos, Neg)], self.instance_embeddings, self.num_examples,
self.vocab, self.inv_vocab, shuffle_examples=False, max_length=self.max_length, sorted_examples=self.sorted_examples)

dataloader = DataLoader(dataset, batch_size=self.batch_size,
@@ -1040,6 +1038,9 @@ def fit(self, learning_problem: PosNegLPStandard, **kwargs):
if isinstance(pos, set) or isinstance(pos, frozenset):
pos_list = list(pos)
neg_list = list(neg)
if not "/" in pos_list[0] and not self.has_renamed_inds:
self.instance_embeddings.index = self.instance_embeddings.index.map(self._rename_individuals)
self.has_renamed_inds = True
if self.sorted_examples:
pos_list = sorted(pos_list)
neg_list = sorted(neg_list)
@@ -1101,7 +1102,7 @@ def fit_from_iterable(self, dataset: Union[List[Tuple[str, Set[OWLNamedIndividua
assert self.load_pretrained and self.learner_names, \
"No pretrained model found. Please first train NCES, refer to the <<train>> method"
dataset = [self.convert_to_list_str_from_iterable(datapoint) for datapoint in dataset]
-dataset = NCESDatasetInference(dataset, self.instance_embeddings, self.vocab, self.inv_vocab, shuffle_examples, max_length=self.max_length)
+dataset = NCESDatasetInference(dataset, self.instance_embeddings, self.num_examples, self.vocab, self.inv_vocab, shuffle_examples, max_length=self.max_length)
dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_batch_inference, shuffle=False)
simpleSolution = SimpleSolution(list(self.vocab), self.atomic_concept_names)
predictions_as_owl_class_expressions = []
@@ -1153,7 +1154,10 @@ def train(self, data: Iterable[List[Tuple]]=None, epochs=50, batch_size=64, max_
if data is None:
data = self.generate_training_data(self.knowledge_base_path, max_num_lps=max_num_lps, refinement_expressivity=refinement_expressivity,
refs_sample_size=refs_sample_size, storage_path=storage_path)

+example_ind = data[0][-1]["positive examples"][0]
+if not "/" in example_ind and not self.has_renamed_inds:
+self.instance_embeddings.index = self.instance_embeddings.index.map(self._rename_individuals)
+self.has_renamed_inds = True
trainer = NCESTrainer(self, epochs=epochs, batch_size=batch_size, learning_rate=learning_rate, tmax=tmax, eta_min=eta_min,
clip_value=clip_value, num_workers=num_workers, storage_path=storage_path)
trainer.train(data=data, save_model=save_model, optimizer=optimizer, record_runtime=record_runtime)
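Note: _rename_individuals was promoted from a closure inside _set_prerequisites to a proper method, and the eager index renaming at load time was replaced by the lazy, one-shot renaming above, guarded by has_renamed_inds and triggered only when the learning problem uses short names (no '/' in the first positive example). A standalone illustration of what the mapping does; the IRIs are hypothetical:

import pandas as pd

def _rename_individuals(individual_name):
    # Same logic as the new method: strip everything up to the last '/'
    # from a full IRI, leave short names untouched.
    if isinstance(individual_name, str) and '/' in individual_name:
        return individual_name.split('/')[-1]
    return individual_name

# Hypothetical embedding matrix keyed by full IRIs.
emb = pd.DataFrame([[0.1, 0.2], [0.3, 0.4]],
                   index=["http://example.com/family#anna",
                          "http://example.com/family#heinz"])
emb.index = emb.index.map(_rename_individuals)
print(list(emb.index))  # ['family#anna', 'family#heinz']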
@@ -1357,9 +1361,10 @@ def fit_one(self, pos: Union[Set[OWLNamedIndividual], Set[str]], neg: Union[Set[
dataloaders = []
for num_ind_points in self.model:
dataset = ROCESDatasetInference([("", pos_str, neg_str)],
-triples_data=self.triples_data, k=self.k if hasattr(self, "k") else None,
+triples_data=self.triples_data, num_examples=self.num_examples,
+k=self.k if hasattr(self, "k") else None,
vocab=self.vocab, inv_vocab=self.inv_vocab,
-max_length=self.max_length, num_examples=self.num_examples,
+max_length=self.max_length,
sampling_strategy=self.sampling_strategy,
num_pred_per_lp=self.num_predictions)
dataset.load_embeddings(self.model[num_ind_points]["emb_model"])
@@ -1453,9 +1458,9 @@ def fit_from_iterable(self, data: Union[List[Tuple[str, Set[OWLNamedIndividual],
dataloaders = []
for num_ind_points in self.model:
dataset = ROCESDatasetInference(data,
-self.triples_data, k=self.k if hasattr(self, "k") else None,
+self.triples_data, num_examples=self.num_examples, k=self.k if hasattr(self, "k") else None,
vocab=self.vocab, inv_vocab=self.inv_vocab,
-max_length=self.max_length, num_examples=self.num_examples,
+max_length=self.max_length,
sampling_strategy=self.sampling_strategy,
num_pred_per_lp=self.num_predictions)
dataset.load_embeddings(self.model[num_ind_points]["emb_model"])
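Note: in both ROCESDatasetInference call sites, num_examples moved ahead of k in the constructor signature. A caller still passing these positionally would silently bind num_examples to k; the updated calls sidestep that by naming every argument, as in this reconstruction of the new call:

dataset = ROCESDatasetInference(
    data,
    triples_data=self.triples_data,
    num_examples=self.num_examples,  # now precedes k in the signature
    k=self.k if hasattr(self, "k") else None,
    vocab=self.vocab,
    inv_vocab=self.inv_vocab,
    max_length=self.max_length,
    sampling_strategy=self.sampling_strategy,
    num_pred_per_lp=self.num_predictions,
)
dataset.load_embeddings(self.model[num_ind_points]["emb_model"])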
38 changes: 22 additions & 16 deletions ontolearn/data_struct.py
@@ -32,6 +32,7 @@
import numpy as np
import random
from rdflib import graph
+from .nces_utils import try_get_embs


class PrepareBatchOfPrediction(torch.utils.data.Dataset): # pragma: no cover
@@ -193,11 +194,12 @@ def get_entities(data):

class CLIPDataset(torch.utils.data.Dataset): # pragma: no cover

-def __init__(self, data: list, embeddings, shuffle_examples, example_sizes: list=None,
+def __init__(self, data, embeddings, num_examples, shuffle_examples, example_sizes=None,
k=5, sorted_examples=True):
super().__init__()
self.data = data
self.embeddings = embeddings
+self.num_examples = num_examples
self.shuffle_examples = shuffle_examples
self.example_sizes = example_sizes
self.k = k
@@ -210,6 +212,7 @@ def __getitem__(self, idx):
key, value = self.data[idx]
pos = value['positive examples']
neg = value['negative examples']
+pos, neg = try_get_embs(pos, neg, self.embeddings, self.num_examples)
length = value['length']
if self.example_sizes is not None:
k_pos, k_neg = random.choice(self.example_sizes)
@@ -241,11 +244,12 @@ def __getitem__(self, idx):

class CLIPDatasetInference(torch.utils.data.Dataset): # pragma: no cover

-def __init__(self, data: list, embeddings, shuffle_examples,
+def __init__(self, data: list, embeddings, num_examples, shuffle_examples,
sorted_examples=True):
super().__init__()
self.data = data
self.embeddings = embeddings
+self.num_examples = num_examples
self.shuffle_examples = shuffle_examples
self.sorted_examples = sorted_examples

@@ -254,6 +258,7 @@ def __len__(self):

def __getitem__(self, idx):
_, pos, neg = self.data[idx]
+pos, neg = try_get_embs(pos, neg, self.embeddings, self.num_examples)
if self.sorted_examples:
pos, neg = sorted(pos), sorted(neg)
elif self.shuffle_examples:
@@ -313,10 +318,11 @@ def get_labels(self, target):

class NCESDataset(NCESBaseDataset, torch.utils.data.Dataset): # pragma: no cover

-def __init__(self, data: list, embeddings, vocab, inv_vocab, shuffle_examples, max_length, example_sizes=None, sorted_examples=True):
+def __init__(self, data, embeddings, num_examples, vocab, inv_vocab, shuffle_examples, max_length, example_sizes=None, sorted_examples=True):
super().__init__(vocab, inv_vocab, max_length)
self.data = data
self.embeddings = embeddings
+self.num_examples = num_examples
self.shuffle_examples = shuffle_examples
self.example_sizes = example_sizes
self.sorted_examples = sorted_examples
@@ -328,6 +334,7 @@ def __getitem__(self, idx):
key, value = self.data[idx]
pos = value['positive examples']
neg = value['negative examples']
+pos, neg = try_get_embs(pos, neg, self.embeddings, self.num_examples)
if self.example_sizes is not None:
k_pos, k_neg = random.choice(self.example_sizes)
k_pos = min(k_pos, len(pos))
@@ -338,27 +345,26 @@ def __getitem__(self, idx):
selected_pos = pos
selected_neg = neg

-selected_pos = list(filter(lambda x: x in self.embeddings, pos))
-selected_neg = list(filter(lambda x: x in self.embeddings, neg))

labels, length = self.get_labels(key)

try:
datapoint_pos = torch.FloatTensor(self.embeddings.loc[selected_pos].values.squeeze())
datapoint_neg = torch.FloatTensor(self.embeddings.loc[selected_neg].values.squeeze())
-except:
-#print(f'\nSome individuals are not found in embedding matrix: {list(filter(lambda x: x not in self.embeddings, pos+neg))}')
-return torch.zeros(len(pos), self.embeddings.shape[1]), torch.zeros(len(neg), self.embeddings.shape[1]), torch.cat([torch.tensor(labels), self.vocab['PAD'] * torch.ones(max(0, self.max_length-length))]).long()
+except Exception as e:
+print(e)
+return None
+#torch.zeros(len(pos), self.embeddings.shape[1]), torch.zeros(len(neg), self.embeddings.shape[1]), torch.cat([torch.tensor(labels), self.vocab['PAD'] * torch.ones(max(0, self.max_length-length))]).long()

return datapoint_pos, datapoint_neg, torch.cat([torch.tensor(labels), self.vocab['PAD'] * torch.ones(max(0, self.max_length-length))]).long()


class NCESDatasetInference(NCESBaseDataset, torch.utils.data.Dataset): # pragma: no cover

-def __init__(self, data: list, embeddings, vocab, inv_vocab, shuffle_examples, max_length=48, sorted_examples=True):
+def __init__(self, data, embeddings, num_examples, vocab, inv_vocab, shuffle_examples, max_length=48, sorted_examples=True):
super().__init__(vocab, inv_vocab, max_length)
self.data = data
self.embeddings = embeddings
+self.num_examples = num_examples
self.shuffle_examples = shuffle_examples
self.sorted_examples = sorted_examples

@@ -367,14 +373,12 @@ def __len__(self):

def __getitem__(self, idx):
_, pos, neg = self.data[idx]
+pos, neg = try_get_embs(pos, neg, self.embeddings, self.num_examples)
if self.sorted_examples:
pos, neg = sorted(pos), sorted(neg)
elif self.shuffle_examples:
random.shuffle(pos)
random.shuffle(neg)

-selected_pos = list(filter(lambda x: x in self.embeddings, pos))
-selected_neg = list(filter(lambda x: x in self.embeddings, neg))

try:
datapoint_pos = torch.FloatTensor(self.embeddings.loc[selected_pos].values.squeeze())
@@ -388,10 +392,11 @@ def __getitem__(self, idx):

class ROCESDataset(NCESBaseDataset, torch.utils.data.Dataset):

-def __init__(self, data, triples_data, k, vocab, inv_vocab, max_length, sampling_strategy="p"):
+def __init__(self, data, triples_data, num_examples, k, vocab, inv_vocab, max_length, sampling_strategy="p"):
super(ROCESDataset, self).__init__(vocab, inv_vocab, max_length)
self.data = data
self.triples_data = triples_data
+self.num_examples = num_examples
self.k = k
self.sampling_strategy = sampling_strategy

@@ -410,6 +415,7 @@ def __getitem__(self, idx):
key, value = self.data[idx]
pos = value['positive examples']
neg = value['negative examples']
+pos, neg = try_get_embs(pos, neg, self.embeddings, self.num_examples)
if self.sampling_strategy == 'p':
prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k)))
prob_pos_set = prob_pos_set/prob_pos_set.sum()
@@ -440,7 +446,7 @@ def __getitem__(self, idx):

class ROCESDatasetInference(NCESBaseDataset, torch.utils.data.Dataset):

-def __init__(self, data, triples_data, k, vocab, inv_vocab, max_length, num_examples, sampling_strategy='p', num_pred_per_lp=1):
+def __init__(self, data, triples_data, num_examples, k, vocab, inv_vocab, max_length, sampling_strategy='p', num_pred_per_lp=1):
super(ROCESDatasetInference, self).__init__(vocab, inv_vocab, max_length)
self.data = data
self.triples_data = triples_data
@@ -461,7 +467,7 @@ def __len__(self):

def __getitem__(self, idx):
_, pos, neg = self.data[idx]

+pos, neg = try_get_embs(pos, neg, self.embeddings, self.num_examples)
if self.sampling_strategy == 'p':
prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k)))
prob_pos_set = prob_pos_set/prob_pos_set.sum()
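Note: every dataset class now routes its examples through try_get_embs(pos, neg, self.embeddings, self.num_examples), imported from ontolearn.nces_utils; its implementation is not part of this diff. Judging from the call sites and from the inline filter logic it replaces, it plausibly keeps only individuals that actually have embeddings and normalizes each list toward num_examples entries. A purely illustrative guess at its shape:

import random

def try_get_embs(pos, neg, embeddings, num_examples):
    # Illustrative sketch only; the real helper lives in
    # ontolearn/nces_utils.py. Assumes `embeddings` is a pandas
    # DataFrame indexed by individual names.
    pos = [ind for ind in pos if ind in embeddings.index]
    neg = [ind for ind in neg if ind in embeddings.index]
    # Pad short lists by resampling, then cap at num_examples
    # (assumption, not confirmed by the diff).
    while pos and len(pos) < num_examples:
        pos.append(random.choice(pos))
    while neg and len(neg) < num_examples:
        neg.append(random.choice(neg))
    return pos[:num_examples], neg[:num_examples]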