diff --git a/gnomad/variant_qc/evaluation.py b/gnomad/variant_qc/evaluation.py
index 8bc3a9d50..03d50fd42 100644
--- a/gnomad/variant_qc/evaluation.py
+++ b/gnomad/variant_qc/evaluation.py
@@ -81,6 +81,23 @@ def compute_ranked_bin(
         _rand=hl.rand_unif(0, 1),
     )
 
+    # Checkpoint the bin Table prior to the variant count aggregation.
+    bin_ht = bin_ht.checkpoint(hl.utils.new_temp_file("bin", "ht"))
+
+    # Compute variant counts per group defined by bin_expr. These counts are used
+    # to determine each variant's bin assignment.
+    bin_group_variant_counts = bin_ht.aggregate(
+        hl.Struct(
+            **{
+                bin_id: hl.agg.filter(
+                    bin_ht[f"_filter_{bin_id}"],
+                    hl.agg.count(),
+                )
+                for bin_id in bin_expr
+            }
+        )
+    )
+
     logger.info(
         "Sorting the HT by score_expr followed by a random float between 0 and 1. "
         "Then adding a row index per grouping defined by bin_expr..."
@@ -97,22 +114,6 @@ def compute_ranked_bin(
     )
     bin_ht = bin_ht.key_by("locus", "alleles")
 
-    # Annotate globals with variant counts per group defined by bin_expr. This
-    # is used to determine bin assignment
-    bin_ht = bin_ht.annotate_globals(
-        bin_group_variant_counts=bin_ht.aggregate(
-            hl.Struct(
-                **{
-                    bin_id: hl.agg.filter(
-                        bin_ht[f"_filter_{bin_id}"],
-                        hl.agg.count(),
-                    )
-                    for bin_id in bin_expr
-                }
-            )
-        )
-    )
-
     logger.info("Binning ranked rows into %d bins...", n_bins)
     bin_ht = bin_ht.select(
         "snv",
@@ -123,7 +124,7 @@ def compute_ranked_bin(
                         n_bins
                         * (
                             bin_ht[f"{bin_id}_rank"]
-                            / hl.float64(bin_ht.bin_group_variant_counts[bin_id])
+                            / hl.float64(bin_group_variant_counts[bin_id])
                         )
                     )
                     + 1
@@ -143,20 +144,18 @@ def compute_ranked_bin(
     # in bin names in the table
     if compute_snv_indel_separately:
         bin_expr_no_snv = {
-            bin_id.rsplit("_", 1)[0] for bin_id in bin_ht.bin_group_variant_counts
+            bin_id.rsplit("_", 1)[0] for bin_id in bin_group_variant_counts
         }
-        bin_ht = bin_ht.annotate_globals(
-            bin_group_variant_counts=hl.struct(
-                **{
-                    bin_id: hl.struct(
-                        **{
-                            snv: bin_ht.bin_group_variant_counts[f"{bin_id}_{snv}"]
-                            for snv in ["snv", "indel"]
-                        }
-                    )
-                    for bin_id in bin_expr_no_snv
-                }
-            )
+        bin_group_variant_counts = hl.struct(
+            **{
+                bin_id: hl.struct(
+                    **{
+                        snv: bin_group_variant_counts[f"{bin_id}_{snv}"]
+                        for snv in ["snv", "indel"]
+                    }
+                )
+                for bin_id in bin_expr_no_snv
+            }
         )
 
     bin_ht = bin_ht.transmute(
@@ -170,6 +169,10 @@ def compute_ranked_bin(
         }
     )
 
+    bin_ht = bin_ht.annotate_globals(
+        bin_group_variant_counts=bin_group_variant_counts
+    )
+
     return bin_ht
 
 
@@ -264,7 +267,10 @@ def compute_binned_truth_sample_concordance(
         score=indexed_binned_score_ht.score,
         global_bin=indexed_binned_score_ht.bin,
    )
 
+    # Checkpoint the annotated HT before computing the truth sample bins.
+    ht = ht.checkpoint(hl.utils.new_temp_file("pre_bin", "ht"))
+
     # Annotate the truth sample bin
     bin_ht = compute_ranked_bin(
         ht,
diff --git a/gnomad/variant_qc/pipeline.py b/gnomad/variant_qc/pipeline.py
index f03b6bd56..cf9d24027 100644
--- a/gnomad/variant_qc/pipeline.py
+++ b/gnomad/variant_qc/pipeline.py
@@ -1,7 +1,7 @@
 # noqa: D100
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import hail as hl
 import pyspark.sql
@@ -69,10 +69,11 @@ def create_binned_ht(
 
     :return: table with bin number for each variant
     """
 
-    def _update_bin_expr(
-        bin_expr: Dict[str, hl.expr.BooleanExpression],
+    def _new_bin_expr(
+        bin_expr: Union[Dict[str, hl.expr.BooleanExpression], Dict[str, bool]],
         new_expr: hl.expr.BooleanExpression,
         new_id: str,
+        update: bool = False,
     ) -> Dict[str, hl.expr.BooleanExpression]:
         """
-        Update a dictionary of expressions to add another stratification.
+        Create a new dictionary of expressions with an additional stratification.
@@ -83,29 +84,34 @@ def _update_bin_expr(
-        :return: Dictionary of `bin_expr` updated with `new_expr` added as an additional
-            stratification to all expressions already in `bin_expr`
+        :param update: If True, add the new stratifications to `bin_expr` and return
+            the updated dictionary; otherwise return only the new stratifications
+        :return: Dictionary of stratification expressions
         """
-        bin_expr.update(
-            {
-                f"{new_id}_{bin_id}": bin_expr & new_expr
-                for bin_id, bin_expr in bin_expr.items()
-            }
-        )
-        return bin_expr
+        new_bin_expr = {
+            f"{new_id}_{bin_id}": expr & new_expr
+            for bin_id, expr in bin_expr.items()
+        }
+        if update:
+            bin_expr.update(new_bin_expr)
+            return bin_expr
+        else:
+            return new_bin_expr
 
     # Desired bins and sub-bins
     bin_expr = {"bin": True}
 
     if singleton:
-        bin_expr = _update_bin_expr(bin_expr, ht.ac_raw == 1, "singleton")
+        bin_expr = _new_bin_expr(bin_expr, ht.ac_raw == 1, "singleton", update=True)
 
     if biallelic:
-        bin_expr = _update_bin_expr(bin_expr, ~ht.was_split, "biallelic")
+        bin_expr = _new_bin_expr(bin_expr, ~ht.was_split, "biallelic", update=True)
 
     if adj:
-        bin_expr = _update_bin_expr(bin_expr, (ht.ac > 0), "adj")
+        bin_expr = _new_bin_expr(bin_expr, (ht.ac > 0), "adj", update=True)
 
-    if add_substrat:
+    if add_substrat is not None:
+        new_bin_expr = {}
         for add_id, add_expr in add_substrat.items():
-            bin_expr = _update_bin_expr(bin_expr, add_expr, add_id)
+            new_bin_expr.update(_new_bin_expr(bin_expr, add_expr, add_id))
+        bin_expr.update(new_bin_expr)
 
     bin_ht = compute_ranked_bin(
         ht, score_expr=ht.score, bin_expr=bin_expr, n_bins=n_bins
@@ -223,22 +229,40 @@ def score_bin_agg(
         "Either 'fail_hard_filters' or 'info' must be present in the input Table!"
     )
 
+    ins_expr = hl.is_insertion(ht.alleles[0], ht.alleles[1])
+    del_expr = hl.is_deletion(ht.alleles[0], ht.alleles[1])
+    indel_1bp_expr = indel_length == 1
+    # Expressions to count with hl.agg.count_where, keyed by annotation name.
+    count_where_expr = {
+        "n_ins": ins_expr,
+        "n_del": del_expr,
+        "n_ti": hl.is_transition(ht.alleles[0], ht.alleles[1]),
+        "n_tv": hl.is_transversion(ht.alleles[0], ht.alleles[1]),
+        "n_1bp_indel": indel_1bp_expr,
+        "n_1bp_ins": ins_expr & indel_1bp_expr,
+        "n_2bp_ins": ins_expr & (indel_length == 2),
+        "n_3bp_ins": ins_expr & (indel_length == 3),
+        "n_1bp_del": del_expr & indel_1bp_expr,
+        "n_2bp_del": del_expr & (indel_length == 2),
+        "n_3bp_del": del_expr & (indel_length == 3),
+        "n_mod3bp_indel": (indel_length % 3) == 0,
+        "n_singleton": ht.singleton,
+        "fail_hard_filters": fail_hard_filters_expr,
+        "n_pos_train": ht.positive_train_site,
+        "n_neg_train": ht.negative_train_site,
+        "n_clinvar": hl.is_defined(clinvar),
+        "n_clinvar_path": hl.is_defined(clinvar_path),
+        "n_omni": truth_data.omni,
+        "n_mills": truth_data.mills,
+        "n_hapmap": truth_data.hapmap,
+        "n_kgp_phase1_hc": truth_data.kgp_phase1_hc,
+    }
+
     return dict(
         min_score=hl.agg.min(ht.score),
         max_score=hl.agg.max(ht.score),
         n=hl.agg.count(),
-        n_ins=hl.agg.count_where(hl.is_insertion(ht.alleles[0], ht.alleles[1])),
-        n_del=hl.agg.count_where(hl.is_deletion(ht.alleles[0], ht.alleles[1])),
-        n_ti=hl.agg.count_where(hl.is_transition(ht.alleles[0], ht.alleles[1])),
-        n_tv=hl.agg.count_where(hl.is_transversion(ht.alleles[0], ht.alleles[1])),
-        n_1bp_indel=hl.agg.count_where(indel_length == 1),
-        n_mod3bp_indel=hl.agg.count_where((indel_length % 3) == 0),
-        n_singleton=hl.agg.count_where(ht.singleton),
-        fail_hard_filters=hl.agg.count_where(fail_hard_filters_expr),
-        n_pos_train=hl.agg.count_where(ht.positive_train_site),
-        n_neg_train=hl.agg.count_where(ht.negative_train_site),
-        n_clinvar=hl.agg.count_where(hl.is_defined(clinvar)),
-        n_clinvar_path=hl.agg.count_where(hl.is_defined(clinvar_path)),
+        **{k: hl.agg.count_where(v) for k, v in count_where_expr.items()},
        n_de_novos_singleton_adj=hl.agg.filter(
            ht.ac == 1, hl.agg.sum(fam.n_de_novos_adj)
        ),
@@ -247,6 +271,21 @@ def score_bin_agg(
         ),
         n_de_novos_adj=hl.agg.sum(fam.n_de_novos_adj),
         n_de_novo=hl.agg.sum(fam.n_de_novos_raw),
+        # De novo counts restricted to variants with parental allele frequency < 0.1%.
+        n_de_novos_AF_001_adj=hl.agg.filter(
+            hl.if_else(
+                fam.ac_parents_adj == 0, 0.0, fam.ac_parents_adj / fam.an_parents_adj
+            )
+            < 0.001,
+            hl.agg.sum(fam.n_de_novos_adj),
+        ),
+        n_de_novos_AF_001=hl.agg.filter(
+            hl.if_else(
+                fam.ac_parents_raw == 0, 0.0, fam.ac_parents_raw / fam.an_parents_raw
+            )
+            < 0.001,
+            hl.agg.sum(fam.n_de_novos_raw),
+        ),
         n_trans_singletons=hl.agg.filter(
             ht.ac_raw == 2, hl.agg.sum(fam.n_transmitted_raw)
         ),
@@ -257,10 +296,6 @@ def score_bin_agg(
         n_train_trans_singletons=hl.agg.filter(
             (ht.ac_raw == 2) & ht.positive_train_site, hl.agg.sum(fam.n_transmitted_raw)
         ),
-        n_omni=hl.agg.count_where(truth_data.omni),
-        n_mills=hl.agg.count_where(truth_data.mills),
-        n_hapmap=hl.agg.count_where(truth_data.hapmap),
-        n_kgp_phase1_hc=hl.agg.count_where(truth_data.kgp_phase1_hc),
     )
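
For reference, a minimal sketch of the `_new_bin_expr` semantics introduced above, using plain Python bools in place of Hail BooleanExpressions; the toy values and the `lcr` substratification id are hypothetical, chosen only to illustrate why `add_substrat` uses `update=False`:

    # Mirror of the helper added in pipeline.py, operating on plain bools.
    def _new_bin_expr(bin_expr, new_expr, new_id, update=False):
        new = {
            f"{new_id}_{bin_id}": expr & new_expr
            for bin_id, expr in bin_expr.items()
        }
        if update:
            # In place: keep existing stratifications and add the new ones.
            bin_expr.update(new)
            return bin_expr
        # Default: return only the new stratifications.
        return new

    bin_expr = {"bin": True}
    bin_expr = _new_bin_expr(bin_expr, True, "singleton", update=True)
    print(bin_expr)  # {'bin': True, 'singleton_bin': True}

    # With update=False (as used for add_substrat), each substratification is
    # derived from the same base bins and does not compound onto the others.
    print(_new_bin_expr(bin_expr, False, "lcr"))
    # {'lcr_bin': False, 'lcr_singleton_bin': False}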