From a2f6ff3bb46794684e43542749710b1e485316c9 Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 10 Jan 2024 12:19:52 -0700 Subject: [PATCH 1/5] Cleanup coverage function and use `agg_by_strata` --- gnomad/utils/sparse_mt.py | 232 +++++++++++++++++++------------------- 1 file changed, 118 insertions(+), 114 deletions(-) diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 0f0ea87a1..9930063a0 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -6,6 +6,7 @@ import hail as hl from gnomad.utils.annotations import ( + agg_by_strata, fs_from_sb, generate_freq_group_membership_array, get_adj_expr, @@ -913,6 +914,7 @@ def compute_coverage_stats( coverage_over_x_bins: List[int] = [1, 5, 10, 15, 20, 25, 30, 50, 100], row_key_fields: List[str] = ["locus"], strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None, + group_membership_ht: Optional[hl.Table] = None, ) -> hl.Table: """ Compute coverage statistics for every base of the `reference_ht` provided. @@ -934,8 +936,12 @@ def compute_coverage_stats( :param row_key_fields: List of row key fields to use for joining `mtds` with `reference_ht` :param strata_expr: Optional list of dicts containing expressions to stratify the - coverage stats by. - :return: Table with per-base coverage stats + coverage stats by. Only one of `group_membership_ht` or `strata_expr` can be + specified. + :param group_membership_ht: Optional Table containing group membership annotations + to stratify the coverage stats by. Only one of `group_membership_ht` or + `strata_expr` can be specified. + :return: Table with per-base coverage stats. """ is_vds = isinstance(mtds, hl.vds.VariantDataset) if is_vds: @@ -943,29 +949,44 @@ def compute_coverage_stats( else: mt = mtds - if strata_expr is None: - strata_expr = {} - no_strata = True + # Determine the genotype field. + gt_field = set(mt.entry) & {"GT", "LGT"} + if len(gt_field) == 0: + raise ValueError("No genotype field found in entry fields.") + + gt_field = gt_field.pop() + + no_strata = False + if group_membership_ht is not None: + if strata_expr is not None: + raise ValueError( + "Only one of 'group_membership_ht' or 'strata_expr' can be specified." + ) else: - no_strata = False - - # Annotate the MT cols with each of the expressions in strata_expr and redefine - # strata_expr based on the column HT with added annotations. - ht = mt.annotate_cols(**{k: v for d in strata_expr for k, v in d.items()}).cols() - strata_expr = [{k: ht[k] for k in d} for d in strata_expr] - - # Use the function for creating the frequency stratified by `freq_meta`, - # `freq_meta_sample_count`, and `group_membership` annotations to give - # stratification group membership info for computing coverage. By default, this - # function returns annotations where the second element is a placeholder for the - # "raw" frequency of all samples, where the first 2 elements are the same sample - # set, but freq_meta startswith [{"group": "adj", "group": "raw", ...]. Use - # `no_raw_group` to exclude the "raw" group so there is a single annotation - # representing the full samples set. `freq_meta` is updated below to remove "group" - # from all dicts. - group_membership_ht = generate_freq_group_membership_array( - ht, strata_expr, no_raw_group=True - ) + if strata_expr is None: + strata_expr = {} + no_strata = True + + # Annotate the MT cols with each of the expressions in strata_expr and redefine + # strata_expr based on the column HT with added annotations. + ht = mt.annotate_cols( + **{k: v for d in strata_expr for k, v in d.items()} + ).cols() + strata_expr = [{k: ht[k] for k in d} for d in strata_expr] + + # Use 'generate_freq_group_membership_array' to create a group_membership Table + # that gives stratification group membership info based on 'strata_expr'. The + # returned Table has the following annotations: 'freq_meta', + # 'freq_meta_sample_count', and 'group_membership'. By default, this + # function returns annotations where the second element is a placeholder for the + # "raw" frequency of all samples, where the first 2 elements are the same sample + # set, but 'freq_meta' startswith [{"group": "adj", "group": "raw", ...]. Use + # `no_raw_group` to exclude the "raw" group so there is a single annotation + # representing the full samples set and update 'freq_meta' "group" to all "raw". + group_membership_ht = generate_freq_group_membership_array( + ht, strata_expr, no_raw_group=True + ) + n_samples = group_membership_ht.count() sample_counts = group_membership_ht.index_globals().freq_meta_sample_count @@ -991,30 +1012,29 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: """ Outer join MatrixTable with reference Table. - Add 'in_ref' annotation indicating whether a given position is found in the reference Table. + Add 'in_ref' annotation indicating whether a given position is found in the + reference Table. :param mt: Input MatrixTable. :return: MatrixTable with 'in_ref' annotation added. """ - keep_entries = ["DP"] - if "END" in mt.entry: - keep_entries.append("END") - if "LGT" in mt.entry: - keep_entries.append("LGT") - if "GT" in mt.entry: - keep_entries.append("GT") + # Get the total number of samples. + n_samples = mt.count_cols() + + entry_keep_fields = set(mt.entry) & {gt_field, "DP", "END"} mt_col_key_fields = list(mt.col_key) mt_row_key_fields = list(mt.row_key) - t = mt.select_entries(*keep_entries).select_cols().select_rows() + t = mt.select_entries(*entry_keep_fields).select_cols().select_rows() + + # Localize entries and perform an outer join with the reference HT. t = t._localize_entries("__entries", "__cols") - t = ( - t.key_by(*row_key_fields) - .join( - reference_ht.key_by(*row_key_fields).select(_in_ref=True), - how="outer", - ) - .key_by(*mt_row_key_fields) + t = t.key_by(*row_key_fields) + t = t.join( + reference_ht.key_by(*row_key_fields).select(_in_ref=True), how="outer" ) + t = t.key_by(*mt_row_key_fields) + + # Fill in missing entries with missing values for each entry field. t = t.annotate( __entries=hl.or_else( t.__entries, @@ -1024,8 +1044,10 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: ) ) + # Unlocalize entries to turn the HT back to a MT. return t._unlocalize_entries("__entries", "__cols", mt_col_key_fields) + # Densify VDS/sparse MT at all sites. if is_vds: mtds = hl.vds.VariantDataset( mtds.reference_data.select_entries("END", "DP").select_cols().select_rows(), @@ -1039,10 +1061,10 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: # Densify mt = hl.experimental.densify(mtds) - # Filter rows where the reference is missing + # Filter rows where the reference is missing. mt = mt.filter_rows(mt._in_ref) - # Unfilter entries so that entries with no ref block overlap aren't null + # Unfilter entries so that entries with no ref block overlap aren't null. mt = mt.unfilter_entries() # Annotate with group membership @@ -1050,102 +1072,84 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: group_membership=group_membership_ht[mt.col_key].group_membership ) - # Compute coverage stats - coverage_over_x_bins = sorted(coverage_over_x_bins) - max_coverage_bin = coverage_over_x_bins[-1] - hl_coverage_over_x_bins = hl.array(coverage_over_x_bins) - - # This expression creates a counter DP -> number of samples for DP between - # 0 and max_coverage_bin - coverage_counter_expr = hl.agg.counter( - hl.min(max_coverage_bin, hl.or_else(mt.DP, 0)) - ) - mean_expr = hl.agg.mean(hl.or_else(mt.DP, 0)) - - # Annotate all rows with coverage stats for each strata group. - ht = mt.select_rows( - coverage_stats=hl.agg.array_agg( - lambda x: hl.agg.filter( - x, - hl.struct( - coverage_counter=coverage_counter_expr, - mean=hl.if_else(hl.is_nan(mean_expr), 0, mean_expr), - median_approx=hl.or_else( - hl.agg.approx_median(hl.or_else(mt.DP, 0)), 0 - ), - total_DP=hl.agg.sum(mt.DP), - ), + # Compute coverage stats. + cov_bins = sorted(coverage_over_x_bins) + rev_cov_bins = list(reversed(cov_bins)) + max_cov_bin = cov_bins[-1] + cov_bins = hl.array(cov_bins) + entry_agg_funcs = { + "coverage_stats": ( + lambda t: hl.if_else(hl.is_missing(t.DP) | hl.is_nan(t.DP), 0, t.DP), + lambda dp: hl.struct( + # This expression creates a counter DP -> number of samples for DP + # between 0 and max_cov_bin. + coverage_counter=hl.agg.counter(hl.min(max_cov_bin, dp)), + mean=hl.if_else(hl.is_nan(hl.agg.mean(dp)), 0, hl.agg.mean(dp)), + median_approx=hl.or_else(hl.agg.approx_median(dp), 0), + total_DP=hl.agg.sum(dp), ), - mt.group_membership, ) - ).rows() + } + ht = agg_by_strata(mt, entry_agg_funcs, group_membership_ht=group_membership_ht) ht = ht.checkpoint(hl.utils.new_temp_file("coverage_stats", "ht")) - # This expression aggregates the DP counter in reverse order of the - # coverage_over_x_bins and computes the cumulative sum over them. - # It needs to be in reverse order because we want the sum over samples - # covered by > X. - count_array_expr = ht.coverage_stats.map( - lambda x: hl.cumulative_sum( - hl.array( - # The coverage was already floored to the max_coverage_bin, so no more - # aggregation is needed for the max bin. - [hl.int32(x.coverage_counter.get(max_coverage_bin, 0))] - # For each of the other bins, coverage needs to be summed between the - # boundaries. - ).extend( - hl.range(hl.len(hl_coverage_over_x_bins) - 1, 0, step=-1).map( - lambda i: hl.sum( - hl.range( - hl_coverage_over_x_bins[i - 1], hl_coverage_over_x_bins[i] - ).map(lambda j: hl.int32(x.coverage_counter.get(j, 0))) - ) + # This expression aggregates the DP counter in reverse order of the cov_bins and + # computes the cumulative sum over them. It needs to be in reverse order because we + # want the sum over samples covered by > X. + def _cov_stats( + cov_stat: hl.expr.StructExpression, n: hl.expr.Int32Expression + ) -> hl.expr.StructExpression: + # The coverage was already floored to the max_coverage_bin, so no more + # aggregation is needed for the max bin. + count_expr = cov_stat.coverage_counter + max_bin_expr = hl.int32(count_expr.get(max_cov_bin, 0)) + + # For each of the other bins, coverage is summed between the boundaries. + bin_expr = hl.range(hl.len(cov_bins) - 1, 0, step=-1) + bin_expr = bin_expr.map( + lambda i: hl.sum( + hl.range(cov_bins[i - 1], cov_bins[i]).map( + lambda j: hl.int32(count_expr.get(j, 0)) ) ) ) - ) + bin_expr = hl.cumulative_sum(hl.array([max_bin_expr]).extend(bin_expr)) + + # Use reversed bins as count_array_expr has reverse order. + bin_expr = {f"over_{x}": bin_expr[i] / n for i, x in enumerate(rev_cov_bins)} + + return cov_stat.annotate(**bin_expr).drop("coverage_counter") ht = ht.annotate( coverage_stats=hl.map( - lambda c, g, n: c.annotate( - **{ - f"over_{x}": g[i] / n - for i, x in zip( - range(len(coverage_over_x_bins) - 1, -1, -1), - # Reverse the bin index as count_array_expr has reverse order. - coverage_over_x_bins, - ) - } - ).drop("coverage_counter"), + lambda c, n: _cov_stats(c, n), ht.coverage_stats, - count_array_expr, sample_counts, ) ) current_keys = list(ht.key) - ht = ( - ht.key_by(*row_key_fields) - .select_globals() - .drop(*[k for k in current_keys if k not in row_key_fields]) - ) + ht = ht.key_by(*row_key_fields).select_globals() + ht = ht.drop(*[k for k in current_keys if k not in row_key_fields]) + + group_globals = group_membership_ht.index_globals() + global_expr = {} if no_strata: # If there was no stratification, move coverage_stats annotations to the top # level. ht = ht.select(**{k: ht.coverage_stats[0][k] for k in ht.coverage_stats[0]}) + global_expr["sample_count"] = group_globals.freq_meta_sample_count[0] else: # If there was stratification, add the metadata and sample count info for the # stratification to the globals. - ht = ht.annotate_globals( - coverage_stats_meta=( - group_membership_ht.index_globals().freq_meta.map( - lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group")) - ) - ), - coverage_stats_meta_sample_count=( - group_membership_ht.index_globals().freq_meta_sample_count - ), + global_expr["coverage_stats_meta"] = group_globals.freq_meta.map( + lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group")) + ) + global_expr["coverage_stats_meta_sample_count"] = ( + group_globals.freq_meta_sample_count ) + ht = ht.annotate_globals(**global_expr) + return ht From 368f57d91ad3d0cd7096dab2aa287a25ec6ad68f Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 10 Jan 2024 13:54:57 -0700 Subject: [PATCH 2/5] Add fixes from testing --- gnomad/utils/sparse_mt.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 9930063a0..9ae5fbb89 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -986,6 +986,15 @@ def compute_coverage_stats( group_membership_ht = generate_freq_group_membership_array( ht, strata_expr, no_raw_group=True ) + group_membership_ht = group_membership_ht.annotate_globals( + freq_meta=group_membership_ht.freq_meta.map( + lambda x: hl.dict( + x.items().map( + lambda m: hl.if_else(m[0] == "group", ("group", "raw"), m) + ) + ) + ) + ) n_samples = group_membership_ht.count() sample_counts = group_membership_ht.index_globals().freq_meta_sample_count From c08cebce18c414c30f1e0e1de596c8acf9db70e8 Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Tue, 16 Jan 2024 12:59:31 -0700 Subject: [PATCH 3/5] `group_membership` column annotation is handled in `agg_by_strata` --- gnomad/utils/sparse_mt.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 9ae5fbb89..558c91b85 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -1076,11 +1076,6 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: # Unfilter entries so that entries with no ref block overlap aren't null. mt = mt.unfilter_entries() - # Annotate with group membership - mt = mt.annotate_cols( - group_membership=group_membership_ht[mt.col_key].group_membership - ) - # Compute coverage stats. cov_bins = sorted(coverage_over_x_bins) rev_cov_bins = list(reversed(cov_bins)) From 8b8a37b95b5f450d677d1e7fb1ae70f9d53beae1 Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:08:31 -0700 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Mike Wilson --- gnomad/utils/sparse_mt.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 558c91b85..101770d1a 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -956,17 +956,16 @@ def compute_coverage_stats( gt_field = gt_field.pop() - no_strata = False - if group_membership_ht is not None: - if strata_expr is not None: - raise ValueError( - "Only one of 'group_membership_ht' or 'strata_expr' can be specified." - ) - else: - if strata_expr is None: - strata_expr = {} - no_strata = True - + if group_membership_ht is not None and strata_expr is not None: + raise ValueError( + "Only one of 'group_membership_ht' or 'strata_expr' can be specified." + ) + + # Initialize no_strata and default strata_expr if neither group_membership_ht nor strata_expr is provided + no_strata = group_membership_ht is None and strata_expr is None + if no_strata: + strata_expr = {} + if group_membership_ht is None: # Annotate the MT cols with each of the expressions in strata_expr and redefine # strata_expr based on the column HT with added annotations. ht = mt.annotate_cols( @@ -980,9 +979,10 @@ def compute_coverage_stats( # 'freq_meta_sample_count', and 'group_membership'. By default, this # function returns annotations where the second element is a placeholder for the # "raw" frequency of all samples, where the first 2 elements are the same sample - # set, but 'freq_meta' startswith [{"group": "adj", "group": "raw", ...]. Use + # set, but 'freq_meta' starts with [{"group": "adj", "group": "raw", ...]. Use # `no_raw_group` to exclude the "raw" group so there is a single annotation - # representing the full samples set and update 'freq_meta' "group" to all "raw". + # representing the full samples set and update all 'freq_meta' entries' "group" + # to "raw". group_membership_ht = generate_freq_group_membership_array( ht, strata_expr, no_raw_group=True ) @@ -1029,7 +1029,6 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable: """ # Get the total number of samples. n_samples = mt.count_cols() - entry_keep_fields = set(mt.entry) & {gt_field, "DP", "END"} mt_col_key_fields = list(mt.col_key) mt_row_key_fields = list(mt.row_key) @@ -1119,7 +1118,6 @@ def _cov_stats( ) bin_expr = hl.cumulative_sum(hl.array([max_bin_expr]).extend(bin_expr)) - # Use reversed bins as count_array_expr has reverse order. bin_expr = {f"over_{x}": bin_expr[i] / n for i, x in enumerate(rev_cov_bins)} return cov_stat.annotate(**bin_expr).drop("coverage_counter") From 2f49efb7e1b455db2d1bca4e6ae55b882aa7b534 Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:17:27 -0700 Subject: [PATCH 5/5] address review comments --- gnomad/utils/sparse_mt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 101770d1a..575d38ecb 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -960,11 +960,13 @@ def compute_coverage_stats( raise ValueError( "Only one of 'group_membership_ht' or 'strata_expr' can be specified." ) - - # Initialize no_strata and default strata_expr if neither group_membership_ht nor strata_expr is provided + + # Initialize no_strata and default strata_expr if neither group_membership_ht nor + # strata_expr is provided. no_strata = group_membership_ht is None and strata_expr is None if no_strata: strata_expr = {} + if group_membership_ht is None: # Annotate the MT cols with each of the expressions in strata_expr and redefine # strata_expr based on the column HT with added annotations. @@ -981,8 +983,10 @@ def compute_coverage_stats( # "raw" frequency of all samples, where the first 2 elements are the same sample # set, but 'freq_meta' starts with [{"group": "adj", "group": "raw", ...]. Use # `no_raw_group` to exclude the "raw" group so there is a single annotation - # representing the full samples set and update all 'freq_meta' entries' "group" - # to "raw". + # representing the full samples set. Update all 'freq_meta' entries' "group" + # to "raw" because `generate_freq_group_membership_array` will return them all + # as "adj" since it was built for frequency computation, but for the coverage + # computation we don't want to do any filtering. group_membership_ht = generate_freq_group_membership_array( ht, strata_expr, no_raw_group=True )