Skip to content
248 changes: 129 additions & 119 deletions gnomad/utils/sparse_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hail as hl

from gnomad.utils.annotations import (
agg_by_strata,
fs_from_sb,
generate_freq_group_membership_array,
get_adj_expr,
Expand Down Expand Up @@ -913,6 +914,7 @@ def compute_coverage_stats(
coverage_over_x_bins: List[int] = [1, 5, 10, 15, 20, 25, 30, 50, 100],
row_key_fields: List[str] = ["locus"],
strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None,
group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
"""
Compute coverage statistics for every base of the `reference_ht` provided.
Expand All @@ -934,38 +936,70 @@ def compute_coverage_stats(
:param row_key_fields: List of row key fields to use for joining `mtds` with
`reference_ht`
:param strata_expr: Optional list of dicts containing expressions to stratify the
coverage stats by.
:return: Table with per-base coverage stats
coverage stats by. Only one of `group_membership_ht` or `strata_expr` can be
specified.
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the coverage stats by. Only one of `group_membership_ht` or
`strata_expr` can be specified.
:return: Table with per-base coverage stats.
"""
is_vds = isinstance(mtds, hl.vds.VariantDataset)
if is_vds:
mt = mtds.variant_data
else:
mt = mtds

if strata_expr is None:
# Determine the genotype field.
gt_field = set(mt.entry) & {"GT", "LGT"}
if len(gt_field) == 0:
raise ValueError("No genotype field found in entry fields.")

gt_field = gt_field.pop()

if group_membership_ht is not None and strata_expr is not None:
raise ValueError(
"Only one of 'group_membership_ht' or 'strata_expr' can be specified."
)

# Initialize no_strata and default strata_expr if neither group_membership_ht nor
# strata_expr is provided.
no_strata = group_membership_ht is None and strata_expr is None
if no_strata:
strata_expr = {}
no_strata = True
else:
no_strata = False

# Annotate the MT cols with each of the expressions in strata_expr and redefine
# strata_expr based on the column HT with added annotations.
ht = mt.annotate_cols(**{k: v for d in strata_expr for k, v in d.items()}).cols()
strata_expr = [{k: ht[k] for k in d} for d in strata_expr]

# Use the function for creating the frequency stratified by `freq_meta`,
# `freq_meta_sample_count`, and `group_membership` annotations to give
# stratification group membership info for computing coverage. By default, this
# function returns annotations where the second element is a placeholder for the
# "raw" frequency of all samples, where the first 2 elements are the same sample
# set, but freq_meta startswith [{"group": "adj", "group": "raw", ...]. Use
# `no_raw_group` to exclude the "raw" group so there is a single annotation
# representing the full samples set. `freq_meta` is updated below to remove "group"
# from all dicts.
group_membership_ht = generate_freq_group_membership_array(
ht, strata_expr, no_raw_group=True
)

if group_membership_ht is None:
# Annotate the MT cols with each of the expressions in strata_expr and redefine
# strata_expr based on the column HT with added annotations.
ht = mt.annotate_cols(
**{k: v for d in strata_expr for k, v in d.items()}
).cols()
strata_expr = [{k: ht[k] for k in d} for d in strata_expr]

# Use 'generate_freq_group_membership_array' to create a group_membership Table
# that gives stratification group membership info based on 'strata_expr'. The
# returned Table has the following annotations: 'freq_meta',
# 'freq_meta_sample_count', and 'group_membership'. By default, this
# function returns annotations where the second element is a placeholder for the
# "raw" frequency of all samples, where the first 2 elements are the same sample
    # set, but 'freq_meta' starts with [{"group": "adj"}, {"group": "raw"}, ...]. Use
# `no_raw_group` to exclude the "raw" group so there is a single annotation
# representing the full samples set. Update all 'freq_meta' entries' "group"
# to "raw" because `generate_freq_group_membership_array` will return them all
# as "adj" since it was built for frequency computation, but for the coverage
# computation we don't want to do any filtering.
group_membership_ht = generate_freq_group_membership_array(
ht, strata_expr, no_raw_group=True
)
group_membership_ht = group_membership_ht.annotate_globals(
freq_meta=group_membership_ht.freq_meta.map(
lambda x: hl.dict(
x.items().map(
lambda m: hl.if_else(m[0] == "group", ("group", "raw"), m)
)
)
)
)

n_samples = group_membership_ht.count()
sample_counts = group_membership_ht.index_globals().freq_meta_sample_count

Expand All @@ -991,30 +1025,28 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable:
"""
Outer join MatrixTable with reference Table.

Add 'in_ref' annotation indicating whether a given position is found in the reference Table.
Add 'in_ref' annotation indicating whether a given position is found in the
reference Table.

:param mt: Input MatrixTable.
:return: MatrixTable with 'in_ref' annotation added.
"""
keep_entries = ["DP"]
if "END" in mt.entry:
keep_entries.append("END")
if "LGT" in mt.entry:
keep_entries.append("LGT")
if "GT" in mt.entry:
keep_entries.append("GT")
# Get the total number of samples.
n_samples = mt.count_cols()
entry_keep_fields = set(mt.entry) & {gt_field, "DP", "END"}
mt_col_key_fields = list(mt.col_key)
mt_row_key_fields = list(mt.row_key)
t = mt.select_entries(*keep_entries).select_cols().select_rows()
t = mt.select_entries(*entry_keep_fields).select_cols().select_rows()

# Localize entries and perform an outer join with the reference HT.
t = t._localize_entries("__entries", "__cols")
t = (
t.key_by(*row_key_fields)
.join(
reference_ht.key_by(*row_key_fields).select(_in_ref=True),
how="outer",
)
.key_by(*mt_row_key_fields)
t = t.key_by(*row_key_fields)
t = t.join(
reference_ht.key_by(*row_key_fields).select(_in_ref=True), how="outer"
)
t = t.key_by(*mt_row_key_fields)

# Fill in missing entries with missing values for each entry field.
t = t.annotate(
__entries=hl.or_else(
t.__entries,
Expand All @@ -1024,8 +1056,10 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable:
)
)

# Unlocalize entries to turn the HT back to a MT.
return t._unlocalize_entries("__entries", "__cols", mt_col_key_fields)

# Densify VDS/sparse MT at all sites.
if is_vds:
mtds = hl.vds.VariantDataset(
mtds.reference_data.select_entries("END", "DP").select_cols().select_rows(),
Expand All @@ -1039,112 +1073,88 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable:
# Densify
mt = hl.experimental.densify(mtds)

# Filter rows where the reference is missing
# Filter rows where the reference is missing.
mt = mt.filter_rows(mt._in_ref)

# Unfilter entries so that entries with no ref block overlap aren't null
# Unfilter entries so that entries with no ref block overlap aren't null.
mt = mt.unfilter_entries()

# Annotate with group membership
mt = mt.annotate_cols(
group_membership=group_membership_ht[mt.col_key].group_membership
)

# Compute coverage stats
coverage_over_x_bins = sorted(coverage_over_x_bins)
max_coverage_bin = coverage_over_x_bins[-1]
hl_coverage_over_x_bins = hl.array(coverage_over_x_bins)

# This expression creates a counter DP -> number of samples for DP between
# 0 and max_coverage_bin
coverage_counter_expr = hl.agg.counter(
hl.min(max_coverage_bin, hl.or_else(mt.DP, 0))
)
mean_expr = hl.agg.mean(hl.or_else(mt.DP, 0))

# Annotate all rows with coverage stats for each strata group.
ht = mt.select_rows(
coverage_stats=hl.agg.array_agg(
lambda x: hl.agg.filter(
x,
hl.struct(
coverage_counter=coverage_counter_expr,
mean=hl.if_else(hl.is_nan(mean_expr), 0, mean_expr),
median_approx=hl.or_else(
hl.agg.approx_median(hl.or_else(mt.DP, 0)), 0
),
total_DP=hl.agg.sum(mt.DP),
),
# Compute coverage stats.
cov_bins = sorted(coverage_over_x_bins)
rev_cov_bins = list(reversed(cov_bins))
max_cov_bin = cov_bins[-1]
cov_bins = hl.array(cov_bins)
entry_agg_funcs = {
"coverage_stats": (
lambda t: hl.if_else(hl.is_missing(t.DP) | hl.is_nan(t.DP), 0, t.DP),
lambda dp: hl.struct(
# This expression creates a counter DP -> number of samples for DP
# between 0 and max_cov_bin.
coverage_counter=hl.agg.counter(hl.min(max_cov_bin, dp)),
mean=hl.if_else(hl.is_nan(hl.agg.mean(dp)), 0, hl.agg.mean(dp)),
median_approx=hl.or_else(hl.agg.approx_median(dp), 0),
total_DP=hl.agg.sum(dp),
),
mt.group_membership,
)
).rows()
}
ht = agg_by_strata(mt, entry_agg_funcs, group_membership_ht=group_membership_ht)
ht = ht.checkpoint(hl.utils.new_temp_file("coverage_stats", "ht"))

# This expression aggregates the DP counter in reverse order of the
# coverage_over_x_bins and computes the cumulative sum over them.
# It needs to be in reverse order because we want the sum over samples
# covered by > X.
count_array_expr = ht.coverage_stats.map(
lambda x: hl.cumulative_sum(
hl.array(
# The coverage was already floored to the max_coverage_bin, so no more
# aggregation is needed for the max bin.
[hl.int32(x.coverage_counter.get(max_coverage_bin, 0))]
# For each of the other bins, coverage needs to be summed between the
# boundaries.
).extend(
hl.range(hl.len(hl_coverage_over_x_bins) - 1, 0, step=-1).map(
lambda i: hl.sum(
hl.range(
hl_coverage_over_x_bins[i - 1], hl_coverage_over_x_bins[i]
).map(lambda j: hl.int32(x.coverage_counter.get(j, 0)))
)
# This expression aggregates the DP counter in reverse order of the cov_bins and
# computes the cumulative sum over them. It needs to be in reverse order because we
# want the sum over samples covered by > X.
def _cov_stats(
cov_stat: hl.expr.StructExpression, n: hl.expr.Int32Expression
) -> hl.expr.StructExpression:
# The coverage was already floored to the max_coverage_bin, so no more
# aggregation is needed for the max bin.
count_expr = cov_stat.coverage_counter
max_bin_expr = hl.int32(count_expr.get(max_cov_bin, 0))

# For each of the other bins, coverage is summed between the boundaries.
bin_expr = hl.range(hl.len(cov_bins) - 1, 0, step=-1)
bin_expr = bin_expr.map(
lambda i: hl.sum(
hl.range(cov_bins[i - 1], cov_bins[i]).map(
lambda j: hl.int32(count_expr.get(j, 0))
)
)
)
)
bin_expr = hl.cumulative_sum(hl.array([max_bin_expr]).extend(bin_expr))

bin_expr = {f"over_{x}": bin_expr[i] / n for i, x in enumerate(rev_cov_bins)}

return cov_stat.annotate(**bin_expr).drop("coverage_counter")

ht = ht.annotate(
coverage_stats=hl.map(
lambda c, g, n: c.annotate(
**{
f"over_{x}": g[i] / n
for i, x in zip(
range(len(coverage_over_x_bins) - 1, -1, -1),
# Reverse the bin index as count_array_expr has reverse order.
coverage_over_x_bins,
)
}
).drop("coverage_counter"),
lambda c, n: _cov_stats(c, n),
ht.coverage_stats,
count_array_expr,
sample_counts,
)
)
current_keys = list(ht.key)
ht = (
ht.key_by(*row_key_fields)
.select_globals()
.drop(*[k for k in current_keys if k not in row_key_fields])
)
ht = ht.key_by(*row_key_fields).select_globals()
ht = ht.drop(*[k for k in current_keys if k not in row_key_fields])

group_globals = group_membership_ht.index_globals()
global_expr = {}
if no_strata:
# If there was no stratification, move coverage_stats annotations to the top
# level.
ht = ht.select(**{k: ht.coverage_stats[0][k] for k in ht.coverage_stats[0]})
global_expr["sample_count"] = group_globals.freq_meta_sample_count[0]
else:
# If there was stratification, add the metadata and sample count info for the
# stratification to the globals.
ht = ht.annotate_globals(
coverage_stats_meta=(
group_membership_ht.index_globals().freq_meta.map(
lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group"))
)
),
coverage_stats_meta_sample_count=(
group_membership_ht.index_globals().freq_meta_sample_count
),
global_expr["coverage_stats_meta"] = group_globals.freq_meta.map(
lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group"))
)
global_expr["coverage_stats_meta_sample_count"] = (
group_globals.freq_meta_sample_count
)

ht = ht.annotate_globals(**global_expr)

return ht

Expand Down