
Commit

Merge pull request #182 from medema-group/hotfix/query
Hotfix/query
nlouwen authored Oct 1, 2024
2 parents 6a07618 + 3ab5de5 commit 054e396
Showing 14 changed files with 193 additions and 119 deletions.
17 changes: 1 addition & 16 deletions big_scape/cli/cli_common_options.py
@@ -273,21 +273,6 @@ def common_cluster_query(fn):
"based on the generic 'mix' weights."
),
),
- click.option(
- "--legacy_classify",
- is_flag=True,
- help=(
- "Does not use antiSMASH BGC classes to run analyses on "
- "class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
- "PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others. "
- "Will also use BiG-SCAPE v1 legacy_weights for distance calculations. "
- "This feature is available for backwards compatibility with "
- "antiSMASH versions up to v7. For higher antiSMASH versions, use "
- "at your own risk, as BGC classes may have changed. All antiSMASH "
- "classes that this legacy mode does not recognize will be grouped in "
- "'others'."
- ),
- ),
click.option(
"--alignment_mode",
type=click.Choice(["global", "glocal", "local", "auto"]),
@@ -339,7 +324,7 @@ def common_cluster_query(fn):
"-db",
"--db_path",
type=click.Path(path_type=Path, dir_okay=False),
help="Path to sqlite db output file. (default: output_dir/data_sqlite.db).",
help="Path to sqlite db output file. (default: output_dir/output_dir.db).",
),
# TODO: implement cand_cluster here and LCS-ext
click.option(
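As an aside on the pattern being edited here: `common_cluster_query` collects the click options shared between workflows and applies them as a decorator list. A minimal runnable sketch of that pattern, with an abbreviated option set rather than the full BiG-SCAPE one:

```python
import click


def common_options(fn):
    """Apply a shared list of click options to a command (sketch)."""
    options = [
        click.option(
            "--alignment_mode",
            type=click.Choice(["global", "glocal", "local", "auto"]),
            default="glocal",
        ),
        click.option(
            "-db",
            "--db_path",
            type=click.Path(dir_okay=False),
            help="Path to sqlite db output file.",
        ),
    ]
    # click decorators compose: apply them in reverse so the first option in
    # the list ends up outermost, matching a normal decorator stack
    for option in reversed(options):
        fn = option(fn)
    return fn


@click.command()
@common_options
def cluster(alignment_mode, db_path):
    click.echo(f"mode={alignment_mode}, db={db_path}")
```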
2 changes: 1 addition & 1 deletion big_scape/cli/cli_validations.py
@@ -112,7 +112,7 @@ def validate_output_paths(ctx) -> None:
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())

if "db_path" in ctx.obj and ctx.obj["db_path"] is None:
- db_path = ctx.obj["output_dir"] / Path("data_sqlite.db")
+ db_path = ctx.obj["output_dir"] / Path(f"{ctx.obj['output_dir'].name}.db")
ctx.obj["db_path"] = db_path

if "log_path" in ctx.obj and ctx.obj["log_path"] is None:
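The effect of this change, sketched with a hypothetical output directory: the default database file is now named after the output directory instead of the fixed `data_sqlite.db`.

```python
from pathlib import Path

output_dir = Path("my_antismash_run")  # hypothetical output directory

# old default
old_db_path = output_dir / Path("data_sqlite.db")
# new default: derived from the output directory's own name
new_db_path = output_dir / Path(f"{output_dir.name}.db")

print(old_db_path)  # my_antismash_run/data_sqlite.db
print(new_db_path)  # my_antismash_run/my_antismash_run.db
```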
15 changes: 15 additions & 0 deletions big_scape/cli/cluster_cli.py
@@ -35,6 +35,21 @@
)
# binning parameters
@click.option("--no_mix", is_flag=True, help=("Don't run the all-vs-all analysis."))
+ @click.option(
+ "--legacy_classify",
+ is_flag=True,
+ help=(
+ "Does not use antiSMASH BGC classes to run analyses on "
+ "class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
+ "PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others. "
+ "Will also use BiG-SCAPE v1 legacy_weights for distance calculations. "
+ "This feature is available for backwards compatibility with "
+ "antiSMASH versions up to v7. For higher antiSMASH versions, use "
+ "at your own risk, as BGC classes may have changed. All antiSMASH "
+ "classes that this legacy mode does not recognize will be grouped in "
+ "'others'."
+ ),
+ )
# networking parameters
@click.option(
"--include_singletons",
12 changes: 6 additions & 6 deletions big_scape/cli/query_cli.py
@@ -49,14 +49,13 @@
),
)
@click.option(
"--skip_propagation",
"--propagate",
is_flag=True,
help=(
"Only generate edges between the query and reference BGCs. If not set, "
"BiG-SCAPE will also propagate edge generation to reference BGCs. "
"Warning: if the database already contains all edges, this will not work, "
"and the output will still showcase all edges between nodes "
"in the query connected component."
"By default, BiG-SCAPE will only generate edges between the query and reference"
" BGCs. With the propagate flag, BiG-SCAPE will go through multiple cycles of "
"edge generation until no new reference BGCs are connected to the query "
"connected component."
),
)
@click.pass_context
@@ -74,6 +73,7 @@ def query(ctx, *args, **kwarg):
ctx.obj.update(ctx.params)
ctx.obj["no_mix"] = None
ctx.obj["hybrids_off"] = False
ctx.obj["legacy_classify"] = False
ctx.obj["mode"] = "Query"

# workflow validations
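A toy illustration of the propagation behaviour described in the new help text, assuming a simple adjacency mapping in place of BiG-SCAPE's actual distance database (all names here are hypothetical):

```python
def propagate(query: str, neighbours: dict[str, set[str]]) -> set[str]:
    """Grow the query connected component cycle by cycle (toy model)."""
    connected = {query}
    frontier = {query}
    while frontier:
        # one cycle of edge generation: only look outward from new nodes
        discovered = set()
        for node in frontier:
            discovered |= neighbours.get(node, set()) - connected
        connected |= discovered
        frontier = discovered  # empty once no new reference BGCs connect
    return connected


# query Q links to reference R1, which links onward to R2; R3 never connects
network = {"Q": {"R1"}, "R1": {"Q", "R2"}, "R2": {"R1"}, "R3": set()}
print(propagate("Q", network))  # {'Q', 'R1', 'R2'}
```

Without `--propagate`, only the first cycle (query to direct references) runs; with it, the loop continues until the component stops growing.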
50 changes: 48 additions & 2 deletions big_scape/comparison/binning.py
@@ -133,7 +133,7 @@ def num_pairs(self) -> int:
# find a collection of gbks with more than one subrecord
member_table = (
select(func.count(record_table.c.gbk_id).label("rec_count"))
- .where(record_table.c.record_type == self.record_type.value)
+ .where(record_table.c.id.in_(self.record_ids))
.group_by(record_table.c.gbk_id)
.having(func.count() > 1)
.subquery()
@@ -406,8 +406,9 @@ def __init__(
label: str,
edge_param_id: int,
weights: str,
+ record_type: Optional[RECORD_TYPE],
):
- super().__init__(label, edge_param_id, weights)
+ super().__init__(label, edge_param_id, weights, record_type)
self.reference_records: set[BGCRecord] = set()
self.done_records: set[BGCRecord] = set()
self.working_query_records: set[BGCRecord] = set()
@@ -426,6 +427,9 @@ def generate_pairs(
if record_a == record_b:
continue

+ if record_a.parent_gbk == record_b.parent_gbk:
+     continue
+
if legacy_sorting:
sorted_a, sorted_b = sorted((record_a, record_b), key=sort_name_key)
if sorted_a._db_id is None or sorted_b._db_id is None:
@@ -455,6 +459,48 @@ def num_pairs(self) -> int:

num_pairs = num_query_records * num_ref_records

+ # delete pairs originating from the same parent gbk
+ if self.record_type is not None and self.record_type != RECORD_TYPE.REGION:
+ query_ids = [record._db_id for record in self.working_query_records]
+ ref_ids = [record._db_id for record in self.working_ref_records]
+
+ if DB.metadata is None:
+ raise RuntimeError("DB.metadata is None")
+
+ rec_table = DB.metadata.tables["bgc_record"]
+
+ # construct two tables that hold the gbk id and the number of subrecords
+ # present in the set of query and ref records respectively
+ query_gbk = (
+ select(
+ rec_table.c.gbk_id,
+ func.count(rec_table.c.gbk_id).label("query_count"),
+ )
+ .where(rec_table.c.id.in_(query_ids))
+ .group_by(rec_table.c.gbk_id)
+ .subquery()
+ )
+
+ ref_gbk = (
+ select(
+ rec_table.c.gbk_id,
+ func.count(rec_table.c.gbk_id).label("ref_count"),
+ )
+ .where(rec_table.c.id.in_(ref_ids))
+ .group_by(rec_table.c.gbk_id)
+ .subquery()
+ )
+
+ # now we can join the two tables and obtain the number of links between
+ # records from the same gbks by multiplying their counts
+ same_gbk_query = select(
+ func.sum(query_gbk.c.query_count * ref_gbk.c.ref_count)
+ ).join(ref_gbk, query_gbk.c.gbk_id == ref_gbk.c.gbk_id)
+
+ same_gbks = DB.execute(same_gbk_query).scalar_one()
+ if same_gbks:
+ num_pairs -= same_gbks
+
return num_pairs

def add_records(self, record_list: list[BGCRecord]) -> None:
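The subtraction implemented in `num_pairs` above can be checked with a small worked example (toy data, not BiG-SCAPE code): the naive count is |query records| x |reference records|, and every GBK contributing records to both sets removes query_count * ref_count pairs.

```python
from collections import Counter

# gbk id of each query / reference record (hypothetical toy data)
query_record_gbks = ["gbk1", "gbk1", "gbk2"]
ref_record_gbks = ["gbk1", "gbk3", "gbk3"]

q_counts = Counter(query_record_gbks)
r_counts = Counter(ref_record_gbks)

naive_pairs = len(query_record_gbks) * len(ref_record_gbks)  # 3 * 3 = 9
# pairs whose two records come from the same parent gbk must be excluded
same_gbk_pairs = sum(
    q_counts[gbk] * r_counts[gbk] for gbk in q_counts.keys() & r_counts.keys()
)  # gbk1: 2 query records * 1 ref record = 2
print(naive_pairs - same_gbk_pairs)  # 7
```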
83 changes: 49 additions & 34 deletions big_scape/distances/query.py
@@ -43,7 +43,9 @@ def calculate_distances_query(
max_cutoff = max(run["gcf_cutoffs"])
edge_param_id = bs_comparison.get_edge_param_id(run, weights)

query_bin = bs_comparison.QueryRecordPairGenerator("Query", edge_param_id, weights)
query_bin = bs_comparison.QueryRecordPairGenerator(
"Query", edge_param_id, weights, run["record_type"]
)
query_bin.add_records(query_records)

missing_query_bin = bs_comparison.QueryMissingRecordPairGenerator(query_bin)
@@ -54,9 +56,13 @@

query_connected_component = next(
bs_network.get_connected_components(
- max_cutoff, edge_param_id, query_bin, run["run_id"]
- )
- )
+ max_cutoff, edge_param_id, query_bin, run["run_id"], query_record
+ ),
+ None,
+ )
+ if query_connected_component is None:
+ # no nodes are connected even with the highest cutoffs in the run
+ return query_bin

query_nodes = bs_network.get_nodes_from_cc(query_connected_component, query_records)

@@ -159,47 +165,56 @@ def calculate_distances(run: dict, bin: bs_comparison.RecordPairGenerator):
# fetches the current number of singleton ref <-> connected ref pairs from the database
num_pairs = bin.num_pairs()

- # if there are no more singleton ref <-> connected ref pairs, then break and exit
- if num_pairs == 0:
-     break
-
- logging.info("Calculating distances for %d pairs", num_pairs)
+ if num_pairs > 0:
+     logging.info("Calculating distances for %d pairs", num_pairs)

- save_batch = []
- num_edges = 0
+     save_batch = []
+     num_edges = 0

- with tqdm.tqdm(total=num_pairs, unit="edge", desc="Calculating distances") as t:
+     with tqdm.tqdm(
+         total=num_pairs, unit="edge", desc="Calculating distances"
+     ) as t:

-     def callback(edges):
-         nonlocal num_edges
-         nonlocal save_batch
-         batch_size = run["cores"] * 100000
-         for edge in edges:
-             num_edges += 1
-             t.update(1)
-             save_batch.append(edge)
-             if len(save_batch) > batch_size:
-                 bs_comparison.save_edges_to_db(save_batch, commit=True)
-                 save_batch = []
+         def callback(edges):
+             nonlocal num_edges
+             nonlocal save_batch
+             batch_size = run["cores"] * 100000
+             for edge in edges:
+                 num_edges += 1
+                 t.update(1)
+                 save_batch.append(edge)
+                 if len(save_batch) > batch_size:
+                     bs_comparison.save_edges_to_db(save_batch, commit=True)
+                     save_batch = []

-     bs_comparison.generate_edges(
-         bin,
-         run["alignment_mode"],
-         run["extend_strategy"],
-         run["cores"],
-         run["cores"] * 2,
-         callback,
-     )
+         bs_comparison.generate_edges(
+             bin,
+             run["alignment_mode"],
+             run["extend_strategy"],
+             run["cores"],
+             run["cores"] * 2,
+             callback,
+         )

- bs_comparison.save_edges_to_db(save_batch)
+     bs_comparison.save_edges_to_db(save_batch)

- bs_data.DB.commit()
+     bs_data.DB.commit()

logging.info("Generated %d edges", num_edges)
logging.info("Generated %d edges", num_edges)

if run["skip_propagation"]:
if not run["propagate"]:
# in this case we only want one iteration, the Query -> Ref edges
break

if isinstance(bin, bs_comparison.MissingRecordPairGenerator):
# in this case we only need edges within one cc, no cycles needed
break

+ if isinstance(bin, bs_comparison.QueryMissingRecordPairGenerator):
+     # use the num_pairs from the parent bin because in a partial database,
+     # all distances for the first cycle(s) might already be present.
+     # we still only want to stop when no other connected nodes are discovered.
+     if bin.bin.num_pairs() == 0:
+         break

bin.cycle_records(max(run["gcf_cutoffs"]))
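The callback above buffers edges and writes them to the database in batches; a self-contained sketch of that pattern, with illustrative names standing in for `bs_comparison.save_edges_to_db`:

```python
def process_edges(edge_batches, save_fn, batch_size=3):
    """Buffer incoming edges and flush them to storage in batches (sketch)."""
    buffer, total = [], 0
    for edges in edge_batches:  # each iteration mimics one callback invocation
        for edge in edges:
            total += 1
            buffer.append(edge)
            if len(buffer) > batch_size:
                save_fn(buffer)  # bulk write, then start a fresh batch
                buffer = []
    save_fn(buffer)  # final flush for the partial batch left over
    return total


stored: list[int] = []
count = process_edges([[1, 2], [3, 4, 5], [6]], stored.extend)
print(count, stored)  # 6 [1, 2, 3, 4, 5, 6]
```

Batching keeps memory bounded and amortizes database commits over many edges instead of one commit per edge.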
22 changes: 13 additions & 9 deletions big_scape/file_input/load_files.py
@@ -409,21 +409,16 @@ def get_all_bgc_records_query(
"""Get all BGC records from the working list of GBKs
Args:
+ run (dict): run parameters
gbks (list[GBK]): list of GBK objects
Returns:
- list[bs_gbk.BGCRecord]: list of BGC records
+ list[bs_gbk.BGCRecord], bs_gbk.BGCRecord: list of BGC records, query BGC record
"""
all_bgc_records: list[bs_gbk.BGCRecord] = []
for gbk in gbks:
if gbk.region is not None:
- gbk_records = bs_gbk.bgc_record.get_sub_records(
- gbk.region, run["record_type"]
- )
if gbk.source_type == bs_enums.SOURCE_TYPE.QUERY:
query_record_type = run["record_type"]
query_record_number = run["query_record_number"]

@@ -435,15 +430,24 @@
query_record = query_sub_records[0]

else:
- query_record = [
+ matching_query_records = [
record
for record in query_sub_records
if record.number == query_record_number
- ][0]
+ ]
+ if len(matching_query_records) == 0:
+ raise RuntimeError(
+ f"Could not find {query_record_type.value} number {query_record_number} in query GBK. "
+ "Depending on config settings, overlapping records will be merged and take on the lower number."
+ )
+ query_record = matching_query_records[0]

all_bgc_records.append(query_record)

+ else:
+ gbk_records = bs_gbk.bgc_record.get_sub_records(
+ gbk.region, run["record_type"]
+ )
all_bgc_records.extend(gbk_records)

return all_bgc_records, query_record
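The reworked lookup above avoids a bare IndexError when the requested record number is absent; the shape of the fix, reduced to a sketch with hypothetical records:

```python
def find_record(records: list[tuple[str, int]], number: int) -> tuple[str, int]:
    """Return the (type, number) record matching number, with a clear error."""
    matching = [record for record in records if record[1] == number]
    if len(matching) == 0:
        # the old code indexed [...][0] directly, raising an opaque IndexError
        raise RuntimeError(f"Could not find record number {number} in query GBK.")
    return matching[0]


records = [("proto_cluster", 1), ("proto_cluster", 2)]
print(find_record(records, 2))  # ('proto_cluster', 2)
```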
46 changes: 25 additions & 21 deletions big_scape/network/families.py
@@ -370,35 +370,39 @@ def run_family_assignments_query(
# get_connected_components returns a list of connected components, but we only
# want the first one, so we use next()

- try:
- query_connected_component = next(
- bs_network.get_connected_components(
- cutoff, query_bin.edge_param_id, query_bin, run["run_id"]
- )
- )
-
- cc_cutoff[cutoff] = query_connected_component
-
- logging.debug(
- "Found connected component with %d edges",
- len(query_connected_component),
- )
-
- regions_families = generate_families(
- query_connected_component, query_bin.label, cutoff, run["run_id"]
- )
-
- # save families to database
- save_to_db(regions_families)
+ query_connected_component = next(
+ bs_network.get_connected_components(
+ cutoff,
+ query_bin.edge_param_id,
+ query_bin,
+ run["run_id"],
+ query_record,
+ ),
+ None,
+ )

- except StopIteration:
+ if query_connected_component is None:
logging.warning(
"No connected components found for %s bin at cutoff %s",
query_bin.label,
cutoff,
)
+ continue

+ cc_cutoff[cutoff] = query_connected_component
+
+ logging.debug(
+ "Found connected component with %d edges",
+ len(query_connected_component),
+ )
+
+ regions_families = generate_families(
+ query_connected_component, query_bin.label, cutoff, run["run_id"]
+ )
+
+ # save families to database
+ save_to_db(regions_families)

DB.commit()

# no connected components found
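The families code now uses `next()` with a default instead of catching StopIteration; a minimal before/after comparison using a toy generator in place of the real `get_connected_components`:

```python
def connected_components():
    yield from []  # toy: no connected components exist


# before: try/except around next()
try:
    cc = next(connected_components())
except StopIteration:
    cc = None

# after: next() with an explicit default, no exception handling needed
cc = next(connected_components(), None)

if cc is None:
    print("No connected components found")
```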
