diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 0f0ea87a1..fa18b2f7f 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -1,11 +1,13 @@ # noqa: D100 import logging -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import hail as hl from gnomad.utils.annotations import ( + agg_by_strata, + annotate_adj, fs_from_sb, generate_freq_group_membership_array, get_adj_expr, @@ -913,6 +915,8 @@ def compute_coverage_stats( coverage_over_x_bins: List[int] = [1, 5, 10, 15, 20, 25, 30, 50, 100], row_key_fields: List[str] = ["locus"], strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None, + group_membership_ht: Optional[hl.Table] = None, + include_an: bool = False, ) -> hl.Table: """ Compute coverage statistics for every base of the `reference_ht` provided. @@ -934,8 +938,13 @@ def compute_coverage_stats( :param row_key_fields: List of row key fields to use for joining `mtds` with `reference_ht` :param strata_expr: Optional list of dicts containing expressions to stratify the - coverage stats by. - :return: Table with per-base coverage stats + coverage stats by. Only one of `group_membership_ht` or `strata_expr` can be + specified. + :param group_membership_ht: Optional Table containing group membership annotations + to stratify the coverage stats by. Only one of `group_membership_ht` or + `strata_expr` can be specified. + :param include_an: Whether to also compute AN. Default is False. + :return: Table with per-base coverage stats. """ is_vds = isinstance(mtds, hl.vds.VariantDataset) if is_vds: @@ -943,34 +952,109 @@ def compute_coverage_stats( else: mt = mtds - if strata_expr is None: - strata_expr = {} - no_strata = True - else: - no_strata = False - - # Annotate the MT cols with each of the expressions in strata_expr and redefine - # strata_expr based on the column HT with added annotations. - ht = mt.annotate_cols(**{k: v for d in strata_expr for k, v in d.items()}).cols() - strata_expr = [{k: ht[k] for k in d} for d in strata_expr] - - # Use the function for creating the frequency stratified by `freq_meta`, - # `freq_meta_sample_count`, and `group_membership` annotations to give - # stratification group membership info for computing coverage. By default, this - # function returns annotations where the second element is a placeholder for the - # "raw" frequency of all samples, where the first 2 elements are the same sample - # set, but freq_meta startswith [{"group": "adj", "group": "raw", ...]. Use - # `no_raw_group` to exclude the "raw" group so there is a single annotation - # representing the full samples set. `freq_meta` is updated below to remove "group" - # from all dicts. - group_membership_ht = generate_freq_group_membership_array( - ht, strata_expr, no_raw_group=True + # Determine the genotype field. + gt_field = set(mt.entry) & {"GT", "LGT"} + if len(gt_field) == 0: + raise ValueError("No genotype field found in entry fields.") + + gt_field = gt_field.pop() + + # Add function to compute coverage stats. + cov_bins = sorted(coverage_over_x_bins) + rev_cov_bins = list(reversed(cov_bins)) + max_cov_bin = cov_bins[-1] + cov_bins = hl.array(cov_bins) + entry_agg_funcs = { + "coverage_stats": ( + lambda t: hl.if_else(hl.is_missing(t.DP) | hl.is_nan(t.DP), 0, t.DP), + lambda dp: hl.struct( + # This expression creates a counter DP -> number of samples for DP + # between 0 and max_cov_bin. + coverage_counter=hl.agg.counter(hl.min(max_cov_bin, dp)), + mean=hl.agg.mean(dp), + median_approx=hl.agg.approx_median(dp), + total_DP=hl.agg.sum(dp), + ), + ) + } + + if include_an: + entry_agg_funcs["AN"] = get_allele_number_agg_func(gt_field) + + ht = compute_stats_per_ref_site( + mtds, + reference_ht, + entry_agg_funcs, + row_key_fields=row_key_fields, + interval_ht=interval_ht, + entry_keep_fields=[gt_field, "DP"], + strata_expr=strata_expr, + group_membership_ht=group_membership_ht, ) - n_samples = group_membership_ht.count() - sample_counts = group_membership_ht.index_globals().freq_meta_sample_count - logger.info("Computing coverage stats on %d samples.", n_samples) - # Filter datasets to interval list + # This expression aggregates the DP counter in reverse order of the cov_bins and + # computes the cumulative sum over them. It needs to be in reverse order because we + # want the sum over samples covered by > X. + def _cov_stats( + cov_stat: hl.expr.StructExpression, n: hl.expr.Int32Expression + ) -> hl.expr.StructExpression: + # The coverage was already floored to the max_coverage_bin, so no more + # aggregation is needed for the max bin. + count_expr = cov_stat.coverage_counter + max_bin_expr = hl.int32(count_expr.get(max_cov_bin, 0)) + + # For each of the other bins, coverage is summed between the boundaries. + bin_expr = hl.range(hl.len(cov_bins) - 1, 0, step=-1) + bin_expr = bin_expr.map( + lambda i: hl.sum( + hl.range(cov_bins[i - 1], cov_bins[i]).map( + lambda j: hl.int32(count_expr.get(j, 0)) + ) + ) + ) + bin_expr = hl.cumulative_sum(hl.array([max_bin_expr]).extend(bin_expr)) + + # Use reversed bins as count_array_expr has reverse order. + bin_expr = {f"over_{x}": bin_expr[i] / n for i, x in enumerate(rev_cov_bins)} + + return cov_stat.annotate(**bin_expr).drop("coverage_counter") + + ht_globals = ht.index_globals() + if isinstance(ht.coverage_stats, hl.expr.ArrayExpression): + cov_stats_expr = hl.map( + lambda c, n: _cov_stats(c, n), + ht.coverage_stats, + ht_globals.strata_sample_count, + ) + else: + cov_stats_expr = _cov_stats(ht.coverage_stats, ht_globals.sample_count) + + ht = ht.annotate(coverage_stats=cov_stats_expr) + + return ht + + +def densify_all_reference_sites( + mtds: Union[hl.MatrixTable, hl.vds.VariantDataset], + reference_ht: hl.Table, + interval_ht: Optional[hl.Table] = None, + row_key_fields: Union[Tuple[str], List[str], Set[str]] = ("locus",), + entry_keep_fields: Union[Tuple[str], List[str], Set[str]] = ("GT",), +) -> hl.MatrixTable: + """ + Densify a VariantDataset or Sparse MatrixTable at all sites in a reference Table. + + :param mtds: Input sparse Matrix Table or VariantDataset. + :param reference_ht: Table of reference sites. + :param interval_ht: Optional Table of intervals to filter to. + :param row_key_fields: Fields to use as row key. Defaults to locus. + :param entry_keep_fields: Fields to keep in entries before performing the + densification. Defaults to GT. + :return: Densified MatrixTable. + """ + is_vds = isinstance(mtds, hl.vds.VariantDataset) + + # Filter datasets to interval list. if interval_ht is not None: reference_ht = reference_ht.filter( hl.is_defined(interval_ht[reference_ht.locus]) @@ -986,169 +1070,249 @@ def compute_coverage_stats( " not supported." ) - # Create an outer join with the reference Table - def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: - """ - Outer join MatrixTable with reference Table. + entry_keep_fields = set(entry_keep_fields) + if is_vds: + mt = mtds.variant_data + else: + mt = mtds + entry_keep_fields.add("END") - Add 'in_ref' annotation indicating whether a given position is found in the reference Table. + # Get the total number of samples. + n_samples = mt.count_cols() - :param mt: Input MatrixTable. - :return: MatrixTable with 'in_ref' annotation added. - """ - keep_entries = ["DP"] - if "END" in mt.entry: - keep_entries.append("END") - if "LGT" in mt.entry: - keep_entries.append("LGT") - if "GT" in mt.entry: - keep_entries.append("GT") - mt_col_key_fields = list(mt.col_key) - mt_row_key_fields = list(mt.row_key) - t = mt.select_entries(*keep_entries).select_cols().select_rows() - t = t._localize_entries("__entries", "__cols") - t = ( - t.key_by(*row_key_fields) - .join( - reference_ht.key_by(*row_key_fields).select(_in_ref=True), - how="outer", - ) - .key_by(*mt_row_key_fields) - ) - t = t.annotate( - __entries=hl.or_else( - t.__entries, - hl.range(n_samples).map( - lambda x: hl.missing(t.__entries.dtype.element_type) - ), - ) + mt_col_key_fields = list(mt.col_key) + mt_row_key_fields = list(mt.row_key) + ht = mt.select_entries(*entry_keep_fields).select_cols().select_rows() + + # Localize entries and perform an outer join with the reference HT. + ht = ht._localize_entries("__entries", "__cols") + ht = ht.key_by(*row_key_fields) + ht = ht.join(reference_ht.key_by(*row_key_fields).select(_in_ref=True), how="outer") + ht = ht.key_by(*mt_row_key_fields) + + # Fill in missing entries with missing values for each entry field. + ht = ht.annotate( + __entries=hl.or_else( + ht.__entries, + hl.range(n_samples).map( + lambda x: hl.missing(ht.__entries.dtype.element_type) + ), ) + ) - return t._unlocalize_entries("__entries", "__cols", mt_col_key_fields) + # Unlocalize entries to turn the HT back to a MT. + mt = ht._unlocalize_entries("__entries", "__cols", mt_col_key_fields) + # Densify VDS/sparse MT at all sites. if is_vds: - mtds = hl.vds.VariantDataset( - mtds.reference_data.select_entries("END", "DP").select_cols().select_rows(), - join_with_ref(mtds.variant_data), + mt = hl.vds.to_dense_mt( + hl.vds.VariantDataset(mtds.reference_data.select_cols().select_rows(), mt) ) - - # Densify - mt = hl.vds.to_dense_mt(mtds) else: - mtds = join_with_ref(mtds) - # Densify - mt = hl.experimental.densify(mtds) + mt = hl.experimental.densify(mt) - # Filter rows where the reference is missing + # Filter rows where the reference is missing. mt = mt.filter_rows(mt._in_ref) - # Unfilter entries so that entries with no ref block overlap aren't null + # Unfilter entries so that entries with no ref block overlap aren't null. mt = mt.unfilter_entries() - # Annotate with group membership - mt = mt.annotate_cols( - group_membership=group_membership_ht[mt.col_key].group_membership - ) + return mt - # Compute coverage stats - coverage_over_x_bins = sorted(coverage_over_x_bins) - max_coverage_bin = coverage_over_x_bins[-1] - hl_coverage_over_x_bins = hl.array(coverage_over_x_bins) - # This expression creates a counter DP -> number of samples for DP between - # 0 and max_coverage_bin - coverage_counter_expr = hl.agg.counter( - hl.min(max_coverage_bin, hl.or_else(mt.DP, 0)) - ) - mean_expr = hl.agg.mean(hl.or_else(mt.DP, 0)) - - # Annotate all rows with coverage stats for each strata group. - ht = mt.select_rows( - coverage_stats=hl.agg.array_agg( - lambda x: hl.agg.filter( - x, - hl.struct( - coverage_counter=coverage_counter_expr, - mean=hl.if_else(hl.is_nan(mean_expr), 0, mean_expr), - median_approx=hl.or_else( - hl.agg.approx_median(hl.or_else(mt.DP, 0)), 0 - ), - total_DP=hl.agg.sum(mt.DP), - ), - ), - mt.group_membership, +def compute_stats_per_ref_site( + mtds: Union[hl.MatrixTable, hl.vds.VariantDataset], + reference_ht: hl.Table, + entry_agg_funcs: Dict[str, Tuple[Callable, Callable]], + row_key_fields: Union[Tuple[str], List[str]] = ("locus",), + interval_ht: Optional[hl.Table] = None, + entry_keep_fields: Union[Tuple[str], List[str], Set[str]] = None, + strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None, + group_membership_ht: Optional[hl.Table] = None, +) -> hl.Table: + """ + Compute stats per site in a reference Table. + + :param mtds: Input sparse Matrix Table or VariantDataset. + :param reference_ht: Table of reference sites. + :param entry_agg_funcs: Dict of entry aggregation functions to perform on the + VariantDataset/MatrixTable. The keys of the dict are the names of the + annotations and the values are tuples of functions. The first function is used + to transform the `mt` entries in some way, and the second function is used to + aggregate the output from the first function. + :param row_key_fields: Fields to use as row key. Defaults to locus. + :param interval_ht: Optional table of intervals to filter to. + :param entry_keep_fields: Fields to keep in entries before performing the + densification in `densify_all_reference_sites`. Should include any fields + needed for the functions in `entry_agg_funcs`. By default, only GT or LGT is + kept. + :param strata_expr: Optional list of dicts of expressions to stratify by. + :param group_membership_ht: Optional Table of group membership annotations. + :return: Table of stats per site. + """ + is_vds = isinstance(mtds, hl.vds.VariantDataset) + if is_vds: + mt = mtds.variant_data + else: + mt = mtds + + if entry_keep_fields is None: + entry_keep_fields = [] + + entry_keep_fields = set(entry_keep_fields) + + # Determine the genotype field. + gt_field = set(mt.entry) & {"GT", "LGT"} + if len(gt_field) == 0: + raise ValueError("No genotype field found in entry fields.") + + gt_field = gt_field.pop() + entry_keep_fields.add(gt_field) + + no_strata = False + add_adj = False + if group_membership_ht is not None: + if strata_expr is not None: + raise ValueError( + "Only one of 'group_membership_ht' or 'strata_expr' can be specified." + ) + + # Identify if adj annotation is needed. + group_globals = group_membership_ht.index_globals() + if "adj_group" in group_globals or "freq_meta" in group_globals: + if "adj" in mt.entry: + entry_keep_fields.add("adj") + else: + add_adj = True + ad_field = set(mt.entry) & {"AD", "LAD"} + if len(ad_field) == 0: + raise ValueError("No AD or LAD field found in entry fields!") + + entry_keep_fields |= {"DP", "GQ", ad_field.pop()} + else: + logger.warning( + "'group_membership_ht' is not specified, no stats are adj filtered." ) - ).rows() - ht = ht.checkpoint(hl.utils.new_temp_file("coverage_stats", "ht")) - - # This expression aggregates the DP counter in reverse order of the - # coverage_over_x_bins and computes the cumulative sum over them. - # It needs to be in reverse order because we want the sum over samples - # covered by > X. - count_array_expr = ht.coverage_stats.map( - lambda x: hl.cumulative_sum( - hl.array( - # The coverage was already floored to the max_coverage_bin, so no more - # aggregation is needed for the max bin. - [hl.int32(x.coverage_counter.get(max_coverage_bin, 0))] - # For each of the other bins, coverage needs to be summed between the - # boundaries. - ).extend( - hl.range(hl.len(hl_coverage_over_x_bins) - 1, 0, step=-1).map( - lambda i: hl.sum( - hl.range( - hl_coverage_over_x_bins[i - 1], hl_coverage_over_x_bins[i] - ).map(lambda j: hl.int32(x.coverage_counter.get(j, 0))) + + if strata_expr is None: + strata_expr = {} + no_strata = True + + # Annotate the MT cols with each of the expressions in strata_expr and redefine + # strata_expr based on the column HT with added annotations. + ht = mt.annotate_cols( + **{k: v for d in strata_expr for k, v in d.items()} + ).cols() + strata_expr = [{k: ht[k] for k in d} for d in strata_expr] + + # Use 'generate_freq_group_membership_array' to create a group_membership Table + # that gives stratification group membership info based on 'strata_expr'. The + # returned Table has the following annotations: 'freq_meta', + # 'freq_meta_sample_count', and 'group_membership'. By default, this + # function returns annotations where the second element is a placeholder for the + # "raw" frequency of all samples, where the first 2 elements are the same sample + # set, but 'freq_meta' startswith [{"group": "adj", "group": "raw", ...]. Use + # `no_raw_group` to exclude the "raw" group so there is a single annotation + # representing the full samples set and update 'freq_meta' "group" to all "raw". + group_membership_ht = generate_freq_group_membership_array( + ht, strata_expr, no_raw_group=True + ) + group_membership_ht = group_membership_ht.annotate( + freq_meta=group_membership_ht.freq_meta.map( + lambda x: hl.dict( + x.items().map( + lambda m: hl.if_else(m[0] == "group", ("group", "raw"), m) ) ) ) ) - ) - ht = ht.annotate( - coverage_stats=hl.map( - lambda c, g, n: c.annotate( - **{ - f"over_{x}": g[i] / n - for i, x in zip( - range(len(coverage_over_x_bins) - 1, -1, -1), - # Reverse the bin index as count_array_expr has reverse order. - coverage_over_x_bins, - ) - } - ).drop("coverage_counter"), - ht.coverage_stats, - count_array_expr, - sample_counts, + if is_vds: + rmt = mtds.reference_data + mtds = hl.vds.VariantDataset( + rmt.select_entries(*((set(entry_keep_fields) & set(rmt.entry)) | {"END"})), + mtds.variant_data, ) + + mt = densify_all_reference_sites( + mtds, + reference_ht, + interval_ht, + row_key_fields, + entry_keep_fields=entry_keep_fields, ) + + # Annotate with adj if needed. + if add_adj: + mt = annotate_adj(mt) + + ht = agg_by_strata(mt, entry_agg_funcs, group_membership_ht=group_membership_ht) + ht = ht.checkpoint(hl.utils.new_temp_file("agg_stats", "ht")) + current_keys = list(ht.key) - ht = ( - ht.key_by(*row_key_fields) - .select_globals() - .drop(*[k for k in current_keys if k not in row_key_fields]) - ) + ht = ht.key_by(*row_key_fields).select_globals() + ht = ht.drop(*[k for k in current_keys if k not in row_key_fields]) + + group_globals = group_membership_ht.index_globals() + global_expr = {} if no_strata: # If there was no stratification, move coverage_stats annotations to the top # level. - ht = ht.select(**{k: ht.coverage_stats[0][k] for k in ht.coverage_stats[0]}) + ht = ht.select(**{ann: ht[ann][0] for ann in entry_agg_funcs}) + global_expr["sample_count"] = group_globals.freq_meta_sample_count[0] else: # If there was stratification, add the metadata and sample count info for the # stratification to the globals. - ht = ht.annotate_globals( - coverage_stats_meta=( - group_membership_ht.index_globals().freq_meta.map( - lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group")) - ) - ), - coverage_stats_meta_sample_count=( - group_membership_ht.index_globals().freq_meta_sample_count - ), - ) + global_expr["strata_meta"] = group_globals.freq_meta + global_expr["strata_sample_count"] = group_globals.freq_meta_sample_count + + ht = ht.annotate_globals(**global_expr) return ht +def get_allele_number_agg_func(gt_field: str = "GT") -> Tuple[Callable, Callable]: + """ + Get a transformation and aggregation function for computing the allele number. + + Can be used as an entry aggregation function in `compute_stats_per_ref_site`. + + :param gt_field: Genotype field to use for computing the allele number. + :return: Tuple of functions to transform and aggregate the allele number. + """ + return lambda t: t[gt_field].ploidy, hl.agg.sum + + +def compute_allele_number_per_ref_site( + mtds: Union[hl.MatrixTable, hl.vds.VariantDataset], + reference_ht: hl.Table, + **kwargs, +) -> hl.Table: + """ + Compute the allele number per reference site. + + :param mtds: Input sparse Matrix Table or VariantDataset. + :param reference_ht: Table of reference sites. + :param kwargs: Keyword arguments to pass to `compute_stats_per_ref_site`. + :return: Table of allele number per reference site. + """ + # Determine the genotype field. + if isinstance(mtds, hl.vds.VariantDataset): + mt = mtds.variant_data + else: + mt = mtds + + gt_field = set(mt.entry) & {"GT", "LGT"} + if len(gt_field) == 0: + raise ValueError("No genotype field found in entry fields.") + + # Use ploidy to determine the number of alleles for each sample at each site. + gt_field = gt_field.pop() + entry_agg_funcs = {"AN": get_allele_number_agg_func(gt_field)} + + return compute_stats_per_ref_site(mtds, reference_ht, entry_agg_funcs, **kwargs) + + def filter_ref_blocks( t: Union[hl.MatrixTable, hl.Table] ) -> Union[hl.MatrixTable, hl.Table]: