diff --git a/gnomad/utils/constraint.py b/gnomad/utils/constraint.py index 178a3afdf..77039540e 100644 --- a/gnomad/utils/constraint.py +++ b/gnomad/utils/constraint.py @@ -55,6 +55,7 @@ def count_variants_by_group( freq_meta_expr: Optional[hl.expr.ArrayExpression] = None, count_singletons: bool = False, count_downsamplings: Tuple[str] = (), + downsamplings: Optional[List[int]] = None, additional_grouping: Tuple[str] = (), partition_hint: int = 100, omit_methylation: bool = False, @@ -129,6 +130,8 @@ def count_variants_by_group( Default is False. :param count_downsamplings: Tuple of populations to use for downsampling counts. Default is (). + :param downsamplings: Optional List of integers specifying what downsampling + indices to obtain. Default is None, which will return all downsampling counts. :param additional_grouping: Additional features to group by. e.g. 'exome_coverage'. Default is (). :param partition_hint: Target number of partitions for aggregation. Default is 100. @@ -204,7 +207,11 @@ def count_variants_by_group( pop, ) agg[f"downsampling_counts_{pop}"] = downsampling_counts_expr( - freq_expr, freq_meta_expr, pop, max_af=max_af + freq_expr, + freq_meta_expr, + pop, + max_af=max_af, + downsamplings=downsamplings, ) if count_singletons: logger.info( @@ -214,7 +221,12 @@ def count_variants_by_group( pop, ) agg[f"singleton_downsampling_counts_{pop}"] = downsampling_counts_expr( - freq_expr, freq_meta_expr, pop, singleton=True + freq_expr, + freq_meta_expr, + pop, + max_af=max_af, + downsamplings=downsamplings, + singleton=True, ) # Apply each variant count aggregation in `agg` to get counts for all # combinations of `grouping`. @@ -230,23 +242,55 @@ def get_downsampling_freq_indices( freq_meta_expr: hl.expr.ArrayExpression, pop: str = "global", variant_quality: str = "adj", + genetic_ancestry_label: Optional[str] = None, + subset: Optional[str] = None, + downsamplings: Optional[List[int]] = None, ) -> hl.expr.ArrayExpression: """ - Get indices of dictionaries in meta dictionaries that only have the "downsampling" key with specified "pop" and "variant_quality" values. + Get indices of dictionaries in meta dictionaries that only have the "downsampling" key with specified `genetic_ancestry_label` and "variant_quality" values. :param freq_meta_expr: ArrayExpression containing the set of groupings for each element of the `freq_expr` array (e.g., [{'group': 'adj'}, {'group': 'adj', 'pop': 'nfe'}, {'downsampling': '5000', 'group': 'adj', 'pop': 'global'}]). - :param pop: Population to use for filtering by the 'pop' key in `freq_meta_expr`. - Default is 'global'. + :param pop: Population to use for filtering by the `genetic_ancestry_label` key in + `freq_meta_expr`. Default is 'global'. :param variant_quality: Variant quality to use for filtering by the 'group' key in `freq_meta_expr`. Default is 'adj'. + :param genetic_ancestry_label: Label defining the genetic ancestry groups. If None, + "gen_anc" or "pop" is used (in that order of preference) if present. Default is + None. + :param subset: Subset to use for filtering by the 'subset' key in `freq_meta_expr`. + Default is None, which will return all downsampling indices without a 'subset' + key in `freq_meta_expr`. + :param downsamplings: Optional List of integers specifying what downsampling + indices to obtain. Default is None, which will return all downsampling indices. + :return: ArrayExpression of indices of dictionaries in `freq_meta_expr` that only + have the "downsampling" key with specified `genetic_ancestry_label` and + "variant_quality" values. """ - indices = hl.enumerate(freq_meta_expr).filter( - lambda f: (f[1].get("group") == variant_quality) - & (f[1].get("pop") == pop) - & f[1].contains("downsampling") - ) + if genetic_ancestry_label is None: + gen_anc = ["gen_anc", "pop"] + else: + gen_anc = [genetic_ancestry_label] + + def _get_filter_expr(m: hl.expr.StructExpression) -> hl.expr.BooleanExpression: + filter_expr = ( + (m.get("group") == variant_quality) + & (hl.any([m.get(l, "") == pop for l in gen_anc])) + & m.contains("downsampling") + ) + if downsamplings is not None: + filter_expr &= hl.literal(downsamplings).contains( + hl.int(m.get("downsampling", "0")) + ) + if subset is None: + filter_expr &= ~m.contains("subset") + else: + filter_expr &= m.get("subset", "") == subset + return filter_expr + + indices = hl.enumerate(freq_meta_expr).filter(lambda f: _get_filter_expr(f[1])) + # Get an array of indices and meta dictionaries sorted by "downsampling" key. return hl.sorted(indices, key=lambda f: hl.int(f[1]["downsampling"])) @@ -258,33 +302,50 @@ def downsampling_counts_expr( variant_quality: str = "adj", singleton: bool = False, max_af: Optional[float] = None, + genetic_ancestry_label: Optional[str] = None, + subset: Optional[str] = None, + downsamplings: Optional[List[int]] = None, ) -> hl.expr.ArrayExpression: """ Return an aggregation expression to compute an array of counts of all downsamplings found in `freq_expr` where specified criteria is met. The frequency metadata (`freq_meta_expr`) should be in a similar format to the `freq_meta` annotation added by `annotate_freq()`. Each downsampling should have - 'group', 'pop', and 'downsampling' keys. Included downsamplings are those where - 'group' == `variant_quality` and 'pop' == `pop`. + 'group', `genetic_ancestry_label`, and 'downsampling' keys. Included downsamplings + are those where 'group' == `variant_quality` and `genetic_ancestry_label` == `pop`. :param freq_expr: ArrayExpression of Structs with 'AC' and 'AF' annotations. :param freq_meta_expr: ArrayExpression containing the set of groupings for each element of the `freq_expr` array (e.g., [{'group': 'adj'}, {'group': 'adj', 'pop': 'nfe'}, {'downsampling': '5000', 'group': 'adj', 'pop': 'global'}]). - :param pop: Population to use for filtering by the 'pop' key in `freq_meta_expr`. - Default is 'global'. + :param pop: Population to use for filtering by the `genetic_ancestry_label` key in + `freq_meta_expr`. Default is 'global'. :param variant_quality: Variant quality to use for filtering by the 'group' key in `freq_meta_expr`. Default is 'adj'. :param singleton: Whether to filter to only singletons before counting (AC == 1). Default is False. :param max_af: Maximum variant allele frequency to keep. By default no allele frequency cutoff is applied. + :param genetic_ancestry_label: Label defining the genetic ancestry groups. If None, + "gen_anc" or "pop" is used (in that order of preference) if present. Default is + None. + :param subset: Subset to use for filtering by the 'subset' key in `freq_meta_expr`. + Default is None, which will return all downsampling counts without a 'subset' + key in `freq_meta_expr`. If specified, only downsamplings with the specified + subset will be included. + :param downsamplings: Optional List of integers specifying what downsampling + indices to obtain. Default is None, which will return all downsampling counts. :return: Aggregation Expression for an array of the variant counts in downsamplings for specified population. """ # Get an array of indices sorted by "downsampling" key. sorted_indices = get_downsampling_freq_indices( - freq_meta_expr, pop, variant_quality + freq_meta_expr, + pop, + variant_quality, + genetic_ancestry_label, + subset, + downsamplings, ).map(lambda x: x[0]) def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression: