[query] Add query_matrix_table, an analogue to query_table #14806

Open · wants to merge 4 commits into base: main
Changes from 1 commit
10 changes: 10 additions & 0 deletions hail/hail/src/is/hail/expr/ir/IR.scala
@@ -1186,6 +1186,7 @@ package defs {
classOf[PartitionNativeReader],
classOf[PartitionNativeReaderIndexed],
classOf[PartitionNativeIntervalReader],
classOf[PartitionZippedNativeIntervalReader],
classOf[PartitionZippedNativeReader],
classOf[PartitionZippedIndexedNativeReader],
classOf[BgenPartitionReader],
@@ -1216,6 +1217,15 @@
spec,
(jv \ "uidFieldName").extract[String],
)
case "PartitionZippedNativeIntervalReader" =>
val path = (jv \ "path").extract[String]
val spec = RelationalSpec.read(ctx.fs, path).asInstanceOf[AbstractMatrixTableSpec]
PartitionZippedNativeIntervalReader(
ctx.stateManager,
path,
spec,
(jv \ "uidFieldName").extract[String],
)
case "GVCFPartitionReader" =>
val header = VCFHeaderInfo.fromJSON((jv \ "header"))
val callFields = (jv \ "callFields").extract[Set[String]]
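
For reference, the JSON this new branch parses is produced by the Python reader's render() method, added later in this PR. A minimal sketch of the payload, with a hypothetical path:

import json

# Hypothetical payload; the field names mirror what the Scala branch above extracts.
payload = json.dumps({
    "name": "PartitionZippedNativeIntervalReader",
    "path": "/path/to/dataset.mt",
    "uidFieldName": "__dummy",
})
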
34 changes: 34 additions & 0 deletions hail/hail/src/is/hail/expr/ir/TableIR.scala
@@ -1509,6 +1509,40 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe
}
}

case class PartitionZippedNativeIntervalReader(
sm: HailStateManager,
mtPath: String,
mtSpec: AbstractMatrixTableSpec,
uidFieldName: String,
) extends PartitionReader {
require(mtSpec.indexed)

// XXX: rows and entries paths are hardcoded, see MatrixTableSpec
private val rowsReader = PartitionNativeIntervalReader(sm, mtPath + "/rows", mtSpec.rowsSpec, "__dummy")
private val entriesReader = PartitionNativeIntervalReader(sm, mtPath + "/entries", mtSpec.entriesSpec, uidFieldName)
private val zippedReader = PartitionZippedNativeReader(rowsReader, entriesReader)

def contextType = rowsReader.contextType
def fullRowType = zippedReader.fullRowType
Collaborator Author:

@patrick-schultz I'm getting an NPE in initialization; I think it's because this assertion in PartitionReader runs before zippedReader is initialized. Thoughts on resolving that?

abstract class PartitionReader {
assert(fullRowType.hasField(uidFieldName))

Collaborator:

That's annoying. My first thought is to make the sub-readers lazy vals. Maybe try that for now, and we can think more about whether there's a more satisfying fix.

Collaborator Author:

Thanks for the suggestion; it worked. I guess accessing a lazy val forces its initialization, whereas a plain val is only truly valid after full initialization of the class.

Collaborator:

Right: lazy val foo = init is basically private var _foo = null plus def foo = { if (_foo == null) _foo = init; _foo }. So it is initialized the first time it is accessed, rather than when the class initializer runs. Usually that means it is initialized later than the class, but in this case it happens earlier.
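
For readers following along on the Python side of this PR, a rough analogue of the same on-first-access behavior using functools.cached_property; this is illustrative only (the actual fix is the Scala lazy val above), and the class and field names are made up:

import functools

class Base:
    def __init__(self):
        # Plays the role of the assert in PartitionReader's constructor:
        # it runs before the subclass's __init__ has finished.
        assert self.full_row_type is not None

class Zipped(Base):
    def __init__(self, path):
        self.path = path    # constructor arguments are bound first
        super().__init__()  # Base's assert runs here

    @functools.cached_property
    def zipped_reader(self):
        # Computed on first access (during Base.__init__ in this case)
        # rather than at a fixed point in construction, like a lazy val.
        return f"reader({self.path})"

    @property
    def full_row_type(self):
        return self.zipped_reader

Zipped("/path/to/dataset.mt")  # succeeds: zipped_reader initializes on demand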

def rowRequiredness(requestedType: TStruct): RStruct = zippedReader.rowRequiredness(requestedType)
def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

def emitStream(
ctx: ExecuteContext,
cb: EmitCodeBuilder,
mb: EmitMethodBuilder[_],
_context: EmitCode,
requestedType: TStruct,
): IEmitCode = {
val zipContextType: TBaseStruct = tcoerce(zippedReader.contextType)
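// both sub-readers are keyed by the same interval, so the incoming context is duplicated to fill the zipped reader's two-field context struct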
val contexts = FastSeq(_context, _context)
val st = SStackStruct(zipContextType, contexts.map(_.emitType))
val context = EmitCode.present(mb, st.fromEmitCodes(cb, contexts))

zippedReader.emitStream(ctx, cb, mb, context, requestedType)
}
}

case class PartitionZippedIndexedNativeReader(
specLeft: AbstractTypedCodecSpec,
specRight: AbstractTypedCodecSpec,
2 changes: 2 additions & 0 deletions hail/python/hail/expr/__init__.py
@@ -210,6 +210,7 @@
qchisqtail,
qnorm,
qpois,
query_matrix_table,
query_table,
rand_beta,
rand_bool,
@@ -554,6 +555,7 @@
'_console_log',
'dnorm',
'dchisq',
'query_matrix_table',
'query_table',
'keyed_union',
'keyed_intersection',
66 changes: 66 additions & 0 deletions hail/python/hail/expr/functions.py
@@ -7069,6 +7069,72 @@ def coerce_endpoint(point):
aggregations=partition_interval._aggregations,
)

@typecheck(path=builtins.str, point_or_interval=expr_any, entries_name=builtins.str)
def query_matrix_table(path, point_or_interval, entries_name='entries_array'):
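"""Query the rows and entries of the matrix table at `path` that fall in a row-key point or interval, returning an array of row structs in which each row's entries appear as an array field named `entries_name`."""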
matrix_table = hl.read_matrix_table(path)
if entries_name in matrix_table.row:
raise ValueError(f'field "{entries_name}" is present in matrix table row fields, pick a different `entries_name`')

entries_table = hl.read_table(path + '/entries')
[entry_id] = list(entries_table.row)

full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
key_typ = matrix_table.row_key.dtype
key_names = list(key_typ)
len = builtins.len

def coerce_endpoint(point):
if point.dtype == key_typ[0]:
point = hl.struct(**{key_names[0]: point})
ts = point.dtype
if isinstance(ts, tstruct):
i = 0
while i < len(ts):
if i >= len(key_typ):
raise ValueError(
f"query_table: queried with {len(ts)} key field(s), but table only has {len(key_typ)} key field(s)"
)
if key_typ[i] != ts[i]:
raise ValueError(
f"query_table: key mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, table key type is {key_typ[i]}"
)
i += 1

if i == 0:
raise ValueError("query_table: cannot query with empty key")

point_size = builtins.len(point.dtype)
return hl.tuple([
hl.struct(**{
key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i]))
for i in builtins.range(builtins.len(key_typ))
}),
hl.int32(point_size),
])
else:
raise ValueError(
f"query_table: key mismatch: cannot query a table with key "
f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}"
)

if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
partition_interval = hl.interval(
start=coerce_endpoint(point_or_interval.start),
end=coerce_endpoint(point_or_interval.end),
includes_start=point_or_interval.includes_start,
includes_end=point_or_interval.includes_end,
)
else:
point = coerce_endpoint(point_or_interval)
partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
read_part_ir = ir.ReadPartition(partition_interval._ir, reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type))
stream_expr = construct_expr(
read_part_ir,
type=hl.tstream(full_row_type),
indices=partition_interval._indices,
aggregations=partition_interval._aggregations,
)
return stream_expr.map(lambda item: item.rename({entry_id: entries_name})).to_array()

@typecheck(msg=expr_str, result=expr_any)
def _console_log(msg, result):
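
A minimal usage sketch of the new function, assuming a matrix table written at a hypothetical path and keyed by ['locus', 'alleles'] (GRCh37):

import hail as hl

mt_path = 'gs://my-bucket/dataset.mt'  # hypothetical path

# Query a single row key: evaluates to an array of row structs, each carrying
# the row fields plus an 'entries_array' field with that row's entries.
rows = hl.eval(hl.query_matrix_table(mt_path, hl.locus('1', 123456)))

# Query a row-key interval instead of a point, with a custom entries field name.
interval = hl.interval(hl.locus('1', 100000), hl.locus('1', 200000))
rows = hl.eval(hl.query_matrix_table(mt_path, interval, entries_name='calls'))
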
2 changes: 2 additions & 0 deletions hail/python/hail/ir/__init__.py
@@ -113,6 +113,7 @@
NDArraySVD,
NDArrayWrite,
PartitionNativeIntervalReader,
PartitionZippedNativeIntervalReader,
ProjectedTopLevelReference,
ReadPartition,
Recur,
@@ -527,6 +528,7 @@
'TableNativeFanoutWriter',
'ReadPartition',
'PartitionNativeIntervalReader',
'PartitionZippedNativeIntervalReader',
'GVCFPartitionReader',
'TableGen',
'Partitioner',
31 changes: 31 additions & 0 deletions hail/python/hail/ir/ir.py
@@ -3510,6 +3510,37 @@ def row_type(self):
return tstruct(**self.table_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class PartitionZippedNativeIntervalReader(PartitionReader):
def __init__(self, path, full_row_type, uid_field=None):
self.path = path
self.full_row_type = full_row_type
self.uid_field = uid_field

def with_uid_field(self, uid_field):
return PartitionZippedNativeIntervalReader(path=self.path, full_row_type=self.full_row_type, uid_field=uid_field)

def render(self):
return escape_str(
json.dumps({
"name": "PartitionZippedNativeIntervalReader",
"path": self.path,
"uidFieldName": self.uid_field if self.uid_field is not None else '__dummy',
})
)

def _eq(self, other):
return (
isinstance(other, PartitionZippedNativeIntervalReader)
and self.path == other.path
and self.uid_field == other.uid_field
)

def row_type(self):
if self.uid_field is None:
return self.full_row_type
return tstruct(**self.full_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class ReadPartition(IR):
@typecheck_method(context=IR, reader=PartitionReader)
def __init__(self, context, reader):
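
A small sketch of the uid-field mechanics above (the path and row type are made up for illustration):

from hail.expr.types import tstruct, ttuple, tint32, tint64

full_row_type = tstruct(idx=tint32)
reader = PartitionZippedNativeIntervalReader('/path/to/dataset.mt', full_row_type)
assert reader.row_type() == full_row_type  # no uid field requested

tagged = reader.with_uid_field('__uid')
assert tagged.row_type() == tstruct(idx=tint32, __uid=ttuple(tint64, tint64))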