diff --git a/gnomad/sample_qc/sex.py b/gnomad/sample_qc/sex.py index 09ad75108..82ba62d5d 100644 --- a/gnomad/sample_qc/sex.py +++ b/gnomad/sample_qc/sex.py @@ -42,14 +42,16 @@ def adjusted_sex_ploidy_expr( return ( hl.case(missing_false=True) + # Added to reduce the checks by entry. + .when(row_idx.in_autosome, gt_expr) + .when((row_idx.y_par | row_idx.y_nonpar) & col_idx.xx, hl.missing(hl.tcall)) .when(~row_idx.in_non_par, gt_expr) - .when(col_idx.xx & (row_idx.y_par | row_idx.y_nonpar), hl.null(hl.tcall)) .when( - col_idx.xy & (row_idx.x_nonpar | row_idx.y_nonpar) & gt_expr.is_het(), + (row_idx.x_nonpar | row_idx.y_nonpar) & col_idx.xy & gt_expr.is_het(), hl.null(hl.tcall), ) .when( - col_idx.xy & (row_idx.x_nonpar | row_idx.y_nonpar), + (row_idx.x_nonpar | row_idx.y_nonpar) & col_idx.xy, hl.call(gt_expr[0], phased=False), ) .default(gt_expr) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index 2df5bd7d4..f64b45d4e 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -603,8 +603,8 @@ def create_frequency_bins_expr( def annotate_and_index_source_mt_for_sex_ploidy( - locus_expr: hl.expr.LocusExpression = None, - karyotype_expr: hl.expr.StringExpression = None, + locus_expr: hl.expr.LocusExpression, + karyotype_expr: hl.expr.StringExpression, xy_karyotype_str: str = "XY", xx_karyotype_str: str = "XX", ) -> Tuple[hl.expr.StructExpression, hl.expr.StructExpression]: @@ -640,6 +640,7 @@ def annotate_and_index_source_mt_for_sex_ploidy( ).cols() row_ht = source_mt.annotate_rows( in_non_par=~locus_expr.in_autosome_or_par(), + in_autosome=locus_expr.in_autosome(), x_nonpar=locus_expr.in_x_nonpar(), y_par=locus_expr.in_y_par(), y_nonpar=locus_expr.in_y_nonpar(), diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index a5a59d1a3..ddf6fe1e0 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -990,6 +990,10 @@ def densify_all_reference_sites( # Unfilter entries so that entries with no ref block overlap aren't null. mt = mt.unfilter_entries() + # Rekey by requested row key field and drop unused keys. + mt = mt.key_rows_by(*row_key_fields) + mt = mt.drop(*[k for k in mt_row_key_fields if k not in row_key_fields]) + return mt @@ -1192,12 +1196,7 @@ def compute_stats_per_ref_site( select_fields=row_keep_fields, entry_agg_group_membership=entry_agg_group_membership, ) - ht = ht.checkpoint(hl.utils.new_temp_file("agg_stats", "ht")) - - # Drop no longer needed fields - current_keys = list(ht.key) - ht = ht.key_by(*row_key_fields).select_globals() - ht = ht.drop(*[k for k in current_keys if k not in row_key_fields]) + ht = ht.select_globals().checkpoint(hl.utils.new_temp_file("agg_stats", "ht")) group_globals = group_membership_ht.index_globals() global_expr = {}