Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute local alleles during encode #299

Merged
merged 12 commits into from
Jan 17, 2025
10 changes: 3 additions & 7 deletions bio2zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def show_work_summary(work_summary, json):
@compressor
@progress
@worker_processes
@local_alleles
def explode(
vcfs,
icf_path,
Expand All @@ -231,7 +230,6 @@ def explode(
compressor,
progress,
worker_processes,
local_alleles,
):
"""
Convert VCF(s) to intermediate columnar format
Expand All @@ -245,7 +243,6 @@ def explode(
column_chunk_size=column_chunk_size,
compressor=get_compressor(compressor),
show_progress=progress,
local_alleles=local_alleles,
)


Expand All @@ -260,7 +257,6 @@ def explode(
@verbose
@progress
@worker_processes
@local_alleles
def dexplode_init(
vcfs,
icf_path,
Expand All @@ -272,7 +268,6 @@ def dexplode_init(
verbose,
progress,
worker_processes,
local_alleles,
):
"""
Initial step for distributed conversion of VCF(s) to intermediate columnar format
Expand All @@ -289,7 +284,6 @@ def dexplode_init(
worker_processes=worker_processes,
compressor=get_compressor(compressor),
show_progress=progress,
local_alleles=local_alleles,
)
show_work_summary(work_summary, json)

Expand Down Expand Up @@ -340,7 +334,8 @@ def inspect(path, verbose):
@icf_path
@variants_chunk_size
@samples_chunk_size
def mkschema(icf_path, variants_chunk_size, samples_chunk_size):
@local_alleles
def mkschema(icf_path, variants_chunk_size, samples_chunk_size, local_alleles):
"""
Generate a schema for zarr encoding
"""
Expand All @@ -350,6 +345,7 @@ def mkschema(icf_path, variants_chunk_size, samples_chunk_size):
stream,
variants_chunk_size=variants_chunk_size,
samples_chunk_size=samples_chunk_size,
local_alleles=local_alleles,
)


Expand Down
21 changes: 21 additions & 0 deletions bio2zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ def chunk_aligned_slices(z, n, max_chunks=None):
return slices


def first_dim_slice_iter(z, start, stop):
    """
    Efficiently iterate over the specified slice of the first dimension of the zarr
    array z.

    Blocks are read one at a time through ``z.blocks`` and only those that
    overlap the half-open range ``[start, stop)`` are visited. The first and
    last visited blocks are trimmed so that exactly the requested rows are
    yielded, in order.

    An empty or inverted range (``start >= stop``) yields nothing and does not
    read any block.
    """
    # Nothing to yield: bail out before touching z so that no block is
    # fetched just to produce an empty slice.
    if start >= stop:
        return

    chunk_size = z.chunks[0]
    first_chunk = start // chunk_size
    # Ceiling division: a partially covered trailing chunk must be visited too.
    last_chunk = (stop // chunk_size) + (stop % chunk_size != 0)
    for chunk in range(first_chunk, last_chunk):
        block = z.blocks[chunk]
        chunk_start = chunk * chunk_size
        chunk_stop = chunk_start + chunk_size
        # None leaves that side of the slice open, i.e. the whole block edge.
        slice_start = start - chunk_start if start > chunk_start else None
        slice_stop = stop - chunk_start if stop < chunk_stop else None
        yield from block[slice_start:slice_stop]


def du(path):
"""
Return the total bytes stored at this path.
Expand Down
175 changes: 3 additions & 172 deletions bio2zarr/vcf2zarr/icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def make_field_def(name, vcf_type, vcf_number):
return fields


def scan_vcf(path, target_num_partitions, *, local_alleles):
def scan_vcf(path, target_num_partitions):
with vcf_utils.IndexedVcf(path) as indexed_vcf:
vcf = indexed_vcf.vcf
filters = []
Expand All @@ -237,10 +237,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
pass_filter = filters.pop(pass_index)
filters.insert(0, pass_filter)

# Indicates whether vcf2zarr can introduce local alleles
can_localize = False
should_add_laa_field = True
should_add_lpl_field = True
fields = fixed_vcf_field_definitions()
for h in vcf.header_iter():
if h["HeaderType"] in ["INFO", "FORMAT"]:
Expand All @@ -249,36 +245,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
field.vcf_type = "Integer"
field.vcf_number = "."
fields.append(field)
if field.category == "FORMAT":
if field.name == "PL":
can_localize = True
if field.name == "LAA":
should_add_laa_field = False
if field.name == "LPL":
should_add_lpl_field = False

if local_alleles and can_localize:
if should_add_laa_field:
laa_field = VcfField(
category="FORMAT",
name="LAA",
vcf_type="Integer",
vcf_number=".",
description="1-based indices into ALT, indicating which alleles"
" are relevant (local) for the current sample",
summary=VcfFieldSummary(),
)
fields.append(laa_field)
if should_add_lpl_field:
lpl_field = VcfField(
category="FORMAT",
name="LPL",
vcf_type="Integer",
vcf_number="LG",
description="Local-allele representation of PL",
summary=VcfFieldSummary(),
)
fields.append(lpl_field)

try:
contig_lengths = vcf.seqlens
Expand Down Expand Up @@ -315,14 +281,7 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
return metadata, vcf.raw_header


def scan_vcfs(
paths,
show_progress,
target_num_partitions,
worker_processes=1,
*,
local_alleles,
):
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
logger.info(
f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
f" partitions."
Expand All @@ -346,7 +305,6 @@ def scan_vcfs(
scan_vcf,
path,
max(1, target_num_partitions // len(paths)),
local_alleles=local_alleles,
)
results = list(pwm.results_as_completed())

Expand Down Expand Up @@ -505,104 +463,6 @@ def sanitise_value_int_2d(buff, j, value):
buff[j, :, : value.shape[1]] = value


def compute_laa_field(variant) -> np.ndarray:
    """
    Build the per-sample LAA array for a single variant.

    LAA lists, for every sample, the one-based indices into ALT of the
    alternate alleles seen in that sample. The alleles are inferred from the
    GT field; when the variant has no GT, no alternate allele is reported.

    Returns a 2D array with one row per sample. Rows are right-padded with
    constants.INT_FILL and the array is trimmed to the widest row (never
    narrower than one column).
    """
    num_samples = variant.num_called + variant.num_unknown
    num_alt = len(variant.ALT)
    num_alleles = num_alt + 1
    counts = np.zeros((num_samples, num_alleles), dtype=int)

    if "GT" in variant.FORMAT:
        # Drop the trailing column: it carries the phasing flag, not an allele.
        gt = variant.genotype.array()[:, :-1]
        # Negative codes are folded onto index 0, which is discarded below.
        gt.clip(0, None, out=gt)
        counts += np.apply_along_axis(
            np.bincount, axis=1, arr=gt, minlength=num_alleles
        )

    # The reference allele is never part of LAA.
    counts[:, 0] = 0
    widest = 1

    def observed_alts(row: np.ndarray, *, length: int):
        # Indices of the alleles present in this row, padded to a fixed
        # length so that apply_along_axis can stack the rows.
        nonlocal widest
        present = row.nonzero()[0]
        widest = max(widest, len(present))
        return np.pad(
            present,
            (0, length - len(present)),
            mode="constant",
            constant_values=constants.INT_FILL,
        )

    padded = np.apply_along_axis(
        observed_alts, axis=1, arr=counts, length=max(1, num_alt)
    )
    return padded[:, :widest]


def compute_lpl_field(variant, laa_val: np.ndarray) -> np.ndarray:
    """
    Build the per-sample LPL array (local-allele form of PL) for a variant.

    laa_val is the LAA array for the same variant. A zero column (the
    reference allele) is prepended to it to form the local allele table.

    If the variant has no PL field, the result is filled with
    constants.INT_MISSING and has one column per local genotype. Otherwise
    the PL entries are gathered at index b * (b + 1) / 2 + a for each pair of
    local alleles (a, b), and positions where b is constants.INT_FILL are set
    to constants.INT_FILL.

    Raises ValueError when the ploidy is neither 1 nor 2.
    """
    assert laa_val is not None

    # Local allele table: column 0 is the reference, then the LAA columns.
    local_alleles = np.zeros(
        (laa_val.shape[0], laa_val.shape[1] + 1), dtype=laa_val.dtype
    )
    local_alleles[:, 1:] = laa_val
    ploidy = variant.ploidy

    if "PL" not in variant.FORMAT:
        num_samples = variant.num_called + variant.num_unknown
        num_local = local_alleles.shape[1]

        if ploidy == 1:
            num_genotypes = num_local
        elif ploidy == 2:
            num_genotypes = num_local * (num_local + 1) // 2
        else:
            raise ValueError(f"Cannot handle ploidy = {ploidy}")

        return np.full((num_samples, num_genotypes), constants.INT_MISSING)

    # Work out the allele pair (lower, upper) behind every local genotype.
    if ploidy == 1:
        lower = local_alleles
        upper = np.zeros_like(local_alleles)
    elif ploidy == 2:
        width = local_alleles.shape[1]
        # Column j of the table is repeated j + 1 times.
        upper = np.repeat(local_alleles, np.arange(1, width + 1), axis=1)
        grid = np.tile(np.arange(width), (width, 1))
        # Lower-triangle column indices, replicated for every sample row.
        lower_cols = np.tile(grid[np.tril_indices_from(grid)], (upper.shape[0], 1))
        sample_rows = np.arange(local_alleles.shape[0]).reshape(-1, 1)
        lower = local_alleles[sample_rows, lower_cols]
    else:
        raise ValueError(f"Cannot handle ploidy = {ploidy}")

    # Position of each local genotype within the full PL vector.
    pl_index = (upper * (upper + 1) / 2 + lower).astype(int)

    pl = variant.format("PL")
    pl[pl == constants.VCF_INT_MISSING] = constants.INT_MISSING
    # When the PL value is missing in all samples, pl has shape
    # (sample_count, 1) and has to be broadcast before it can be indexed.
    if pl.shape[1] < pl_index.shape[1]:
        pl = np.broadcast_to(pl, pl_index.shape)
    sample_rows = np.arange(pl.shape[0]).reshape(-1, 1)
    lpl = pl[sample_rows, pl_index]
    lpl[upper == constants.INT_FILL] = constants.INT_FILL

    return lpl


missing_value_map = {
"Integer": constants.INT_MISSING,
"Float": constants.FLOAT32_MISSING,
Expand Down Expand Up @@ -1107,14 +967,11 @@ def init(
target_num_partitions=None,
show_progress=False,
compressor=None,
local_alleles=None,
):
if self.path.exists():
raise ValueError(f"ICF path already exists: {self.path}")
if compressor is None:
compressor = ICF_DEFAULT_COMPRESSOR
if local_alleles is None:
local_alleles = False
vcfs = [pathlib.Path(vcf) for vcf in vcfs]
target_num_partitions = max(target_num_partitions, len(vcfs))

Expand All @@ -1124,7 +981,6 @@ def init(
worker_processes=worker_processes,
show_progress=show_progress,
target_num_partitions=target_num_partitions,
local_alleles=local_alleles,
)
check_field_clobbering(icf_metadata)
self.metadata = icf_metadata
Expand Down Expand Up @@ -1207,17 +1063,6 @@ def process_partition(self, partition_index):
else:
format_fields.append(field)

format_field_names = [format_field.name for format_field in format_fields]
if "LAA" in format_field_names and "LPL" in format_field_names:
laa_index = format_field_names.index("LAA")
lpl_index = format_field_names.index("LPL")
# LAA needs to come before LPL
if lpl_index < laa_index:
format_fields[laa_index], format_fields[lpl_index] = (
format_fields[lpl_index],
format_fields[laa_index],
)

last_position = None
with IcfPartitionWriter(
self.metadata,
Expand Down Expand Up @@ -1245,18 +1090,8 @@ def process_partition(self, partition_index):
else:
val = variant.genotype.array()
tcw.append("FORMAT/GT", val)
laa_val = None
for field in format_fields:
if field.name == "LAA":
if "LAA" not in variant.FORMAT:
laa_val = compute_laa_field(variant)
else:
laa_val = variant.format("LAA")
val = laa_val
elif field.name == "LPL" and "LPL" not in variant.FORMAT:
val = compute_lpl_field(variant, laa_val)
else:
val = variant.format(field.name)
val = variant.format(field.name)
tcw.append(field.full_name, val)

# Note: an issue with updating the progress per variant here like
Expand Down Expand Up @@ -1352,7 +1187,6 @@ def explode(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=None,
):
writer = IntermediateColumnarFormatWriter(icf_path)
writer.init(
Expand All @@ -1363,7 +1197,6 @@ def explode(
show_progress=show_progress,
column_chunk_size=column_chunk_size,
compressor=compressor,
local_alleles=local_alleles,
)
writer.explode(worker_processes=worker_processes, show_progress=show_progress)
writer.finalise()
Expand All @@ -1379,7 +1212,6 @@ def explode_init(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=None,
):
writer = IntermediateColumnarFormatWriter(icf_path)
return writer.init(
Expand All @@ -1389,7 +1221,6 @@ def explode_init(
show_progress=show_progress,
column_chunk_size=column_chunk_size,
compressor=compressor,
local_alleles=local_alleles,
)


Expand Down
Loading
Loading