Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
ae1492d
development notebook
DimaOrekhov Oct 28, 2021
41f1067
change package name
DimaOrekhov Nov 4, 2021
cdf1a49
WIP: pipeline
DimaOrekhov Nov 4, 2021
4ff97f2
WIP: pipeline
DimaOrekhov Nov 8, 2021
9fda0ca
example config file
DimaOrekhov Nov 8, 2021
3f273e3
Debug dataset wrapper
DimaOrekhov Nov 11, 2021
bd9fcb3
add wandb to dependencies
DimaOrekhov Nov 11, 2021
07fb84d
evaluation
DimaOrekhov Nov 11, 2021
154fd2a
main
DimaOrekhov Nov 11, 2021
307c78b
debug config
DimaOrekhov Nov 11, 2021
8c899e7
filtered metrics
DimaOrekhov Nov 14, 2021
5cd5544
to sq
DimaOrekhov Nov 14, 2021
c85c9cc
Add pytest-mock to dependencies
DimaOrekhov Nov 15, 2021
20157d6
to sq: evaluation.py
DimaOrekhov Nov 15, 2021
e2adec8
TestMetrics
DimaOrekhov Nov 15, 2021
84d9bbb
different ranks
DimaOrekhov Nov 15, 2021
af8f80f
to sq: evalutation
DimaOrekhov Nov 16, 2021
09595c8
amri
DimaOrekhov Nov 16, 2021
a184a12
update tests
DimaOrekhov Nov 16, 2021
13096de
sq
DimaOrekhov Nov 16, 2021
e46800d
to sq: fix amri computation
DimaOrekhov Nov 21, 2021
2109f3a
simple amri test
DimaOrekhov Nov 21, 2021
6657313
update catalyst
DimaOrekhov Nov 22, 2021
f43ce90
eval WIP
DimaOrekhov Nov 22, 2021
fe15889
log WIP
DimaOrekhov Nov 22, 2021
dc16e53
eval callback
DimaOrekhov Nov 22, 2021
c9397b4
to sq: main
DimaOrekhov Nov 22, 2021
a3893c7
to sq: main
DimaOrekhov Nov 29, 2021
32e6908
to sq: main
DimaOrekhov Nov 29, 2021
8bb795c
fix: loss computation bug
DimaOrekhov Dec 7, 2021
dffa241
Extend dataset and model to case of heterogenous graphs (edges)
DimaOrekhov Dec 7, 2021
e8d449d
heterogenous evaluation
DimaOrekhov Dec 8, 2021
438538c
Fix metric names
DimaOrekhov Dec 8, 2021
8c7183b
Fix OOM
DimaOrekhov Dec 8, 2021
bbc6561
to sq: device fix
DimaOrekhov Dec 8, 2021
960dc86
to sq: device fix
DimaOrekhov Dec 8, 2021
fbb0948
Revert "to sq: device fix"
DimaOrekhov Dec 8, 2021
6ddc8dd
to sq: device fix
DimaOrekhov Dec 8, 2021
e73248d
Add option to set lr
DimaOrekhov Dec 8, 2021
02d6cb9
change init
DimaOrekhov Dec 8, 2021
925852e
Add loading wn18rr
karl-crl Jan 10, 2022
49fec67
Fix tests
DimaOrekhov Jan 16, 2022
cb796db
Fix filtered metrics bug
DimaOrekhov Jan 16, 2022
f0a149b
Refactor: more readable code in dataset initialization
DimaOrekhov Jan 16, 2022
5a76a26
Add todo to discuss later
DimaOrekhov Jan 16, 2022
7064e96
Add wn18rr dataset
karl-crl Jan 25, 2022
fd8489e
Add omegaconf to dependencies
DimaOrekhov Jan 28, 2022
84ae2d3
Fix bugs with data
karl-crl Jan 31, 2022
14fb66b
Add minor fixes to dataset
karl-crl Feb 1, 2022
c83c473
Nested configs + switch to PN loss
DimaOrekhov Jan 28, 2022
3dfea1b
Eval on valid at the end of run
DimaOrekhov Jan 28, 2022
dddd6e5
Integrate sweeps WIP
DimaOrekhov Feb 1, 2022
f1b83bd
Prepare for merge
karl-crl Feb 2, 2022
1b52aa4
Prepare for merge
karl-crl Feb 2, 2022
810d94b
Delete useless file
karl-crl Feb 2, 2022
4c58800
Merge pull request #4 from jbr-ai-labs/pipeline-add-data
karl-crl Feb 2, 2022
8e88fea
Updated poetry
karl-crl Feb 2, 2022
af67e57
Add ComplEx
karl-crl Feb 3, 2022
f5c82e6
Merge pull request #6 from jbr-ai-labs/pipeline-add-complex
karl-crl Feb 3, 2022
9b04266
Add complex model to config
karl-crl Feb 3, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 296 additions & 0 deletions notebooks/dev.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bb381fe8-0977-4f4e-82f1-800f81095cd0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using backend: pytorch\n"
]
}
],
"source": [
"import typing as ty\n",
"import torch\n",
"import dgl"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3ce89739-301e-40c0-8d3a-56cf8c7c4fd6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# entities: 14541\n",
"# relations: 237\n",
"# training edges: 272115\n",
"# validation edges: 17535\n",
"# testing edges: 20466\n",
"Done loading data from cached files.\n"
]
}
],
"source": [
"fb15 = dgl.data.FB15k237Dataset()\n",
"graph = fb15[0]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6ea9bf48-27fc-44c4-a432-14ea3992e349",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes=14541, num_edges=620232,\n",
" ndata_schemes={'ntype': Scheme(shape=(), dtype=torch.int64)}\n",
" edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'test_edge_mask': Scheme(shape=(), dtype=torch.bool), 'valid_edge_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_edge_mask': Scheme(shape=(), dtype=torch.bool)})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c2f95bc2-f5a9-4546-8224-2d4b9fc040ba",
"metadata": {},
"outputs": [],
"source": [
"def get_split(graph: dgl.DGLGraph, split_key: str) -> ty.Tuple[dgl.DGLGraph, torch.Tensor]:\n",
" split_mask = graph.edata[f\"{split_key}_edge_mask\"]\n",
" split_edges_index = torch.nonzero(split_mask, as_tuple=False).squeeze()\n",
"\n",
" split_graph = graph.edge_subgraph(split_edges_index, preserve_nodes=True)\n",
" split_graph.edata[\"etype\"] = graph.edata[\"etype\"][split_edges_index]\n",
"\n",
" return split_graph, split_edges_index\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0301b4cb-893b-4da1-9139-801be540eee4",
"metadata": {},
"outputs": [],
"source": [
"train_g, train_edges = get_split(graph, \"train\")\n",
"val_g, val_edges = get_split(graph, \"valid\")\n",
"test_g, test_edges = get_split(graph, \"test\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e71631b6-d27f-4267-a6f7-7136fae58e1c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(544230)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.edata[\"train_mask\"].sum()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ab75f30d-872f-4ff5-bbd8-c8d0fc743f20",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(272115)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.edata[\"train_edge_mask\"].sum()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "20868465-5182-47bb-b554-bcd5219df7ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14541"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.number_of_src_nodes()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "08e286d3-2e20-4894-be33-a51a9532ae75",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14541"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.number_of_dst_nodes()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "37e7aa6c-632d-4d8d-b396-53c9390999d1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14541"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.number_of_nodes()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "81c8024f-25fa-41a5-85e6-0a0892a6e392",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class DirectedDglGraphDataset(Dataset):\n",
"\n",
" def __init__(self, graph: dgl.DGLGraph):\n",
" self.graph = graph\n",
" self.number_of_nodes = graph.number_of_nodes()\n",
" self.adjacency_mat = graph.adj()\n",
"\n",
" def __getitem__(self, i):\n",
" denominator = self.number_of_nodes - 1\n",
" src_node_index = i // denominator\n",
" dst_node_index = i % denominator\n",
"\n",
" if src_node_index <= dst_node_index:\n",
" dst_node_index += 1\n",
"\n",
" return {\n",
" \"src_node_index\": src_node_index,\n",
" \"dst_node_index\": dst_node_index,\n",
" \"relation\": self.adjacency_mat[src_node_index, dst_node_index]\n",
" }\n",
"\n",
" def __len__(self):\n",
" return self.number_of_nodes * (self.number_of_nodes - 1)\n",
"\n",
"\n",
"# Very model-specific dataset, since it fixes which entity is the head and which is the tail\n",
"class UndirectedDglGraphDataset(Dataset):\n",
"\n",
" def __init__(self, graph: dgl.DGLGraph):\n",
" self.graph = graph\n",
" self.number_of_nodes = graph.number_of_nodes()\n",
" self.adjacency_mat = graph.adj()\n",
" self.sample = [\n",
" (i, j)\n",
" for i, j in itertools.product(range(self.number_of_nodes), range(self.number_of_nodes))\n",
" if i < j # Take only upper triangle indices\n",
" ]\n",
"\n",
" def __getitem__(self, i):\n",
" src_node_index, dst_node_index = self.sample[i]\n",
"\n",
" return {\n",
" \"src_node_index\": src_node_index,\n",
" \"dst_node_index\": dst_node_index,\n",
" \"relation\": self.adjacency_mat[src_node_index, dst_node_index]\n",
" }\n",
"\n",
" def __len__(self):\n",
" return (self.number_of_nodes * (self.number_of_nodes - 1)) // 2\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7893bed-6853-451f-826e-baf373b98331",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading