diff --git a/examples/concept_learning_evaluation.py b/examples/concept_learning_evaluation.py
index cf5efadb..db368924 100644
--- a/examples/concept_learning_evaluation.py
+++ b/examples/concept_learning_evaluation.py
@@ -14,20 +14,15 @@ def dl_concept_learning(args):
-    try:
-        os.chdir("examples")
-    except FileNotFoundError:
-        pass
-
     with open(args.lps) as json_file:
         settings = json.load(json_file)
     kb = KnowledgeBase(path=args.kb)
-
-    ocel=OCEL(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
-    celoe=CELOE(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
-    evo=EvoLearner(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
-    drill=Drill(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
+    drill = Drill(knowledge_base=kb, path_pretrained_kge=args.path_pretrained_kge, quality_func=F1(),
+                  max_runtime=args.max_runtime).train(num_episode=1, num_learning_problems=1)
+    ocel = OCEL(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
+    celoe = CELOE(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
+    evo = EvoLearner(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
 
     columns = ["LP",
                "OCEL", "F1-OCEL", "RT-OCEL",
                "CELOE", "F1-CELOE", "RT-CELOE",
@@ -45,11 +40,9 @@ def dl_concept_learning(args):
         lp = PosNegLPStandard(pos=typed_pos, neg=typed_neg)
 
         start_time = time.time()
-        # Untrained & max runtime is not fully integrated.
         pred_drill = drill.fit(lp).best_hypotheses(n=1)
         rt_drill = time.time() - start_time
-
         start_time = time.time()
         pred_ocel = ocel.fit(lp).best_hypotheses(n=1)
         rt_ocel = time.time() - start_time
@@ -80,6 +73,5 @@ def dl_concept_learning(args):
     parser.add_argument("--max_runtime", type=int, default=10)
     parser.add_argument("--lps", type=str, default="synthetic_problems.json")
     parser.add_argument("--kb", type=str, default="../KGs/Family/family-benchmark_rich_background.owl")
-    parser.add_argument("--path_pretrained_kge", type=str, default="../KeciFamilyRun")
-
+    parser.add_argument("--path_pretrained_kge", type=str, default=None)
     dl_concept_learning(parser.parse_args())
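With this change the benchmark builds and trains DRILL once before the per-problem loop instead of fitting an untrained agent. A minimal sketch of the intended call pattern (paths are placeholders, the `ontolearn.learners` import path is an assumption, and the chained `.train(...)` relies on `train` returning the learner itself, as the hunks below arrange):

```python
from ontolearn.knowledge_base import KnowledgeBase
from ontolearn.learners import Drill  # import path assumed; Drill lives in ontolearn/learners/drill.py
from ontolearn.metrics import F1

kb = KnowledgeBase(path="../KGs/Family/family-benchmark_rich_background.owl")

# Build and train DRILL once; with path_pretrained_kge=None it falls back to untrained embeddings
# and train() simply returns the learner, so the chained call still works.
drill = Drill(knowledge_base=kb, path_pretrained_kge=None, quality_func=F1(),
              max_runtime=10).train(num_episode=1, num_learning_problems=1)

# Inside the benchmark loop the same trained learner is then reused per problem:
# pred_drill = drill.fit(lp).best_hypotheses(n=1)
```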
diff --git a/ontolearn/data_struct.py b/ontolearn/data_struct.py
index 31f82746..f956794d 100644
--- a/ontolearn/data_struct.py
+++ b/ontolearn/data_struct.py
@@ -12,8 +12,6 @@ def __init__(self, current_state: torch.FloatTensor, next_state_batch: torch.Flo
                  n: torch.FloatTensor):
         assert len(p) > 0 and len(n) > 0
         num_next_states = len(next_state_batch)
-
-        current_state = current_state.repeat(num_next_states, 1, 1)
         p = p.repeat((num_next_states, 1, 1))
         n = n.repeat((num_next_states, 1, 1))
@@ -63,8 +61,9 @@ def __init__(self, current_state_batch: torch.Tensor, next_state_batch: torch.Te
         assert self.S.shape == self.S_Prime.shape == self.Positives.shape == self.Negatives.shape
         assert self.S.dtype == self.S_Prime.dtype == self.Positives.dtype == self.Negatives.dtype == torch.float32
         self.X = torch.cat([self.S, self.S_Prime, self.Positives, self.Negatives], 1)
+
         num_points, depth, dim = self.X.shape
-        self.X = self.X.view(num_points, depth, 1, dim)
+        # self.X = self.X.view(num_points, depth, 1, dim)
         # X[0] => corresponds to a data point, X[0] \in R^{4 \times 1 \times dim}
         # where X[0][0] => current state representation R^{1 \times dim}
         # where X[0][1] => next state representation R^{1 \times dim}
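Dropping the extra `view` keeps the prepared batch three-dimensional: current state, next state, pooled positives and pooled negatives are stacked along one axis of size 4. A small shape check of that layout, with random tensors standing in for the real embeddings:

```python
import torch

n, d = 5, 12                              # number of transitions, embedding dimension
s = torch.rand(n, 1, d)                   # current states
s_prime = torch.rand(n, 1, d)             # next states
e_pos = torch.rand(1, 1, d).repeat(n, 1, 1)  # pooled positive examples, broadcast over the batch
e_neg = torch.rand(1, 1, d).repeat(n, 1, 1)  # pooled negative examples

x = torch.cat([s, s_prime, e_pos, e_neg], dim=1)
assert x.shape == (n, 4, d)               # previously reshaped further to (n, 4, 1, d); now left as-is
```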
diff --git a/ontolearn/learners/drill.py b/ontolearn/learners/drill.py
index 9087d516..21f10ff3 100644
--- a/ontolearn/learners/drill.py
+++ b/ontolearn/learners/drill.py
@@ -9,11 +9,17 @@
 from ontolearn.data_struct import Experience
 from ontolearn.search import DRILLSearchTreePriorityQueue
 from ontolearn.utils import create_experiment_folder
-from collections import Counter
+from collections import Counter, deque
 from itertools import chain
 import time
 import dicee
 import os
+from owlapy.render import DLSyntaxObjectRenderer
+from ontolearn.metrics import F1
+import random
+from ontolearn.heuristics import Reward
+import torch
+from ontolearn.data_struct import PrepareBatchOfTraining, PrepareBatchOfPrediction
 
 
 class Drill(RefinementBasedConceptLearner):
@@ -34,21 +40,24 @@ def __init__(self, knowledge_base,
                  use_card_restrictions=True,
                  card_limit=10,
                  quality_func: AbstractScorer = None,
-                 reward_func=None,
-                 batch_size=None, num_workers=None, pretrained_model_name=None,
-                 iter_bound=None, max_num_of_concepts_tested=None, verbose=None, terminate_on_goal=None,
-                 max_len_replay_memory=None, epsilon_decay=None, epsilon_min=None, num_epochs_per_replay=None,
-                 num_episodes_per_replay=None, learning_rate=None, max_runtime=None, num_of_sequential_actions=None,
-                 num_episode=None):
-
-        print("***DRILL has not yet been fully integrated***")
+                 reward_func: object = None,
+                 batch_size=None, num_workers: int = 1, pretrained_model_name=None,
+                 iter_bound=None, max_num_of_concepts_tested=None, verbose: int = 0, terminate_on_goal=None,
+                 max_len_replay_memory=256,
+                 epsilon_decay: float = 0.01, epsilon_min: float = 0.0,
+                 num_epochs_per_replay: int = 100,
+                 num_episodes_per_replay: int = 2, learning_rate: float = 0.001,
+                 max_runtime=None,
+                 num_of_sequential_actions=3,
+                 num_episode=10):
+        self.name = "DRILL"
         if path_pretrained_kge is not None and os.path.isdir(path_pretrained_kge):
             self.pre_trained_kge = dicee.KGE(path=path_pretrained_kge)
             self.embedding_dim = self.pre_trained_kge.configs["embedding_dim"]
         else:
             self.pre_trained_kge = None
-            self.embedding_dim = 32
+            self.embedding_dim = 12
 
         if path_pretrained_drill is not None and os.path.isdir(path_pretrained_drill):
             raise NotImplementedError()
@@ -61,33 +70,40 @@ def __init__(self, knowledge_base,
                                                         use_card_restrictions=use_card_restrictions,
                                                         card_limit=card_limit,
                                                         use_inverse=use_inverse)
-        self.reward_func = reward_func
+        else:
+            refinement_operator = refinement_operator
+
+        if reward_func is None:
+            self.reward_func = Reward()
+        else:
+            self.reward_func = reward_func
+
         self.representation_mode = "averaging"
-        self.heuristic_func = None
+        self.sample_size = 1
+        self.heuristic_func = DrillHeuristic(mode=self.representation_mode,
+                                             model_args={'input_shape': (4 * self.sample_size, self.embedding_dim),
+                                                         'first_out_channels': 32,
+                                                         'second_out_channels': 16, 'third_out_channels': 8,
+                                                         'kernel_size': 3})
+
         self.num_workers = num_workers
-        self.epsilon = 1
-        self.learning_rate = .001
-        self.num_episode = 1
-        self.num_of_sequential_actions = 3
-        self.num_epochs_per_replay = 1
-        self.max_len_replay_memory = 256
-        self.epsilon_decay = 0.01
-        self.epsilon_min = 0
-        self.batch_size = 1024
-        self.verbose = 0
-        self.num_episodes_per_replay = 2
+        self.learning_rate = learning_rate
+        self.num_episode = num_episode
+        self.num_of_sequential_actions = num_of_sequential_actions
+        self.num_epochs_per_replay = num_epochs_per_replay
+        self.max_len_replay_memory = max_len_replay_memory
+        self.epsilon_decay = epsilon_decay
+        self.epsilon_min = epsilon_min
+        self.batch_size = batch_size
+        self.verbose = verbose
+        self.num_episodes_per_replay = num_episodes_per_replay
         self.seen_examples = dict()
         self.emb_pos, self.emb_neg = None, None
         self.start_time = None
         self.goal_found = False
         self.experiences = Experience(maxlen=self.max_len_replay_memory)
-        self.sample_size = 1
-        self.heuristic_func = DrillHeuristic(mode=self.representation_mode,
-                                             model_args={'input_shape': (4 * self.sample_size, self.embedding_dim),
-                                                         'first_out_channels': 32,
-                                                         'second_out_channels': 16, 'third_out_channels': 8,
-                                                         'kernel_size': 3})
+        self.epsilon = 1
+
         if self.learning_rate:
             self.optimizer = torch.optim.Adam(self.heuristic_func.net.parameters(), lr=self.learning_rate)
@@ -105,11 +121,12 @@ def __init__(self, knowledge_base,
                          iter_bound=iter_bound,
                          max_num_of_concepts_tested=max_num_of_concepts_tested,
                          max_runtime=max_runtime)
-        print('Number of parameters: ', sum([p.numel() for p in self.heuristic_func.net.parameters()]))
-
         self.search_tree = DRILLSearchTreePriorityQueue()
-        self._learning_problem = None
         self.storage_path, _ = create_experiment_folder()
+        self._learning_problem = None
+        self.renderer = DLSyntaxObjectRenderer()
+
+        self.operator: RefinementBasedConceptLearner
 
     def best_hypotheses(self, n=1):
         assert self.search_tree is not None
@@ -220,8 +237,7 @@ def fit(self, lp: PosNegLPStandard, max_runtime=None):
             try:
                 assert len(next_possible_states) > 0
             except AssertionError:
-                if self.verbose > 1:
-                    logger.info(f'DEAD END at {most_promising}')
+                print(f'DEAD END at {most_promising}')
                 continue
             if len(next_possible_states) == 0:
                 # We do not need to compute Q value based on embeddings of "zeros".
@@ -239,10 +255,10 @@ def fit(self, lp: PosNegLPStandard, max_runtime=None):
         return self.terminate()
 
     def show_search_tree(self, heading_step: str, top_n: int = 10) -> None:
-        ValueError('show_search_tree')
+        assert ValueError('show_search_tree')
 
     def terminate_training(self):
-        ValueError('terminate_training')
+        return self
 
     def fit_from_iterable(self,
                           dataset: List[Tuple[object, Set[OWLNamedIndividual], Set[OWLNamedIndividual]]],
@@ -257,9 +273,8 @@ def fit_from_iterable(self,
         results = []
         for (target_ce, p, n) in dataset:
-            if self.verbose > 0:
-                logger.info(f'TARGET OWL CLASS EXPRESSION:\n{target_ce}')
-                logger.info(f'|Sampled Positive|:{len(p)}\t|Sampled Negative|:{len(n)}')
+            print(f'TARGET OWL CLASS EXPRESSION:\n{target_ce}')
+            print(f'|Sampled Positive|:{len(p)}\t|Sampled Negative|:{len(n)}')
             start_time = time.time()
             self.fit(pos=p, neg=n, max_runtime=max_runtime)
             rn = time.time() - start_time
@@ -287,23 +302,25 @@ def init_training(self, pos_uri: Set[OWLNamedIndividual], neg_uri: Set[OWLNamedI
         """ (2) Update REWARD FUNC FOR each learning problem """
         self.reward_func.lp = self._learning_problem
         """ (3) Obtain embeddings of positive and negative examples """
-        self.emb_pos = torch.tensor(
-            self.instance_embeddings.loc[[owl_indv.get_iri().as_str() for owl_indv in pos_uri]].values,
-            dtype=torch.float32)
-        self.emb_neg = torch.tensor(
-            self.instance_embeddings.loc[[owl_indv.get_iri().as_str() for owl_indv in neg_uri]].values,
-            dtype=torch.float32)
-        """ (3) Take the mean of positive and negative examples and reshape it into (1,1,embedding_dim) for mini
-        batching """
-        self.emb_pos = torch.mean(self.emb_pos, dim=0)
-        self.emb_pos = self.emb_pos.view(1, 1, self.emb_pos.shape[0])
-        self.emb_neg = torch.mean(self.emb_neg, dim=0)
-        self.emb_neg = self.emb_neg.view(1, 1, self.emb_neg.shape[0])
-        # Sanity checking
-        if torch.isnan(self.emb_pos).any() or torch.isinf(self.emb_pos).any():
-            raise ValueError('invalid value detected in E+,\n{0}'.format(self.emb_pos))
-        if torch.isnan(self.emb_neg).any() or torch.isinf(self.emb_neg).any():
-            raise ValueError('invalid value detected in E-,\n{0}'.format(self.emb_neg))
+        if self.pre_trained_kge is not None:
+            self.emb_pos = self.pre_trained_kge.get_entity_embeddings(
+                [owl_individual.get_iri().as_str() for owl_individual in pos_uri])
+            self.emb_neg = self.pre_trained_kge.get_entity_embeddings(
+                [owl_individual.get_iri().as_str() for owl_individual in neg_uri])
+            """ (3) Take the mean of positive and negative examples and reshape it into (1,1,embedding_dim) for mini
+            batching """
+            self.emb_pos = torch.mean(self.emb_pos, dim=0)
+            self.emb_pos = self.emb_pos.view(1, 1, self.emb_pos.shape[0])
+            self.emb_neg = torch.mean(self.emb_neg, dim=0)
+            self.emb_neg = self.emb_neg.view(1, 1, self.emb_neg.shape[0])
+            # Sanity checking
+            if torch.isnan(self.emb_pos).any() or torch.isinf(self.emb_pos).any():
+                raise ValueError('invalid value detected in E+,\n{0}'.format(self.emb_pos))
+            if torch.isnan(self.emb_neg).any() or torch.isinf(self.emb_neg).any():
+                raise ValueError('invalid value detected in E-,\n{0}'.format(self.emb_neg))
+        else:
+            self.emb_pos = None
+            self.emb_neg = None
 
         # Default exploration exploitation tradeoff.
         """ (3) Default exploration exploitation tradeoff and number of expression tested """
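With a pretrained KGE, `init_training` mean-pools the embeddings of the positive and negative examples into one `(1, 1, d)` vector each, so they can later be repeated across a mini-batch. A sketch of just the pooling and sanity-check step, with random tensors in place of the `dicee` lookups:

```python
import torch

d = 12
emb_pos = torch.rand(7, d)   # embeddings of the positive individuals
emb_neg = torch.rand(9, d)   # embeddings of the negative individuals

# Mean-pool each example set and reshape to (1, 1, d) so it can be repeated per batch row.
emb_pos = torch.mean(emb_pos, dim=0).view(1, 1, d)
emb_neg = torch.mean(emb_neg, dim=0).view(1, 1, d)

# Sanity check mirroring the NaN/Inf guard above.
assert not torch.isnan(emb_pos).any() and not torch.isinf(emb_neg).any()
```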
@@ -336,6 +353,7 @@ def apply_refinement(self, rl_state: RL_State) -> Generator:
         3. Return Generator.
         """
         assert isinstance(rl_state, RL_State)
+        self.operator: LengthBasedRefinement
         # 1.
         for i in self.operator.refine(rl_state.concept):  # O(N)
             yield self.create_rl_state(i, parent_node=rl_state)
@@ -365,7 +383,7 @@ def learn_from_illustration(self, sequence_of_goal_path: List[RL_State]):
             self.form_experiences(sequence_of_states, rewards)
         self.learn_from_replay_memory()
 
-    def rl_learning_loop(self, pos_uri: Set[OWLNamedIndividual], neg_uri: Set[OWLNamedIndividual],
+    def rl_learning_loop(self, num_episode: int, pos_uri: Set[OWLNamedIndividual], neg_uri: Set[OWLNamedIndividual],
                          goal_path: List[RL_State] = None) -> List[float]:
         """
         Standard RL training loop.
@@ -376,41 +394,41 @@ def rl_learning_loop(self, pos_uri: Set[OWLNamedIndividual], neg_uri: Set[OWLNam
         2. Training Loop.
         """
         """ (1) Initialize RL environment for training """
+        assert isinstance(pos_uri, Set) and isinstance(neg_uri, Set)
         self.init_training(pos_uri=pos_uri, neg_uri=neg_uri)
         root_rl_state = self.create_rl_state(self.start_class, is_root=True)
         self.compute_quality_of_class_expression(root_rl_state)
         sum_of_rewards_per_actions = []
-        log_every_n_episodes = int(self.num_episode * .1) + 1
 
         """ (2) Learn from an illustration if possible """
         if goal_path:
             self.learn_from_illustration(goal_path)
 
         """ (3) Reinforcement Learning offline training loop """
-        for th in range(self.num_episode):
+        for th in range(num_episode):
             """ (3.1) Sequence of decisions """
             sequence_of_states, rewards = self.sequence_of_actions(root_rl_state)
-            if self.verbose >= 10:
-                logger.info('#' * 10, end='')
-                logger.info(f'{th}\t.th Sequence of Actions', end='')
-                logger.info('#' * 10)
-                for step, (current_state, next_state) in enumerate(sequence_of_states):
-                    logger.info(f'{step}. Transition \n{current_state}\n----->\n{next_state}')
-                    logger.info(f'Reward:{rewards[step]}')
-
-            if th % log_every_n_episodes == 0:
-                if self.verbose >= 1:
-                    logger.info('{0}.th iter. SumOfRewards: {1:.2f}\t'
-                                'Epsilon:{2:.2f}\t'
-                                '|ReplayMem.|:{3}'.format(th, sum(rewards),
-                                                          self.epsilon,
-                                                          len(self.experiences)))
+            """
+
+            print('#' * 10, end='')
+            print(f'\t{th}.th Sequence of Actions\t', end='')
+            print('#' * 10)
+            for step, (current_state, next_state) in enumerate(sequence_of_states):
+                print(f'{step}. Transition \n{current_state}\n----->\n{next_state}')
+                print(f'Reward:{rewards[step]}')
+
+            print('{0}.th iter. SumOfRewards: {1:.2f}\t'
+                  'Epsilon:{2:.2f}\t'
+                  '|ReplayMem.|:{3}'.format(th, sum(rewards),
+                                            self.epsilon,
+                                            len(self.experiences)))
+            """
             """(3.2) Form experiences"""
             self.form_experiences(sequence_of_states, rewards)
             sum_of_rewards_per_actions.append(sum(rewards))
             """(3.2) Learn from experiences"""
-            if th % self.num_episodes_per_replay == 0:
-                self.learn_from_replay_memory()
+            # if th % self.num_episodes_per_replay == 0:
+            self.learn_from_replay_memory()
 
             """(3.4) Exploration Exploitation"""
             if self.epsilon < 0:
                 break
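The rewritten loop is a plain offline RL episode loop: roll out a sequence of refinements, store the transitions, replay them, and move from exploration towards exploitation. A stripped-down sketch of that control flow (the dummy trajectory and the explicit epsilon decay are assumptions; the hunk above only shows the `epsilon < 0` stopping check):

```python
import random

epsilon, epsilon_decay, epsilon_min = 1.0, 0.01, 0.0
sum_of_rewards_per_actions = []

for episode in range(10):
    # (3.1) roll out one sequence of refinements; a dummy trajectory stands in here
    rewards = [random.random() for _ in range(3)]
    # (3.2) store the transitions and learn from the replay buffer after every episode
    sum_of_rewards_per_actions.append(sum(rewards))
    # (3.4) exploration/exploitation trade-off: decay epsilon towards epsilon_min
    epsilon = max(epsilon - epsilon_decay, epsilon_min)
    if epsilon < 0:
        break
```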
@@ -464,8 +482,7 @@ def form_experiences(self, state_pairs: List, rewards: List) -> None:
 
         y - Argmax Q value.
         """
-        if self.verbose > 1:
-            print('Form Experiences for the training')
+        print('Form Experiences for the training')
 
         for th, consecutive_states in enumerate(state_pairs):
             e, e_next = consecutive_states
@@ -476,13 +493,15 @@ def learn_from_replay_memory(self) -> None:
         """
         Learning by replaying memory.
         """
-        if self.verbose > 1:
-            print('Learn from Experience')
-
-        current_state_batch, next_state_batch, q_values = self.experiences.retrieve()
+        print('learn_from_replay_memory', end="\t|\t")
+        current_state_batch: List[torch.FloatTensor]
+        next_state_batch: List[torch.FloatTensor]
+        current_state_batch, next_state_batch, y = self.experiences.retrieve()
+        # N, 1, dim
         current_state_batch = torch.cat(current_state_batch, dim=0)
+        # N, 1, dim
         next_state_batch = torch.cat(next_state_batch, dim=0)
-        q_values = torch.Tensor(q_values)
+        y = torch.Tensor(y)
 
         try:
             assert current_state_batch.shape[1] == next_state_batch.shape[1] == self.emb_pos.shape[1] == \
@@ -499,6 +518,14 @@ def learn_from_replay_memory(self) -> None:
             assert current_state_batch.shape[2] == next_state_batch.shape[2] == self.emb_pos.shape[2] == self.emb_neg.shape[
                 2]
+
+        num_next_states = len(current_state_batch)
+
+        # batch, 4, dim
+        X = torch.cat([current_state_batch, next_state_batch, self.emb_pos.repeat((num_next_states, 1, 1)),
+                       self.emb_neg.repeat((num_next_states, 1, 1))], 1)
+        """
+        # We can skip this part perhaps
         dataset = PrepareBatchOfTraining(current_state_batch=current_state_batch,
                                          next_state_batch=next_state_batch,
                                          p=self.emb_pos, n=self.emb_neg, q=q_values)
@@ -506,27 +533,27 @@ def learn_from_replay_memory(self) -> None:
         data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                    shuffle=True,
                                                    num_workers=self.num_workers)
-        if self.verbose > 1:
-            print(f'Number of experiences:{num_experience}')
-            print('DQL agent is learning via experience replay')
+        """
+
+        print(f'Experiences:{X.shape}', end="\t|\t")
         self.heuristic_func.net.train()
+        total_loss = 0
         for m in range(self.num_epochs_per_replay):
-            total_loss = 0
-            for X, y in data_loader:
-                self.optimizer.zero_grad()  # zero the gradient buffers
-                # forward
-                predicted_q = self.heuristic_func.net.forward(X)
-                # loss
-                loss = self.heuristic_func.net.loss(predicted_q, y)
-                total_loss += loss.item()
-                # compute the derivative of the loss w.r.t. the parameters using backpropagation
-                loss.backward()
-                # clip gradients if gradients are killed. =>torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
-                self.optimizer.step()
-            if self.verbose > 1:
-                print(f'{m}.th Epoch average loss during training:{total_loss / num_experience}')
-
-        self.heuristic_func.net.train().eval()
+            self.optimizer.zero_grad()  # zero the gradient buffers
+            # forward
+            # n by 4, dim
+            predicted_q = self.heuristic_func.net.forward(X)
+            # loss
+            loss = self.heuristic_func.net.loss(predicted_q, y)
+            total_loss += loss.item()
+            # compute the derivative of the loss w.r.t. the parameters using backpropagation
+            loss.backward()
+            # clip gradients if gradients are killed. =>torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
+            self.optimizer.step()
+
+        print(f'Average loss during training: {total_loss / self.num_epochs_per_replay:0.5f}')
+
+        self.heuristic_func.net.eval()
 
     def update_search(self, concepts, predicted_Q_values=None):
         """
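Replay no longer goes through `PrepareBatchOfTraining` and a `DataLoader`; the whole buffer is turned into one `(N, 4, d)` tensor by repeating the pooled E+/E- vectors, and the heuristic network is updated with `num_epochs_per_replay` full-batch steps. A self-contained sketch with a stand-in network and loss (the real layers live in `DrillHeuristic`):

```python
import torch

n, d = 8, 12
current_state_batch = torch.rand(n, 1, d)
next_state_batch = torch.rand(n, 1, d)
emb_pos, emb_neg = torch.rand(1, 1, d), torch.rand(1, 1, d)
y = torch.rand(n)                         # target Q values retrieved from the replay buffer

x = torch.cat([current_state_batch, next_state_batch,
               emb_pos.repeat(n, 1, 1), emb_neg.repeat(n, 1, 1)], dim=1)  # (N, 4, d)

# Stand-in for the heuristic network: flatten the 4 x d rows and regress a scalar Q value.
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(4 * d, 1))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

net.train()
for _ in range(100):                      # num_epochs_per_replay full-batch updates
    optimizer.zero_grad()
    predicted_q = net(x).flatten()
    loss = loss_fn(predicted_q, y)
    loss.backward()
    optimizer.step()
net.eval()
```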
@@ -642,7 +669,7 @@ def exploration_exploitation_tradeoff(self, current_state: AbstractNode,
         (1) Exploration.
         (2) Exploitation.
         """
-        if np.random.random() < self.epsilon:
+        if random.random() < self.epsilon:
             next_state = random.choice(next_states)
             self.assign_embeddings(next_state)
         else:
@@ -710,7 +737,52 @@ def retrieve_concept_chain(rl_state: RL_State) -> List[RL_State]:
             hierarchy.appendleft(rl_state)
         return list(hierarchy)
 
-    def train(self, dataset: Iterable[Tuple[str, Set, Set]], relearn_ratio: int = 2):
+    def generate_learning_problems(self, dataset: Optional[Iterable[Tuple[str, Set, Set]]] = None,
+                                   num_learning_problems: int = 5) -> Iterable[
+        Tuple[str, Set, Set]]:
+        """ Generate learning problems if none is provided.
+
+        Time complexity: O(n^2) n = named concepts
+        """
+
+        if dataset is None:
+            learning_problems = []
+            counter = 0
+            size_of_examples = 3
+            print("Generating learning problems...")
+            for i in self.kb.get_concepts():
+                individuals_i = set(self.kb.individuals(i))
+
+                if len(individuals_i) > size_of_examples:
+                    str_dl_concept_i = self.renderer.render(i)
+                    for j in self.kb.get_concepts():
+                        if i == j:
+                            continue
+                        individuals_j = set(self.kb.individuals(j))
+                        if len(individuals_j) < size_of_examples:
+                            continue
+
+                        lp = (str_dl_concept_i,
+                              set(random.sample(individuals_i, size_of_examples)),
+                              set(random.sample(individuals_j, size_of_examples)))
+                        yield lp
+                        counter += 1
+
+                        if counter == num_learning_problems:
+                            break
+
+                    if counter == num_learning_problems:
+                        break
+                else:
+                    """Empty concept"""
+
+            # assert isinstance(learning_problems, Iterable)
+            # return learning_problems
+        else:
+            return dataset
+
+    def train(self, dataset: Optional[Iterable[Tuple[str, Set, Set]]] = None, num_episode: int = 10,
+              relearn_ratio: int = 2, num_learning_problems=3):
         """
         Train RL agent on learning problems with relearn_ratio.
@@ -730,8 +802,10 @@ def train(self, dataset: Iterable[Tuple[str, Set, Set]], relearn_ratio: int = 2)
         Returns: self.
         """
-        if self.verbose > 0:
-            logger.info(f'Training starts.\nNumber of learning problem:{len(dataset)},\t Relearn ratio:{relearn_ratio}')
+        if self.pre_trained_kge is None:
+            return self.terminate_training()
+
+        dataset = self.generate_learning_problems(dataset, num_learning_problems)
 
         counter = 1
         renderer = DLSyntaxObjectRenderer()
@@ -739,23 +813,18 @@ def train(self, dataset: Iterable[Tuple[str, Set, Set]], relearn_ratio: int = 2)
         for _ in range(relearn_ratio):
             for (target_owl_ce, positives, negatives) in dataset:
-                if self.verbose > 0:
-                    logger.info(
-                        'Goal Concept:{0}\tE^+:[{1}] \t E^-:[{2}]'.format(target_owl_ce,
-                                                                          len(positives), len(negatives)))
-                    logger.info(f'RL training on {counter}.th learning problem starts')
+                print('Goal Concept:{0}\tE^+:[{1}] \t E^-:[{2}]'.format(target_owl_ce,
+                                                                        len(positives), len(negatives)))
+                print(f'RL training on {counter}.th learning problem {target_owl_ce} starts')
 
-                goal_path = list(reversed(self.retrieve_concept_chain(target_owl_ce)))
-                # goal_path: [⊤, Daughter, Daughter ⊓ Mother]
-                sum_of_rewards_per_actions = self.rl_learning_loop(pos_uri=positives, neg_uri=negatives,
-                                                                   goal_path=goal_path)
+                sum_of_rewards_per_actions = self.rl_learning_loop(num_episode=num_episode, pos_uri=positives,
+                                                                   neg_uri=negatives)
 
-                if self.verbose > 2:
-                    logger.info(f'Sum of Rewards in first 3 trajectory:{sum_of_rewards_per_actions[:3]}')
-                    logger.info(f'Sum of Rewards in last 3 trajectory:{sum_of_rewards_per_actions[:3]}')
+                print(f'Sum of Rewards in first 3 trajectory:{sum_of_rewards_per_actions[:3]}')
+                print(f'Sum of Rewards in last 3 trajectory:{sum_of_rewards_per_actions[:3]}')
 
                 self.seen_examples.setdefault(counter, dict()).update(
-                    {'Concept': renderer.render(target_owl_ce.concept),
+                    {'Concept': target_owl_ce,
                      'Positives': [i.get_iri().as_str() for i in positives],
                      'Negatives': [i.get_iri().as_str() for i in negatives]})
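`generate_learning_problems` manufactures training problems straight from the ontology: for each ordered pair of named concepts with enough members, a few individuals of the first concept become positives and a few of the second become negatives. The core sampling idea over plain Python sets (toy concept names; note that `random.sample` over a raw `set` is deprecated since Python 3.9, so this sketch sorts first):

```python
import random

# Toy stand-ins for kb.get_concepts() / kb.individuals(concept).
concepts = {"Mother": {"anna", "maria", "julia", "eva"},
            "Father": {"stefan", "markus", "martin", "heinz"}}
size_of_examples, num_learning_problems, generated = 3, 5, []

for name_i, individuals_i in concepts.items():
    if len(individuals_i) <= size_of_examples:
        continue                                  # skip concepts with too few members
    for name_j, individuals_j in concepts.items():
        if name_i == name_j or len(individuals_j) < size_of_examples:
            continue
        generated.append((name_i,
                          set(random.sample(sorted(individuals_i), size_of_examples)),
                          set(random.sample(sorted(individuals_j), size_of_examples))))
        if len(generated) == num_learning_problems:
            break

print(generated)
```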
""" - if self.verbose > 0: - logger.info(f'Training starts.\nNumber of learning problem:{len(dataset)},\t Relearn ratio:{relearn_ratio}') + if self.pre_trained_kge is None: + return self.terminate_training() + + dataset = self.generate_learning_problems(dataset, num_learning_problems) counter = 1 renderer = DLSyntaxObjectRenderer() @@ -739,23 +813,18 @@ def train(self, dataset: Iterable[Tuple[str, Set, Set]], relearn_ratio: int = 2) for _ in range(relearn_ratio): for (target_owl_ce, positives, negatives) in dataset: - if self.verbose > 0: - logger.info( - 'Goal Concept:{0}\tE^+:[{1}] \t E^-:[{2}]'.format(target_owl_ce, - len(positives), len(negatives))) - logger.info(f'RL training on {counter}.th learning problem starts') + print('Goal Concept:{0}\tE^+:[{1}] \t E^-:[{2}]'.format(target_owl_ce, + len(positives), len(negatives))) + print(f'RL training on {counter}.th learning problem {target_owl_ce} starts') - goal_path = list(reversed(self.retrieve_concept_chain(target_owl_ce))) - # goal_path: [⊤, Daughter, Daughter ⊓ Mother] - sum_of_rewards_per_actions = self.rl_learning_loop(pos_uri=positives, neg_uri=negatives, - goal_path=goal_path) + sum_of_rewards_per_actions = self.rl_learning_loop(num_episode=num_episode, pos_uri=positives, + neg_uri=negatives) - if self.verbose > 2: - logger.info(f'Sum of Rewards in first 3 trajectory:{sum_of_rewards_per_actions[:3]}') - logger.info(f'Sum of Rewards in last 3 trajectory:{sum_of_rewards_per_actions[:3]}') + print(f'Sum of Rewards in first 3 trajectory:{sum_of_rewards_per_actions[:3]}') + print(f'Sum of Rewards in last 3 trajectory:{sum_of_rewards_per_actions[:3]}') self.seen_examples.setdefault(counter, dict()).update( - {'Concept': renderer.render(target_owl_ce.concept), + {'Concept': target_owl_ce, 'Positives': [i.get_iri().as_str() for i in positives], 'Negatives': [i.get_iri().as_str() for i in negatives]}) @@ -840,10 +909,10 @@ def forward(self, X: torch.FloatTensor): X n by 4 by d float tensor """ # N x 32 x D - X = F.relu(self.conv1(X)) + X = torch.nn.functional.relu(self.conv1(X)) X = X.flatten(start_dim=1) # N x (32D/2) - X = F.relu(self.fc1(X)) + X = torch.nn.functional.relu(self.fc1(X)) # N x 1 scores = self.fc2(X).flatten() return scores diff --git a/ontolearn/refinement_operators.py b/ontolearn/refinement_operators.py index dd23e33a..0417738b 100644 --- a/ontolearn/refinement_operators.py +++ b/ontolearn/refinement_operators.py @@ -44,8 +44,7 @@ def __init__(self, knowledge_base: KnowledgeBase, use_inverse=False, assert num_of_named_classes == len(list(i for i in self.kb.ontology().classes_in_signature())) self.max_len_refinement_top = 5 - self.top_refinements = {ref for ref in self.refine_top()} - print("Top refinements:", len(self.top_refinements)) + self.top_refinements = None # {ref for ref in self.refine_top()} def from_iterables(self, cls, a_operands, b_operands): assert (isinstance(a_operands, Generator) is False) and (isinstance(b_operands, Generator) is False) @@ -237,6 +236,9 @@ def refine_object_intersection_of(self, class_expression: OWLClassExpression) -> def refine(self, class_expression) -> Iterable[OWLClassExpression]: assert isinstance(class_expression, OWLClassExpression) + if self.top_refinements is None: + self.top_refinements = {ref for ref in self.refine_top()} + if class_expression.is_owl_thing(): yield from self.top_refinements elif class_expression.is_owl_nothing(): @@ -263,11 +265,6 @@ def refine(self, class_expression) -> Iterable[OWLClassExpression]: else: raise ValueError(f"{type(class_expression)} objects 
are not yet supported") - """ - - - """ - class ModifiedCELOERefinement(BaseRefinement[OENode]): """