From 09f087c3c3d24f688f286271d58dcecdbae0f787 Mon Sep 17 00:00:00 2001 From: Mathieu Fourment Date: Thu, 28 Oct 2021 19:00:29 +1100 Subject: [PATCH] Fix classifiers in setup.cfg --- benchmarks/benchmark.py | 12 ++++++++---- setup.cfg | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index e8620257..984be715 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -10,6 +10,7 @@ from torchtree.evolution.alignment import Alignment, Sequence from torchtree.evolution.coalescent import ConstantCoalescent from torchtree.evolution.datatype import NucleotideDataType +from torchtree.evolution.io import read_tree, read_tree_and_alignment from torchtree.evolution.site_pattern import compress_alignment from torchtree.evolution.substitution_model import JC69 from torchtree.evolution.taxa import Taxa, Taxon @@ -21,7 +22,6 @@ ReparameterizedTimeTreeModel, heights_from_branch_lengths, ) -from torchtree.io import read_tree, read_tree_and_alignment def benchmark(f): @@ -473,10 +473,13 @@ def fn_grad(ratios_root_height): def ratio_transform(args): replicates = args.replicates tree = read_tree(args.tree, True, True) + taxa_count = len(tree.taxon_namespace) taxa = [] for node in tree.leaf_node_iter(): taxa.append(Taxon(node.label, {'date': node.date})) - ratios_root_height = Parameter("internal_heights", torch.tensor([0.5] * 67 + [10])) + ratios_root_height = Parameter( + "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10]) + ) tree_model = ReparameterizedTimeTreeModel( "tree", tree, Taxa('taxa', taxa), ratios_root_height ) @@ -632,11 +635,12 @@ def fn3_grad_jit(ratios_root_height): def constant_coalescent(args): tree = read_tree(args.tree, True, True) + taxa_count = len(tree.taxon_namespace) taxa = [] for node in tree.leaf_node_iter(): taxa.append(Taxon(node.label, {'date': node.date})) ratios_root_height = Parameter( - "internal_heights", torch.tensor([0.5] * 67 + [20.0]) + "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0]) ) tree_model = ReparameterizedTimeTreeModel( "tree", tree, Taxa('taxa', taxa), ratios_root_height @@ -743,7 +747,7 @@ def fn3_grad(tree_model, ratios_root_height, pop_size): return log_p x, counts = torch.unique(tree_model.sampling_times, return_counts=True) - counts = torch.cat((counts, torch.tensor([-1] * 68))) + counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1)))) with torch.no_grad(): total_time, log_p = fn3(args.replicates, tree_model, pop_size) diff --git a/setup.cfg b/setup.cfg index a7a38427..dc14317a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,13 +10,13 @@ long_description = file: README.md license = GPL3 license_file = LICENSE classifiers = - Intended Audience :: Science/Research", + Intended Audience :: Science/Research License :: OSI Approved :: GNU General Public License v3 (GPLv3) Operating System :: OS Independent - Programming Language :: Python :: 3.5 Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 Topic :: Scientific/Engineering :: Bio-Informatics [options]