diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index cc0c62306..416017fc0 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -334,6 +334,9 @@ def annotate_freq( pop_expr: Optional[hl.expr.StringExpression] = None, subpop_expr: Optional[hl.expr.StringExpression] = None, additional_strata_expr: Optional[Dict[str, hl.expr.StringExpression]] = None, + additional_strata_grouping_expr: Optional[ + Dict[str, hl.expr.StringExpression] + ] = None, downsamplings: Optional[List[int]] = None, ) -> hl.MatrixTable: """ @@ -388,6 +391,7 @@ def annotate_freq( :param pop_expr: When specified, frequencies are stratified by population. If `sex_expr` is also specified, then a pop/sex stratifiction is added. :param subpop_expr: When specified, frequencies are stratified by sub-continental population. Note that `pop_expr` is required as well when using this option. :param additional_strata_expr: When specified, frequencies are stratified by the given additional strata found in the dict. This can e.g. be used to stratify by platform. + :param additional_strata_grouping_expr: When specified, frequencies are further stratified by groups within the additional_strata_expr. This can e.g. be used to stratify by platform-population. :param downsamplings: When specified, frequencies are computed by downsampling the data to the number of samples given in the list. Note that if `pop_expr` is specified, downsamplings by population is also computed. :return: MatrixTable with `freq` annotation """ @@ -396,10 +400,20 @@ def annotate_freq( "annotate_freq requires pop_expr when using subpop_expr" ) + if additional_strata_grouping_expr is not None and additional_strata_expr is None: + raise NotImplementedError( + "annotate_freq requires additional_strata_expr when using" + " additional_strata_grouping_expr" + ) + if additional_strata_expr is None: additional_strata_expr = {} _freq_meta_expr = hl.struct(**additional_strata_expr) + if additional_strata_grouping_expr is None: + additional_strata_grouping_expr = {} + else: + _freq_meta_expr = _freq_meta_expr.annotate(**additional_strata_grouping_expr) if sex_expr is not None: _freq_meta_expr = _freq_meta_expr.annotate(sex=sex_expr) if pop_expr is not None: @@ -410,7 +424,7 @@ def annotate_freq( # Annotate cols with provided cuts mt = mt.annotate_cols(_freq_meta=_freq_meta_expr) - # Get counters for sex, pop and subpop if set + # Get counters for sex, pop and if set subpop and additional strata cut_dict = { cut: hl.agg.filter( hl.is_defined(mt._freq_meta[cut]), hl.agg.counter(mt._freq_meta[cut]) @@ -509,6 +523,22 @@ def annotate_freq( + sample_group_filters ) + # Add additional groupings to strata, e.g. strata-pop, strata-sex, strata-pop-sex + if additional_strata_grouping_expr is not None: + sample_group_filters.extend( + [ + ( + {strata: str(s_value), add_strata: str(as_value)}, + (mt._freq_meta[strata] == s_value) + & (mt._freq_meta[add_strata] == as_value), + ) + for strata in additional_strata_expr + for s_value in cut_data.get(strata, {}) + for add_strata in additional_strata_grouping_expr + for as_value in cut_data.get(add_strata, {}) + ] + ) + freq_sample_count = mt.aggregate_cols( [hl.agg.count_where(x[1]) for x in sample_group_filters] )