Skip to content

Commit

Permalink
Merge pull request #327 from dice-group/DRILL
Browse files Browse the repository at this point in the history
WIP: DRILL training is available
  • Loading branch information
Demirrr authored Dec 6, 2023
2 parents 76b9a8c + 65aacbd commit 551407e
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 145 deletions.
20 changes: 6 additions & 14 deletions examples/concept_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@


def dl_concept_learning(args):
try:
os.chdir("examples")
except FileNotFoundError:
pass

with open(args.lps) as json_file:
settings = json.load(json_file)

kb = KnowledgeBase(path=args.kb)

ocel=OCEL(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
celoe=CELOE(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
evo=EvoLearner(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
drill=Drill(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
drill = Drill(knowledge_base=kb, path_pretrained_kge=args.path_pretrained_kge, quality_func=F1(),
max_runtime=args.max_runtime).train(num_episode=1, num_learning_problems=1)
ocel = OCEL(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
celoe = CELOE(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
evo = EvoLearner(knowledge_base=kb, quality_func=F1(), max_runtime=args.max_runtime)
columns = ["LP",
"OCEL", "F1-OCEL", "RT-OCEL",
"CELOE", "F1-CELOE", "RT-CELOE",
Expand All @@ -45,11 +40,9 @@ def dl_concept_learning(args):
lp = PosNegLPStandard(pos=typed_pos, neg=typed_neg)

start_time = time.time()
# Untrained & max runtime is not fully integrated.
pred_drill = drill.fit(lp).best_hypotheses(n=1)
rt_drill = time.time() - start_time


start_time = time.time()
pred_ocel = ocel.fit(lp).best_hypotheses(n=1)
rt_ocel = time.time() - start_time
Expand Down Expand Up @@ -80,6 +73,5 @@ def dl_concept_learning(args):
parser.add_argument("--max_runtime", type=int, default=10)
parser.add_argument("--lps", type=str, default="synthetic_problems.json")
parser.add_argument("--kb", type=str, default="../KGs/Family/family-benchmark_rich_background.owl")
parser.add_argument("--path_pretrained_kge", type=str, default="../KeciFamilyRun")

parser.add_argument("--path_pretrained_kge", type=str, default=None)
dl_concept_learning(parser.parse_args())
5 changes: 2 additions & 3 deletions ontolearn/data_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ def __init__(self, current_state: torch.FloatTensor, next_state_batch: torch.Flo
n: torch.FloatTensor):
assert len(p) > 0 and len(n) > 0
num_next_states = len(next_state_batch)


current_state = current_state.repeat(num_next_states, 1, 1)
p = p.repeat((num_next_states, 1, 1))
n = n.repeat((num_next_states, 1, 1))
Expand Down Expand Up @@ -63,8 +61,9 @@ def __init__(self, current_state_batch: torch.Tensor, next_state_batch: torch.Te
assert self.S.shape == self.S_Prime.shape == self.Positives.shape == self.Negatives.shape
assert self.S.dtype == self.S_Prime.dtype == self.Positives.dtype == self.Negatives.dtype == torch.float32
self.X = torch.cat([self.S, self.S_Prime, self.Positives, self.Negatives], 1)

num_points, depth, dim = self.X.shape
self.X = self.X.view(num_points, depth, 1, dim)
# self.X = self.X.view(num_points, depth, 1, dim)
# X[0] => corresponds to a data point, X[0] \in R^{4 \times 1 \times dim}
# where X[0][0] => current state representation R^{1 \times dim}
# where X[0][1] => next state representation R^{1 \times dim}
Expand Down
Loading

0 comments on commit 551407e

Please sign in to comment.