Commit f3c06a4
Updates tools needed to store full sentences. Updates point cloud dataset

JSybrandt committed Feb 12, 2020
1 parent 4fe6533 commit f3c06a4
Showing 7 changed files with 72 additions and 30 deletions.
pymoliere/ml/point_cloud_evaluator/dataset.py (32 additions, 15 deletions)
@@ -17,47 +17,64 @@ class PointCloudTensors:
   lemmas:torch.FloatTensor
 
 def pointclouds_to_tensors(
-    samples:List[PointCloudObservation]
+    samples:List[PointCloudObservation],
 )->PointCloudTensors:
   return PointCloudTensors(
       # Seq Leng X Batch size X emb dim
-      lemmas=torch.nn.utils.rnn.pad_sequence([
-        torch.FloatTensor(s.lemma_embeddings) for s in samples
-      ]),
+      lemmas=torch.nn.utils.rnn.pad_sequence(
+          [torch.FloatTensor(s.lemma_embeddings) for s in samples]
+      ),
   )
 
 def sample_lemma(examples:List[PointCloudObservation]):
   return random.choice(random.choice(examples).lemma_embeddings)
 
-def generate_negative_scramble_batch(
-    positive_examples:List[PointCloudObservation]
+def generate_neg_batch(
+    positive_examples:List[PointCloudObservation],
+    scramble_prob:float=0.0,
+    drop_prob:float=0.0,
 )->PointCloudTensors:
+  def get_lemma_emb(pos_ref):
+    lemma_embeddings = [
+        (
+            sample_lemma(positive_examples)
+            if random.random() < scramble_prob else
+            emb
+        )
+        for emb in pos_ref.lemma_embeddings
+        if random.random() >= drop_prob
+    ]
+    if len(lemma_embeddings) == 0:
+      lemma_embeddings = [sample_lemma(positive_examples)]
+    return lemma_embeddings
+
   return pointclouds_to_tensors([
       PointCloudObservation(
-          # generate a random set of equal length
-          lemma_embeddings=[
-              sample_lemma(positive_examples)
-              for _ in pos_ref.lemma_embeddings
-          ]
+          lemma_embeddings=get_lemma_emb(pos_ref)
      )
       # Duplicate the sizes from each positive example
       for pos_ref in positive_examples
   ])
 
 def collate_point_clouds(
     positive_examples:List[PointCloudObservation],
-    neg_scrambles_per:int,
+    full_scrambles_per:int,
+    fractional_scrambles_per:int,
+    deletes_per:int,
 )->List[PointCloudTensors]:
   """
   The first one is the positive sample, the rest are negatives
   """
-  assert neg_scrambles_per > 0
   positive_examples = [
       p for p in positive_examples if len(p.lemma_embeddings) > 0
   ]
   res = [pointclouds_to_tensors(positive_examples)]
-  for _ in range(neg_scrambles_per):
-    res.append(generate_negative_scramble_batch(positive_examples))
+  for _ in range(full_scrambles_per):
+    res.append(generate_neg_batch(positive_examples, scramble_prob=1.0))
+  for _ in range(fractional_scrambles_per):
+    res.append(generate_neg_batch(positive_examples, scramble_prob=0.1))
+  for _ in range(deletes_per):
+    res.append(generate_neg_batch(positive_examples, drop_prob=0.5))
   return res
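
To make the new negative-sampling knobs concrete, here is a minimal usage sketch. It is not code from this commit: it assumes PointCloudObservation is defined in this module and accepts lemma_embeddings as a keyword argument (as in the constructor calls above), and that each embedding is a plain float vector. The dimensions and rates are made up.

import random

from pymoliere.ml.point_cloud_evaluator.dataset import (
    PointCloudObservation,
    collate_point_clouds,
)

# Two toy observations with different point-cloud sizes (5 and 9 lemmas).
observations = [
    PointCloudObservation(
        lemma_embeddings=[[random.random()] * 128 for _ in range(n)]
    )
    for n in (5, 9)
]

batches = collate_point_clouds(
    positive_examples=observations,
    full_scrambles_per=3,        # every lemma replaced by a random one
    fractional_scrambles_per=4,  # roughly 10% of lemmas replaced
    deletes_per=4,               # roughly 50% of lemmas dropped
)

# batches[0] holds the positives; the remaining 11 entries are negatives.
# pad_sequence stacks each batch to (max seq len, batch size, emb dim).
print(batches[0].lemmas.shape)  # torch.Size([9, 2, 128])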


pymoliere/ml/point_cloud_evaluator/point_cloud_evaluator.py (17 additions, 3 deletions)
@@ -142,15 +142,17 @@ def _config_dl(self, dataset):
       sampler=torch.utils.data.distributed.DistributedSampler(dataset)
     collate = lambda batch: collate_point_clouds(
         positive_examples=batch,
-        neg_scrambles_per=self.hparams.neg_scramble_rate,
+        full_scrambles_per=self.hparams.neg_easy_rate,
+        fractional_scrambles_per=self.hparams.neg_hard_rate,
+        deletes_per=self.hparams.neg_delete_rate,
     )
     return torch.utils.data.DataLoader(
         dataset=self.training_data,
         shuffle=shuffle,
         sampler=sampler,
         batch_size=self.hparams.positives_per_batch,
         collate_fn=collate,
-        num_workers=4,
+        #num_workers=4,
     )
 
   @pl.data_loader
@@ -257,11 +259,23 @@ def configure_argument_parser(parser:ArgumentParser)->ArgumentParser:
       default=512,
   )
   parser.add_argument(
-      "--neg-scramble-rate",
+      "--neg-easy-rate",
       type=int,
       default=10,
       help="A negative scramble draws the cloud sets randomly from batch"
   )
+  parser.add_argument(
+      "--neg-hard-rate",
+      type=int,
+      default=10,
+      help="A negative scramble draws the cloud sets randomly from batch"
+  )
+  parser.add_argument(
+      "--neg-delete-rate",
+      type=int,
+      default=10,
+      help="A negative delete removes lemmas from the batch"
+  )
   parser.add_argument("--debug", action="store_true")
   parser.add_argument("--distributed", action="store_true")
   return parser
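
One detail that connects this hunk to the previous one: argparse converts the dashed flag names into underscored attribute names, which is how --neg-easy-rate surfaces as self.hparams.neg_easy_rate inside _config_dl above. A standalone sketch of just that mapping (the values are the ones the training script below passes):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--neg-easy-rate", type=int, default=10)
parser.add_argument("--neg-hard-rate", type=int, default=10)
parser.add_argument("--neg-delete-rate", type=int, default=10)

args = parser.parse_args(
    ["--neg-easy-rate", "3", "--neg-hard-rate", "4", "--neg-delete-rate", "4"]
)
# Dashes become underscores on the parsed namespace.
print(args.neg_easy_rate, args.neg_hard_rate, args.neg_delete_rate)  # 3 4 4
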
scripts/models/sentence_point_cloud/train_model (4 additions, 2 deletions)
@@ -14,10 +14,12 @@ if [[ ! -z "$1" ]]; then
     --entity-dir $ENTITIES \
     --embedding-dir $EMBEDDINGS \
     --positives-per-batch 256 \
-    --neg-scramble-rate 7 \
+    --neg-easy-rate 3 \
+    --neg-hard-rate 4 \
+    --neg-delete-rate 4 \
     --transformer-layers 4 \
     --transformer-heads 8 \
-    --train-fraction 0.005
+    --train-fraction 0.005
 else
 
   NODEFILE=$PBS_NODEFILE
tools/convert_sentences_for_sqlite/Makefile (1 addition, 1 deletion)
@@ -15,7 +15,7 @@ LIBS+=-lsqlite3
 LIB_PATHS=
 LIB_PATHS+=-L/software/sqlite/3.21.0/lib
 
-EX=bow_json_to_sqlite
+EX=sent_json_to_sqlite
 
 $(EX): main.cc
 	$(CC) $(FLAGS) $(LIB_PATHS) -o $@ $(INCLUDES) $< $(LIBS)
Binary file not shown.
tools/convert_sentences_for_sqlite/main.cc (18 additions, 9 deletions)
@@ -18,19 +18,27 @@ namespace fs = std::filesystem;
 namespace sql = sqlite_orm;
 
 
-struct BowEntry {
+struct SentenceEntry {
   std::string id;
   std::string json_encoded_bow;
-  BowEntry(std::string i, std::string j):id(i), json_encoded_bow(j){};
+  std::string text;
+  SentenceEntry(
+      const std::string& i,
+      const std::string& j,
+      const std::string k
+  ):
+    id(i),
+    json_encoded_bow(j),
+    text(k){};
 };
 
-std::list<BowEntry> parse_json_file(const fs::path& json_path){
-  std::list<BowEntry> res;
+std::list<SentenceEntry> parse_json_file(const fs::path& json_path){
+  std::list<SentenceEntry> res;
   std::fstream json_file(json_path, std::ios::in);
   std::string line;
   while(getline(json_file, line)){
     json id_bow = json::parse(line);
-    res.emplace_back(id_bow["id"], id_bow["bow"].dump());
+    res.emplace_back(id_bow["id"], id_bow["bow"].dump(), id_bow["sent_text"]);
   }
   json_file.close();
   return res;
@@ -65,10 +73,10 @@ int main(int argc, char** argv){
 
   std::cout << "Loading all bow" << std::endl;
   int num_finished = 0;
-  std::list<BowEntry> bag_entries;
+  std::list<SentenceEntry> bag_entries;
   #pragma omp parallel for schedule(dynamic)
   for(size_t i = 0; i < all_json_files.size(); ++i){
-    std::list<BowEntry> local_bags = parse_json_file(all_json_files[i]);
+    std::list<SentenceEntry> local_bags = parse_json_file(all_json_files[i]);
     #pragma omp critical
     {
       bag_entries.splice(bag_entries.end(), local_bags);
@@ -82,8 +90,9 @@
       sqlite_path,
       sql::make_table(
           "sentences",
-          sql::make_column("id", &BowEntry::id),
-          sql::make_column("bow", &BowEntry::json_encoded_bow)
+          sql::make_column("id", &SentenceEntry::id),
+          sql::make_column("bow", &SentenceEntry::json_encoded_bow),
+          sql::make_column("text", &SentenceEntry::text)
       )
   );
   storage.sync_schema();
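
Once the tool runs, the resulting database can be inspected from Python. A hedged sketch: the table and column names come from the make_table call above, while the database filename here is only a placeholder.

import json
import sqlite3

conn = sqlite3.connect("sentences.sqlite3")  # placeholder path
for sent_id, bow_json, text in conn.execute(
    "SELECT id, bow, text FROM sentences LIMIT 5"
):
    bow = json.loads(bow_json)  # the bow column stores a JSON-encoded bag of words
    print(sent_id, text, len(bow))
conn.close()
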
Binary file not shown.
