Skip to content
50 changes: 46 additions & 4 deletions gnomad/utils/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,48 @@ def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression:
return hl.agg.array_sum(hl.map(_get_criteria, sorted_indices))


def explode_downsamplings_oe(
ht: hl.Table,
downsampling_meta: Dict[str, List[str]],
metrics: List[str] = ["syn", "lof", "mis"],
) -> hl.Table:
"""
Explode observed and expected downsampling counts for each genetic ancestry group and metric.

The input `ht` must contain struct of downsampling information for genetic ancestry
groups under each metric name. For example: 'lof': struct {gen_anc_exp: struct
{global: array<float64>}.

:param ht: Input Table.
:param metrics: List of metrics to explode. Default is '['syn', 'lof', 'mis']'.
:param downsampling_meta: Dictionary containing downsampling metadata. Keys are the
genetic ancestry group names and values are the list of downsamplings for that
genetic ancestry group. Example: {'global': ['5000', '10000'], 'afr': ['5000',
'10000']}.
:return: Table with downsampling counts exploded so that observed and expected
metric counts for each pair of genetic ancestry groups and downsamplings is a
separate row.
"""
ht = ht.select(
_data=[
hl.struct(
gen_anc=pop,
downsampling=downsampling,
**{
f"{metric}.{oe}": ht[metric][f"gen_anc_{oe}"][pop][i]
for oe in ["obs", "exp"]
for metric in metrics
},
)
for pop, downsamplings in downsampling_meta.items()
for i, downsampling in enumerate(downsamplings)
]
)
ht = ht.explode("_data")
ht = ht.transmute(**ht._data)
return ht


def annotate_mutation_type(
t: Union[hl.MatrixTable, hl.Table],
context_length: Optional[int] = None,
Expand Down Expand Up @@ -1062,9 +1104,9 @@ def oe_aggregation_expr(
- oe - observed:expected ratio of variants filtered to `filter_expr`.

If `pops` is specified:
- pop_exp - Struct with the expected number of variants per population (for
- gen_anc_exp - Struct with the expected number of variants per population (for
all pop in `pops`) filtered to `filter_expr`.
- pop_obs - Struct with the observed number of variants per population (for
- gen_anc_obs - Struct with the observed number of variants per population (for
all pop in `pops`) filtered to `filter_expr`.

.. note::
Expand Down Expand Up @@ -1100,10 +1142,10 @@ def oe_aggregation_expr(
# Create aggregators that sum the number of observed variants
# and expected variants for each population if pops is specified.
if pops:
agg_expr["pop_exp"] = hl.struct(
agg_expr["gen_anc_exp"] = hl.struct(
**{pop: hl.agg.array_sum(ht[f"expected_variants_{pop}"]) for pop in pops}
)
agg_expr["pop_obs"] = hl.struct(
agg_expr["gen_anc_obs"] = hl.struct(
**{pop: hl.agg.array_sum(ht[f"downsampling_counts_{pop}"]) for pop in pops}
)

Expand Down