[query] Add query_matrix_table, an analogue to query_table #14806

Open · wants to merge 12 commits into base: main
10 changes: 10 additions & 0 deletions hail/hail/src/is/hail/expr/ir/IR.scala
@@ -1186,6 +1186,7 @@ package defs {
classOf[PartitionNativeReader],
classOf[PartitionNativeReaderIndexed],
classOf[PartitionNativeIntervalReader],
classOf[PartitionZippedNativeIntervalReader],
classOf[PartitionZippedNativeReader],
classOf[PartitionZippedIndexedNativeReader],
classOf[BgenPartitionReader],
@@ -1216,6 +1217,15 @@
spec,
(jv \ "uidFieldName").extract[String],
)
case "PartitionZippedNativeIntervalReader" =>
val path = (jv \ "path").extract[String]
val spec = RelationalSpec.read(ctx.fs, path).asInstanceOf[AbstractMatrixTableSpec]
PartitionZippedNativeIntervalReader(
ctx.stateManager,
path,
spec,
(jv \ "uidFieldName").extract[String],
)
case "GVCFPartitionReader" =>
val header = VCFHeaderInfo.fromJSON((jv \ "header"))
val callFields = (jv \ "callFields").extract[Set[String]]
56 changes: 56 additions & 0 deletions hail/hail/src/is/hail/expr/ir/TableIR.scala
@@ -943,6 +943,7 @@ case class PartitionNativeIntervalReader(
lazy val partitioner = rowsSpec.partitioner(sm)

lazy val contextType: Type = RVDPartitioner.intervalIRRepresentation(partitioner.kType)
require(partitioner.kType.size > 0)

def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

@@ -1509,6 +1510,61 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe
}
}

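// Inferred note (not stated in the original PR): the on-disk entries table is
// unkeyed, so its own spec carries no row-key partitioner; reusing the rows
// table's partitioner keeps the zipped row/entry partitions aligned.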
private[this] class PartitionEntriesNativeIntervalReader(
sm: HailStateManager,
entriesPath: String,
entriesSpec: AbstractTableSpec,
uidFieldName: String,
rowsTableSpec: AbstractTableSpec,
) extends PartitionNativeIntervalReader(sm, entriesPath, entriesSpec, uidFieldName) {
override lazy val partitioner = rowsTableSpec.rowsSpec.partitioner(sm)
}

case class PartitionZippedNativeIntervalReader(
sm: HailStateManager,
mtPath: String,
mtSpec: AbstractMatrixTableSpec,
uidFieldName: String,
) extends PartitionReader {
require(mtSpec.indexed)

// XXX: rows and entries paths are hardcoded, see MatrixTableSpec
private lazy val rowsReader =
PartitionNativeIntervalReader(sm, mtPath + "/rows", mtSpec.rowsSpec, "__dummy")

private lazy val entriesReader =
new PartitionEntriesNativeIntervalReader(
sm,
mtPath + "/entries",
mtSpec.entriesSpec,
uidFieldName,
rowsReader.tableSpec,
)

private lazy val zippedReader = PartitionZippedNativeReader(rowsReader, entriesReader)

def contextType = rowsReader.contextType
def fullRowType = zippedReader.fullRowType
Collaborator Author commented:

@patrick-schultz I'm getting an NPE in initialization, I think it's because this assertion in PartitionReader is running before zippedReader is initialized. Thoughts on resolving that?

abstract class PartitionReader {
assert(fullRowType.hasField(uidFieldName))

Collaborator replied:

That's annoying. My first thought is to make the sub-readers lazy vals. Maybe try that for now, and we can think more about whether there's a more satisfying fix.

Collaborator Author replied:

Thanks for the suggestion. It worked. I guess a lazy val forces initialization on first access, whereas a plain val is only valid once the class is fully initialized.

Collaborator replied:

Right, lazy val foo = init is basically private var _foo = null and def foo = { if (_foo == null) _foo = init; _foo }. So it will be initialized the first time it's accessed, rather than when the class initializer is run. Usually that means it's initialized later than the class, but in this case it happens earlier.
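
To make that concrete, a minimal self-contained sketch (illustrative names, not Hail's actual classes) of the trap and the lazy-val fix:

abstract class Reader {
  def rowType: String
  // Runs during Reader's constructor, before any subclass constructor has
  // assigned its fields.
  require(rowType != null, "rowType accessed before initialization")
}

// A plain val is still null while the superclass constructor runs: this throws.
class EagerReader extends Reader {
  val rowType: String = "row"
}

// A lazy val is computed on first access, even from the superclass constructor.
class LazyReader extends Reader {
  lazy val rowType: String = "row"
}

object LazyValDemo extends App {
  new LazyReader // fine
  try new EagerReader
  catch { case e: IllegalArgumentException => println(s"as expected: ${e.getMessage}") }
}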

Comment on lines +1513 to +1547
Member commented:

Forgive me Chris, I'm not really following what's mandating all these private lazy vals.
Since you're not using PartitionZippedNativeIntervalReader for pattern matching, can you make this a smart constructor that returns an anonymous instance of PartitionReader? Perhaps that might solve the initialisation order issue?
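
For reference, a sketch of the suggested shape, using stand-in types rather than the Hail API:

trait PartitionReader { def fullRowType: String }

object ZippedIntervalReader {
  // Smart constructor: every dependency is an ordinary local by the time the
  // anonymous instance is created, so no half-initialized field is ever read.
  def apply(rowsType: String, entriesType: String): PartitionReader = {
    val zipped = rowsType + "," + entriesType
    new PartitionReader { def fullRowType: String = zipped }
  }
}

The trade-off is losing the case-class conveniences (pattern matching, structural equality, and possibly the JSON round-trip the case class gives for free), which is why it only works when nothing matches on the concrete type.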

def rowRequiredness(requestedType: TStruct): RStruct = zippedReader.rowRequiredness(requestedType)
def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

def emitStream(
ctx: ExecuteContext,
cb: EmitCodeBuilder,
mb: EmitMethodBuilder[_],
codeContext: EmitCode,
requestedType: TStruct,
): IEmitCode = {
val zipContextType: TBaseStruct = tcoerce(zippedReader.contextType)
val valueContext = cb.memoize(codeContext)
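// Both sub-readers scan the same key interval, so the zipped context simply
// duplicates the single memoized query context.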
val contexts: IndexedSeq[EmitCode] = FastSeq(valueContext, valueContext)
val st = SStackStruct(zipContextType, contexts.map(_.emitType))
val context = EmitCode.present(mb, st.fromEmitCodes(cb, contexts))

zippedReader.emitStream(ctx, cb, mb, context, requestedType)
}
}

case class PartitionZippedIndexedNativeReader(
specLeft: AbstractTypedCodecSpec,
specRight: AbstractTypedCodecSpec,
2 changes: 2 additions & 0 deletions hail/python/hail/expr/__init__.py
@@ -210,6 +210,7 @@
qchisqtail,
qnorm,
qpois,
query_matrix_table_rows,
query_table,
rand_beta,
rand_bool,
@@ -554,6 +555,7 @@
'_console_log',
'dnorm',
'dchisq',
'query_matrix_table_rows',
'query_table',
'keyed_union',
'keyed_intersection',
147 changes: 106 additions & 41 deletions hail/python/hail/expr/functions.py
@@ -2,6 +2,7 @@
import functools
import itertools
import operator
import os.path
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

import numpy as np
@@ -6984,6 +6985,43 @@ def shuffle(a, seed: Optional[builtins.int] = None) -> ArrayExpression:
return sorted(a, key=lambda _: hl.rand_unif(0.0, 1.0))


def __validate_and_coerce_endpoint(point, key_typ):
"""query validation for the points or endpoints of the query in query_table"""
len = builtins.len
key_names = list(key_typ)
if point.dtype == key_typ[0]:
point = hl.struct(**{key_names[0]: point})
ts = point.dtype
if not isinstance(ts, tstruct):
raise ValueError(
f'key mismatch: cannot use query point type {point.dtype} to query a table with key of '
f'({", ".join(builtins.str(x) for x in key_typ.values())}) '
)

if not ts:
raise ValueError("query point value cannot be an empty struct")

for i, qt, kt in builtins.zip(builtins.range(len(ts)), ts.values(), key_typ.values()):
if kt != qt:
raise ValueError(
f'mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {qt}, key type is {kt}'
)

# this check is here because it is more useful to the user to check each
# type than it is to fail fast with a larger query struct than the table
# has key fields
if len(ts) > len(key_typ):
raise ValueError(f'query point type has {len(ts)} field(s), but key only has {len(key_typ)} field(s)')

point_size = len(point.dtype)
return hl.tuple([
hl.struct(**{
key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i])) for i in builtins.range(len(key_typ))
}),
hl.int32(point_size),
])
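
# For intuition, a hedged illustration (key names and types are hypothetical): for
# a table keyed by (locus: tlocus, alleles: array<str>), querying with a bare locus
# yields an endpoint padded with missing key fields, plus the count of fields that
# actually constrain the query:
#
#     point = hl.locus('1', 12345)
#     coerced = hl.tuple([
#         hl.struct(locus=point, alleles=hl.missing(hl.tarray(hl.tstr))),
#         hl.int32(1),  # only the first key field is constrained
#     ])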


@typecheck(path=builtins.str, point_or_interval=expr_any)
def query_table(path, point_or_interval):
"""Query records from a table corresponding to a given point or range of keys.
@@ -7013,54 +7051,18 @@ def query_table(path, point_or_interval):
row_typ = table.row.dtype

key_typ = table.key.dtype
key_names = list(key_typ)
len = builtins.len
if len(key_typ) == 0:
raise ValueError("query_table: cannot query unkeyed table")

def coerce_endpoint(point):
if point.dtype == key_typ[0]:
point = hl.struct(**{key_names[0]: point})
ts = point.dtype
if isinstance(ts, tstruct):
i = 0
while i < len(ts):
if i >= len(key_typ):
raise ValueError(
f"query_table: queried with {len(ts)} key field(s), but table only has {len(key_typ)} key field(s)"
)
if key_typ[i] != ts[i]:
raise ValueError(
f"query_table: key mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, table key type is {key_typ[i]}"
)
i += 1

if i == 0:
raise ValueError("query_table: cannot query with empty key")

point_size = builtins.len(point.dtype)
return hl.tuple([
hl.struct(**{
key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i]))
for i in builtins.range(builtins.len(key_typ))
}),
hl.int32(point_size),
])
else:
raise ValueError(
f"query_table: key mismatch: cannot query a table with key "
f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}"
)
if builtins.len(key_typ) == 0:
raise ValueError('cannot query unkeyed table')

if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
partition_interval = hl.interval(
start=coerce_endpoint(point_or_interval.start),
end=coerce_endpoint(point_or_interval.end),
start=__validate_and_coerce_endpoint(point_or_interval.start, key_typ),
end=__validate_and_coerce_endpoint(point_or_interval.end, key_typ),
includes_start=point_or_interval.includes_start,
includes_end=point_or_interval.includes_end,
)
else:
point = coerce_endpoint(point_or_interval)
point = __validate_and_coerce_endpoint(point_or_interval, key_typ)
partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
return construct_expr(
ir.ToArray(ir.ReadPartition(partition_interval._ir, reader=ir.PartitionNativeIntervalReader(path, row_typ))),
@@ -7070,6 +7072,69 @@ def coerce_endpoint(point):
)


@typecheck(path=builtins.str, point_or_interval=expr_any, entries_name=builtins.str)
def query_matrix_table_rows(path, point_or_interval, entries_name='entries_array'):
"""Query row records from a matrix table corresponding to a given point or
range of row keys. The entry fields are localized as an array of structs as
in :meth:`.MatrixTable.localize_entries`.

Notes
-----
This function does not dispatch to a distributed runtime; it can be used inside
already-distributed queries such as in :meth:`.Table.annotate`.

Warning
-------
This function contains no safeguards against reading large amounts of data
using a single thread.

Parameters
----------
path : :class:`str`
Matrix table path.
point_or_interval
Point or interval to query.
entries_name : :class:`str`
Identifier to use for the localized entries array. Must not conflict
with any row field identifiers.

Returns
-------
:class:`.ArrayExpression`
"""
matrix_table = hl.read_matrix_table(path)
if entries_name in matrix_table.row:
raise ValueError(
f'field "{entries_name}" is present in matrix table row fields, use a different `entries_name`'
)
entries_table = hl.read_table(os.path.join(path, 'entries'))
[entry_id] = list(entries_table.row)

full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
key_typ = matrix_table.row_key.dtype

if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
partition_interval = hl.interval(
start=__validate_and_coerce_endpoint(point_or_interval.start, key_typ),
end=__validate_and_coerce_endpoint(point_or_interval.end, key_typ),
includes_start=point_or_interval.includes_start,
includes_end=point_or_interval.includes_end,
)
else:
point = __validate_and_coerce_endpoint(point_or_interval, key_typ)
partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
read_part_ir = ir.ReadPartition(
partition_interval._ir, reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type)
)
stream_expr = construct_expr(
read_part_ir,
type=hl.tstream(full_row_type),
indices=partition_interval._indices,
aggregations=partition_interval._aggregations,
)
return stream_expr.map(lambda item: item.rename({entry_id: entries_name})).to_array()
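
# A hedged usage sketch (path, interval, and field names are hypothetical,
# assuming a matrix table keyed by locus):
#
#     interval = hl.interval(hl.locus('1', 100000), hl.locus('1', 200000))
#     rows = hl.query_matrix_table_rows('gs://bucket/dataset.mt', interval,
#                                       entries_name='calls')
#     ht = hl.utils.range_table(1)
#     ht = ht.annotate(region_rows=rows)  # usable inside distributed queries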


@typecheck(msg=expr_str, result=expr_any)
def _console_log(msg, result):
indices, aggregations = unify_all(msg, result)
2 changes: 2 additions & 0 deletions hail/python/hail/ir/__init__.py
@@ -113,6 +113,7 @@
NDArraySVD,
NDArrayWrite,
PartitionNativeIntervalReader,
PartitionZippedNativeIntervalReader,
ProjectedTopLevelReference,
ReadPartition,
Recur,
@@ -527,6 +528,7 @@
'TableNativeFanoutWriter',
'ReadPartition',
'PartitionNativeIntervalReader',
'PartitionZippedNativeIntervalReader',
'GVCFPartitionReader',
'TableGen',
'Partitioner',
31 changes: 31 additions & 0 deletions hail/python/hail/ir/ir.py
@@ -3510,6 +3510,37 @@ def row_type(self):
return tstruct(**self.table_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class PartitionZippedNativeIntervalReader(PartitionReader):
def __init__(self, path, full_row_type, uid_field=None):
self.path = path
self.full_row_type = full_row_type
self.uid_field = uid_field

def with_uid_field(self, uid_field):
return PartitionZippedNativeIntervalReader(path=self.path, full_row_type=self.full_row_type, uid_field=uid_field)

def render(self):
return escape_str(
json.dumps({
"name": "PartitionZippedNativeIntervalReader",
"path": self.path,
"uidFieldName": self.uid_field if self.uid_field is not None else '__dummy',
})
)

def _eq(self, other):
return (
isinstance(other, PartitionZippedNativeIntervalReader)
and self.path == other.path
and self.uid_field == other.uid_field
)

def row_type(self):
if self.uid_field is None:
return self.full_row_type
return tstruct(**self.full_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class ReadPartition(IR):
@typecheck_method(context=IR, reader=PartitionReader)
def __init__(self, context, reader):