Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@ def annotate_freq(
pop_expr: Optional[hl.expr.StringExpression] = None,
subpop_expr: Optional[hl.expr.StringExpression] = None,
additional_strata_expr: Optional[Dict[str, hl.expr.StringExpression]] = None,
additional_strata_grouping_expr: Optional[
Dict[str, hl.expr.StringExpression]
] = None,
downsamplings: Optional[List[int]] = None,
) -> hl.MatrixTable:
"""
Expand Down Expand Up @@ -388,6 +391,7 @@ def annotate_freq(
:param pop_expr: When specified, frequencies are stratified by population. If `sex_expr` is also specified, then a pop/sex stratification is added.
:param subpop_expr: When specified, frequencies are stratified by sub-continental population. Note that `pop_expr` is required as well when using this option.
:param additional_strata_expr: When specified, frequencies are stratified by the given additional strata found in the dict. This can e.g. be used to stratify by platform.
:param additional_strata_grouping_expr: When specified, frequencies are further stratified by groups within the additional_strata_expr. This can e.g. be used to stratify by platform-population.
:param downsamplings: When specified, frequencies are computed by downsampling the data to the number of samples given in the list. Note that if `pop_expr` is specified, downsamplings by population are also computed.
:return: MatrixTable with `freq` annotation
"""
Expand All @@ -396,10 +400,20 @@ def annotate_freq(
"annotate_freq requires pop_expr when using subpop_expr"
)

if additional_strata_grouping_expr is not None and additional_strata_expr is None:
raise NotImplementedError(
"annotate_freq requires additional_strata_expr when using"
" additional_strata_grouping_expr"
)

if additional_strata_expr is None:
additional_strata_expr = {}

_freq_meta_expr = hl.struct(**additional_strata_expr)
if additional_strata_grouping_expr is None:
additional_strata_grouping_expr = {}
else:
_freq_meta_expr = _freq_meta_expr.annotate(**additional_strata_grouping_expr)
if sex_expr is not None:
_freq_meta_expr = _freq_meta_expr.annotate(sex=sex_expr)
if pop_expr is not None:
Expand All @@ -410,7 +424,7 @@ def annotate_freq(
# Annotate cols with provided cuts
mt = mt.annotate_cols(_freq_meta=_freq_meta_expr)

# Get counters for sex, pop and subpop if set
# Get counters for sex, pop, subpop and additional strata, if set
cut_dict = {
cut: hl.agg.filter(
hl.is_defined(mt._freq_meta[cut]), hl.agg.counter(mt._freq_meta[cut])
Expand Down Expand Up @@ -509,6 +523,22 @@ def annotate_freq(
+ sample_group_filters
)

# Add additional groupings to strata, e.g. strata-pop, strata-sex, strata-pop-sex
if additional_strata_grouping_expr is not None:
sample_group_filters.extend(
[
(
{strata: str(s_value), add_strata: str(as_value)},
(mt._freq_meta[strata] == s_value)
& (mt._freq_meta[add_strata] == as_value),
)
for strata in additional_strata_expr
for s_value in cut_data.get(strata, {})
for add_strata in additional_strata_grouping_expr
for as_value in cut_data.get(add_strata, {})
]
)

freq_sample_count = mt.aggregate_cols(
[hl.agg.count_where(x[1]) for x in sample_group_filters]
)
Expand Down