diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 4ea4bc9f..a27eaf29 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -10,13 +10,13 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v5
       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: 3.11.4
       - name: Use pip cache
-        uses: actions/cache@v2
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pip
           key: pip-${{ hashFiles('**/requirements*.txt') }}
@@ -24,9 +24,10 @@ jobs:
             pip-
       - name: Install Python dependencies
         run: |
+          python -m pip install --upgrade pip
           pip install wheel
-          pip install -r requirements-dev.txt
           pip install hail
+          pip install -r requirements-dev.txt
       - name: Setup R
         uses: r-lib/actions/setup-r@v2
         with:
diff --git a/.pylintrc b/.pylintrc
index 5952b3f7..bb074012 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -5,9 +5,10 @@
 notes=FIXME,
       XXX,
 
 [MESSAGES CONTROL]
-init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
 disable=
+    # Disable because ~ is used in a lot of hail code where it is not an error
+    invalid-unary-operand-type,
     unused-wildcard-import,
     R,
     C,
diff --git a/gnomad_constraint/experimental/__init__.py b/gnomad_constraint/experimental/__init__.py
new file mode 100644
index 00000000..6e031999
--- /dev/null
+++ b/gnomad_constraint/experimental/__init__.py
@@ -0,0 +1 @@
+# noqa: D104
diff --git a/gnomad_constraint/experimental/proemis3d/README.md b/gnomad_constraint/experimental/proemis3d/README.md
new file mode 100644
index 00000000..9bdd3b25
--- /dev/null
+++ b/gnomad_constraint/experimental/proemis3d/README.md
@@ -0,0 +1,555 @@
+# Proemis3D Pipeline
+
+The Proemis3D (Protein Missense Constraint in 3D) pipeline is an experimental module for analyzing protein missense constraint using 3D structural information from AlphaFold2. This pipeline integrates genomic variant data with protein structural data to identify regions of high missense constraint in 3D space.
+
+## Overview
+
+The Proemis3D pipeline combines:
+- **Genomic variant data** from gnomAD and other sources
+- **Protein structural data** from AlphaFold2
+- **Functional annotations** from various databases
+- **Statistical methods** for constraint analysis
+
+to identify regions of proteins that are highly intolerant to missense variation in 3D space.
+
+## Key Features
+
+- **3D Constraint Analysis**: Identifies regions of high missense constraint based on 3D protein structure
+- **AlphaFold2 Integration**: Uses AlphaFold2 structural predictions and confidence scores
+- **Multiple Annotation Sources**: Integrates data from COSMIS, InterPro, ClinVar, and other databases
+- **Statistical Rigor**: Implements both greedy and forward algorithms for constraint region identification
+- **Scalable Processing**: Built on Hail for large-scale genomic data processing
+
+## Directory Structure
+
+```
+proemis3d/
+├── README.md       # This file
+├── __init__.py     # Package initialization
+├── constants.py    # Pipeline constants and gene lists
+├── data_import.py  # Data import and processing functions
+├── proemis_3d.py   # Main pipeline script
+├── resources.py    # Resource definitions and paths
+└── utils.py        # Utility functions and algorithms
+```
+
+## Core Components
+
+### 1. Constants (`constants.py`)
+- **MIN_EXP_MIS**: Minimum expected missense variants for constraint calculation (16)
+- **Gene Lists**: HI genes, severe HI genes, and gene categories
+- **Configuration**: Pipeline parameters and thresholds
+
+### 2. Data Import (`data_import.py`)
+- **COSMIS Scores**: Protein constraint scores from multiple structure sources
+- **VARITY Data**: Variant effect predictions
+- **MTR3D Data**: Missense tolerance ratio in 3D
+- **InterPro Annotations**: Protein domain and functional annotations
+- **Kaplanis Variants**: Developmental delay de novo missense variants
+- **Fu Variants**: Autism spectrum disorder de novo variants
+- **ClinVar Data**: Clinical variant interpretations and significance
+- **Constraint Metrics**: Gene-level constraint statistics (synonymous, missense, LoF)
+- **MTR Data**: Missense tolerance ratio annotations
+- **RMC Data**: Regional missense constraint metrics
+- **Context Data**: Variant context and coverage information
+- **Genetics Gym Scores**: AI-based missense prediction scores
+- **REVEL Scores**: Rare Exome Variant Ensemble Learner predictions
+
+#### Key Functions:
+- `import_cosmis_score_data()`: Import COSMIS constraint scores from AlphaFold2, PDB, and Swiss Model structures
+- `import_varity_data()`: Import VARITY variant effect prediction scores
+- `import_mtr3d_data()`: Import MTR3D (Missense Tolerance Ratio in 3D) scores
+- `import_mtr_data()`: Import MTR (Missense Tolerance Ratio) annotations
+- `import_kaplanis_variants()`: Process developmental delay de novo variants with GRCh37/GRCh38 liftover
+- `get_kaplanis_sig_gene_annotations()`: Get significant gene set annotations from the Kaplanis study
+- `import_fu_variants()`: Import autism spectrum disorder de novo variants
+- `import_interpro_annotations()`: Import InterPro protein domain and functional annotations
+- `process_clinvar_ht()`: Process ClinVar data and filter to missense variants
+- `process_gnomad_site_ht()`: Process gnomAD site-level variant data
+- `process_pext_base_ht()`: Process base-level expression data
+- `process_pext_annotation_ht()`: Process annotation-level expression data
+- `process_gnomad_de_novo_ht()`: Process gnomAD de novo variant data
+- `process_rmc_ht()`: Process Regional Missense Constraint data with p-values and confidence intervals
+- `process_constraint_metrics_ht()`: Process gene-level constraint metrics (syn, mis, lof)
+- `process_context_ht()`: Process variant context data with coverage and frequency information
+- `process_genetics_gym_missense_scores_ht()`: Process AI-based missense prediction scores with percentiles
+- `import_revel_ht()`: Import REVEL (Rare Exome Variant Ensemble Learner) scores
+
+### 3. Resources (`resources.py`)
+- **Resource Definitions**: Paths and resource Tables for pipeline inputs, intermediates, and outputs
+- **Versioning**: Supported resource versions (2.1.1, 4.1)
+
+### 4. Utils (`utils.py`)
+- **FASTA Processing**: Convert GENCODE FASTA files to Hail tables
+- **AlphaFold2 Processing**: Extract sequences, distance matrices, and confidence scores
+- **Constraint Algorithms**: Greedy and forward algorithms for region identification
+- **Annotation Functions**: Combine variant and residue-level annotations
+
+### 5. Main Pipeline (`proemis_3d.py`)
+- **Pipeline Orchestration**: Coordinates all pipeline steps
+- **Resource Management**: Handles data dependencies and checkpointing
+- **Command Line Interface**: Provides CLI for pipeline execution
+
+## Key Algorithms
+
+### Greedy Algorithm
+Identifies the most intolerant region by iteratively selecting residues with the lowest upper bound of the observed/expected (OE) confidence interval.
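+
+A minimal sketch of this selection loop, using plain Python data structures, is shown below. It is illustrative only and is not the pipeline's `run_greedy` implementation in `utils.py`: the `oe_upper` helper, the per-residue `obs`/`exp` dictionaries, the `neighbors` adjacency map derived from the distance matrices, and the no-improvement stopping rule are all assumptions made for this example.
+
+```python
+from scipy.stats import chi2
+
+
+def oe_upper(obs: int, exp: float, alpha: float = 0.1) -> float:
+    # Assumed helper: upper bound of an exact Poisson (chi-squared) CI on OE.
+    return chi2.ppf(1 - alpha / 2, 2 * (obs + 1)) / (2 * exp)
+
+
+def greedy_region(obs: dict, exp: dict, neighbors: dict, seed: int):
+    # Grow a region from `seed`, at each step adding the spatially proximal
+    # residue that minimizes the region's OE upper bound.
+    def bound(region):
+        return oe_upper(sum(obs[r] for r in region), sum(exp[r] for r in region))
+
+    region = {seed}
+    while True:
+        candidates = {n for r in region for n in neighbors[r]} - region
+        if not candidates:
+            return region, bound(region)
+        best = min(candidates, key=lambda n: bound(region | {n}))
+        if bound(region | {best}) >= bound(region):
+            # Assumed stopping rule: stop once adding any candidate residue no
+            # longer lowers the region's OE upper bound.
+            return region, bound(region)
+        region = region | {best}
+```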
+
+### Forward Algorithm
+Uses a forward selection approach with Akaike Information Criterion (AIC) to identify optimal constraint regions.
+
+### 3D Constraint Calculation
+- Calculates distance matrices from AlphaFold2 structures
+- Identifies spatially proximal residues
+- Computes constraint metrics for 3D regions
+
+## Usage
+
+### Basic Pipeline Execution
+
+```python
+import hail as hl
+from gnomad_constraint.experimental.proemis3d.proemis_3d import (
+    get_proemis3d_resources,
+)
+
+# Initialize Hail
+hl.init()
+
+# Construct the resource collection used by the pipeline steps
+resources = get_proemis3d_resources(
+    version="4.1",
+    overwrite=False,
+    test=False,
+)
+```
+
+The pipeline steps themselves are executed through the command line interface described below.
+
+### Command Line Interface
+
+```bash
+python proemis_3d.py --version 4.1 --test
+```
+
+### Key Parameters
+
+- `version`: gnomAD version (2.1.1 or 4.1)
+- `test`: Run in test mode with smaller datasets
+- `overwrite`: Overwrite existing outputs
+
+## Command Line Parameters
+
+The Proemis3D pipeline supports extensive command-line parameters for fine-grained control over execution. Here's a complete reference:
+
+### Basic Parameters
+
+#### `--version`
+- **Type**: String
+- **Default**: `4.1` (current version)
+- **Description**: Which version of the resource Tables will be used
+- **Options**: `2.1.1`, `4.1`
+
+#### `--test`
+- **Type**: Boolean flag
+- **Description**: Whether to run a test instead of the full pipeline
+- **Effect**: Uses smaller test datasets and filters to a specific test transcript (ENST00000372435)
+
+#### `--overwrite`
+- **Type**: Boolean flag
+- **Description**: Whether to overwrite existing output files
+- **Effect**: Forces regeneration of existing intermediate and output files
+
+### Data Processing Steps
+
+#### `--convert-gencode-fastn-to-ht`
+- **Type**: Boolean flag
+- **Description**: Import and pre-process the GENCODE transcripts FASTA file as a Hail Table
+- **Output**: GENCODE transcripts Hail Table with sequence data
+
+#### `--convert-gencode-fasta-to-ht`
+- **Type**: Boolean flag
+- **Description**: Import and pre-process the GENCODE translations FASTA file as a Hail Table
+- **Output**: GENCODE translations Hail Table with protein sequence data
+
+#### `--read-af2-sequences`
+- **Type**: Boolean flag
+- **Description**: Process AlphaFold2 structures from a GCS bucket into a Hail Table
+- **Output**: AlphaFold2 sequences Hail Table
+- **Mode**: `sequence`
+
+#### `--compute-af2-distance-matrices`
+- **Type**: Boolean flag
+- **Description**: Compute distance matrices for AlphaFold2 structures
+- **Output**: AlphaFold2 distance matrices Hail Table
+- **Mode**: `distance`
+
+#### `--extract-af2-plddt`
+- **Type**: Boolean flag
+- **Description**: Extract pLDDT (per-residue confidence) scores from AlphaFold2 structures
+- **Output**: AlphaFold2 pLDDT scores Hail Table
+- **Mode**: `plddt`
+
+#### `--extract-af2-pae`
+- **Type**: Boolean flag
+- **Description**: Extract pAE (predicted aligned error) scores from AlphaFold2 structures
+- **Output**: AlphaFold2 pAE matrices Hail Table
+- **Mode**: `pae`
+
+#### `--gencode-alignment`
+- **Type**: Boolean flag
+- **Description**: Join GENCODE translations and AlphaFold2 structures based on sequence
+- **Output**: Matched GENCODE-AlphaFold2 Hail Table
+
+#### `--get-gencode-positions`
+- **Type**: Boolean flag
+- **Description**: Create GENCODE positions Hail Table with genomic coordinates
+- **Output**: GENCODE positions Hail Table
+
+### Constraint Analysis
+
+#### `--run-greedy`
+- **Type**: Boolean flag
+- **Description**: Execute the greedy
algorithm for constraint region identification +- **Output**: Greedy algorithm results Hail Table + +#### `--run-forward` +- **Type**: Boolean flag +- **Description**: Execute the forward algorithm for constraint region identification +- **Output**: Forward algorithm results Hail Table + +#### `--min-exp-mis` +- **Type**: Integer +- **Default**: `16` +- **Description**: Minimum expected number of missense variants to consider for constraint algorithms +- **Effect**: Filters regions with insufficient expected missense variants + +### Output Generation + +#### `--write-per-variant` +- **Type**: Boolean flag +- **Description**: Generate per-variant annotated Hail Table with comprehensive annotations +- **Output**: Fully annotated per-variant Hail Table +- **Dependencies**: Requires forward algorithm results + +#### `--write-per-missense-variant` +- **Type**: Boolean flag +- **Description**: Generate per-variant annotated Hail Table filtered to missense variants only +- **Output**: Missense-only per-variant Hail Table +- **Dependencies**: Requires per-variant Hail Table + +#### `--write-per-residue` +- **Type**: Boolean flag +- **Description**: Generate per-residue Hail Table from per-variant data +- **Output**: Per-residue annotated Hail Table +- **Dependencies**: Requires per-variant Hail Table + +#### `--write-per-region` +- **Type**: Boolean flag +- **Description**: Generate per-region Hail Table from per-residue data +- **Output**: Per-region annotated Hail Table +- **Dependencies**: Requires per-residue Hail Table + +#### `--create-missense-viewer-input-ht` +- **Type**: Boolean flag +- **Description**: Create missense viewer input Hail Table for visualization +- **Output**: Formatted data for web-based visualization +- **Dependencies**: Requires forward algorithm results and GENCODE positions + +### Performance Tuning + +#### `--all-snv-n-partitions` +- **Type**: Integer +- **Default**: `5000` +- **Description**: Number of partitions to use for the all possible SNVs Hail Table +- **Effect**: Controls memory usage and processing speed for large datasets + +### Usage Examples + +#### Complete Pipeline +```bash +python proemis_3d.py --version 4.1 --overwrite +``` + +#### Test Run +```bash +python proemis_3d.py --version 4.1 --test --overwrite +``` + +#### Individual Steps +```bash +# Process GENCODE data +python proemis_3d.py --convert-gencode-fastn-to-ht --convert-gencode-fasta-to-ht + +# Process AlphaFold2 data +python proemis_3d.py --read-af2-sequences --compute-af2-distance-matrices --extract-af2-plddt --extract-af2-pae + +# Run constraint analysis +python proemis_3d.py --run-greedy --run-forward --min-exp-mis 20 + +# Generate outputs +python proemis_3d.py --write-per-variant --write-per-residue --write-per-region +``` + +#### Custom Configuration +```bash +python proemis_3d.py \ + --version 4.1 \ + --test \ + --overwrite \ + --min-exp-mis 25 \ + --all-snv-n-partitions 10000 \ + --run-forward \ + --write-per-variant \ + --write-per-residue +``` + +### Pipeline Dependencies + +The pipeline steps have specific dependencies that must be satisfied: + +1. **GENCODE Processing**: `--convert-gencode-fastn-to-ht` and `--convert-gencode-fasta-to-ht` can run independently +2. **AlphaFold2 Processing**: All AF2 steps can run independently +3. **Alignment**: `--gencode-alignment` requires both GENCODE and AF2 sequence data +4. **Positions**: `--get-gencode-positions` requires GENCODE data and alignment +5. 
**Constraint Analysis**: `--run-greedy` and `--run-forward` require positions and RMC data +6. **Output Generation**: Each output step depends on its prerequisite data + +### Resource Management + +- **Checkpointing**: Intermediate results are automatically checkpointed +- **Temporary Files**: Uses `gs://gnomad-tmp-4day` for temporary storage +- **Logging**: Pipeline logs are written to `/proemis_3d.log` +- **Memory Management**: Large operations use repartitioning for memory efficiency + +## Data Requirements + +### Input Data +- **GENCODE**: Transcript and translation FASTA files +- **AlphaFold2**: Protein structures and confidence scores +- **gnomAD**: Variant data and constraint metrics +- **Annotations**: COSMIS, InterPro, ClinVar, and other functional annotations + +### Output Data +- **Per-SNV Tables**: Annotated variant-level data +- **Per-Residue Tables**: Residue-level constraint metrics +- **Per-Region Tables**: 3D constraint regions +- **Viewer Input**: Formatted data for visualization + +## Key Functions + +### Data Processing +- `convert_fasta_to_table()`: Convert FASTA files to Hail tables +- `process_af2_structures()`: Process AlphaFold2 structural data +- `get_gencode_positions()`: Extract genomic positions for transcripts + +### Constraint Analysis +- `run_greedy()`: Execute greedy constraint algorithm +- `run_forward()`: Execute forward constraint algorithm +- `determine_regions_with_min_oe_upper()`: Identify constraint regions + +### Annotation +- `annotate_snvs_with_variant_level_data()`: Add variant-level annotations +- `annotate_proemis3d_with_af2_metrics()`: Add AlphaFold2 metrics +- `create_per_snv_combined_ht()`: Create comprehensive annotation tables + +## Dependencies + +- **Hail**: Genomic data processing +- **Biopython**: Protein structure analysis +- **NumPy/Pandas**: Numerical computing +- **PySpark**: Distributed computing + +## Output Structure + +The pipeline generates several key outputs: + +1. **Variant-Level Data**: All possible SNVs with comprehensive annotations +2. **Residue-Level Data**: Per-residue constraint metrics and AlphaFold2 scores +3. **Region-Level Data**: 3D constraint regions with statistical significance +4. **Viewer Data**: Formatted data for web-based visualization + +## Statistical Methods + +- **Observed/Expected Ratios**: Compare observed to expected missense variants +- **Confidence Intervals**: Chi-squared based confidence intervals +- **AIC Selection**: Model selection using Akaike Information Criterion +- **3D Distance Metrics**: Spatial proximity calculations + +## Performance Considerations + +- **Checkpointing**: Intermediate results are checkpointed for fault tolerance +- **Partitioning**: Data is partitioned for efficient processing +- **Caching**: Frequently accessed data is cached in memory +- **Parallel Processing**: Uses Spark for distributed computation + +## Testing + +The pipeline includes test mode with smaller datasets: +- **Test Transcript**: ENST00000372435 +- **Test UniProt**: P60891 +- **Reduced Data**: Smaller subsets for development and testing + +## Future Development + +- **Additional Algorithms**: New constraint identification methods +- **Enhanced Annotations**: Integration with additional data sources +- **Visualization Tools**: Interactive 3D constraint visualization +- **Performance Optimization**: Improved scalability and efficiency + +## Citation + +If you use this pipeline in your research, please cite the relevant gnomAD and AlphaFold2 papers, as well as any specific Proemis3D methodology papers. 
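+
+The chi-squared-based confidence intervals referenced under "Statistical Methods" above can be illustrated with a short sketch. This shows one standard construction, the exact Poisson interval, and is for intuition only; the pipeline itself uses `oe_confidence_interval` from `gnomad.utils.constraint` (see `data_import.py`), whose implementation may differ.
+
+```python
+from scipy.stats import chi2
+
+
+def oe_ci_sketch(obs: int, exp: float, alpha: float = 0.1):
+    """Exact Poisson (chi-squared) CI for OE, assuming obs ~ Poisson(oe * exp)."""
+    lower = 0.0 if obs == 0 else chi2.ppf(alpha / 2, 2 * obs) / (2 * exp)
+    upper = chi2.ppf(1 - alpha / 2, 2 * (obs + 1)) / (2 * exp)
+    return lower, upper
+
+
+# Example: 4 observed vs. 16 expected missense variants (OE = 0.25).
+print(oe_ci_sketch(4, 16.0))  # roughly (0.09, 0.57)
+```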
+ +## Contact + +For questions or issues with the Proemis3D pipeline, please contact the gnomAD team or create an issue in the repository. + +## Detailed Output Data Structure + +### Fully Annotated Hail Table Schema + +The main output of the Proemis3D pipeline is a comprehensive Hail Table with the following structure: + +#### Global Fields +- **None** (no global fields) + +#### Row Fields + +##### Basic Identifiers +- `locus`: Genomic position (locus) +- `alleles`: Variant alleles (array) +- `transcript_id`: GENCODE transcript ID (str) +- `uniprot_id`: UniProt protein ID (str) +- `gene_id`: GENCODE gene ID (str) +- `gene_symbol`: Gene symbol (str) + +##### Transcript and Gene Metadata +- `canonical`: Whether transcript is canonical (bool) +- `mane_select`: Whether transcript is MANE select (bool) +- `transcript_biotype`: Transcript biotype (str) +- `most_severe_consequence`: Most severe consequence (str) +- `cds_len_mismatch`: CDS length mismatch flag (bool) +- `cds_len_not_div_by_3`: CDS length not divisible by 3 flag (bool) + +##### Gene Classification Flags +- `is_phaplo_gene`: Haploinsufficiency gene flag (bool) +- `is_ptriplo_gene`: Triplosensitivity gene flag (bool) +- `is_hi_gene`: Haploinsufficiency gene flag (bool) +- `hi_gene_category`: HI gene category (str) +- `one_uniprot_per_transcript`: One UniProt per transcript flag (bool) +- `one_transcript_per_gene`: One transcript per gene flag (bool) + +##### Variant Level Annotations (`variant_level_annotations`) + +**Context and Basic Info:** +- `context`: Sequence context (str) +- `ref`: Reference allele (str) +- `alt`: Alternative allele (str) +- `was_flipped`: Whether variant was flipped (bool) +- `transition`: Whether variant is a transition (bool) +- `cpg`: Whether variant is in CpG context (bool) +- `mutation_type`: Mutation type (str) +- `methylation_level`: Methylation level (int32) + +**gnomAD Exomes Data:** +- `gnomad_exomes_filters`: gnomAD exomes filters (set) +- `gnomad_exomes_coverage`: Coverage statistics (struct with mean, median_approx, AN, percent_AN) +- `gnomad_exomes_freq`: Frequency data by population (struct with total, afr, amr, eas, nfe, sas) +- `gnomad_exomes_flags`: gnomAD exomes flags (set) + +**Functional Predictions:** +- `sift_score`: SIFT score (float64) +- `polyphen_score`: PolyPhen score (float64) +- `vep_domains`: VEP domains (array) +- `revel`: REVEL score (float64) +- `cadd`: CADD scores (struct with phred, raw_score) +- `phylop`: PhyloP score (float64) + +**Genetics Gym Missense Scores:** +- `genetics_gym_missense_scores`: Comprehensive missense prediction scores (struct with esm_score, proteinmpnn_llr, am_pathogenicity, rasp_score, MisFit_S, MisFit_D, popeve, eve, esm1_v, mpc, and negative controls) + +**Disease Associations:** +- `autism`: Autism-related annotations (struct with role) +- `dd_denovo`: Developmental delay de novo variants (struct with grch37_locus, grch37_alleles, case_control, gene flags) +- `dd_denovo_no_transcript_match`: DD de novo without transcript match +- `gnomad_de_novo`: gnomAD de novo variants (struct with de_novo_AC, p_de_novo_stats) +- `clinvar`: ClinVar annotations (comprehensive struct with rsid, frequencies, clinical significance, etc.) 
+ +**Expression Data:** +- `base_level_pext`: Base-level expression data across 49 tissues (struct with exp_prop_mean and tissue-specific data) +- `annotation_level_pext`: Annotation-level expression data (similar structure to base_level_pext) + +**Constraint Metrics:** +- `mtr`: Missense tolerance ratio (struct with mtr, synExp, misExp, expMTR, synObs, misObs, obsMTR, adj_rate, pvalue, qvalue, proteinLength, percentiles) +- `rmc`: Regional missense constraint (struct with section_obs, section_exp, section_oe, section_chisq, interval, section_oe_ci, section_p_value) + +##### Residue Level Annotations (`residue_level_annotations`) + +**Basic Residue Info:** +- `residue_index`: Residue position in protein (int32) +- `residue_ref`: Reference amino acid (str) +- `residue_alt`: Alternative amino acid (str) + +**Functional Annotations:** +- `interpro`: InterPro domain annotations (struct with interpro_id, interpro_short_description, interpro_description) +- `varity`: VariTY scores (struct with varity_r, varity_er, varity_r_loo, varity_er_loo) +- `mtr3d`: MTR3D scores (struct with mean_pLDDT, mtr3daf2_5a, mtr3daf2_8a, mtr3daf2_11a, mtr3daf2_14a) + +**COSMIS Scores (Multiple Sources):** +- `cosmis_alphafold`: COSMIS scores from AlphaFold2 structures +- `cosmis_pdb`: COSMIS scores from PDB structures +- `cosmis_swiss_model`: COSMIS scores from Swiss Model structures +- `cosmis`: Combined COSMIS scores (struct with alphafold, pdb, swiss_model) + +**Proemis3D Annotations:** +- `promis3d`: Proemis3D constraint annotations (struct with residue_level_annotations and region_level_annotations) + +**Proemis3D Residue Level:** +- `residue_to_region_aa_dist_stats`: Amino acid distance statistics within region +- `alphafold2_info`: AlphaFold2 metrics (residue_plddt, residue_to_region_pae_stats, residue_to_region_dist_stats) + +**Proemis3D Region Level:** +- `region_index`: Region index (int32) +- `region_residues`: Array of residue indices in region (array) +- `region_length`: Length of region (int32) +- `obs`: Observed missense variants (int64) +- `exp`: Expected missense variants (float64) +- `oe`: Observed/expected ratio (float64) +- `oe_upper`: Upper bound of OE confidence interval (float64) +- `oe_ci`: OE confidence interval (struct with lower, upper) +- `chisq`: Chi-squared statistic (float64) +- `p_value`: P-value (float64) +- `is_null`: Whether region is null (bool) +- `region_aa_dist_stats`: Region amino acid distance statistics +- `alphafold2_info`: Region-level AlphaFold2 metrics (region_plddt, region_pae_stats, region_dist_stats) + +##### Gene Level Annotations (`gene_level_annotations`) + +**Basic Gene Info:** +- `strand`: Gene strand (str) +- `cds_length`: CDS length (int32) +- `cds_len_mismatch`: CDS length mismatch flag (bool) +- `cds_len_not_div_by_3`: CDS length not divisible by 3 flag (bool) +- `aminoacid_length`: Amino acid length (int32) +- `gene`: Gene symbol (str) +- `canonical`: Canonical transcript flag (bool) +- `mane_select`: MANE select flag (bool) + +**Constraint Metrics (syn, mis, lof):** +Each variant type (synonymous, missense, loss-of-function) includes: +- `mu_snp`: Mutation rate (float64) +- `mu`: Mutation rate (float64) +- `possible_variants`: Number of possible variants (int64) +- `coverage_correction`: Coverage correction factor (float64) +- `flags`: Quality flags (set) +- `z_score`: Z-score (float64) +- `upper_rank`: Upper rank (int64) +- `upper_bin_sextile`: Upper bin sextile (int32) +- `upper_bin_decile`: Upper bin decile (int32) +- `observed_variants`: Number of 
observed variants (int64) +- `predicted_proportion_observed`: Predicted proportion observed (float64) +- `expected_variants`: Number of expected variants (float64) +- `oe`: Observed/expected ratio (float64) +- `oe_ci`: OE confidence interval (struct with lower, upper) +- `z_raw`: Raw Z-score (float64) + +#### Key Structure +The table is keyed by: `['locus', 'alleles', 'transcript_id', 'uniprot_id', 'gene_id']` + +This comprehensive schema provides detailed annotations at multiple levels (variant, residue, gene) with extensive functional predictions, constraint metrics, and structural information from AlphaFold2. diff --git a/gnomad_constraint/experimental/proemis3d/__init__.py b/gnomad_constraint/experimental/proemis3d/__init__.py new file mode 100644 index 00000000..6e031999 --- /dev/null +++ b/gnomad_constraint/experimental/proemis3d/__init__.py @@ -0,0 +1 @@ +# noqa: D104 diff --git a/gnomad_constraint/experimental/proemis3d/constants.py b/gnomad_constraint/experimental/proemis3d/constants.py new file mode 100644 index 00000000..15611643 --- /dev/null +++ b/gnomad_constraint/experimental/proemis3d/constants.py @@ -0,0 +1,264 @@ +"""Constants for the proemis3D pipeline.""" + +MIN_EXP_MIS = 16 +""" +Minimum number of expected missense variants in a proemis3D region to be considered for +constraint calculation. +""" +SEVERE_HI_GENES = [ + "ADNP", + "AHDC1", + "ANKRD11", + "ARID1A", + "ARID1B", + "ARID2", + "AUTS2", + "CHAMP1", + "CHD2", + "CHD7", + "CHD8", + "CREBBP", + "CTCF", + "CTNNB1", + "DMRT1", + "DYRK1A", + "EFTUD2", + "EHMT1", + "EP300", + "FOXG1", + "FOXP1", + "GATA2", + "GATA6", + "GLI2", + "GRIN2B", + "HIVEP2", + "HNRNPK", + "KANSL1", + "KMT2D", + "MBD5", + "MED13L", + "MEF2C", + "NFIA", + "NIPBL", + "OTX2", + "PAFAH1B1", + "PURA", + "RAI1", + "RERE", + "SATB2", + "SCN1A", + "SCN2A", + "SETBP1", + "SETD5", + "SHANK3", + "SHH", + "SIX3", + "SLC2A1", + "SOX5", + "SOX9", + "STXBP1", + "SYNGAP1", + "TCF4", + "TGIF1", + "ZEB2", + "ZIC2", + "CASK", + "CDKL5", + "HCCS", + "KDM6A", + "MECP2", + "PCDH19", + "WDR45", +] +""" +Severe HI genes. +""" + +MODERATE_HI_GENES = [ + "BMP4", + "BMPR1A", + "BMPR2", + "CHRNA7", + "COL11A2", + "COL1A1", + "COL2A1", + "COL5A1", + "ELN", + "EYA1", + "FAS", + "FBN1", + "FGF10", + "FOXC1", + "FOXC2", + "FZD4", + "GATA3", + "GCH1", + "HOXD13", + "IRF6", + "JAG1", + "KCNH2", + "KCNQ1", + "KCNQ2", + "KIF11", + "LHX4", + "MITF", + "MNX1", + "MYCN", + "NF1", + "NKX2-5", + "NOG", + "NRXN1", + "NSD1", + "PAX2", + "PAX3", + "PITX2", + "PTEN", + "RB1", + "RET", + "RPS19", + "RPS24", + "RPS26", + "RUNX2", + "SALL1", + "SALL4", + "SF3B4", + "SMAD3", + "SMAD4", + "SMARCB1", + "SOX10", + "SOX2", + "SPRED1", + "STK11", + "TBX1", + "TBX3", + "TBX5", + "TCF12", + "TCOF1", + "TRPS1", + "TSC1", + "TSC2", + "WT1", + "ACSL4", + "BCOR", + "DCX", + "EBP", + "EFNB1", + "FLNA", + "FMR1", + "GRIA3", + "LAMP2", + "NSDHL", + "OFD1", + "OTC", + "PDHA1", + "PHF6", + "PLP1", + "PORCN", + "SLC6A8", + "SLC9A6", +] +""" +Moderate HI genes. 
+""" + +NEW_HI_GENES = [ + "ADNP", + "AHDC1", + "ANKRD11", + "ARID1A", + "ARID1B", + "ARID2", + "AUTS2", + "CHAMP1", + "CHD2", + "CHD7", + "CHD8", + "CREBBP", + "CTNNB1", + "DYRK1A", + "EFTUD2", + "EHMT1", + "EP300", + "FOXG1", + "FOXP1", + "GATA2", + "GATA6", + "GRIN2B", + "HIVEP2", + "KANSL1", + "KMT2D", + "MBD5", + "MED13L", + "MEF2C", + "NFIA", + "NIPBL", + "PAFAH1B1", + "PURA", + "SATB2", + "SCN1A", + "SCN2A", + "SETD5", + "SHANK3", + "SLC2A1", + "SOX5", + "SOX9", + "STXBP1", + "SYNGAP1", + "TCF4", + "ZEB2", + "ZIC2", + "CTCF", + "HNRNPK", + "RAI1", + "RERE", + "SETBP1", + "ASH1L", + "ASXL1", + "ASXL3", + "BCL11A", + "CIC", + "GATAD2B", + "KAT6A", + "KAT6B", + "KMT2A", + "KMT2C", + "MYT1L", + "NFIX", + "OTX2", + "PBX1", + "PHIP", + "POGZ", + "SETD2", + "SON", + "SOX11", + "TBL1XR1", + "TBR1", + "TCF20", + "TRIP12", + "WAC", + "ZBTB18", + "ZMYND11", + "ZNF462", +] +""" +New HI genes. +""" + +HI_GENES = list(set([*SEVERE_HI_GENES, *MODERATE_HI_GENES, *NEW_HI_GENES])) +""" +HI genes. +""" + +HI_GENES_NO_MODERATE = list(set([*SEVERE_HI_GENES, *NEW_HI_GENES])) +""" +HI genes without moderate HI genes. +""" + +HI_GENE_CATEGORIES = { + "Severe Haploinsufficient": SEVERE_HI_GENES, + "Moderate Haploinsufficient": MODERATE_HI_GENES, + "New Haploinsufficient": NEW_HI_GENES, +} +""" +HI gene categories. +""" diff --git a/gnomad_constraint/experimental/proemis3d/data_import.py b/gnomad_constraint/experimental/proemis3d/data_import.py new file mode 100644 index 00000000..243dc8a7 --- /dev/null +++ b/gnomad_constraint/experimental/proemis3d/data_import.py @@ -0,0 +1,826 @@ +"""Script to import data for Proemis3D. + +This script imports data for Proemis3D, including COSMIS scores, Varity data, MTR3D data, +InterPro annotations, Kaplanis variants, Fu variants, ClinVar missense variants, +constraint metrics, MTR data, RMC data, context data, and Genetics Gym missense scores. +""" + +import argparse + +import hail as hl +from gnomad.resources.grch38.reference_data import clinvar +from gnomad.utils.constraint import oe_confidence_interval +from gnomad.utils.liftover import default_lift_data +from hail.utils.misc import divide_null + +from gnomad_constraint.experimental.proemis3d.resources import ( + CURRENT_VERSION, + get_clinvar_missense_ht, + get_constraint_metrics_ht, + get_context_preprocessed_ht, + get_cosmis_score_ht, + get_cosmis_score_tsv, + get_fu_variants_ht, + get_fu_variants_tsv, + get_genetics_gym_missense_scores_ht, + get_gnomad_de_novo_ht, + get_insilico_annotations_ht, + get_interpro_annotations, + get_interpro_annotations_ht, + get_kaplanis_sig_variants_tsv, + get_kaplanis_variants_ht, + get_kaplanis_variants_tsv, + get_mtr3d_ht, + get_mtr3d_tsv, + get_mtr_ht, + get_mtr_tsv, + get_processed_genetics_gym_missense_scores_ht, + get_revel_csv, + get_rmc_ht, + get_temp_context_preprocessed_ht, + get_temp_processed_constraint_ht, + get_temp_processed_rmc_ht, + get_varity_ht, + get_varity_tsv, +) + + +def import_cosmis_score_data( + model: str, +) -> hl.Table: + """ + Import and process COSMIS TSV file for a specified structure model. + + :param model: Structure model source to process ('alphafold', 'swiss_model', or 'pdb'). + :param distance: Distance metric to use ('8a' or '10a'). Default is '8a'. + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Hail Table with COSMIS scores. 
+ """ + ht = hl.import_table(get_cosmis_score_tsv(model), force=True, impute=True) + ht = ht.transmute( + transcript_id=ht.enst_id, + residue_index=ht.uniprot_pos - 1, + cossyn=hl.float(ht.cossyn), + ) + ht = ht.key_by("transcript_id", "uniprot_id", "residue_index") + + return ht + + +def import_varity_data() -> hl.Table: + """ + Import Varity data. + + :return: Hail Table with Varity data. + """ + ht = hl.import_table( + get_varity_tsv(), + impute=True, + delimiter="\t", + min_partitions=1000, + ) + ht = ht.select( + uniprot_id=ht.p_vid, + residue_index=ht.aa_pos - 1, + residue_ref=ht.aa_ref, + residue_alt=ht.aa_alt, + varity_r=ht.VARITY_R, + varity_er=ht.VARITY_ER, + varity_r_loo=ht.VARITY_R_LOO, + varity_er_loo=ht.VARITY_ER_LOO, + ) + ht = ht.key_by("uniprot_id", "residue_index", "residue_ref", "residue_alt") + + ht.show() + + return ht + + +def import_mtr3d_data() -> hl.Table: + """ + Import MTR3D data. + + :return: Hail Table with MTR3D data. + """ + ht = hl.import_table( + get_mtr3d_tsv(), + impute=True, + delimiter=",", + min_partitions=1000, + missing="", + ) + ht = ht.key_by( + transcript_id=ht.transcript, + uniprot_id=ht.uniprot, + residue_index=ht.aa_num - 1, + ) + ht = ht.select( + "mean_pLDDT", "mtr3daf2_5a", "mtr3daf2_8a", "mtr3daf2_11a", "mtr3daf2_14a" + ) + + return ht + + +def import_mtr_data() -> hl.Table: + """ + Import MTR data. + + :return: Hail Table with MTR data. + """ + ht = hl.import_table( + get_mtr_tsv(), + impute=True, + delimiter="\t", + min_partitions=1000, + missing="", + ) + ht = ht.annotate( + locus=hl.parse_locus( + hl.format( + "chr%s:%s", hl.if_else(ht.CHR == 23, "X", hl.str(ht.CHR)), ht.POS + ), + reference_genome="GRCh38", + ), + alleles=[ht.REF, ht.ALT], + transcript_id=ht.TranscriptId, + mtr=ht.MTR, + ) + ht = ht.key_by("locus", "alleles", "transcript_id") + ht = ht.select( + "mtr", + "synExp", + "misExp", + "expMTR", + "synObs", + "misObs", + "obsMTR", + "adj_rate", + "pvalue", + "qvalue", + "proteinLength", + "MTRpercentile_exome", + "MTRpercentile_transcript", + ) + + return ht + + +def import_kaplanis_variants( + liftover_to_grch38: bool = False, + key_by_gene_and_transcript: bool = False, +) -> hl.Table: + """ + Process Kaplanis de novo missense variants file and lift over loci to GRCh38. + + :param liftover_to_grch38: Whether to lift over loci to GRCh38. Default is False. + :param key_by_gene_and_transcript: Whether to key the table by gene and transcript. + Default is False. + :return: Hail Table with Kaplanis de novo missense variants. + """ + ht = hl.import_table(get_kaplanis_variants_tsv(), min_partitions=10, impute=True) + + variant_expr = ht.Variant.split(":") + ht = ht.select( + locus=hl.locus( + variant_expr[0], hl.int(variant_expr[1]), reference_genome="GRCh37" + ), + transcript_id=ht.Transcript_ID, + gene_id=ht.Gene_ID, + gene_name=ht.Gene, + alleles=[variant_expr[2], variant_expr[3]], + ) + + # Add case_control field to keep track of the number of individuals carrying the + # variant in the Kaplanis study. + ht = ht.annotate(case_control="DD") + ht = ht.group_by("locus", "alleles", "gene_id", "transcript_id").aggregate( + case_control=hl.agg.collect(ht.case_control), + ) + + # Set default select and key by fields. + select_fields = ["case_control"] + key_by_fields = ["locus", "alleles"] + if liftover_to_grch38: + # Perform liftover. 
+        ht = default_lift_data(ht)
+        ht = ht.key_by()
+        ht = ht.transmute(
+            locus=ht.new_locus,
+            alleles=ht.new_alleles,
+            grch37_locus=ht.original_locus,
+            grch37_alleles=ht.original_alleles,
+        )
+        select_fields = ["grch37_locus", "grch37_alleles"] + select_fields
+
+    if key_by_gene_and_transcript:
+        key_by_fields = key_by_fields + ["gene_id", "transcript_id"]
+
+    ht = ht.key_by(*key_by_fields)
+    ht = ht.select(*select_fields)
+
+    return ht
+
+
+def get_kaplanis_sig_gene_annotations(
+    gene_name_expr: hl.expr.StringExpression,
+) -> hl.expr.StructExpression:
+    """
+    Get Kaplanis significant gene set annotations.
+
+    :param gene_name_expr: Gene name expression.
+    :return: Struct with boolean expressions for whether the gene is in the
+        significant, diagnostic consensus, or significant or diagnostic consensus
+        gene set.
+    """
+    # Load the significant gene set.
+    sig_gene_ht = hl.import_table(get_kaplanis_sig_variants_tsv())
+    sig_gene_expr = sig_gene_ht.significant == "TRUE"
+    diagnostic_category_consensus_expr = sig_gene_ht.diagnostic_category == "consensus"
+    agg_expr = hl.agg.collect_as_set(sig_gene_ht.symbol)
+    sig_genes = hl.literal(
+        sig_gene_ht.aggregate(
+            {
+                "sig_gene": hl.agg.filter(sig_gene_expr, agg_expr),
+                "consensus_gene": hl.agg.filter(
+                    diagnostic_category_consensus_expr, agg_expr
+                ),
+                "sig_or_consensus_gene": hl.agg.filter(
+                    sig_gene_expr | diagnostic_category_consensus_expr, agg_expr
+                ),
+            }
+        )
+    )
+    return hl.struct(
+        in_sig_gene=sig_genes["sig_gene"].contains(gene_name_expr),
+        in_diagnostic_consensus_gene=(
+            sig_genes["consensus_gene"].contains(gene_name_expr)
+        ),
+        in_sig_or_diagnostic_consensus_gene=(
+            sig_genes["sig_or_consensus_gene"].contains(gene_name_expr)
+        ),
+    )
+
+
+def import_fu_variants() -> hl.Table:
+    """
+    Import de novo variants from the Fu et al. (2022) paper.
+
+    Imports variants from a TSV into a Hail Table.
+
+    :return: Hail Table with Fu de novo variants.
+    """
+    fu_ht = hl.import_table(
+        get_fu_variants_tsv(),
+        impute=True,
+        missing="",
+        # Skip blank lines at the bottom of this TSV.
+        skip_blank_lines=True,
+    )
+
+    # Remove lines from the bottom of the TSV that are parsed incorrectly upon
+    # import. These lines contain metadata about the TSV, e.g.:
+    # "Supplementary Table 20. The de novo SNV/indel variants used in TADA
+    # association analyses from assembled ASD cohorts"
+    fu_ht = fu_ht.filter(~hl.is_missing(fu_ht.Role))
+    fu_ht = fu_ht.annotate(
+        locus=hl.parse_locus(
+            hl.format(
+                "chr%s:%s",
+                fu_ht.Variant.split(":")[0],
+                fu_ht.Variant.split(":")[1],
+            ),
+            reference_genome="GRCh38",
+        ),
+        alleles=[fu_ht.Variant.split(":")[2], fu_ht.Variant.split(":")[3]],
+    )
+
+    # Rename 'Proband' > 'ASD' and 'Sibling' > 'control'.
+    fu_ht = fu_ht.transmute(role=hl.if_else(fu_ht.Role == "Proband", "ASD", "control"))
+    fu_ht = fu_ht.group_by("locus", "alleles").aggregate(
+        role=hl.agg.collect(fu_ht.role),
+    )
+
+    return fu_ht
+
+
+def import_interpro_annotations() -> hl.Table:
+    """
+    Import InterPro annotations.
+
+    :return: Hail Table with InterPro annotations.
+ """ + ht = hl.import_table(get_interpro_annotations(), impute=True, min_partitions=500) + ht = ht.select( + uniprot_id=ht["UniProtKB/Swiss-Prot ID"], + transcript_id=ht["Transcript stable ID"], + residue_index=hl.range( + hl.int(hl.or_missing(ht["Interpro start"] != "", ht["Interpro start"])) - 1, + hl.int(hl.or_missing(ht["Interpro end"] != "", ht["Interpro end"])), + ), + interpro_id=ht["Interpro ID"], + interpro_short_description=ht["Interpro Short Description"], + interpro_description=ht["Interpro Description"], + ) + ht = ht.explode("residue_index") + ht = ht.key_by("transcript_id", "uniprot_id", "residue_index") + + return ht + + +def process_clinvar_ht(clinvar_version: str = "20250504") -> hl.Table: + """ + Process ClinVar HT by extracting info fields, filtering to missense variants, and rekey. + + :param clinvar_version: Version of ClinVar to use. Default is `20250504`. + :return: Hail Table with ClinVar missense variants. + """ + ht = clinvar.versions[clinvar_version].ht() + ht = ht.select("rsid", **ht.info) + ht = ht.filter(hl.any(ht.MC.map(lambda x: x.split("\\|")[1] == "missense_variant"))) + ht = ht.key_by("locus", "alleles", gene=ht.GENEINFO.split(":")[0]) + + return ht + + +def process_gnomad_site_ht(ht) -> hl.Table: + """ + Process gnomAD site Hail Table. + + :param ht: Hail Table to process. + :return: Hail Table with gnomAD site variants. + """ + return ht.select(gnomad_exomes_flags=ht.exome.flags) + + +def process_pext_base_ht(ht) -> hl.Table: + """ + Process PEXT base level Hail Table. + + :param ht: Hail Table to process. + :return: Hail Table with PEXT base level variants. + """ + return ht.key_by("locus", "gene_id").drop("gene_symbol") + + +def process_pext_annotation_ht(ht) -> hl.Table: + """ + Process PEXT annotation level Hail Table. + + :param ht: Hail Table to process. + :return: Hail Table with PEXT annotation level variants. + """ + return ht.key_by("locus", "alleles", "gene_id", "most_severe_consequence").drop( + "gene_symbol" + ) + + +def process_gnomad_de_novo_ht(ht) -> hl.Table: + """ + Process gnomAD de novo Hail Table. + + :return: Hail Table with gnomAD de novo variants. + """ + ht = ht.key_by("locus", "alleles") + ht = ht.select("de_novo_AC", "p_de_novo_stats") + + return ht + + +def process_rmc_ht(version: str = CURRENT_VERSION) -> hl.Table: + """ + Load RMC Hail Table and annotate with p-values and per-locus expansion. + + Annotates: + - `section_oe_upper`: Upper bound of the observed/expected confidence interval. + - `section_p_value`: Chi-square p-value. + - `locus`: List of loci covered by the interval. + + Explodes the `locus` array into multiple rows and keys the table by `locus` and + `transcript`. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Hail Table with RMC data. + """ + rmc_ht = get_rmc_ht(version).ht() + + rmc_ht = rmc_ht.annotate( + section_oe=divide_null(rmc_ht.section_obs, rmc_ht.section_exp), + section_oe_ci=oe_confidence_interval(rmc_ht.section_obs, rmc_ht.section_exp), + section_p_value=hl.pchisqtail(rmc_ht.section_chisq, 1), + locus=hl.range( + rmc_ht.interval.start.position, rmc_ht.interval.end.position + 1 + ).map( + lambda x: hl.locus( + rmc_ht.interval.start.contig, x, reference_genome="GRCh38" + ) + ), + ) + + rmc_ht = rmc_ht.explode("locus").key_by("locus", "transcript") + + return rmc_ht + + +def process_constraint_metrics_ht(version: str = CURRENT_VERSION) -> hl.Table: + """ + Load constraint metrics Hail Table and filter to transcripts starting with "ENST". 
+ + Only selects the syn, mis, and lof constraint groups and only the first oe_info + struct, which contains info for the full dataset, not per-genetic ancestry group. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Hail Table with constraint metrics. + """ + ht = get_constraint_metrics_ht(version).ht() + ht = ht.filter(ht.transcript.startswith("ENST")) + + ht = ht.select( + syn=ht.constraint_groups[0] + .annotate(**ht.constraint_groups[0].oe_info[0]) + .drop("oe_info"), + mis=ht.constraint_groups[1] + .annotate(**ht.constraint_groups[1].oe_info[0]) + .drop("oe_info"), + lof=ht.constraint_groups[5] + .annotate(**ht.constraint_groups[5].oe_info[0]) + .drop("oe_info"), + ) + + ht = ht.key_by("transcript") + + return ht + + +def process_context_ht(version: str = CURRENT_VERSION) -> hl.Table: + """ + Read, re-key, and restructure the context preprocessed Hail Table. + + This function: + 1. Reads the context table from `get_context_preprocessed_ht`. + 2. Extracts population labels from the `exomes_freq_meta` field. + 3. Re-keys the table by (`locus`, `alleles`). + 4. Selects and renames key fields: + - Coverage statistics (`mean`, `median_approx`, `AN`, `percent_AN`) + - Calibrated mutation frequencies per population + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: The processed and checkpointed Hail Table. + """ + ht = get_context_preprocessed_ht(version).ht() + pops = hl.eval(ht.exomes_freq_meta.map(lambda x: x.get("gen_anc", "total"))) + + ht = ht.select( + "context", + "ref", + "alt", + "was_flipped", + "transition", + "cpg", + "mutation_type", + "methylation_level", + transcript_consequences=ht.vep.transcript_consequences.map( + lambda x: x.annotate( + canonical=hl.or_else(x.canonical == 1, False), + mane_select=hl.is_defined(x.mane_select), + vep_domains=x.domains, + residue_ref=x.amino_acids.split("/").first(), + residue_alt=x.amino_acids.split("/").last(), + ).select( + "transcript_id", + "gene_id", + "gene_symbol", + "canonical", + "mane_select", + "biotype", + "most_severe_consequence", + "sift_score", + "polyphen_score", + "vep_domains", + "residue_ref", + "residue_alt", + ) + ), + gnomad_exomes_filters=ht.filters.exomes, + gnomad_exomes_coverage=hl.struct( + mean=ht.coverage.exomes.mean, + median_approx=ht.coverage.exomes.median_approx, + AN=ht.AN.exomes, + percent_AN=ht.exomes_coverage, + ), + gnomad_exomes_freq=hl.struct( + **{pop: ht.calibrate_mu.exomes_freq[i] for i, pop in enumerate(pops)} + ), + ).explode("transcript_consequences") + ht = ht.transmute(**ht.transcript_consequences) + ht = ht.key_by("locus", "alleles", "transcript_id") + + return ht + + +def process_genetics_gym_missense_scores_ht() -> hl.Table: + """ + Process Genetics Gym missense scores Hail Table. + + :return: Hail Table with Genetics Gym missense scores. 
+ """ + scores = [ + "esm_score", + "proteinmpnn_llr", + "am_pathogenicity", + "rasp_score", + "MisFit_S", + "MisFit_D", + "popeve", + "eve", + "esm1_v", + "mpc", + "esm_score_neg", + "proteinmpnn_llr_neg", + "popeve_neg", + "esm1_v_neg", + ] + ht = ( + get_genetics_gym_missense_scores_ht() + .ht() + .select( + "ensembl_tid", + "uniprot_id", + *scores, + ) + ) + ht = ht.filter(hl.any([hl.is_defined(ht[s]) for s in scores])) + ht = ht.annotate( + uniprot_id=hl.or_else(ht.uniprot_id, "None"), rand_n=hl.rand_unif(0, 1) + ).cache() + + for s in scores: + ht = ht.order_by(ht[s], ht.rand_n) + ht = ht.annotate( + **{ + f"{s}_idx": hl.or_missing( + hl.is_defined(ht[s]), hl.scan.count_where(hl.is_defined(ht[s])) + ) + } + ).cache() + + ht = ht.annotate(transcript_id=ht.ensembl_tid) + ht = ht.key_by("locus", "alleles", "transcript_id", "uniprot_id").cache() + max_idx = ht.aggregate(hl.struct(**{s: hl.agg.max(ht[f"{s}_idx"]) for s in scores})) + ht = ht.select( + **{ + s: hl.struct( + score=ht[s], + idx=ht[f"{s}_idx"], + percentile=hl.int((ht[f"{s}_idx"] / max_idx[s]) * 100), + ) + for s in scores + } + ) + + return ht + + +def import_revel_ht() -> hl.Table: + """ + Import REVEL Hail Table. + + :return: Hail Table with REVEL annotations. + """ + ht = hl.import_table( + get_revel_csv(), + delimiter=",", + min_partitions=1000, + types={"grch38_pos": hl.tstr, "REVEL": hl.tfloat64}, + ) + + ht = ht.drop("hg19_pos", "aaref", "aaalt") + + # Drop variants that have no position in GRCh38 when lifted over from GRCh37. + ht = ht.filter(ht.grch38_pos.contains("."), keep=False) + ht = ht.transmute(chr="chr" + ht.chr) + ht = ht.select( + locus=hl.locus(ht.chr, hl.int(ht.grch38_pos), reference_genome="GRCh38"), + alleles=hl.array([ht.ref, ht.alt]), + revel=ht.REVEL, + transcript_id=ht.Ensembl_transcriptid.strip().split(";"), + ) + ht = ht.explode("transcript_id") + ht = ht.key_by("locus", "alleles", "transcript_id") + + return ht + + +def main(args): + """Execute the Proemis 3D pipeline.""" + hl.init( + log="/proemis_3d_data_import.log", + tmp_dir="gs://gnomad-tmp-4day", + ) + overwrite = args.overwrite + + if args.import_cosmis_score_data or args.import_all: + for model in ["alphafold", "swiss_model", "pdb"]: + import_cosmis_score_data(model).write( + get_cosmis_score_ht(model).path, + overwrite=overwrite, + ) + + if args.import_varity_data or args.import_all: + import_varity_data().write( + get_varity_ht().path, + overwrite=overwrite, + ) + + if args.import_mtr3d_data or args.import_all: + import_mtr3d_data().write( + get_mtr3d_ht().path, + overwrite=overwrite, + ) + + if args.import_interpro_annotations or args.import_all: + import_interpro_annotations().write( + get_interpro_annotations_ht().path, + overwrite=overwrite, + ) + + if args.import_kaplanis_variants or args.import_all: + import_kaplanis_variants().write( + get_kaplanis_variants_ht().path, + overwrite=overwrite, + ) + import_kaplanis_variants( + liftover_to_grch38=True, + key_by_gene_and_transcript=True, + ).write( + get_kaplanis_variants_ht( + liftover_to_grch38=True, key_by_transcript=True + ).path, + overwrite=overwrite, + ) + import_kaplanis_variants( + liftover_to_grch38=True, + key_by_gene_and_transcript=False, + ).write( + get_kaplanis_variants_ht( + liftover_to_grch38=True, key_by_transcript=False + ).path, + overwrite=overwrite, + ) + import_kaplanis_variants( + liftover_to_grch38=False, + key_by_gene_and_transcript=True, + ).write( + get_kaplanis_variants_ht( + liftover_to_grch38=False, key_by_transcript=True + ).path, + 
overwrite=overwrite, + ) + + if args.import_fu_variants or args.import_all: + import_fu_variants().write( + get_fu_variants_ht().path, + overwrite=overwrite, + ) + + if args.import_revel_ht or args.import_all: + import_revel_ht().write( + get_insilico_annotations_ht("revel").path, + overwrite=overwrite, + ) + + if args.process_clinvar_ht or args.import_all: + process_clinvar_ht(args.clinvar_version).write( + get_clinvar_missense_ht().path, + overwrite=overwrite, + ) + + if args.process_constraint_metrics_ht or args.import_all: + process_constraint_metrics_ht(args.constraint_metrics_ht_version).write( + get_temp_processed_constraint_ht(args.constraint_metrics_ht_version).path, + overwrite=overwrite, + ) + + if args.import_mtr_data or args.import_all: + import_mtr_data().write( + get_mtr_ht().path, + overwrite=overwrite, + ) + + if args.process_rmc_ht or args.import_all: + process_rmc_ht(args.rmc_ht_version).write( + get_temp_processed_rmc_ht(args.rmc_ht_version).path, + overwrite=overwrite, + ) + + if args.process_context_ht or args.import_all: + process_context_ht(args.context_ht_version).write( + get_temp_context_preprocessed_ht(args.context_ht_version).path, + overwrite=overwrite, + ) + + if args.process_genetics_gym_missense_scores_ht or args.import_all: + process_genetics_gym_missense_scores_ht().write( + get_processed_genetics_gym_missense_scores_ht().path, + overwrite=overwrite, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--overwrite", help="Whether to overwrite output files.", action="store_true" + ) + parser.add_argument( + "--import-all", + help="Import all data.", + action="store_true", + ) + parser.add_argument( + "--import-cosmis-score-data", + help="Whether to import COSMIS score data.", + action="store_true", + ) + parser.add_argument( + "--import-varity-data", + help="Whether to import Varity data.", + action="store_true", + ) + parser.add_argument( + "--import-mtr3d-data", + help="Whether to import MTR3D data.", + action="store_true", + ) + parser.add_argument( + "--import-interpro-annotations", + help="Whether to import InterPro annotations.", + action="store_true", + ) + parser.add_argument( + "--import-kaplanis-variants", + help="Whether to import Kaplanis variants.", + action="store_true", + ) + parser.add_argument( + "--import-fu-variants", + help="Whether to import Fu variants.", + action="store_true", + ) + parser.add_argument( + "--process-clinvar-ht", + help="Whether to process ClinVar HT.", + action="store_true", + ) + parser.add_argument( + "--import-revel-ht", + help="Whether to import REVEL HT.", + action="store_true", + ) + parser.add_argument( + "--clinvar-version", + help="Version of ClinVar HT to process.", + default="20250504", + ) + parser.add_argument( + "--process-constraint-metrics-ht", + help="Whether to process constraint metrics HT.", + action="store_true", + ) + parser.add_argument( + "--constraint-metrics-ht-version", + help="Version of constraint metrics HT to process.", + default=CURRENT_VERSION, + ) + parser.add_argument( + "--import-mtr-data", + help="Whether to import MTR data.", + action="store_true", + ) + parser.add_argument( + "--process-rmc-ht", + help="Whether to process the RMC HT.", + action="store_true", + ) + parser.add_argument( + "--rmc-ht-version", + help="Version of RMC HT to process.", + default=CURRENT_VERSION, + ) + parser.add_argument( + "--process-context-ht", + help="Whether to process the context HT.", + action="store_true", + ) + parser.add_argument( + 
"--context-ht-version", + help="Version of context HT to process.", + default=CURRENT_VERSION, + ) + parser.add_argument( + "--process-genetics-gym-missense-scores-ht", + help="Whether to process the Genetics Gym missense scores HT.", + action="store_true", + ) + + args = parser.parse_args() + main(args) diff --git a/gnomad_constraint/experimental/proemis3d/proemis_3d.py b/gnomad_constraint/experimental/proemis3d/proemis_3d.py new file mode 100644 index 00000000..bdd640b7 --- /dev/null +++ b/gnomad_constraint/experimental/proemis3d/proemis_3d.py @@ -0,0 +1,609 @@ +"""Script to perform the proemis3d pipeline.""" + +import argparse +import logging + +import hail as hl +from gnomad_qc.resource_utils import ( + PipelineResourceCollection, + PipelineStepResourceCollection, +) + +import gnomad_constraint.experimental.proemis3d.resources as proemis3d_res +from gnomad_constraint.experimental.proemis3d.constants import MIN_EXP_MIS +from gnomad_constraint.experimental.proemis3d.utils import ( + COLNAMES_TRANSLATIONS, + convert_fasta_to_table, + convert_gencode_transcripts_fasta_to_table, + create_missense_viewer_input_ht, + create_per_proemis3d_region_ht_from_residue_ht, + create_per_residue_ht_from_snv_ht, + create_per_snv_combined_ht, + determine_regions_with_min_oe_upper, + generate_all_possible_snvs_from_gencode_positions, + generate_codon_oe_table, + get_gencode_positions, + join_by_sequence, + process_af2_structures, + remove_multi_frag_uniprots, + run_forward, + run_greedy, +) + +logging.basicConfig( + format="%(asctime)s (%(name)s %(lineno)s): %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) +logger = logging.getLogger("proemis3d_pipeline") +logger.setLevel(logging.INFO) + +TEST_TRANSCRIPT_ID = "ENST00000372435" # "ENST00000215754" +TEST_UNIPROT_ID = "P60891" # "P14174" +"""Transcript and UniProt IDs for testing.""" + + +def get_proemis3d_resources( + version: str, + overwrite: bool, + test: bool, +) -> PipelineResourceCollection: + """ + Get PipelineResourceCollection for all resources needed in the proemis3d pipeline. + + :param version: Version of proemis3d resources to use. + :param overwrite: Whether to overwrite existing resources. + :param test: Whether to use test resources. + :return: PipelineResourceCollection containing resources for all steps of the + proemis3D pipeline. + """ + # Get glob for AlphaFold2 structures. + af2_dir_path = proemis3d_res.get_alpha_fold2_dir(version=version) + if test: + af2_struct_dir_path = f"{af2_dir_path}/AF-{TEST_UNIPROT_ID}-*.cif.gz" + else: + af2_struct_dir_path = f"{af2_dir_path}/*.cif.gz" + + # Get glob for AlphaFold2 confidence. + if test: + af2_conf_dir_path = ( + f"{af2_dir_path}/AF-{TEST_UNIPROT_ID}-*confidence_v4.json.gz" + ) + else: + af2_conf_dir_path = f"{af2_dir_path}/*confidence_v4.json.gz" + + # Get glob for AlphaFold2 pAE. + if test: + af2_pae_dir_path = ( + f"{af2_dir_path}/AF-{TEST_UNIPROT_ID}-*predicted_aligned_error_v4.json.gz" + ) + else: + af2_pae_dir_path = f"{af2_dir_path}/*predicted_aligned_error_v4.json.gz" + + # Initialize proemis3D pipeline resource collection. + proemis3d_pipeline = PipelineResourceCollection( + pipeline_name="proemis3d", + overwrite=overwrite, + pipeline_resources={ + "AlphaFold2 directory": { + "af2_struct_dir_path": af2_struct_dir_path, + "af2_conf_dir_path": af2_conf_dir_path, + "af2_pae_dir_path": af2_pae_dir_path, + } + }, + ) + + # Create resource collection for each step of the proemis3D pipeline. 
+ gencode_transcipt = PipelineStepResourceCollection( + "--convert-gencode-fastn-to-ht", + input_resources={ + "GENCODE transcripts FASTA": { + "gencode_transcipt_fasta_path": proemis3d_res.get_gencode_fasta( + version=version, name="pc_transcripts" + ) + }, + }, + output_resources={ + "gencode_transcipt_ht": proemis3d_res.get_gencode_seq_ht( + version=version, name="pc_transcripts", test=test + ), + }, + ) + gencode_translation = PipelineStepResourceCollection( + "--convert-gencode-fasta-to-ht", + input_resources={ + "GENCODE translations FASTA": { + "gencode_translation_fasta_path": proemis3d_res.get_gencode_fasta( + version=version, name="pc_translations" + ) + }, + }, + output_resources={ + "gencode_translation_ht": proemis3d_res.get_gencode_seq_ht( + version=version, name="pc_translations", test=test + ), + }, + ) + read_af2_sequences = PipelineStepResourceCollection( + "--read-af2-sequences", + output_resources={"af2_ht": proemis3d_res.get_af2_ht(version, test)}, + ) + compute_af2_distance_matrices = PipelineStepResourceCollection( + "--compute-af2-distance-matrices", + output_resources={"af2_dist_ht": proemis3d_res.get_af2_dist_ht(version, test)}, + ) + extract_af2_plddt = PipelineStepResourceCollection( + "--extract-af2-plddt", + output_resources={ + "af2_plddt_ht": proemis3d_res.get_af2_plddt_ht(version, test) + }, + ) + extract_af2_pae = PipelineStepResourceCollection( + "--extract-af2-pae", + output_resources={"af2_pae_ht": proemis3d_res.get_af2_pae_ht(version, test)}, + ) + gencode_alignment = PipelineStepResourceCollection( + "--gencode-alignment", + pipeline_input_steps=[gencode_translation, read_af2_sequences], + output_resources={ + "matched_ht": proemis3d_res.get_gencode_translations_matched_ht( + version, test + ), + }, + ) + get_gencode_positions = PipelineStepResourceCollection( + "--get-gencode-positions", + pipeline_input_steps=[gencode_transcipt, gencode_alignment], + add_input_resources={ + "GENCODE GTF": {"gencode_gtf_ht": proemis3d_res.get_gencode_ht(version)} + }, + output_resources={ + "gencode_pos_ht": proemis3d_res.get_gencode_pos_ht(version, test), + }, + ) + run_greedy = PipelineStepResourceCollection( + "--run-greedy", + pipeline_input_steps=[compute_af2_distance_matrices, get_gencode_positions], + add_input_resources={ + "RMC OE Table": {"obs_exp_ht": proemis3d_res.get_obs_exp_ht(version)} + }, + output_resources={"greedy_ht": proemis3d_res.get_greedy_ht(version, test)}, + ) + run_forward = PipelineStepResourceCollection( + "--run-forward", + pipeline_input_steps=[compute_af2_distance_matrices, get_gencode_positions], + add_input_resources={ + "RMC OE Table": {"obs_exp_ht": proemis3d_res.get_obs_exp_ht(version)} + }, + output_resources={"forward_ht": proemis3d_res.get_forward_ht(version, test)}, + ) + # Add resources for per-variant, per-residue, and per-region HTs. 
+    write_per_variant = PipelineStepResourceCollection(
+        "--write-per-variant",
+        pipeline_input_steps=[
+            gencode_transcipt,
+            gencode_translation,
+            gencode_alignment,
+            compute_af2_distance_matrices,
+            extract_af2_plddt,
+            extract_af2_pae,
+            run_forward,
+        ],
+        add_input_resources={
+            "GENCODE GTF": {"gencode_gtf_ht": proemis3d_res.get_gencode_ht(version)}
+        },
+        output_resources={
+            "per_variant_ht": proemis3d_res.get_forward_annotation_ht(
+                "per_variant", version, test
+            ),
+        },
+    )
+    write_per_missense_variant = PipelineStepResourceCollection(
+        "--write-per-missense-variant",
+        pipeline_input_steps=[write_per_variant],
+        output_resources={
+            "per_missense_variant_ht": proemis3d_res.get_forward_annotation_ht(
+                "per_missense_variant", version, test
+            ),
+        },
+    )
+    write_per_residue = PipelineStepResourceCollection(
+        "--write-per-residue",
+        pipeline_input_steps=[write_per_variant],
+        output_resources={
+            "per_residue_ht": proemis3d_res.get_forward_annotation_ht(
+                "per_residue", version, test
+            ),
+        },
+    )
+    write_per_region = PipelineStepResourceCollection(
+        "--write-per-region",
+        pipeline_input_steps=[write_per_residue],
+        output_resources={
+            "per_region_ht": proemis3d_res.get_forward_annotation_ht(
+                "per_region", version, test
+            ),
+        },
+    )
+    create_missense_viewer_input_ht = PipelineStepResourceCollection(
+        "--create-missense-viewer-input-ht",
+        pipeline_input_steps=[get_gencode_positions, run_forward],
+        output_resources={
+            "missense_viewer_input_ht": proemis3d_res.get_missense_viewer_input_ht(
+                version
+            ),
+        },
+    )
+
+    # Add all steps to the proemis3D pipeline resource collection.
+    proemis3d_pipeline.add_steps(
+        {
+            "gencode_transcipt": gencode_transcipt,
+            "gencode_translation": gencode_translation,
+            "read_af2_sequences": read_af2_sequences,
+            "compute_af2_distance_matrices": compute_af2_distance_matrices,
+            "extract_af2_plddt": extract_af2_plddt,
+            "extract_af2_pae": extract_af2_pae,
+            "gencode_alignment": gencode_alignment,
+            "get_gencode_positions": get_gencode_positions,
+            "run_greedy": run_greedy,
+            "run_forward": run_forward,
+            "write_per_variant": write_per_variant,
+            "write_per_missense_variant": write_per_missense_variant,
+            "write_per_residue": write_per_residue,
+            "write_per_region": write_per_region,
+            "create_missense_viewer_input_ht": create_missense_viewer_input_ht,
+        }
+    )
+
+    return proemis3d_pipeline
+
+
+def main(args):
+    """Execute the Proemis 3D pipeline."""
+    hl.init(
+        log="/proemis_3d.log",
+        tmp_dir="gs://gnomad-tmp-4day",
+    )
+    version = args.version
+    test = args.test
+    overwrite = args.overwrite
+
+    if version not in proemis3d_res.VERSIONS:
+        raise ValueError("The requested version of resource Tables is not available.")
+
+    # Construct resources with paths for intermediate Tables generated in the pipeline.
+    resources = get_proemis3d_resources(version, overwrite, test)
+
+    if args.convert_gencode_fastn_to_ht:
+        logger.info(
+            "Importing and pre-processing GENCODE transcripts FASTA file as a Hail Table."
+        )
+        res = resources.gencode_transcipt
+        res.check_resource_existence()
+        ht = convert_gencode_transcripts_fasta_to_table(
+            res.gencode_transcipt_fasta_path
+        )
+        if test:
+            ht = ht.filter(ht.enst == TEST_TRANSCRIPT_ID)
+        ht = ht.checkpoint(res.gencode_transcipt_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.convert_gencode_fasta_to_ht:
+        logger.info(
+            "Importing and pre-processing GENCODE translations FASTA file as a Hail Table."
+        )
+        res = resources.gencode_translation
+        res.check_resource_existence()
+        ht = convert_fasta_to_table(
+            res.gencode_translation_fasta_path, COLNAMES_TRANSLATIONS[version]
+        )
+        if test:
+            ht = ht.filter(ht.enst == TEST_TRANSCRIPT_ID)
+        ht = ht.checkpoint(res.gencode_translation_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.read_af2_sequences:
+        logger.info(
+            "Processing AlphaFold2 structures from a GCS bucket into a Hail Table."
+        )
+        res = resources.read_af2_sequences
+        res.check_resource_existence()
+        ht = process_af2_structures(resources.af2_struct_dir_path, mode="sequence")
+        ht = remove_multi_frag_uniprots(ht)
+        ht = ht.checkpoint(res.af2_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.compute_af2_distance_matrices:
+        logger.info("Computing distance matrices for AlphaFold2 structures.")
+        res = resources.compute_af2_distance_matrices
+        res.check_resource_existence()
+        ht = process_af2_structures(
+            resources.af2_struct_dir_path, mode="distance_matrix"
+        )
+        ht = remove_multi_frag_uniprots(ht)
+        ht = ht.checkpoint(res.af2_dist_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.extract_af2_plddt:
+        logger.info("Extracting pLDDT scores from AlphaFold2 structures.")
+        res = resources.extract_af2_plddt
+        res.check_resource_existence()
+        ht = process_af2_structures(resources.af2_conf_dir_path, mode="plddt")
+        ht = remove_multi_frag_uniprots(ht)
+        ht = ht.checkpoint(res.af2_plddt_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.extract_af2_pae:
+        logger.info("Extracting pAE scores from AlphaFold2 structures.")
+        res = resources.extract_af2_pae
+        res.check_resource_existence()
+        ht = process_af2_structures(resources.af2_pae_dir_path, mode="pae")
+        ht = remove_multi_frag_uniprots(ht)
+        ht = ht.checkpoint(res.af2_pae_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.gencode_alignment:
+        logger.info(
+            "Joining the GENCODE translations and AlphaFold2 structures based on "
+            "sequence."
+        )
+        res = resources.gencode_alignment
+        res.check_resource_existence()
+        ht = join_by_sequence(res.af2_ht.ht(), res.gencode_translation_ht.ht())
+        ht = ht.checkpoint(res.matched_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.get_gencode_positions:
+        logger.info("Creating GENCODE positions Hail Table.")
+        res = resources.get_gencode_positions
+        res.check_resource_existence()
+        ht = res.gencode_gtf_ht.ht()
+
+        if test:
+            ht = ht.filter(ht.transcript_id == TEST_TRANSCRIPT_ID)
+        ht = get_gencode_positions(
+            res.gencode_transcript_ht.ht(), res.matched_ht.ht(), ht
+        )
+        ht = ht.checkpoint(res.gencode_pos_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.run_greedy or args.run_forward:
+        logger.info("Preparing to run greedy and/or forward algorithms.")
+        if args.run_greedy:
+            res = resources.run_greedy
+            res.check_resource_existence()
+        if args.run_forward:
+            res = resources.run_forward
+            res.check_resource_existence()
+
+        # Use the new shuffle method to prevent shuffle errors when applying models.
+        hl._set_flags(use_new_shuffle="1")
+
+        ht = res.obs_exp_ht.ht()
+        if test:
+            ht = ht.filter(ht.transcript == TEST_TRANSCRIPT_ID)
+
+        ht = ht.filter(ht.annotation == "missense_variant")
+
+        if version == "2.1.1":
+            ht = ht.group_by("locus", "transcript").aggregate(
+                obs=hl.agg.sum(ht.observed), exp=hl.agg.sum(ht.expected)
+            )
+        elif version == "4.1":
+            ht = ht.group_by("locus", "transcript").aggregate(
+                obs=hl.agg.sum(ht.calibrate_mu.observed_variants[0]),
+                exp=hl.agg.sum(ht.expected_variants[0]),
+            )
+        else:
+            raise ValueError(
+                "The requested version of the resource Tables is not available."
+            )
+        ht = generate_codon_oe_table(ht, res.gencode_pos_ht.ht())
+        ht = ht.repartition(5000).checkpoint(hl.utils.new_temp_file("codon_oe", "ht"))
+        af2_ht = (
+            res.af2_dist_ht.ht()
+            .repartition(5000)
+            .checkpoint(hl.utils.new_temp_file("af2_dist", "ht"))
+        )
+        ht = determine_regions_with_min_oe_upper(
+            af2_ht, ht, min_exp_mis=args.min_exp_mis
+        )
+        ht = ht.repartition(5000).checkpoint(
+            hl.utils.new_temp_file("sort_regions_by_oe", "ht")
+        )
+
+        if args.run_greedy:
+            logger.info("Running greedy algorithm.")
+            res = resources.run_greedy
+            greedy_ht = run_greedy(ht)
+            greedy_ht = greedy_ht.checkpoint(res.greedy_ht.path, overwrite=overwrite)
+            greedy_ht.show()
+
+        if args.run_forward:
+            logger.info("Running forward algorithm.")
+            res = resources.run_forward
+            forward_ht = run_forward(ht, min_exp_mis=args.min_exp_mis)
+            forward_ht = forward_ht.checkpoint(res.forward_ht.path, overwrite=overwrite)
+            forward_ht.show()
+
+    if args.write_per_variant:
+        logger.info("Creating per-variant annotated Hail Table.")
+        res = resources.write_per_variant
+        res.check_resource_existence()
+
+        all_snv_temp_path = hl.utils.new_temp_file("all_snv", "ht")
+        ht = generate_all_possible_snvs_from_gencode_positions(
+            res.gencode_transcript_ht.ht(),
+            res.gencode_translation_ht.ht().repartition(1000),
+            res.gencode_gtf_ht.ht(),
+            res.matched_ht.ht(),
+        ).checkpoint(all_snv_temp_path, overwrite=overwrite)
+        partition_intervals = ht._calculate_new_partitions(args.all_snv_n_partitions)
+        ht = hl.read_table(
+            all_snv_temp_path, _intervals=partition_intervals
+        ).checkpoint(
+            proemis3d_res.get_temp_all_possible_snvs_ht().path, overwrite=overwrite
+        )
+
+        ht = create_per_snv_combined_ht(
+            ht,
+            res.forward_ht.ht(),
+            res.af2_plddt_ht.ht(),
+            res.af2_pae_ht.ht(),
+            res.af2_dist_ht.ht(),
+        )
+        ht = ht.checkpoint(res.per_variant_ht.path, overwrite=overwrite)
+        ht.describe()
+
+    if args.write_per_missense_variant:
+        logger.info("Filtering per-variant annotated Hail Table to missense variants.")
+        res = resources.write_per_missense_variant
+        res.check_resource_existence()
+        ht = res.per_variant_ht.ht()
+        ht = ht.filter(
+            hl.any(
+                ht.variant_level_annotations.transcript_consequences.map(
+                    lambda x: x == "missense_variant"
+                )
+            )
+        )
+        ht = ht.checkpoint(res.per_missense_variant_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.write_per_residue:
+        logger.info("Creating per-residue annotated Hail Table.")
+        res = resources.write_per_residue
+        res.check_resource_existence()
+        ht = create_per_residue_ht_from_snv_ht(res.per_variant_ht.ht())
+        ht = ht.checkpoint(res.per_residue_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.write_per_region:
+        logger.info("Creating per-region annotated Hail Table.")
+        res = resources.write_per_region
+        res.check_resource_existence()
+        ht = create_per_proemis3d_region_ht_from_residue_ht(res.per_residue_ht.ht())
+        ht = ht.checkpoint(res.per_region_ht.path, overwrite=overwrite)
+        ht.show()
+
+    if args.create_missense_viewer_input_ht:
+        logger.info("Creating missense viewer input Hail Table.")
+        res = resources.create_missense_viewer_input_ht
+        res.check_resource_existence()
+        ht = create_missense_viewer_input_ht(
+            res.gencode_pos_ht.ht(), res.forward_ht.ht()
+        )
+        ht = ht.checkpoint(res.missense_viewer_input_ht.path, overwrite=overwrite)
+        ht.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--overwrite", help="Whether to overwrite output files.", action="store_true"
+    )
+    parser.add_argument(
+        "--version",
+        help=(
+            "Which version of the resource Tables will be used. Default is"
+            f" {proemis3d_res.CURRENT_VERSION}."
+        ),
+        type=str,
+        default=proemis3d_res.CURRENT_VERSION,
+    )
+    parser.add_argument(
+        "--test",
+        help="Whether to run a test instead of the full pipeline.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--convert-gencode-fastn-to-ht",
+        help="Import and pre-process the GENCODE transcripts FASTA file.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--convert-gencode-fasta-to-ht",
+        help="Import and pre-process the GENCODE translations FASTA file.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--read-af2-sequences",
+        help="Read sequences from AlphaFold2 structures into a Hail Table.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--compute-af2-distance-matrices",
+        help="Compute distance matrices for AlphaFold2 structures.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--extract-af2-plddt",
+        help="Extract pLDDT scores from AlphaFold2 structures.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--extract-af2-pae",
+        help="Extract pAE scores from AlphaFold2 structures.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--gencode-alignment",
+        help="Join GENCODE translations and AlphaFold2 structures by sequence.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--get-gencode-positions",
+        help="Create the GENCODE positions Hail Table.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--run-greedy",
+        help="Run the greedy algorithm to identify intolerant 3D regions.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--run-forward",
+        help="Run the forward algorithm to identify intolerant 3D regions.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--min-exp-mis",
+        help=(
+            "Minimum expected number of missense variants to consider for the "
+            f"greedy and forward algorithms. Default is {MIN_EXP_MIS}."
+        ),
+        type=int,
+        default=MIN_EXP_MIS,
+    )
+    parser.add_argument(
+        "--write-per-variant",
+        action="store_true",
+        help="Generate per-variant annotated HT.",
+    )
+    parser.add_argument(
+        "--all-snv-n-partitions",
+        help="Number of partitions to use for the all possible SNVs Hail Table.",
+        type=int,
+        default=5000,
+    )
+    parser.add_argument(
+        "--write-per-missense-variant",
+        action="store_true",
+        help="Generate per-missense-variant annotated HT.",
+    )
+    parser.add_argument(
+        "--write-per-residue",
+        action="store_true",
+        help="Generate per-residue HT from per-variant HT.",
+    )
+    parser.add_argument(
+        "--write-per-region",
+        action="store_true",
+        help="Generate per-region HT from per-residue HT.",
+    )
+    parser.add_argument(
+        "--create-missense-viewer-input-ht",
+        action="store_true",
+        help="Create missense viewer input HT.",
+    )
+
+    args = parser.parse_args()
+    main(args)
diff --git a/gnomad_constraint/experimental/proemis3d/resources.py b/gnomad_constraint/experimental/proemis3d/resources.py
new file mode 100644
index 00000000..77ea7329
--- /dev/null
+++ b/gnomad_constraint/experimental/proemis3d/resources.py
@@ -0,0 +1,651 @@
+"""Resource definitions for the proemis3d pipeline."""
+
+import logging
+
+import hail as hl
+from gnomad.resources.grch37.reference_data import gencode as grch37_gencode
+from gnomad.resources.grch38.reference_data import gencode as grch38_gencode
+from gnomad.resources.resource_utils import (
+    ExpressionResource,
+    TableResource,
+)
+
+logging.basicConfig(
+    format="%(asctime)s (%(name)s %(lineno)s): %(message)s",
+    datefmt="%m/%d/%Y %I:%M:%S %p",
+)
+logger = logging.getLogger("proemis3d_pipeline")
+logger.setLevel(logging.INFO)
+
+VERSIONS = ["2.1.1", "4.1"]
+"""Possible gnomAD versions for the proemis3d pipeline."""
+
+CURRENT_VERSION = "4.1"
+"""Current gnomAD version for the proemis3d pipeline."""
+
+GENCODE_VERSION_MAP = {
+    "2.1.1": "19",
+    "4.1": "39",
+}
+"""GENCODE version map for each gnomAD version."""
+
+
+def get_proemis3d_root(version: str = CURRENT_VERSION, test: bool = False) -> str:
+    """
+    Get root path to proemis3d resources.
+
+    :param version: Version of proemis3d resources to use.
+    :param test: Whether to use a tmp path for testing.
+    :return: Root path to proemis3d resources.
+    """
+    return (
+        f"gs://gnomad-tmp/gnomad_v{version}_testing/constraint/proemis3d"
+        if test
+        else f"gs://gnomad/v{version}/constraint/proemis3d"
+    )
+
+
+def get_gencode_fasta(
+    version: str = CURRENT_VERSION,
+    name: str = "pc_transcripts",
+) -> str:
+    """
+    Get GENCODE FASTA file path.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+    :param name: Name of the type of GENCODE FASTA file to get. One of 'pc_transcripts',
+        'pc_translations'. Default is 'pc_transcripts'.
+    :return: GENCODE FASTA file path.
+    """
+    # TODO: Change this path when moved to a more permanent location.
+    return (
+        f"gs://gnomad-julia/proemis3d/resources/"
+        f"gencode.v{GENCODE_VERSION_MAP[version]}.{name}.fa.gz"
+    )
+
+
+def get_alpha_fold2_dir(version: str = CURRENT_VERSION) -> str:
+    """
+    Get AlphaFold2 directory path.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+    :return: AlphaFold2 directory path.
+    """
+    # TODO: Change this path when moved to a more permanent location and add any
+    # needed versioning.
+    return "gs://gnomad-julia/alphafold2"
+
+
+def get_gencode_seq_ht(
+    version: str = CURRENT_VERSION,
+    name: str = "pc_transcripts",
+    test: bool = False,
+) -> TableResource:
+    """
+    Get GENCODE sequences Hail Table resource.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+    :param name: Name of the type of GENCODE sequences to get. One of 'pc_transcripts',
+        'pc_translations'. Default is 'pc_transcripts'.
+    :param test: Whether to use a tmp path for testing. Default is False.
+    :return: GENCODE sequences Hail Table resource.
+    """
+    return TableResource(
+        f"{get_proemis3d_root(version, test)}/preprocessed_data/"
+        f"gencode_sequences.{name}.ht"
+    )
+
+
+def get_af2_ht(
+    version: str = CURRENT_VERSION,
+    test: bool = False,
+) -> TableResource:
+    """
+    Get AlphaFold2 Hail Table resource.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+    :param test: Whether to use a tmp path for testing. Default is False.
+    :return: AlphaFold2 Hail Table resource.
+    """
+    # NOTE: All AF2-derived Tables are currently stored under the v2.1.1 root,
+    # regardless of the requested `version`.
+    return TableResource(
+        f"{get_proemis3d_root('2.1.1', test)}/preprocessed_data/af2.ht"
+    )
+
+
+def get_af2_dist_ht(
+    version: str = CURRENT_VERSION,
+    test: bool = False,
+) -> TableResource:
+    """
+    Get AlphaFold2 distance matrix Hail Table resource.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+    :param test: Whether to use a tmp path for testing. Default is False.
+    :return: AlphaFold2 distance matrix Hail Table resource.
+    """
+    return TableResource(
+        f"{get_proemis3d_root('2.1.1', test)}/preprocessed_data/af2_dist.ht"
+    )
+
+
+def get_af2_plddt_ht(
+    version: str = CURRENT_VERSION,
+    test: bool = False,
+) -> TableResource:
+    """
+    Get AlphaFold2 pLDDT Hail Table resource.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+    :param test: Whether to use a tmp path for testing. Default is False.
+    :return: AlphaFold2 pLDDT Hail Table resource.
+    """
+    return TableResource(
+        f"{get_proemis3d_root('2.1.1', test)}/preprocessed_data/af2_plddt.ht"
+    )
+
+
+def get_af2_pae_ht(
+    version: str = CURRENT_VERSION,
+    test: bool = False,
+) -> TableResource:
+    """
+    Get AlphaFold2 pAE Hail Table resource.
+
+    :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`.
+ :param test: Whether to use a tmp path for testing. Default is False. + :return: AlphaFold2 pAE Hail Table resource. + """ + return TableResource( + f"{get_proemis3d_root('2.1.1', test)}/preprocessed_data/af2_pae.ht" + ) + + +def get_gencode_translations_matched_ht( + version: str = CURRENT_VERSION, + test: bool = False, +) -> TableResource: + """ + Get GENCODE translations matched Hail Table resource. + + :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`. + :param test: Whether to use a tmp path for testing. Default is False. + :return: GENCODE translations matched Hail Table resource. + """ + return TableResource( + f"{get_proemis3d_root(version, test)}/preprocessed_data/" + f"gencode_translations_matched.ht" + ) + + +def get_gencode_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get GENCODE Hail Table resource. + + :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`. + :return: GENCODE Hail Table resource. + """ + if version == "2.1.1": + gencode = grch37_gencode + elif version == "4.1": + gencode = grch38_gencode + else: + raise ValueError(f"Invalid version: {version}") + + return gencode.versions[f"v{GENCODE_VERSION_MAP[version]}"] + + +def get_gencode_pos_ht( + version: str = CURRENT_VERSION, + test: bool = False, +) -> TableResource: + """ + Get GENCODE positions Hail Table resource. + + :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`. + :param test: Whether to use a tmp path for testing. Default is False. + :return: GENCODE positions Hail Table resource. + """ + return TableResource( + f"{get_proemis3d_root(version, test)}/preprocessed_data/gencode_positions.ht" + ) + + +def get_obs_exp_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get observed/expected Hail Table resource. + + :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`. + :return: Observed/expected Hail Table resource. + """ + # TODO: Change this path when moved to a more permanent location. + if version == "2.1.1": + return TableResource( + "gs://gnomad/v2.1.1/constraint/temp/gnomad.v2.1.1.per_base_expected.ht" + ) + elif version == "4.1": + return TableResource( + "gs://gnomad/v4.1/constraint_coverage_corrected/apply_models/transcript_consequences/gnomad.v4.1.per_variant_expected.coverage_corrected.ht" + ) + else: + raise ValueError(f"Invalid version: {version}") + + +def get_greedy_ht( + version: str = CURRENT_VERSION, + test: bool = False, +) -> TableResource: + """ + Get Proemis3D greedy algorithm Hail Table resource. + + :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`. + :param test: Whether to use a tmp path for testing. Default is False. + :return: Proemis3D greedy algorithm Hail Table resource. + """ + return TableResource(f"{get_proemis3d_root(version, test)}/proemis3D_greedy.ht") + + +def get_forward_ht( + version: str = CURRENT_VERSION, + test: bool = False, +) -> TableResource: + """ + Get Proemis3D forward algorithm Hail Table resource. + + :param version: Version of gnomAD to use. Default is `CURRENT_VERSION`. + :param test: Whether to use a tmp path for testing. Default is False. + :return: Proemis3D forward algorithm Hail Table resource. + """ + return TableResource(f"{get_proemis3d_root(version, test)}/proemis3D_forward.ht") + + +def get_forward_annotation_ht( + name: str, + version: str = CURRENT_VERSION, + test: bool = False, +) -> TableResource: + """ + Get PROEMIS3D forward algorithm annotation Hail Table resource. 
+
+    :param name: Annotation type ('per_variant', 'per_missense_variant',
+        'per_residue', or 'per_region').
+    :param version: gnomAD version to use. Default is `CURRENT_VERSION`.
+    :param test: Whether to use a tmp path for testing. Default is False.
+    :return: PROEMIS3D annotated forward algorithm Hail Table resource.
+    """
+    return TableResource(
+        f"{get_proemis3d_root(version, test)}/proemis3D_forward.{name}.annotated.ht"
+    )
+
+
+def get_temp_all_possible_snvs_ht(version: str = CURRENT_VERSION) -> TableResource:
+    """
+    Get temp all possible SNVs Hail Table resource.
+
+    :param version: gnomAD version to use. Default is `CURRENT_VERSION`.
+    :return: Temp all possible SNVs Hail Table resource.
+    """
+    return TableResource(
+        f"gs://gnomad-tmp-4day/v{version}/constraint/proemis3d/all_possible_snvs.ht"
+    )
+
+
+def get_missense_viewer_input_ht(version: str = CURRENT_VERSION) -> TableResource:
+    """
+    Get missense viewer input Hail Table resource.
+
+    :param version: gnomAD version to use. Default is `CURRENT_VERSION`.
+    :return: Missense viewer input Hail Table resource.
+    """
+    return TableResource(
+        f"gs://gnomad/v{version}/constraint/proemis3d/missense_viewer_input.ht"
+    )
+
+
+########################################################################################
+# The following functions are for external resources.
+########################################################################################
+def get_kaplanis_variants_tsv() -> str:
+    """
+    Get Kaplanis annotated variants file path.
+
+    :return: File path to Kaplanis annotated variants.
+    """
+    return "gs://gnomad/v4.1/constraint/resources/variant_lists/tsv/kaplanis_variants_annotated_2024-05-15.txt"
+
+
+def get_kaplanis_sig_variants_tsv() -> str:
+    """
+    Get Kaplanis significant variants file path.
+
+    :return: File path to Kaplanis significant variants.
+    """
+    return "gs://gnomad/v4.1/constraint/resources/variant_lists/tsv/kaplanis_variants_sig.txt"
+
+
+def get_kaplanis_variants_ht(
+    liftover_to_grch38: bool = False,
+    key_by_transcript: bool = False,
+) -> TableResource:
+    """
+    Get processed Hail Table path for Kaplanis annotated de novo missense variants.
+
+    This is the output of the `process_kaplanis_variants_ht` function, containing
+    lifted GRCh38 loci and relevant annotations.
+
+    :param liftover_to_grch38: Whether to liftover the variants to GRCh38. Default is
+        False.
+    :param key_by_transcript: Whether to key the table by transcript. Default is False.
+    :return: Hail Table resource path for processed Kaplanis missense variants.
+    """
+    postfix = ".liftover_to_grch38" if liftover_to_grch38 else ""
+    if key_by_transcript:
+        postfix += ".keyed_by_transcript"
+    return TableResource(
+        f"gs://gnomad/v4.1/constraint/resources/variant_lists/ht/kaplanis_variants{postfix}.ht"
+    )
+
+
+def get_fu_variants_tsv() -> str:
+    """
+    Get Fu variants TSV file path.
+
+    :return: Fu variants TSV file path.
+    """
+    return "gs://gnomad/v4.1/constraint/resources/variant_lists/tsv/fu_2022_supp20.txt"
+
+
+def get_fu_variants_ht() -> TableResource:
+    """
+    Get processed Hail Table path for Fu annotated de novo missense variants.
+
+    This is the output of the `process_fu_variants_ht` function, containing
+    lifted GRCh38 loci and relevant annotations.
+
+    :return: Hail Table resource path for processed Fu missense variants.
+ """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/variant_lists/ht/fu_variants.ht" + ) + + +def get_interpro_annotations() -> str: + """ + Get Ensembl BioMart export InterPro annotations file path. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: File path to InterPro annotations. + """ + return "gs://gnomad/v4.1/constraint/resources/annotations/tsv/ensembl_biomart_export_interpro.txt" + + +def get_interpro_annotations_ht() -> TableResource: + """ + Get InterPro annotations Hail Table resource. + + :return: InterPro annotations Hail Table resource. + """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/annotations/ht/ensembl_biomart_export_interpro.ht" + ) + + +def get_cosmis_score_tsv(model: str) -> str: + """ + Get COSMIS scores TSV path for a specified structure model. + + :param model: Structure model source ('alphafold', 'swiss_model', or 'pdb'). + :return: COSMIS scores TSV file path. + """ + return f"gs://gnomad/v4.1/constraint/resources/3d_missense_methods/tsv/cosmis_scores_{model}.tsv.gz" + + +def get_cosmis_score_ht(model: str) -> TableResource: + """ + Get COSMIS Hail Table resource for a specified structure model. + + :param model: Structure model source ('alphafold', 'swiss_model', or 'pdb'). + :return: COSMIS scores Hail Table resource. + """ + return TableResource( + f"gs://gnomad/v4.1/constraint/resources/3d_missense_methods/ht/cosmis_scores_{model}.ht" + ) + + +def get_varity_tsv() -> str: + """ + Get Varity TSV file path. + + :return: Varity TSV file path. + """ + return "gs://gnomad/v4.1/constraint/resources/3d_missense_methods/tsv/varity_all_predictions.txt" + + +def get_varity_ht() -> TableResource: + """ + Get Varity Hail Table resource. + + :return: Varity Hail Table resource. + """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/3d_missense_methods/ht/varity_all_predictions.ht" + ) + + +def get_mtr3d_tsv() -> str: + """ + Get MTR3D TSV file path. + + :return: MTR3D TSV file path. + """ + return "gs://gnomad/v4.1/constraint/resources/3d_missense_methods/tsv/mtr_data.csv" + + +def get_mtr3d_ht() -> TableResource: + """ + Get MTR3D Hail Table resource. + + :return: MTR3D Hail Table resource. + """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/3d_missense_methods/ht/mtr3d_data.ht" + ) + + +def get_mtr_tsv() -> str: + """ + Get MTR TSV file path. + + :return: MTR TSV file path. + """ + return "gs://gnomad/v4.1/constraint/resources/annotations/tsv/full_MTR_scores.tsv" + + +def get_mtr_ht() -> TableResource: + """ + Get MTR Hail Table resource. + + :return: MTR Hail Table resource. + """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/annotations/ht/mtr_data.ht" + ) + + +def get_rmc_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get Hail Table resource with all regional missense constraint (RMC) scores. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: RMC Hail Table resource. + """ + return TableResource(f"gs://gnomad/v{version}/constraint/resources/all_rmc.ht") + + +def get_rmc_browser_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get Hail Table resource with all regional missense constraint (RMC) scores. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: RMC Hail Table resource. 
+ """ + return TableResource( + f"gs://regional_missense_constraint/constraint/{version}/2/rmc_browser.ht" + ) + + +def get_temp_processed_rmc_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get temp processed RMC Hail Table resource. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Temp processed RMC Hail Table resource. + """ + return TableResource( + f"gs://gnomad-tmp-4day/v{version}/constraint/resources/all_rmc.ht" + ) + + +def get_context_preprocessed_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get preprocessed context Hail Table resource. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Context preprocessed Hail Table resource. + """ + return TableResource( + f"gs://gnomad/v{version}/constraint_coverage_corrected/preprocessed_data/gnomad.v{version}.context.preprocessed.ht" + ) + + +def get_temp_context_preprocessed_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get temp preprocessed context Hail Table resource. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Temp preprocessed context Hail Table resource. + """ + return TableResource( + f"gs://gnomad-tmp-4day/v{version}/constraint_coverage_corrected/preprocessed_data/gnomad.v{version}.context.preprocessed.ht" + ) + + +def get_constraint_metrics_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get coverage-corrected constraint metrics Hail Table resource. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Constraint metrics Hail Table resource. + """ + return TableResource( + f"gs://gnomad/v{version}/constraint_coverage_corrected/metrics/transcript_consequences/gnomad.v{version}.constraint_metrics.coverage_corrected.ht" + ) + + +def get_temp_processed_constraint_ht(version: str = CURRENT_VERSION) -> TableResource: + """ + Get temp processed constraint Hail Table resource. + + :param version: gnomAD version to use. Default is `CURRENT_VERSION`. + :return: Temp processed constraint Hail Table resource. + """ + return TableResource( + f"gs://gnomad-tmp-4day/v{version}/constraint_coverage_corrected/gnomad.v{version}.constraint.ht" + ) + + +def get_clinvar_missense_ht() -> TableResource: + """ + Get ClinVar missense Hail Table resource. + + :return: ClinVar missense Hail Table resource. + """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/annotations/ht/clinvar_missense.ht" + ) + + +def get_revel_csv() -> str: + """ + Get REVEL CSV file path. + + :return: REVEL CSV file path. + """ + return "gs://gnomad-insilico/revel/revel-v1.3_all_chromosomes_with_transcript_ids.csv.bgz" + + +def get_insilico_annotations_ht(method: str) -> TableResource: + """ + Get insilico annotations Hail Table resource. + + :param method: Insilico method to use. Must be one of 'cadd', 'phylop', or 'revel'. + :return: Insilico annotations Hail Table resource. + """ + if method == "revel": + return TableResource( + "gs://gnomad/v4.1/constraint/resources/annotations/ht/revel.ht" + ) + else: + if method not in {"cadd", "phylop"}: + raise ValueError( + f"Invalid method: {method}. Must be one of 'cadd' or 'phylop'." + ) + + return TableResource( + f"gs://gnomad/v4.0/annotations/in_silico_predictors/gnomad.v4.0.{method}.grch38.ht" + ) + + +def get_genetics_gym_missense_scores_ht() -> TableResource: + """ + Get Genetics Gym missense scores Hail Table resource. + + :return: Genetics Gym missense scores Hail Table resource. 
+ """ + return TableResource("gs://genetics-gym/vsm-tables/all-models-no-PAI3D.ht") + + +def get_processed_genetics_gym_missense_scores_ht() -> TableResource: + """ + Get temp Genetics Gym missense scores Hail Table resource. + + :return: Temp Genetics Gym missense scores Hail Table resource. + """ + return TableResource( + "gs://gnomad/v4.1/constraint/resources/annotations/ht/all_missense_scores_percentile.with_uniprot.ht" + ) + + +def get_phaplo() -> ExpressionResource: + """ + Get phaplo Hail expression resource. + + :return: Phaplo Hail expression resource. + """ + return ExpressionResource( + "gs://gnomad/v4.1/constraint/resources/gene_lists/he/phaplo_genes.he" + ) + + +def get_ptriplo() -> ExpressionResource: + """ + Get ptriplo Hail expression resource. + + :return: Ptriplo Hail expression resource. + """ + return ExpressionResource( + "gs://gnomad/v4.1/constraint/resources/gene_lists/he/ptriplo_genes.he" + ) + + +def get_gnomad_de_novo_ht() -> TableResource: + """ + Get gnomAD de novo Hail Table resource. + + :return: GnomAD de novo Hail Table resource. + """ + return TableResource( + "gs://gcp-public-data--gnomad/release/4.1/ht/exomes/gnomad.exomes.v4.1.de_novo.high_quality_coding.ht" + ) diff --git a/gnomad_constraint/experimental/proemis3d/utils.py b/gnomad_constraint/experimental/proemis3d/utils.py new file mode 100644 index 00000000..9512b99f --- /dev/null +++ b/gnomad_constraint/experimental/proemis3d/utils.py @@ -0,0 +1,2141 @@ +"""Script with utility functions for the Proemis3D pipeline.""" + +import io +import json +import logging +import os +from typing import Dict, Iterator, List, Optional, Union + +import hail as hl +import numpy as np +import pandas as pd +from Bio.PDB import PPBuilder +from Bio.PDB.MMCIFParser import MMCIFParser +from Bio.PDB.Polypeptide import is_aa +from gnomad.resources.grch38.gnomad import browser_gene, browser_variant, pext +from gnomad.utils.constraint import oe_confidence_interval +from gnomad.utils.filtering import filter_gencode_ht +from gnomad.utils.reference_genome import get_reference_genome +from hail.utils.misc import divide_null +from pyspark.sql.functions import col, explode, pandas_udf, rtrim, split +from pyspark.sql.types import StringType, StructField, StructType + +from gnomad_constraint.experimental.proemis3d.constants import ( + HI_GENE_CATEGORIES, + HI_GENES, + MIN_EXP_MIS, +) +from gnomad_constraint.experimental.proemis3d.data_import import ( + get_kaplanis_sig_gene_annotations, + process_gnomad_de_novo_ht, + process_gnomad_site_ht, + process_pext_annotation_ht, + process_pext_base_ht, +) +from gnomad_constraint.experimental.proemis3d.resources import ( + get_clinvar_missense_ht, + get_cosmis_score_ht, + get_fu_variants_ht, + get_gnomad_de_novo_ht, + get_insilico_annotations_ht, + get_interpro_annotations_ht, + get_kaplanis_variants_ht, + get_mtr3d_ht, + get_mtr_ht, + get_phaplo, + get_processed_genetics_gym_missense_scores_ht, + get_ptriplo, + get_rmc_browser_ht, + get_temp_context_preprocessed_ht, + get_temp_processed_constraint_ht, + get_temp_processed_rmc_ht, + get_varity_ht, +) + +logging.basicConfig( + format="%(asctime)s (%(name)s %(lineno)s): %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) +logger = logging.getLogger("proemis3d_utils") +logger.setLevel(logging.INFO) + +######################################################################################## +# Functions to perform tasks from convert_gencode_fastn_to_dt.R and +# convert_gencode_fasta_to_dt.R 
+######################################################################################## +COLNAMES_TRANSCRIPTS = [ + "enst", + "ensg", + "havana_g", + "havana_t", + "transcript", + "gene", + "ntlength", + "index1", + "index2", + "index3", +] +""" +Column names for the GENCODE transcripts Hail Table. +""" + +COLNAMES_TRANSLATIONS = { + "2.1.1": [ + "enst", + "ensg", + "havana_g", + "havana_t", + "transcript", + "gene", + "aalength", + ], + "4.1": [ + "ensp", + "enst", + "ensg", + "havana_g", + "havana_t", + "transcript", + "gene", + "aalength", + ], +} +""" +Column names for the GENCODE translations Hail Table. +""" + +VARIANT_LEVEL_ANNOTATION_CONFIG = { + "context": { + "ht": get_temp_context_preprocessed_ht(), + "keys": ["locus", "alleles", "transcript_id"], + }, + "gnomad_site": { + "ht": browser_variant(), + "keys": ["locus", "alleles"], + "custom_select": process_gnomad_site_ht, + }, + "revel": { + "ht": get_insilico_annotations_ht("revel"), + "keys": ["locus", "alleles", "transcript_id"], + }, + "cadd": { + "ht": get_insilico_annotations_ht("cadd"), + "keys": ["locus", "alleles"], + }, + "phylop": {"ht": get_insilico_annotations_ht("phylop"), "keys": ["locus"]}, + "genetics_gym": { + "ht": get_processed_genetics_gym_missense_scores_ht(), + "keys": ["locus", "alleles", "transcript_id", "uniprot_id"], + "annotation_name": "genetics_gym_missense_scores", + }, + "autism": { + "ht": get_fu_variants_ht(), + "keys": ["locus", "alleles"], + "annotation_name": "autism", + }, + "dd_denovo": { + "ht": get_kaplanis_variants_ht(liftover_to_grch38=True, key_by_transcript=True), + "keys": ["locus", "alleles", "gene_id", "transcript_id"], + "annotation_name": "dd_denovo", + }, + "dd_denovo_no_transcript": { + "ht": get_kaplanis_variants_ht(liftover_to_grch38=True), + "keys": ["locus", "alleles"], + "annotation_name": "dd_denovo_no_transcript_match", + }, + "gnomad_de_novo": { + "ht": get_gnomad_de_novo_ht(), + "keys": ["locus", "alleles"], + "annotation_name": "gnomad_de_novo", + "custom_select": process_gnomad_de_novo_ht, + }, + "clinvar": { + "ht": get_clinvar_missense_ht(), + "keys": ["locus", "alleles", "gene_symbol"], + "annotation_name": "clinvar", + }, + "pext_base": { + "ht": pext("base_level"), + "keys": ["locus", "gene_id"], + "annotation_name": "base_level_pext", + "custom_select": process_pext_base_ht, + }, + "mtr": { + "ht": get_mtr_ht(), + "keys": ["locus", "alleles", "transcript_id"], + "annotation_name": "mtr", + }, + "rmc": { + "ht": get_temp_processed_rmc_ht(), + "keys": ["locus", "transcript_id"], + "annotation_name": "rmc", + }, +} +""" +Configuration for variant level annotations. 
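+
+Each entry maps an annotation name to the source Table (`ht`), the join keys
+(`keys`), an optional `custom_select` preprocessing function, and an optional
+`annotation_name` used for the resulting annotation field.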
+""" + +RESIDUE_LEVEL_ANNOTATION_CONFIG = { + "interpro": { + "ht": get_interpro_annotations_ht(), + "keys": ["transcript_id", "uniprot_id", "residue_index"], + "annotation_name": "interpro", + }, + "varity": { + "ht": get_varity_ht(), + "keys": ["uniprot_id", "residue_index", "residue_ref", "residue_alt"], + "annotation_name": "varity", + }, + "mtr3d": { + "ht": get_mtr3d_ht(), + "keys": ["transcript_id", "uniprot_id", "residue_index"], + "annotation_name": "mtr3d", + }, + "cosmis_alphafold": { + "ht": get_cosmis_score_ht("alphafold"), + "keys": ["transcript_id", "uniprot_id", "residue_index"], + "annotation_name": "cosmis_alphafold", + }, + "cosmis_pdb": { + "ht": get_cosmis_score_ht("pdb"), + "keys": ["transcript_id", "uniprot_id", "residue_index"], + "annotation_name": "cosmis_pdb", + }, + "cosmis_swiss_model": { + "ht": get_cosmis_score_ht("swiss_model"), + "keys": ["transcript_id", "uniprot_id", "residue_index"], + "annotation_name": "cosmis_swiss_model", + }, +} +""" +Configuration for residue level annotations. +""" + +BASE_LEVEL_ANNOTATION_FIELDS = [ + "gene_symbol", + "canonical", + "mane_select", + "transcript_biotype", + "most_severe_consequence", +] +""" +Fields to keep at the base level. +""" + + +def convert_fasta_to_table(fasta_file: str, colnames: List[str]) -> hl.Table: + """ + Convert a FASTA file to a Hail Table. + + :param fasta_file: Path to the FASTA file. + :param colnames: Column names for the Hail Table. + :return: Hail Table with the FASTA file contents. + """ + spark = hl.utils.java.Env.spark_session() + df = spark.read.format("text").load(fasta_file, wholetext=True) + df = df.select(explode(split(df["value"], ">")).alias("sequence")) + + # Convert the Spark DataFrame to a Hail Table. + ht = hl.Table.from_spark(df) + ht = ht.filter(ht.sequence != "") + + split_expr = ht.sequence.split("\n") + split_info_expr = split_expr[0].split("\\|") + ht = ht.select( + **{ + n: hl.or_missing( + (split_info_expr.length() > i) + & ~hl.array(["", "-"]).contains(split_info_expr[i]), + split_info_expr[i], + ) + for i, n in enumerate(colnames) + }, + sequence=split_expr[1].upper(), + ) + + # Remove version numbers from ENST and ENSG. + ht = ht.annotate(enst=ht.enst.split("\\.")[0], ensg=ht.ensg.split("\\.")[0]) + + return ht + + +def convert_gencode_transcripts_fasta_to_table(fasta_file: str) -> hl.Table: + """ + Convert GENCODE transcripts FASTA file to a Hail Table. + + :param fasta_file: Path to the GENCODE transcripts FASTA file. + :return: Hail Table with the GENCODE transcripts FASTA file contents parsed. + """ + ht = convert_fasta_to_table(fasta_file, COLNAMES_TRANSCRIPTS) + + # Organize the UTR5/CDS/UTR3 indices. + # If the UTR5/CDS/UTR3 index is not present, it is set to missing. + ht = ht.annotate( + index1=hl.or_missing(ht.index1.startswith("UTR5"), ht.index1), + index2=( + hl.case() + .when(ht.index1.startswith("CDS"), ht.index1) + .when(ht.index2.startswith("CDS"), ht.index2) + .or_missing() + ), + index3=( + hl.case() + .when(ht.index1.startswith("UTR3"), ht.index1) + .when(ht.index2.startswith("UTR3"), ht.index2) + .when(ht.index3.startswith("UTR3"), ht.index3) + .or_missing() + ), + ) + + # Rename the indices to 'utr5', 'cds', and 'utr3'. + names = ["utr5", "cds", "utr3"] + ht = ht.rename({f"index{i + 1}": n for i, n in enumerate(names)}) + + # Split the 'utr5', 'cds', and 'utr3' annotations into start and end positions. 
+    ht = ht.annotate(
+        **{
+            n: hl.bind(lambda x: x.map(hl.int), ht[n].split(":")[1].split("-"))
+            for n in names
+        }
+    )
+
+    # Trim the sequence to the CDS range.
+    ht = ht.annotate(cds_sequence=ht.sequence[ht.cds[0] - 1 : ht.cds[1]])
+
+    return ht
+
+
+########################################################################################
+# Note that the functionality in split_context_obs_exp.R is not implemented here
+# because it only splits the TSV file by transcript; we use the Hail Table directly
+# instead.
+########################################################################################
+
+
+########################################################################################
+# Functions to perform tasks from read_af2_sequences.R
+########################################################################################
+def get_plddt_from_confidence_json(plddt_content: str) -> List[float]:
+    """
+    Get the pLDDT from a confidence JSON file.
+
+    :param plddt_content: Content of the pLDDT JSON file.
+    :return: List of pLDDT scores.
+    """
+    data = json.loads(plddt_content)
+    return data["confidenceScore"]
+
+
+def get_pae_from_json(pae_content: str) -> List[List[int]]:
+    """
+    Get the PAE from a PAE JSON file.
+
+    :param pae_content: Content of the PAE JSON file.
+    :return: List of PAE scores.
+    """
+    data = json.loads(pae_content)
+    return data[0]["predicted_aligned_error"]
+
+
+def get_structure_peptide(structure) -> str:
+    """
+    Get the sequence from a structure.
+
+    :param structure: Structure object.
+    :return: Sequence as a string.
+    """
+    ppb = PPBuilder()
+
+    # Return the sequence as a string.
+    return "".join([str(pp.get_sequence()) for pp in ppb.build_peptides(structure)])
+
+
+def get_structure_dist_matrix(structure) -> np.ndarray:
+    """
+    Calculate the "calpha" distance matrix from a structure.
+
+    :param structure: Structure object.
+    :return: Distance matrix as a NumPy array.
+    """
+    calpha_atoms = []
+    for model in structure:
+        for chain in model:
+            for residue in chain:
+                if is_aa(residue, standard=True) and "CA" in residue:
+                    calpha_atoms.append(residue["CA"].get_coord())
+
+    def _calc_dist_matrix(calpha_atoms: List[np.ndarray]) -> np.ndarray:
+        """
+        Calculate the pairwise distance matrix between Calpha atoms.
+
+        :param calpha_atoms: List of Calpha atoms.
+        :return: Distance matrix as a NumPy array.
+        """
+        num_atoms = len(calpha_atoms)
+        dist_matrix = np.zeros((num_atoms, num_atoms))
+        for i, atom1 in enumerate(calpha_atoms):
+            for j, atom2 in enumerate(calpha_atoms):
+                dist_matrix[i, j] = np.linalg.norm(atom1 - atom2)
+
+        return dist_matrix
+
+    # Calculate the distance matrix.
+    return _calc_dist_matrix(calpha_atoms)
+
+
+def process_af2_mmcif(
+    uniprot_id: str,
+    mmcif_content: str,
+    distance_matrix: bool = False,
+) -> Union[str, np.ndarray, List[float]]:
+    """
+    Process AlphaFold2 MMCIF content.
+
+    :param uniprot_id: UniProt ID.
+    :param mmcif_content: MMCIF content as a string.
+    :param distance_matrix: Whether to return the distance matrix. Default is False.
+    :return: Sequence or distance matrix.
+ """ + parser = MMCIFParser(QUIET=True) + structure = parser.get_structure(uniprot_id, io.StringIO(mmcif_content)) + + if distance_matrix: + return get_structure_dist_matrix(structure) + else: + return get_structure_peptide(structure) + + +def process_af2_file_by_mode( + uniprot_id: str, + file_content: str, + mode: str, +) -> Union[str, np.ndarray, List[float]]: + """ + Dispatcher to handle different AF2 modes based on filename suffix and mode. + + :param uniprot_id: UniProt ID. + :param file_content: File content. + :param mode: Mode for processing files. Options are 'sequence', 'distance_matrix', + 'plddt', or 'pae'. Default is 'sequence'. + :return: Sequence or distance matrix. + """ + if mode in {"sequence", "distance_matrix"}: + return process_af2_mmcif( + uniprot_id, file_content, distance_matrix=(mode == "distance_matrix") + ) + if mode == "plddt": + return get_plddt_from_confidence_json(file_content) + + if mode == "pae": + return get_pae_from_json(file_content) + + raise ValueError(f"Unsupported mode: {mode}") + + +def process_af2_structures( + gcs_bucket_glob: str, + mode: str = "sequence", +) -> hl.Table: + """ + Process AlphaFold2 structures from a GCS bucket. + + .. note:: + + All files in the bucket must be in CIF format with a '.cif.gz' extension. + + :param gcs_bucket_glob: GCS bucket glob pattern. + :param mode: Mode for processing files. Options are 'sequence', 'distance_matrix', + 'plddt', or 'pae'. Default is 'sequence'. + :return: Hail Table with UniProt IDs and sequences or distance matrices. + """ + # Get Spark session for file distribution and processing. + spark = hl.utils.java.Env.spark_session() + spark.conf.set( + "spark.sql.execution.arrow.maxRecordsPerBatch", + 1000, + ) + + # Define schema for loading the files. + schema = StructType( + [ + StructField("file_content", StringType(), True), + StructField("af2_file", StringType(), True), + ] + ) + + # Use Spark to read files in parallel. + # This reads the entire content of each file as a (filename, content) pair. + af2_files_df = ( + spark.read.format("text") + .load(gcs_bucket_glob, schema=schema, wholetext=True) + .withColumn("af2_file", col("_metadata.file_path")) + ) + if mode == "distance_matrix": + col_name = "dist_mat" + rtype = "array>" + elif mode == "pae": + col_name = "pae" + rtype = "array>" + elif mode == "plddt": + col_name = "plddt" + rtype = "array" + else: + col_name = "sequence" + rtype = "string" + + @pandas_udf(f"uniprot_id string, {col_name} {rtype}") + def process_file( + file_path_series: pd.Series, file_content_series: pd.Series + ) -> pd.DataFrame: + """ + Process a list of files in parallel using a Pandas UDF. + + :param file_path_series: File paths. + :param file_content_series: File contents. + :return: Pandas DataFrame with UniProt IDs and sequences. + """ + result = [] + for file_path, file_content in zip(file_path_series, file_content_series): + # Extract UniProt ID from the file path. + uniprot_id = os.path.basename(file_path).split("-")[1] + + # Process the file content. + af2_data = process_af2_file_by_mode(uniprot_id, file_content, mode=mode) + result.append((uniprot_id, af2_data)) + + return pd.DataFrame(result, columns=["uniprot_id", col_name]) + + # Apply the Pandas UDF to process the files. + result_df = af2_files_df.withColumn( + "uniprot_id_sequence", process_file(col("af2_file"), col("file_content")) + ) + + # Split the resulting column into separate columns. 
+    result_df = result_df.select(
+        "af2_file",
+        col("uniprot_id_sequence.uniprot_id"),
+        col(f"uniprot_id_sequence.{col_name}"),
+    )
+    if mode in {"distance_matrix", "pae"}:
+        from pyspark.sql.functions import posexplode
+
+        result_df = result_df.select(
+            "af2_file",
+            "uniprot_id",
+            posexplode(col(col_name)).alias("aa_index", col_name),
+        )
+        tmp_path = hl.utils.new_temp_file("process_af2_structures", "parquet")
+        logger.info(f"Writing processed AlphaFold2 structures to {tmp_path}")
+        result_df.write.mode("overwrite").save(tmp_path)
+        logger.info("Finished writing.")
+        result_df = spark.read.parquet(tmp_path)
+
+    # Convert the Spark DataFrame to a Hail Table.
+    key = ["af2_file", "uniprot_id"]
+    if mode in {"distance_matrix", "pae"}:
+        key.append("aa_index")
+
+    ht = hl.Table.from_spark(result_df, key=key)
+
+    return ht
+
+
+def remove_multi_frag_uniprots(ht: hl.Table) -> hl.Table:
+    """
+    Remove UniProt IDs with multiple fragments (F2).
+
+    .. note::
+
+        With the current release, there is only one structure per UniProt once
+        multi-fragment entries are removed, so the Table can be keyed by
+        'uniprot_id'.
+
+    :param ht: Hail Table with structures keyed by 'af2_file' and 'uniprot_id'.
+    :return: Hail Table with UniProt IDs with multiple fragments removed and
+        keyed by 'uniprot_id'.
+    """
+    uniprots_with_multifrags = (
+        ht.filter(ht.af2_file.contains("-F2-")).select().distinct()
+    )
+
+    return ht.anti_join(uniprots_with_multifrags).key_by("uniprot_id").drop("af2_file")
+
+
+########################################################################################
+# Functions to perform tasks from gencode_alignment.R
+########################################################################################
+def join_by_sequence(ht1: hl.Table, ht2: hl.Table) -> hl.Table:
+    """
+    Join two Hail Tables based on the 'sequence' field.
+
+    .. note::
+
+        The join is an inner join.
+
+    :param ht1: First Hail Table.
+    :param ht2: Second Hail Table.
+    :return: Hail Table with the two input tables joined based on 'sequence'.
+    """
+    # Overlap the tables based on sequence.
+    ht1 = ht1.key_by("sequence")
+    ht2 = ht2.key_by("sequence")
+
+    return ht1.join(ht2, how="inner")
+
+
+########################################################################################
+# Functions to perform tasks from make_gencode_positions_files.R
+########################################################################################
+def get_gencode_positions(
+    transcripts_ht: hl.Table,
+    translations_ht: hl.Table,
+    gencode_gtf_ht: hl.Table,
+    no_filter: bool = False,
+) -> hl.Table:
+    """
+    Get GENCODE positions for the given transcripts and translations.
+
+    :param transcripts_ht: Hail Table with GENCODE transcripts.
+    :param translations_ht: Hail Table with GENCODE translations.
+    :param gencode_gtf_ht: Hail Table with GENCODE GTF data.
+    :param no_filter: Whether to skip filtering of transcripts whose CDS length is
+        not divisible by 3 or whose sequence length in `transcripts_ht` does not
+        match the length in `gencode_gtf_ht`. If True, two fields,
+        `cds_len_mismatch` and `cds_len_not_div_by_3`, are added to the Table to
+        facilitate downstream filtering. Default is False.
+    :return: Hail Table with GENCODE positions.
+    """
+    build = get_reference_genome(gencode_gtf_ht.interval.start).name
+
+    # Filter GTF data to keep only CDS features.
+    gencode_gtf_ht = gencode_gtf_ht.filter(gencode_gtf_ht.feature == "CDS")
+
+    # Get list of intervals for each transcript, and keep strand information.
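+    # GTF coordinates are 1-based and end-inclusive, which is why the per-base
+    # range built below uses `end.position + 1` (assuming the imported intervals
+    # preserve the GTF convention).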
+ gencode_gtf_ht = gencode_gtf_ht.annotate(chrom=gencode_gtf_ht.interval.start.contig) + gencode_gtf_ht = gencode_gtf_ht.group_by( + "transcript_id", "strand", "chrom" + ).aggregate(intervals=hl.agg.collect(gencode_gtf_ht.interval)) + + # Get CDS positions and lengths for each transcript in the GTF data. + positions = gencode_gtf_ht.intervals.flatmap( + lambda x: hl.range(x.start.position, x.end.position + 1) + ) + gencode_gtf_ht = gencode_gtf_ht.transmute( + gtf_cds_pos=hl.sorted(positions), + gtf_cds_len=hl.len(positions), + ).key_by("transcript_id") + + # Get CDS sequences and lengths for each transcript in the transcript data. + transcripts_ht = transcripts_ht.select( + "enst", "cds_sequence", cds_len=hl.len(transcripts_ht.cds_sequence) + ).key_by("enst") + + # Join the CDS data from the GTF and transcript data. + cds_len_ht = transcripts_ht.join(gencode_gtf_ht, how="inner") + + # Filter CDS data to keep only transcripts with matching CDS lengths after removing + # stop codons. + cds_sequence_expr = hl.if_else( + cds_len_ht.gtf_cds_len == (cds_len_ht.cds_len - 3), + cds_len_ht.cds_sequence[:-3], + cds_len_ht.cds_sequence, + ) + cds_len_ht = cds_len_ht.annotate( + cds_sequence=cds_sequence_expr, + cds_len=hl.len(cds_sequence_expr), + ) + cds_len_mismatch_expr = cds_len_ht.gtf_cds_len != cds_len_ht.cds_len + cds_len_not_div_by_3_expr = cds_len_ht.cds_len % 3 != 0 + if no_filter: + cds_len_ht = cds_len_ht.annotate( + cds_len_mismatch=cds_len_mismatch_expr, + cds_len_not_div_by_3=cds_len_not_div_by_3_expr, + ) + else: + cds_len_ht = cds_len_ht.filter( + ~cds_len_mismatch_expr & ~cds_len_not_div_by_3_expr + ) + + # Get the reference sequence and amino acid positions for each position in the CDS. + # If the strand is negative, reverse the reference sequence and amino acid + # positions. + aapos_expr = hl.flatten(hl.repeat(hl.range(hl.int(cds_len_ht.cds_len / 3)), 3)) + cds_len_ht = cds_len_ht.annotate( + ref=hl.if_else( + cds_len_ht.strand == "+", + cds_len_ht.cds_sequence, + hl.expr.functions.reverse_complement(cds_len_ht.cds_sequence, rna=False), + ), + aapos=hl.if_else( + cds_len_ht.strand == "+", + hl.sorted(aapos_expr), + hl.sorted(aapos_expr, reverse=True), + ), + ) + + # Annotate the translations data with the genomic positions, reference sequence, and + # amino acid positions. + # Explode the positions to get one row per position. + ht = translations_ht.annotate(**cds_len_ht[translations_ht.enst]) + ht = ht.transmute( + positions=hl.map( + lambda gp, r, ap: hl.struct( + locus=hl.locus(ht.chrom, gp, reference_genome=build), ref=r, aapos=ap + ), + ht.gtf_cds_pos, + hl.range(hl.len(ht.ref)).map(lambda i: ht.ref[i]), + ht.aapos, + ) + ).explode("positions") + + return ht.transmute(**ht.positions) + + +######################################################################################## +# Functions to perform tasks from run_greedy.R and run_forward.R +######################################################################################## +def generate_codon_oe_table(obs_exp_ht: hl.Table, pos_ht: hl.Table) -> hl.Table: + """ + Generate a Table with observed and expected values for codons. + + :param obs_exp_ht: Hail Table with observed and expected values. + :param pos_ht: Hail Table with positions. + :return: Hail Table with observed and expected values for codons. + """ + oe_keyed = obs_exp_ht[pos_ht.locus, pos_ht.enst] + pos_ht = pos_ht.annotate(obs=oe_keyed.obs, exp=oe_keyed.exp) + + # Get aggregate sum of observed and expected values for each codon. 
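+    # Each amino acid position (aapos) spans three CDS bases, so summing obs/exp
+    # within (enst, uniprot_id, aapos) pools per-base counts into per-codon counts.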
+ oe_codon_ht = pos_ht.group_by("enst", "uniprot_id", "aapos").aggregate( + obs=hl.agg.sum(pos_ht.obs), + exp=hl.agg.sum(pos_ht.exp), + ) + + # Get a list of observed and expected codon values for each transcript and UniProt + # ID sorted by amino acid position. + oe_codon_ht = oe_codon_ht.group_by("enst", "uniprot_id").aggregate( + oe=hl.agg.collect( + ( + oe_codon_ht.aapos, + hl.struct(obs=oe_codon_ht.obs, exp=oe_codon_ht.exp), + ) + ) + ) + oe_codon_ht = oe_codon_ht.annotate( + oe=hl.sorted(oe_codon_ht.oe, key=lambda x: x[0]).map(lambda x: x[1]) + ).key_by("uniprot_id") + + return oe_codon_ht.collect_by_key("oe_by_transcript") + + +def add_idx_to_array( + expr: hl.expr.ArrayExpression, idx_name: str, element_name: Optional[str] = None +) -> hl.expr.ArrayExpression: + """ + Add an index to each element in an array expression. + + If the elements are structs, the index is added as a field with the name `idx_name`. + If the elements are not structs, a new struct is created with the index as a field + with the name `idx_name` and the element as a field with the name `element_name`. + If `element_name` is not provided, then only hl.enumerate is used to add the index. + + :param expr: Array expression to add index to. + :param idx_name: Name of the index field. + :param element_name: Name of the element field. Default is None. + :return: Array expression with index added to each element. + """ + element_type = expr.dtype.element_type + expr = hl.enumerate(expr) + + if isinstance(element_type, hl.tstruct): + return expr.map(lambda x: x[1].annotate(**{idx_name: x[0]})) + elif element_name is not None: + return expr.map(lambda x: hl.struct(**{element_name: x[1], idx_name: x[0]})) + else: + return expr + + +def get_cumulative_oe(oe_expr): + """ + Get the cumulative OE. + + :param oe_expr: Array expression with observed and expected values. + :return: Array expression with cumulative OE. + """ + oe_expr = hl.array_scan( + lambda i, j: j.annotate(obs=i.obs + j.obs, exp=i.exp + j.exp), + oe_expr[0], + oe_expr[1:], + ) + + return oe_expr + + +def calculate_oe_upper(oe_expr, alpha=0.05): + """ + Calculate the upper bound of the OE confidence interval. + + :param oe_expr: Array expression with observed and expected values. + :param alpha: Significance level for the OE confidence interval. Default is 0.05. + :return: Array expression with upper bound of the OE confidence interval. + """ + # Calculate upper bound of oe confidence interval. + oe_upper_expr = oe_expr.map( + lambda x: x.annotate( + oe=divide_null(x.obs, x.exp), + oe_upper=( + hl.qchisqtail(1 - alpha / 2, 2 * (x.obs + 1), lower_tail=True) + / (2 * x.exp) + ), + ) + ) + + return oe_upper_expr + + +def get_min_oe_upper(oe_expr, min_exp_mis=None): + """ + Get the 3D residue with the lowest upper bound of the OE confidence interval. + + :param oe_expr: Array expression with observed and expected values. + :param min_exp_mis: Minimum number of expected missense variants in a region to be + considered for constraint calculation. Default is None. + :return: Struct expression with the 3D residue with the lowest upper bound of the OE + confidence interval. 
+ """ + oe_expr = add_idx_to_array(oe_expr, "dist_index") + if min_exp_mis is None: + filtered_oe_expr = oe_expr + else: + filtered_oe_expr = oe_expr.filter(lambda x: x.exp >= min_exp_mis) + filtered_oe_expr = hl.or_missing( + filtered_oe_expr.length() > 0, filtered_oe_expr + ) + + min_oe_upper_expr = hl.sorted(filtered_oe_expr, key=lambda x: x.oe_upper)[0] + dist_index_expr = min_oe_upper_expr.dist_index + oe_expr = hl.or_missing( + hl.is_defined(filtered_oe_expr), oe_expr[: dist_index_expr + 1] + ) + min_oe_upper_expr = min_oe_upper_expr.drop("dist_index") + min_oe_upper_expr = min_oe_upper_expr.annotate( + region=oe_expr.map(lambda x: x.residue_index) + ) + + return min_oe_upper_expr + + +def get_3d_residue( + dist_mat_expr: hl.expr.ArrayExpression, + oe_expr: hl.expr.ArrayExpression, + alpha: float = 0.05, + min_exp_mis: int = MIN_EXP_MIS, +) -> hl.expr.StructExpression: + """ + Get the 3D residue with the lowest upper bound of the OE confidence interval. + + :param dist_mat_expr: Array expression with distance matrix. + :param oe_expr: Array expression with observed and expected values. + :param alpha: Significance level for the OE confidence interval. Default is 0.05. + :param min_exp_mis: Minimum number of expected missense variants in a region to be + considered for constraint calculation. Default is MIN_EXP_MIS. + :return: Struct expression with the 3D residue with the lowest upper bound of the OE + confidence interval. + """ + # Annotate neighbor order per residue. + dist_mat_expr = add_idx_to_array( + dist_mat_expr, "residue_index", element_name="dist" + ) + dist_mat_expr = hl.sorted(dist_mat_expr, key=lambda x: x.dist) + dist_mat_expr = dist_mat_expr.map(lambda x: x.drop("dist")) + + # Annotate neighbor observed and expected, cumulative observed and expected, and + # upper bound of OE confidence interval. + oe_expr = dist_mat_expr.map(lambda x: x.annotate(**oe_expr[x.residue_index])) + oe_expr = get_cumulative_oe(oe_expr) + oe_expr = calculate_oe_upper(oe_expr, alpha=alpha) + + # Get the 3D region with the lowest upper bound of the OE confidence interval for + # each residue. + min_moeuf_expr = get_min_oe_upper(oe_expr, min_exp_mis=min_exp_mis) + + return min_moeuf_expr + + +def determine_regions_with_min_oe_upper( + af2_ht: hl.Table, oe_codon_ht: hl.Table, min_exp_mis: int = MIN_EXP_MIS +) -> hl.Table: + """ + Determine the most intolerant region for each UniProt ID and residue index. + + :param af2_ht: Hail Table with AlphaFold2 data. + :param oe_codon_ht: Hail Table with observed and expected values for codons. + :param min_exp_mis: Minimum number of expected missense variants in a region to be + considered for constraint calculation. Default is MIN_EXP_MIS. 
+    :return: Hail Table with the most intolerant region for each UniProt ID and residue
+        index.
+    """
+    af2_ht = af2_ht.annotate(oe=oe_codon_ht[af2_ht.uniprot_id].oe_by_transcript)
+    af2_ht = af2_ht.explode(af2_ht.oe)
+    af2_ht = af2_ht.annotate(**af2_ht.oe)
+    af2_ht = af2_ht.transmute(
+        transcript_id=af2_ht.enst,
+        oe=af2_ht.oe,
+        min_oe_upper=get_3d_residue(
+            af2_ht.dist_mat, af2_ht.oe, min_exp_mis=min_exp_mis
+        ),
+    )
+
+    af2_ht = af2_ht.group_by("uniprot_id", "transcript_id").aggregate(
+        oe=hl.agg.take(af2_ht.oe, 1)[0],
+        min_oe_upper=hl.agg.collect(
+            af2_ht.min_oe_upper.annotate(residue_index=af2_ht.aa_index)
+        ),
+    )
+
+    af2_ht = af2_ht.annotate(
+        oe=hl.enumerate(af2_ht.oe).map(lambda x: x[1].annotate(residue_index=x[0])),
+        min_oe_upper=hl.sorted(af2_ht.min_oe_upper, key=lambda x: x.oe),
+    )
+
+    return af2_ht
+
+
+########################################################################################
+# Functions specific to the greedy algorithm.
+########################################################################################
+def run_greedy(ht: hl.Table) -> hl.Table:
+    """
+    Run the greedy algorithm to assign each residue to its most intolerant region.
+
+    :param ht: Hail Table keyed by (uniprot_id, transcript_id) with a `min_oe_upper`
+        array of candidate regions sorted by OE, as produced by
+        `determine_regions_with_min_oe_upper`.
+    :return: Hail Table with the most intolerant region for each UniProt ID and residue
+        index.
+    """
+    min_oe_upper_expr = add_idx_to_array(ht.min_oe_upper, "region_index")
+    initial_score_expr = min_oe_upper_expr.map(
+        lambda x: hl.missing(min_oe_upper_expr.dtype.element_type)
+    )
+    score_expr = hl.fold(
+        lambda i, j: hl.enumerate(i).map(
+            lambda x: hl.coalesce(x[1], hl.or_missing(j.region.contains(x[0]), j))
+        ),
+        initial_score_expr,
+        min_oe_upper_expr,
+    )
+    score_expr = add_idx_to_array(score_expr, "residue_index")
+    ann_keep = ["residue_index", "region_index", "obs", "exp", "oe", "oe_upper"]
+    ht = ht.select(score=score_expr.map(lambda x: x.select(*ann_keep)))
+    ht = ht.checkpoint(hl.utils.new_temp_file("sort_regions_by_oe", "ht"))
+    ht = ht.explode("score")
+
+    ht = ht.select(**ht.score).key_by("uniprot_id", "transcript_id", "residue_index")
+
+    return ht
+
+
+########################################################################################
+# Functions specific to the forward algorithm.
+########################################################################################
+def annotate_region_with_oe(region_expr, oe_expr):
+    """
+    Annotate each residue in a region with its observed and expected (OE) values.
+
+    :param region_expr: Array expression of residue indices in the region.
+    :param oe_expr: Array expression of per-residue OE structs, indexed by residue.
+    :return: Array expression of OE structs for the region's residues.
+    """
+    return region_expr.map(lambda x: oe_expr[x])
+
+
+def get_agg_oe_for_region(region_expr):
+    """
+    Aggregate observed and expected counts over a region and compute its OE ratio.
+
+    :param region_expr: Array expression of per-residue OE structs for the region.
+    :return: Struct expression with summed `obs`, summed `exp`, and the `oe` ratio.
+    """
+    oe_agg_expr = hl.or_missing(
+        hl.is_defined(region_expr),
+        region_expr.aggregate(
+            lambda x: hl.struct(
+                obs=hl.agg.sum(x.obs),
+                exp=hl.agg.sum(x.exp),
+            )
+        ),
+    )
+    gamma_expr = divide_null(oe_agg_expr.obs, oe_agg_expr.exp)
+
+    return oe_agg_expr.annotate(oe=gamma_expr)
+
+
+def calculate_neg_log_likelihood(region_expr, gamma_expr):
+    """
+    Calculate the negative log-likelihood of a region under a Poisson model.
+
+    :param region_expr: Array expression of per-residue OE structs for the region.
+    :param gamma_expr: OE ratio used to scale the expected counts (the Poisson rate).
+    :return: Negative log-likelihood expression.
+    """
+    return hl.sum(
+        region_expr.map(lambda x: -hl.dpois(x.obs, gamma_expr * x.exp, log_p=True))
+    )
+
+
+def getAIC(region_expr, nll):
+    """
+    Compute the Akaike information criterion (AIC) as `2 * k + 2 * nll`.
+
+    Here `k` is the number of regions: the length of `region_expr` if it is an
+    array of regions, otherwise 1 if the single region is non-empty and 0 if it
+    is empty.
+
+    :param region_expr: Array of region structs, or a single region struct.
+    :param nll: Negative log-likelihood of the model.
+    :return: AIC.
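+
+    Example (illustrative values only; just the array length matters here)::
+
+        >>> # Three selected regions and a total negative log-likelihood of 10.
+        >>> hl.eval(getAIC(hl.literal([1, 2, 3]), 10.0))
+        26.0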
+ """ + if isinstance(region_expr, hl.expr.ArrayExpression): + region_count = region_expr.length() + else: + region_count = hl.int(region_expr.region.length() > 0) + + return 2 * region_count + 2 * nll + + +def remove_residues_from_region(region_expr, remove_region_expr): + """ + Remove residues from a region. + + :param region_expr: Region expression. + :param remove_region_expr: Region expression to remove. + :return: Region expression with residues removed. + """ + remove_region_residues = hl.set(remove_region_expr.region) + updated_region_expr = hl.set(region_expr.region).difference(remove_region_residues) + return hl.or_missing( + hl.is_defined(region_expr.region) & hl.is_defined(remove_region_expr.region), + region_expr.annotate(region=hl.array(updated_region_expr)), + ) + + +def get_min_region(regions_expr, min_field="oe_upper", min_exp_mis=None): + """ + Get the minimum region. + + :param regions_expr: Regions expression. + :param min_field: Field to use for sorting. Default is "oe_upper". + :param min_exp_mis: Minimum number of expected missense variants in a region to be + considered for constraint calculation. Default is None. + :return: Minimum region expression. + """ + regions_expr = hl.agg.collect(regions_expr) + if min_exp_mis is None: + filtered_regions_expr = regions_expr + else: + filtered_regions_expr = regions_expr.filter(lambda x: x.exp >= min_exp_mis) + filtered_regions_expr = hl.or_missing( + filtered_regions_expr.length() > 0, filtered_regions_expr + ) + + min_region_expr = hl.sorted(filtered_regions_expr, key=lambda x: x[min_field])[0] + + return min_region_expr + + +def prep_region_struct(region_expr, oe_expr): + """ + Prepare a region struct. + + :param region_expr: Region expression. + :param oe_expr: OE expression. + :return: Region struct expression. + """ + oe_expr = annotate_region_with_oe(region_expr, oe_expr) + oe_agg_expr = get_agg_oe_for_region(oe_expr) + nll_expr = calculate_neg_log_likelihood(oe_expr, oe_agg_expr.oe) + return hl.struct( + region=region_expr, + **oe_agg_expr, + region_nll=nll_expr, + nll=nll_expr, + ) + + +def calculate_oe_neq_1_chisq( + obs_expr: hl.expr.Int64Expression, + exp_expr: hl.expr.Float64Expression, +) -> hl.expr.Float64Expression: + """ + Check for significance that observed/expected values for regions are different from 1. + + Formula is: (obs - exp)^2 / exp. + + :param obs_expr: Observed variant counts. + :param exp_expr: Expected variant counts. + :return: Chi-squared value. + """ + return ((obs_expr - exp_expr) ** 2) / exp_expr + + +def run_forward(ht, min_exp_mis=MIN_EXP_MIS): + """ + Run the forward algorithm to find the most intolerant region. + + :param ht: Hail Table with the most intolerant region for each UniProt ID and residue + index + :return: Hail Table annotated with the observed and expected values for each residue. 
+ """ + num_residues = ht.oe.length() + null_region = hl.range(num_residues) + null_model = prep_region_struct(null_region, ht.oe) + ht = ht.select( + "oe", + num_residues=num_residues, + null_model=null_model, + regions=hl.enumerate( + ht.min_oe_upper.map(lambda x: x.select("region")).filter( + lambda x: x.region.length() < num_residues + ) + ), + selected=hl.empty_array(null_model.dtype), + selected_nll=0, + best_aic=getAIC(null_model, null_model.nll), + found_best=False, + ) + ht = ht.explode("regions") + ht = ht.transmute(idx=ht.regions[0], region=ht.regions[1].region) + ht = ht.key_by("uniprot_id", "transcript_id", "idx") + ht = ht.repartition(5000, shuffle=True) + ht = ht.checkpoint(hl.utils.new_temp_file(f"forward_explode", "ht")) + ht.describe() + ht.show(5) + round_num = 1 + while ht.aggregate(hl.agg.any(hl.is_defined(ht.region))): + # For each region in regions, update the list of selected by + # removing the residues in the region from the "catch all remaining" + # region, and adding the new region to the selected list. + region_expr = prep_region_struct(ht.region, ht.oe) + # TODO: Consider adding a checkpoint here or after the next step. + ht = ht.annotate(_region=region_expr).checkpoint( + hl.utils.new_temp_file(f"forward_round_{round_num}.prep", "ht") + ) + region_expr = ht._region + updated_null_expr = remove_residues_from_region(ht.null_model, region_expr) + ht = ht.annotate(_updated_null=updated_null_expr).checkpoint( + hl.utils.new_temp_file(f"forward_round_{round_num}.remove", "ht") + ) + updated_null_expr = ht._updated_null + updated_null_expr = prep_region_struct(updated_null_expr.region, ht.oe) + ht = ht.annotate(_updated_null=updated_null_expr).checkpoint( + hl.utils.new_temp_file(f"forward_round_{round_num}.prep2", "ht") + ) + updated_null_expr = ht._updated_null + region_expr = ht._region + region_expr = region_expr.annotate( + null_model=updated_null_expr, + region_nll=region_expr.nll, + nll=updated_null_expr.nll + ht.selected_nll + region_expr.nll, + ) + ht2 = ht.select(exp=region_expr.exp, nll=region_expr.nll) + ht2 = ht2.filter(hl.is_defined(ht2.nll) & (ht2.exp >= min_exp_mis)).checkpoint( + hl.utils.new_temp_file(f"forward_round_{round_num}.scan1", "ht") + ) + ht2 = ( + ht2.group_by("uniprot_id", "transcript_id") + .aggregate( + **hl.agg.fold( + hl.missing(hl.tstruct(min_idx=hl.tint, min_nll=hl.tfloat)), + lambda accum: ( + hl.case() + .when( + hl.is_missing(accum) | (accum.min_nll > ht2.nll), + hl.struct(min_idx=ht2.idx, min_nll=ht2.nll), + ) + .when(accum.min_nll <= ht2.nll, accum) + .or_missing() + ), + lambda accum1, accum2: ( + hl.case() + .when(hl.is_missing(accum1), accum2) + .when(hl.is_missing(accum2), accum1) + .when(accum1.min_nll <= accum2.min_nll, accum1) + .default(accum2) + ), + ) + ) + .checkpoint( + hl.utils.new_temp_file(f"forward_round_{round_num}.scan2", "ht") + ) + ) + _ht = ht.select(region=region_expr) + ht2 = ht2.annotate( + best_region=_ht[ht2.uniprot_id, ht2.transcript_id, ht2.min_idx].region + ) + ht2 = ht2.checkpoint( + hl.utils.new_temp_file(f"forward_round_{round_num}.scan3", "ht") + ) + + # Get AIC of best candidate model. + best_region = ht2[ht.uniprot_id, ht.transcript_id].best_region + region_expr = hl.struct(region=ht.region) + + # Update region list. 
+        region_expr = remove_residues_from_region(region_expr, best_region).region
+        region_expr = hl.or_missing(region_expr.length() > 0, region_expr)
+
+        updated_null_model = best_region.null_model
+        candidate_model = ht.selected.append(best_region.drop("null_model"))
+        updated_vals = {
+            "null_model": updated_null_model,
+            "selected": candidate_model,
+            "selected_nll": best_region.region_nll + ht.selected_nll,
+            "best_aic": getAIC(updated_null_model, 0)
+            + getAIC(candidate_model, best_region.nll),
+            "region": region_expr,
+        }
+        curr_vals = {k: ht[k] for k in updated_vals}
+        curr_vals["region"] = hl.missing(region_expr.dtype)
+
+        found_best = (
+            ht.found_best
+            | hl.is_missing(best_region)
+            | (updated_vals["best_aic"] >= ht.best_aic)
+        )
+        curr_vals = hl.struct(**curr_vals)
+        updated_vals = hl.struct(**updated_vals)
+        ht = ht.annotate(
+            **hl.if_else(found_best, curr_vals, updated_vals),
+            found_best=found_best,
+        )
+
+        ht = ht.filter(hl.is_defined(ht.region) | (ht.idx == 0))
+        ht = ht.checkpoint(hl.utils.new_temp_file(f"forward_round_{round_num}", "ht"))
+        round_num += 1
+
+    selected_expr = ht.selected.map(lambda x: x.annotate(is_null=False))
+    ht = ht.select(
+        selected=add_idx_to_array(
+            selected_expr.append(ht.null_model.annotate(is_null=True)), "region_index"
+        )
+    )
+    ht = ht.explode("selected")
+    ht = ht.select(**ht.selected)
+    ht = ht.annotate(region_length=hl.len(ht.region))
+    ht = ht.explode("region")
+    chisq_expr = calculate_oe_neq_1_chisq(ht.obs, ht.exp)
+    ht = ht.annotate(
+        residue_index=ht.region,
+        oe_upper=(
+            hl.qchisqtail(1 - 0.05 / 2, 2 * (ht.obs + 1), lower_tail=True)
+            / (2 * ht.exp)
+        ),
+        chisq=chisq_expr,
+        p_value=hl.pchisqtail(chisq_expr, 1),
+    )
+    ht = ht.key_by("uniprot_id", "transcript_id", "residue_index").select(
+        "region_index", "obs", "exp", "oe", "oe_upper", "chisq", "p_value", "is_null"
+    )
+
+    return ht
+
+
+def create_missense_viewer_input_ht(
+    pos_ht: hl.Table,
+    proemis3d_ht: hl.Table,
+) -> hl.Table:
+    """
+    Create missense viewer input Hail Table.
+
+    :param pos_ht: Hail Table of per-position codon information, including `locus`,
+        amino acid positions, and transcript/UniProt identifiers.
+    :param proemis3d_ht: Per-residue PROEMIS3D constraint Hail Table.
+    :return: Missense viewer input Hail Table.
+    """
+    pos_ht = pos_ht.key_by(
+        "uniprot_id", "enst", "gene", "aalength", "cds_len", "strand", "aapos"
+    )
+    pos_ht = pos_ht.select("locus")
+    pos_ht = pos_ht.collect_by_key("locus")
+    pos_ht = pos_ht.annotate(
+        locus=hl.sorted(pos_ht.locus.locus, key=lambda x: x.position)[0]
+    )
+    pos_ht = pos_ht.group_by(
+        "uniprot_id", "enst", "gene", "aalength", "cds_len", "strand"
+    ).aggregate(locus_by_aapos=hl.dict(hl.agg.collect((pos_ht.aapos, pos_ht.locus))))
+    pos_ht = pos_ht.key_by("uniprot_id", "enst").cache()
+
+    chisq_expr = calculate_oe_neq_1_chisq(proemis3d_ht.obs, proemis3d_ht.exp)
+    proemis3d_ht = proemis3d_ht.annotate(
+        chisq=chisq_expr, p_value=hl.pchisqtail(chisq_expr, 1)
+    )
+
+    # Key by all fields except 'pos' and collect by key into a field named 'pos'.
+    proemis3d_ht = proemis3d_ht.key_by(
+        "uniprot_id", "transcript_id", "region_index", "is_null"
+    ).collect_by_key("pos")
+
+    # Sort the 'pos' field in ascending order.
+    proemis3d_ht = proemis3d_ht.annotate(
+        pos=hl.sorted(proemis3d_ht.pos, key=lambda x: x.residue_index)
+    )
+
+    # Annotate with 'start' and 'stop' positions for regions by merging adjacent
+    # positions.
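+    # Each element of the fold accumulator is a tuple of (start_residue,
+    # stop_residue, obs, exp, oe, oe_upper, chisq, p_value); a new tuple is
+    # opened whenever the next residue is not adjacent to the previous one.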
+    proemis3d_ht = proemis3d_ht.annotate(
+        pos=hl.fold(
+            lambda i, j: hl.if_else(
+                j.residue_index > (i[-1][1] + 1),
+                i.append(
+                    (
+                        j.residue_index,
+                        j.residue_index,
+                        j.obs,
+                        j.exp,
+                        j.oe,
+                        j.oe_upper,
+                        j.chisq,
+                        j.p_value,
+                    )
+                ),
+                i[:-1].append(
+                    (
+                        i[-1][0],
+                        j.residue_index,
+                        j.obs,
+                        j.exp,
+                        j.oe,
+                        j.oe_upper,
+                        j.chisq,
+                        j.p_value,
+                    )
+                ),
+            ),
+            [
+                (
+                    proemis3d_ht.pos[0].residue_index,
+                    proemis3d_ht.pos[0].residue_index,
+                    proemis3d_ht.pos[0].obs,
+                    proemis3d_ht.pos[0].exp,
+                    proemis3d_ht.pos[0].oe,
+                    proemis3d_ht.pos[0].oe_upper,
+                    proemis3d_ht.pos[0].chisq,
+                    proemis3d_ht.pos[0].p_value,
+                )
+            ],
+            proemis3d_ht.pos[1:],
+        )
+    )
+    proemis3d_ht = proemis3d_ht.explode("pos")
+
+    # Re-key and transform each 'pos' tuple into named 'start'/'stop' and
+    # per-region statistic fields.
+    proemis3d_ht = proemis3d_ht.key_by(
+        "uniprot_id", "transcript_id", "region_index", "is_null"
+    )
+    proemis3d_ht = proemis3d_ht.transmute(
+        start=proemis3d_ht.pos[0],
+        stop=proemis3d_ht.pos[1],
+        obs_mis=proemis3d_ht.pos[2],
+        exp_mis=proemis3d_ht.pos[3],
+        obs_exp=proemis3d_ht.pos[4],
+        oe_upper=proemis3d_ht.pos[5],
+        chisq=proemis3d_ht.pos[6],
+        p_value=proemis3d_ht.pos[7],
+    )
+
+    # Collect rows by key into a field named 'regions' and join in per-transcript
+    # position information.
+    proemis3d_ht = proemis3d_ht.collect_by_key("regions")
+    proemis3d_ht = proemis3d_ht.annotate(
+        **pos_ht[proemis3d_ht.uniprot_id, proemis3d_ht.transcript_id]
+    )
+    proemis3d_ht = proemis3d_ht.annotate(
+        regions=proemis3d_ht.regions.map(
+            lambda x: x.select(
+                chrom=proemis3d_ht.locus_by_aapos[x.start].contig,
+                start=hl.if_else(
+                    proemis3d_ht.locus_by_aapos[x.start].position
+                    <= proemis3d_ht.locus_by_aapos[x.stop].position,
+                    proemis3d_ht.locus_by_aapos[x.start].position,
+                    proemis3d_ht.locus_by_aapos[x.stop].position,
+                ),
+                stop=hl.if_else(
+                    proemis3d_ht.locus_by_aapos[x.start].position
+                    <= proemis3d_ht.locus_by_aapos[x.stop].position,
+                    proemis3d_ht.locus_by_aapos[x.stop].position + 2,
+                    proemis3d_ht.locus_by_aapos[x.start].position + 2,
+                ),
+                aa_start=x.start,
+                aa_stop=x.stop,
+                obs_mis=x.obs_mis,
+                exp_mis=x.exp_mis,
+                obs_exp=x.obs_exp,
+                oe_upper=x.oe_upper,
+                region_index=proemis3d_ht.region_index,
+                is_null=proemis3d_ht.is_null,
+                chisq_diff_null=x.chisq,
+                p_value=x.p_value,
+            )
+        )
+    )
+    proemis3d_ht = proemis3d_ht.group_by("transcript_id", "uniprot_id").aggregate(
+        gnomad_proemis3d_constraint=hl.struct(
+            has_no_rmc_evidence=False,
+            passed_qc=True,
+            regions=hl.flatten(hl.agg.collect(proemis3d_ht.regions)),
+        )
+    )
+    proemis3d_ht = proemis3d_ht.key_by("transcript_id")
+    rmc_ht = get_rmc_browser_ht().ht()
+    rmc_ht = rmc_ht.annotate(
+        regions=rmc_ht.regions.map(
+            lambda x: x.select(
+                chrom=x.start_coordinate.contig,
+                start=hl.if_else(
+                    x.start_coordinate.position <= x.stop_coordinate.position,
+                    x.start_coordinate.position,
+                    x.stop_coordinate.position,
+                ),
+                stop=hl.if_else(
+                    x.start_coordinate.position <= x.stop_coordinate.position,
+                    x.stop_coordinate.position + 2,
+                    x.start_coordinate.position + 2,
+                ),
+                aa_start=x.start_aa,
+                aa_stop=x.stop_aa,
+                obs_mis=x.obs,
+                exp_mis=x.exp,
+                obs_exp=x.oe,
+                chisq_diff_null=x.chisq,
+                p_value=x.p,
+            )
+        )
+    )
+    proemis3d_ht = proemis3d_ht.annotate(
+        gnomad_regional_missense_constraint=hl.struct(
+            has_no_rmc_evidence=False,
+            passed_qc=True,
+            regions=rmc_ht[proemis3d_ht.transcript_id].regions,
+        ),
+    )
+
+    ht = browser_gene().ht()
+    ht = ht.select(
+        "interval",
+        "gencode_symbol",
+        "chrom",
+        "strand",
+        "start",
+        "stop",
"xstart", + "xstop", + "exons", + "transcripts", + "reference_genome", + "canonical_transcript_id", + "preferred_transcript_id", + "preferred_transcript_source", + **ht[ht.canonical_transcript_id], + ) + ht = ht.repartition(15, shuffle=True) + + return ht + + +def prioritize_transcripts_and_uniprots( + residue_ht: hl.Table, +) -> hl.Table: + """ + Prioritize and label transcript/uniprot combinations for each gene. + + This function: + 1. De-duplicates both tables by key and selects relevant fields. + 2. Annotates residue HT with gene-level info (canonical, mane_select, cds_length, etc.). + 3. Adds a random number to break ties. + 4. Orders transcripts by gene_id, MANE select, canonical, CDS length, and random. + 5. Assigns an index for prioritization. + 6. Aggregates per gene: + - All transcript/uniprot pairs with their priority. + - The lowest-index (prioritized) uniprot per transcript. + - The lowest-index (prioritized) transcript per gene. + 7. Annotates flags indicating for each row if it’s the “one uniprot per transcript” and/or “one transcript per gene”. + 8. Returns a de-nested table keyed by (uniprot_id, transcript_id). + + :param residue_ht: Hail Table keyed by (uniprot_id, transcript_id). + :return: Annotated and prioritized Hail Table keyed by (uniprot_id, transcript_id). + """ + # Prepare keys and fields + ht = residue_ht.key_by("transcript_id", "uniprot_id") + ht = ht.select("gene_symbol", "gene_id", "canonical", "mane_select", "cds_length") + ht = ht.distinct() + + # Annotate with a random number for tie-breaking + ht = ht.annotate(rand_n=hl.rand_unif(0, 1)) + + # Order and add index + ht = ht.order_by( + "gene_id", + hl.desc(ht.mane_select), + hl.desc(ht.canonical), + hl.desc(ht.cds_length), + "rand_n", + ) + ht = ht.add_index() + + # Aggregate per gene for prioritization + ht = ht.group_by("gene_id").aggregate( + _all_rows=hl.agg.collect_as_set( + hl.struct( + **{ + k: ht[k] + for k in [ + "transcript_id", + "uniprot_id", + "gene_symbol", + "canonical", + "mane_select", + "idx", + ] + } + ) + ), + one_uniprot_per_transcript=hl.agg.group_by( + ht.transcript_id, hl.agg.min(ht.idx) + ), + one_transcript_per_gene=hl.agg.min(ht.idx), + ) + + # Mark priority flags + ht = ht.select( + _all_rows=ht._all_rows.map( + lambda x: x.annotate( + one_uniprot_per_transcript=ht.one_uniprot_per_transcript.get( + x.transcript_id + ) + == x.idx, + one_transcript_per_gene=ht.one_transcript_per_gene == x.idx, + ).drop("idx") + ) + ) + ht = ht.explode("_all_rows") + + return ht.select(**ht._all_rows).key_by("transcript_id", "uniprot_id").cache() + + +def explode_af2_plddt_by_residue(af2_plddt_ht: hl.Table) -> hl.Table: + """ + Explode AlphaFold2 pLDDT array into per-residue rows. + + This function: + 1. Takes an input Hail Table with a `plddt` array field containing per-residue + pLDDT scores. + 2. Enumerates the array to associate each score with a 0-based residue index. + 3. Explodes the table to produce one row per residue. + 4. Extracts the `residue_index` and `plddt` values. + 5. Keys the table by (`uniprot_id`, `residue_index`). + + :param af2_plddt_ht: Input Hail Table containing AlphaFold2 pLDDT scores as an array. + :return: Transformed Hail Table with one row per residue and associated pLDDT score. 
+ """ + ht = af2_plddt_ht.annotate(plddt=hl.enumerate(af2_plddt_ht.plddt)) + ht = ht.explode("plddt") + ht = ht.annotate(residue_index=ht.plddt[0], plddt=ht.plddt[1]) + ht = ht.key_by("uniprot_id", "residue_index") + + return ht + + +def annotate_proemis3d_with_af2_metrics( + proemis3D_ht: hl.Table, + af2_plddt_ht: hl.Table, + af2_pae_ht: hl.Table, + af2_dist_ht: hl.Table, +) -> hl.Table: + """ + Annotate a PROEMIS3D Hail Table with per-residue AlphaFold2 metrics (pLDDT, pAE, dist). + + This function: + 1. Explodes the pLDDT array into (residue_index, score) format. + 2. Keys the pAE and distance matrices by (uniprot_id, aa_index). + 3. Filters PROEMIS3D rows to coding transcripts (ENST). + 4. Annotates each region residue with pLDDT, pAE, and dist. + 5. Aggregates residue-level and region-level metrics. + 6. Outputs both residue-level and region-level structured annotations. + + :param proemis3D_ht: PROEMIS3D region Hail Table keyed by (uniprot_id, transcript_id, + residue_index). + :param af2_plddt_ht: AlphaFold2 pLDDT Hail Table with array of scores. + :param af2_pae_ht: AlphaFold2 predicted aligned error matrix Hail Table. + :param af2_dist_ht: AlphaFold2 distance matrix Hail Table. + :return: Annotated PROEMIS3D Hail Table keyed by (uniprot_id, transcript_id, + residue_index). + """ + # Preprocess inputs. + af2_plddt_ht = explode_af2_plddt_by_residue(af2_plddt_ht) + af2_pae_ht = af2_pae_ht.key_by("uniprot_id", "aa_index").checkpoint( + hl.utils.new_temp_file("af2_pae.keyed", "ht") + ) + af2_dist_ht = af2_dist_ht.key_by("uniprot_id", "aa_index").checkpoint( + hl.utils.new_temp_file("af2_dist.keyed", "ht") + ) + + # Filter transcripts. + proemis3D_ht = proemis3D_ht.filter(proemis3D_ht.transcript_id.startswith("ENST")) + + # Annotate with per-residue metrics. + proemis3D_ht = proemis3D_ht.annotate( + plddt=af2_plddt_ht[proemis3D_ht.uniprot_id, proemis3D_ht.residue_index].plddt, + pae=af2_pae_ht[proemis3D_ht.uniprot_id, proemis3D_ht.residue_index].pae, + dist=af2_dist_ht[proemis3D_ht.uniprot_id, proemis3D_ht.residue_index].dist_mat, + ).checkpoint(hl.utils.new_temp_file("proemis3D.annotated", "ht")) + + # Group by region. 
+ proemis3D_ht = proemis3D_ht.key_by("uniprot_id", "transcript_id", "region_index") + proemis3D_ht = proemis3D_ht.collect_by_key("by_residue") + + # Compute per-residue references + residues_expr = proemis3D_ht.by_residue.map(lambda x: x.residue_index) + by_residue_expr = proemis3D_ht.by_residue.map( + lambda x: x.annotate( + res=hl.array(hl.set(residues_expr).remove(x.residue_index)) + ) + ) + by_residue_expr = by_residue_expr.map( + lambda x: x.annotate( + res=x.res.map(lambda y: hl.abs(x.residue_index - y)), + pae=x.res.map(lambda y: x.pae[y]), + dist=x.res.map(lambda y: x.dist[y]), + ) + ) + + # Region-level aggregates + region_plddt = by_residue_expr.map(lambda x: x.plddt) + region_res = hl.flatten(by_residue_expr.map(lambda x: x.res)) + region_pae = hl.flatten(by_residue_expr.map(lambda x: x.pae)) + region_dist = hl.flatten(by_residue_expr.map(lambda x: x.dist)) + + proemis3D_ht = proemis3D_ht.annotate( + by_residue=by_residue_expr, + region_residues=residues_expr, + region_aa_dist_stats=region_res.aggregate(lambda x: hl.agg.stats(x)).annotate( + median=hl.median(region_res) + ), + alphafold2_info=hl.struct( + region_plddt=region_plddt.aggregate(lambda x: hl.agg.stats(x)).annotate( + median=hl.median(region_plddt) + ), + region_pae_stats=region_pae.aggregate(lambda x: hl.agg.stats(x)).annotate( + median=hl.median(region_pae) + ), + region_dist_stats=region_dist.aggregate(lambda x: hl.agg.stats(x)).annotate( + median=hl.median(region_dist) + ), + ), + ) + proemis3D_ht = proemis3D_ht.checkpoint( + hl.utils.new_temp_file("proemis3D.region_annotated", "ht") + ) + + # Flatten by residue + proemis3D_ht = proemis3D_ht.explode("by_residue") + proemis3D_ht = ( + proemis3D_ht.transmute(**proemis3D_ht.by_residue) + .key_by("uniprot_id", "transcript_id", "residue_index") + .checkpoint(hl.utils.new_temp_file("proemis3D.by_residue.exploded", "ht")) + ) + + # Final structured output + proemis3D_ht = proemis3D_ht.select( + residue_level_annotations=hl.struct( + residue_to_region_aa_dist_stats=proemis3D_ht.res.aggregate( + lambda x: hl.agg.stats(x) + ).annotate(median=hl.median(proemis3D_ht.res)), + alphafold2_info=hl.struct( + residue_plddt=proemis3D_ht.plddt, + residue_to_region_pae_stats=proemis3D_ht.pae.aggregate( + lambda x: hl.agg.stats(x) + ).annotate(median=hl.median(proemis3D_ht.pae)), + residue_to_region_dist_stats=proemis3D_ht.dist.aggregate( + lambda x: hl.agg.stats(x) + ).annotate(median=hl.median(proemis3D_ht.dist)), + ), + ), + region_level_annotations=hl.struct( + region_index=proemis3D_ht.region_index, + region_residues=proemis3D_ht.region_residues, + region_length=proemis3D_ht.region_length, + obs=proemis3D_ht.obs, + exp=proemis3D_ht.exp, + oe=proemis3D_ht.oe, + oe_upper=proemis3D_ht.oe_upper, + oe_ci=oe_confidence_interval(proemis3D_ht.obs, proemis3D_ht.exp), + chisq=proemis3D_ht.chisq, + p_value=proemis3D_ht.p_value, + is_null=proemis3D_ht.is_null, + region_aa_dist_stats=proemis3D_ht.region_aa_dist_stats, + alphafold2_info=proemis3D_ht.alphafold2_info, + ), + ) + + return proemis3D_ht + + +def generate_all_possible_snvs_from_gencode_positions( + transcripts_ht: hl.Table, + translations_ht: hl.Table, + gencode_gtf_ht: hl.Table, + gencode_translations_matched_ht: hl.Table, +) -> hl.Table: + """ + Generate all possible single nucleotide variants (SNVs) from the GENCODE positions Hail Table. + + :param transcripts_ht: GENCODE transcripts Hail Table. + :param translations_ht: GENCODE translations Hail Table. + :param gencode_gtf_ht: GENCODE GTF Hail Table. 
+    :param gencode_translations_matched_ht: GENCODE translations matched Hail Table.
+    :return: Hail Table with all possible SNVs at each GENCODE position.
+    """
+    gencode_translations_matched_ht = gencode_translations_matched_ht.group_by(
+        "enst"
+    ).aggregate(
+        uniprot_id=hl.agg.collect_as_set(gencode_translations_matched_ht.uniprot_id)
+    )
+    translations_ht = translations_ht.annotate(
+        uniprot_id=hl.or_else(
+            gencode_translations_matched_ht[translations_ht.enst].uniprot_id, {"None"}
+        )
+    )
+    translations_ht = translations_ht.explode("uniprot_id")
+
+    ht = get_gencode_positions(
+        transcripts_ht,
+        translations_ht,
+        gencode_gtf_ht,
+        no_filter=True,
+    )
+    ht = ht.annotate(
+        uniprot_id=hl.if_else(
+            ht.cds_len_mismatch | ht.cds_len_not_div_by_3, "None", ht.uniprot_id
+        )
+    )
+    ht = ht.filter(ht.locus.contig != "chrM")
+
+    nucleotides = hl.set({"A", "T", "C", "G"})
+    has_aa_info_expr = ~ht.cds_len_mismatch & ~ht.cds_len_not_div_by_3
+    ht = ht.annotate(
+        aminoacid_ref=hl.or_missing(has_aa_info_expr, ht.sequence[ht.aapos]),
+        alleles=nucleotides.remove(ht.ref).map(lambda alt: [ht.ref, alt]),
+        residue_index=hl.or_missing(has_aa_info_expr, ht.aapos),
+        aminoacid_length=hl.int(ht.aalength),
+        cds_length=hl.int(ht.cds_len),
+        transcript_id=ht.enst,
+        gene_id=ht.ensg,
+        gene_symbol=ht.gene,
+    )
+    ht = ht.explode("alleles")
+    ht = ht.key_by("locus", "alleles", "transcript_id", "uniprot_id", "gene_id")
+    ht = ht.select(
+        "gene_symbol",
+        "strand",
+        "cds_length",
+        "aminoacid_length",
+        "residue_index",
+        "aminoacid_ref",
+        "cds_len_mismatch",
+        "cds_len_not_div_by_3",
+    )
+    ht = ht.distinct()
+
+    return ht
+
+
+def make_temp_annotation_ht(
+    base_ht: hl.Table,
+    annotation_ht: hl.Table,
+    keys: List[str] = ["locus", "alleles"],
+    temp_path_prefix: str = "tmp_annotation_ht",
+    annotation_name: Optional[str] = None,
+) -> hl.Table:
+    """
+    Annotate a Hail Table with fields from another Hail Table and checkpoint the result.
+
+    :param base_ht: Base Hail Table to annotate.
+    :param annotation_ht: Annotation Hail Table to index.
+    :param keys: List of keys to index the annotation Hail Table with.
+    :param temp_path_prefix: Prefix for the temporary file path.
+    :param annotation_name: Name of the annotation to annotate the base Hail Table with.
+    :return: Annotated Hail Table checkpointed to a temporary file.
+    """
+    base_ht = base_ht.select(*[k for k in keys if k not in base_ht.key])
+    keys_expr = [base_ht[k] for k in keys]
+    fields = [f for f in annotation_ht.row_value if f not in base_ht.key]
+    annotation_expr = annotation_ht.select(*fields)
+    annotation_expr = annotation_expr.index(*keys_expr)
+    if annotation_name is not None:
+        annotation_expr = {annotation_name: annotation_expr}
+
+    base_ht = base_ht.annotate(**annotation_expr)
+    base_ht = base_ht.checkpoint(hl.utils.new_temp_file(temp_path_prefix, "ht"))
+
+    return base_ht
+
+
+def annotate_snvs_with_variant_level_data(ht: hl.Table) -> hl.Table:
+    """
+    Annotate a per-SNV Hail Table with variant-level annotations from multiple sources.
+
+    Adds the following annotations:
+
+    - context
+    - gnomad_site
+    - revel
+    - cadd
+    - phylop
+    - genetics_gym
+    - autism
+    - dd_denovo
+    - dd_denovo_no_transcript_match
+    - gnomad_de_novo
+    - clinvar
+    - pext_base
+    - pext_annotation
+    - rmc
+
+    :param ht: Input Hail Table.
+    :return: Annotated Hail Table.
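+
+    The set of sources is driven by `VARIANT_LEVEL_ANNOTATION_CONFIG`: each entry
+    supplies the source Hail Table, the join keys, an optional `custom_select`
+    preprocessing function, and an optional annotation name.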
+ """ + annotation_hts = { + n: make_temp_annotation_ht( + ht, + ( + c["ht"].ht() + if "custom_select" not in c + else c["custom_select"](c["ht"].ht()) + ), + keys=c["keys"], + temp_path_prefix=n, + annotation_name=c.get("annotation_name"), + )[ht.key] + for n, c in VARIANT_LEVEL_ANNOTATION_CONFIG.items() + } + annotation_expr = hl.struct() + for t in annotation_hts.values(): + annotation_expr = annotation_expr.annotate(**t) + + ht = ht.annotate(variant_level_annotations=annotation_expr).checkpoint( + hl.utils.new_temp_file("snvs_with_variant_level_data", "ht") + ) + pext_annotation_ht = process_pext_annotation_ht(pext("annotation_level").ht()) + + var_expr = ht.variant_level_annotations.rename({"biotype": "transcript_biotype"}) + var_update_expr = var_expr.annotate( + dd_denovo=var_expr.dd_denovo.annotate( + **get_kaplanis_sig_gene_annotations(ht.gene_symbol) + ), + dd_denovo_no_transcript_match=var_expr.dd_denovo_no_transcript_match.annotate( + **get_kaplanis_sig_gene_annotations(ht.gene_symbol) + ), + annotation_level_pext=pext_annotation_ht[ + ht.locus, ht.alleles, ht.gene_id, var_expr.most_severe_consequence + ], + ) + rearrange_fields = BASE_LEVEL_ANNOTATION_FIELDS + ["residue_alt"] + var_update_expr = var_update_expr.drop(*rearrange_fields) + ht = ht.annotate( + **{k: var_expr[k] for k in rearrange_fields}, + variant_level_annotations=var_update_expr, + ) + + return ht + + +def combine_residue_level_annotations( + ht: hl.Table, + proemis3d_ht: hl.Table, +) -> hl.Table: + """ + Combine residue-level annotations by joining PROEMIS3D regions with COSMIS (multiple sources) and InterPro data. + + Adds the following annotations: + + - interpro + - mtr3d + - cosmis_alphafold + - cosmis_pdb + - cosmis_swiss_model + - proemis3d + + :param ht: Input Hail Table. + :param proemis3d_ht: PROEMIS3D Hail Table. + :return: Annotated Hail Table. + """ + annotation_hts = { + **RESIDUE_LEVEL_ANNOTATION_CONFIG, + "proemis3d": { + "ht": proemis3d_ht, + "keys": ["transcript_id", "uniprot_id", "residue_index"], + "annotation_name": "proemis3d", + }, + } + annotation_hts = { + n: make_temp_annotation_ht( + ht, + ( + c["ht"] + if isinstance(c["ht"], hl.Table) + else ( + c["ht"].ht() + if "custom_select" not in c + else c["custom_select"](c["ht"].ht()) + ) + ), + keys=c["keys"], + temp_path_prefix=n, + annotation_name=c["annotation_name"], + )[ht.key] + for n, c in annotation_hts.items() + } + annotation_expr = hl.struct() + for t in annotation_hts.values(): + annotation_expr = annotation_expr.annotate(**t) + + ht = ht.annotate(**annotation_expr) + ht = ht.transmute( + cosmis=hl.struct( + alphafold=ht.row_value.cosmis_alphafold, + pdb=ht.row_value.cosmis_pdb, + swiss_model=ht.row_value.cosmis_swiss_model, + ) + ) + + return ht + + +def create_per_snv_combined_ht( + ht: hl.Table, + proemis3d_ht: hl.Table, + af2_plddt_ht: hl.Table, + af2_pae_ht: hl.Table, + af2_dist_ht: hl.Table, +) -> hl.Table: + """ + Create a fully annotated per-SNV Hail Table with structured variant-, residue-, and gene-level annotations. + + :param ht: All possible SNVs Hail Table. + :param proemis3d_ht: PROEMIS3D Hail Table. + :param af2_plddt_ht: AlphaFold2 pLDDT Hail Table. + :param af2_pae_ht: AlphaFold2 PAE Hail Table. + :param af2_dist_ht: AlphaFold2 distance matrix Hail Table. + :param partition_intervals: Partition intervals to read annotation Hail Tables with. + :return: Annotated and checkpointed Hail Table. 
+ """ + hl._set_flags(use_new_shuffle="1") + ht = annotate_snvs_with_variant_level_data(ht).naive_coalesce(5000).cache() + hl._set_flags(use_new_shuffle=None) + + ht = ht.annotate(residue_ref=ht.aminoacid_ref) + + base_residue_ht = ( + ht.key_by("transcript_id", "uniprot_id", "residue_index") + .select( + "gene_id", + "gene_symbol", + "canonical", + "mane_select", + "cds_length", + "residue_ref", + "residue_alt", + ) + .distinct() + ).cache() + residue_ht = combine_residue_level_annotations( + base_residue_ht, + annotate_proemis3d_with_af2_metrics( + proemis3d_ht, af2_plddt_ht, af2_pae_ht, af2_dist_ht + ).cache(), + ).cache() + select_uniprot_transcript_ht = prioritize_transcripts_and_uniprots( + base_residue_ht + ).select("one_uniprot_per_transcript", "one_transcript_per_gene") + + hl._set_flags(use_new_shuffle="1") + residue_ht = make_temp_annotation_ht( + ht, + residue_ht, + keys=["transcript_id", "uniprot_id", "residue_index"], + temp_path_prefix="residue", + ).drop("residue_index") + select_uniprot_transcript_ht = make_temp_annotation_ht( + ht, + select_uniprot_transcript_ht, + keys=["transcript_id", "uniprot_id"], + temp_path_prefix="select_uniprot_transcript", + ) + gene_constraint_ht = make_temp_annotation_ht( + ht, + get_temp_processed_constraint_ht().ht(), + keys=["transcript_id"], + temp_path_prefix="gene_constraint", + ) + hl._set_flags(use_new_shuffle=None) + + hi_expr = hl.case() + for n, g in HI_GENE_CATEGORIES.items(): + hi_expr = hi_expr.when(hl.set(g).contains(ht.gene_symbol), n) + hi_expr = hi_expr.or_missing() + + # Structure final output. + ht = ht.select( + "gene_symbol", + "canonical", + "mane_select", + "transcript_biotype", + "most_severe_consequence", + "cds_len_mismatch", + "cds_len_not_div_by_3", + is_phaplo_gene=hl.set(get_phaplo().he()).contains(ht.gene_symbol), + is_ptriplo_gene=hl.set(get_ptriplo().he()).contains(ht.gene_symbol), + is_hi_gene=hl.set(HI_GENES).contains(ht.gene_symbol), + hi_gene_category=hi_expr, + **select_uniprot_transcript_ht[ht.key], + variant_level_annotations=ht.variant_level_annotations, + residue_level_annotations=hl.struct(**residue_ht[ht.key]), + gene_level_annotations=hl.struct( + strand=ht.strand, + cds_length=ht.cds_length, + cds_len_mismatch=ht.cds_len_mismatch, + cds_len_not_div_by_3=ht.cds_len_not_div_by_3, + aminoacid_length=ht.aminoacid_length, + **gene_constraint_ht[ht.key], + ), + ) + + return ht + + +def create_per_residue_ht_from_snv_ht(per_snv_ht: hl.Table) -> hl.Table: + """ + Create a per-residue Hail Table from a fully annotated per-SNV Hail Table. + + This function: + 1. Extracts residue-relevant fields and drops alleles. + 2. Deduplicates rows at the residue level. + 3. Aggregates mean coverage and RMC sets per residue. + 4. Extracts flattened annotations from residue and gene level. + + :param per_snv_ht: Annotated per-SNV Hail Table from `create_per_snv_combined_ht`. + :return: Final aggregated per-residue Hail Table. + """ + keep_fields = [ + *BASE_LEVEL_ANNOTATION_FIELDS, + "is_phaplo_gene", + "is_ptriplo_gene", + "is_hi_gene", + "hi_gene_category", + "cds_len_mismatch", + "cds_len_not_div_by_3", + "one_uniprot_per_transcript", + "one_transcript_per_gene", + ] + + # Extract and deduplicate. 
+ ht = ( + per_snv_ht.select( + *keep_fields, + residue_index=per_snv_ht.residue_level_annotations.residue_index, + exomes_coverage=per_snv_ht.variant_level_annotations.exomes_coverage, + rmc=per_snv_ht.variant_level_annotations.rmc, + residue_level_annotations=per_snv_ht.residue_level_annotations, + gene_level_annotations=per_snv_ht.gene_level_annotations, + ) + .key_by("locus", "transcript_id", "uniprot_id") + .drop("alleles") + ) + + ht = ht.distinct() + ht = ht.checkpoint(hl.utils.new_temp_file("per_residue_dedup", "ht")) + + # Group and aggregate. + ht = ht.group_by("transcript_id", "uniprot_id", "residue_index").aggregate( + **{k: hl.agg.take(ht[k], 1)[0] for k in keep_fields}, + residue_mean_exomes_coverage=hl.struct( + mean=hl.agg.mean(ht.exomes_coverage.mean), + median_approx=hl.agg.mean(ht.exomes_coverage.median_approx), + AN=hl.agg.mean(ht.exomes_coverage.AN), + percent_AN=hl.agg.mean(ht.exomes_coverage.percent_AN), + ), + rmc=hl.agg.collect_as_set(ht.rmc), + residue_level_annotations=hl.agg.take(ht.residue_level_annotations, 1)[0], + gene_level_annotations=hl.agg.take(ht.gene_level_annotations, 1)[0], + ) + ht = ht.checkpoint(hl.utils.new_temp_file("per_residue_agg", "ht")) + + # Flatten and finalize. + ht = ht.select( + *keep_fields, + residue_ref=ht.residue_level_annotations.residue_ref, + residue_mean_exomes_coverage=ht.residue_mean_exomes_coverage, + interpro=ht.residue_level_annotations.interpro, + rmc_regions=ht.rmc, + cosmis=ht.residue_level_annotations.cosmis, + proemis3d=ht.residue_level_annotations.proemis3d, + gene_level_annotations=ht.gene_level_annotations, + ) + + return ht.naive_coalesce(2000) + + +def create_per_proemis3d_region_ht_from_residue_ht(ht: hl.Table) -> hl.Table: + """ + Create a PROEMIS3D region-level Hail Table from a per-residue annotated Hail Table. + + This function: + 1. Extracts region_index from PROEMIS3D residue annotations. + 2. Groups rows by (transcript_id, uniprot_id, region_index). + 3. Aggregates exomes coverage, collects RMC regions, computes per-region pLDDT stats. + 4. Annotates PROEMIS3D region-level alphaFold2 info with region-level pLDDT stats. + + :param ht: Hail Table with residue-level annotations including PROEMIS3D and coverage. + :return: Aggregated region-level Hail Table. 
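+
+    Scalar annotations (gene flags, the gene-level struct, etc.) are assumed to be
+    constant within a region, so only the first value encountered is kept via
+    `hl.agg.take(..., 1)[0]`.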
+ """ + keep_fields = [ + "gene_id", + *BASE_LEVEL_ANNOTATION_FIELDS, + "is_phaplo_gene", + "is_ptriplo_gene", + "is_hi_gene", + "hi_gene_category", + "cds_len_mismatch", + "cds_len_not_div_by_3", + "one_uniprot_per_transcript", + "one_transcript_per_gene", + ] + + ht = ht.annotate(region_index=ht.proemis3d.region_level_annotations.region_index) + + ht = ht.group_by("transcript_id", "uniprot_id", "region_index").aggregate( + **{k: hl.agg.take(ht[k], 1)[0] for k in keep_fields}, + region_mean_exomes_coverage=hl.struct( + mean=hl.agg.mean(ht.residue_mean_exomes_coverage.mean), + median_approx=hl.agg.mean(ht.residue_mean_exomes_coverage.median_approx), + AN=hl.agg.mean(ht.residue_mean_exomes_coverage.AN), + percent_AN=hl.agg.mean(ht.residue_mean_exomes_coverage.percent_AN), + ), + rmc_regions=hl.agg.explode( + lambda x: hl.agg.collect_as_set(x), ht.rmc_regions + ).filter(lambda x: hl.is_defined(x)), + proemis3d=hl.agg.take(ht.proemis3d.region_level_annotations, 1)[0].annotate( + region_plddt_stats=hl.agg.stats( + ht.proemis3d.residue_level_annotations.alphafold2_info.residue_plddt + ).annotate( + median=hl.median( + hl.agg.collect( + ht.proemis3d.residue_level_annotations.alphafold2_info.residue_plddt + ) + ) + ) + ), + gene_level_annotations=hl.agg.take(ht.gene_level_annotations, 1)[0], + ) + + # Move pLDDT stats into alphaFold2 info. + ht = ht.annotate( + proemis3d=ht.proemis3d.annotate( + alphafold2_info=ht.proemis3d.alphafold2_info.annotate( + region_plddt_stats=ht.proemis3d.region_plddt_stats + ) + ).drop("region_plddt_stats") + ) + + return ht.naive_coalesce(100) diff --git a/gnomad_constraint/pipeline/constraint_pipeline.py b/gnomad_constraint/pipeline/constraint_pipeline.py index 77163354..b356d6f6 100644 --- a/gnomad_constraint/pipeline/constraint_pipeline.py +++ b/gnomad_constraint/pipeline/constraint_pipeline.py @@ -330,10 +330,7 @@ def main(args): "Allele number tables are not available for versions prior to v4.0." ) - if coverage_model_type == "logarithmic": - log10_coverage = True - elif coverage_model_type == "linear": - log10_coverage = False + log10_coverage = True if coverage_model_type == "logarithmic" else False # Construct resources with paths for intermediate Tables generated in the pipeline. 
resources = get_constraint_resources( @@ -518,7 +515,7 @@ def main(args): coverage_ht=training_ht, coverage_expr=training_ht[coverage_metric], weighted=args.use_weights, - pops=pops, + gen_ancs=pops, high_cov_definition=args.high_cov_definition, upper_cov_cutoff=args.upper_cov_cutoff, skip_coverage_model=True if args.skip_coverage_model else False, diff --git a/gnomad_constraint/resources/resource_utils.py b/gnomad_constraint/resources/resource_utils.py index 3f497786..dec14644 100644 --- a/gnomad_constraint/resources/resource_utils.py +++ b/gnomad_constraint/resources/resource_utils.py @@ -135,12 +135,7 @@ def get_methylation_ht(build: str) -> TableResource: if build == "GRCh37": return ref_grch37.methylation_sites elif build == "GRCh38": - methylation_chrx = ref_grch38.methylation_sites_chrx.ht() - methylation_autosomes = ref_grch38.methylation_sites.ht() - methylation_ht = methylation_autosomes.union(methylation_chrx) - tmp_path = get_constraint_root(version=build, test=True) - methylation_ht = methylation_ht.checkpoint(tmp_path, overwrite=True) - return TableResource(path=tmp_path) + return ref_grch38.methylation_sites.ht() else: raise ValueError("Build must be one of 'GRCh37' or 'GRCh38'.") diff --git a/gnomad_constraint/utils/constraint.py b/gnomad_constraint/utils/constraint.py index 07d710b1..6979202a 100644 --- a/gnomad_constraint/utils/constraint.py +++ b/gnomad_constraint/utils/constraint.py @@ -482,6 +482,9 @@ def apply_models( ) exome_ht = exome_ht.filter(exome_ht[coverage_metric] >= low_coverage_filter) + include_canonical_group = False + include_mane_select_group = False + # Add necessary constraint annotations for grouping. if custom_vep_annotation == "worst_csq_by_gene": vep_annotation = "worst_csq_by_gene" @@ -490,7 +493,6 @@ def apply_models( "'mane_select' cannot be set to True when custom_vep_annotation is set" " to 'worst_csq_by_gene'." ) - else: vep_annotation = "transcript_consequences" include_canonical_group = True @@ -572,13 +574,13 @@ def apply_models( cov_corr_expr=cov_corr_expr, possible_variants_expr=poss_expr, cpg_expr=ht.cpg, - pop=pop, + gen_anc=pop, ) ) # Store which downsamplings are obtained for each pop in a # downsampling_meta dictionary. 
- ds = hl.eval(get_downsampling_freq_indices(ht.freq_meta, pop=pop)) + ds = hl.eval(get_downsampling_freq_indices(ht.freq_meta, gen_anc=pop)) key_names = {key for _, meta_dict in ds for key in meta_dict.keys()} genetic_ancestry_label = "gen_anc" if "gen_anc" in key_names else "pop" downsampling_meta[pop] = [ @@ -993,7 +995,7 @@ def compute_constraint_metrics( ann: oe_aggregation_expr( ht, filter_expr, - pops=() if ann == "mis_pphen" else pops, + gen_ancs=() if ann == "mis_pphen" else pops, exclude_mu_sum=True if ann == "mis_pphen" else False, ) for ann, filter_expr in annotation_dict.items() diff --git a/requirements-dev.in b/requirements-dev.in index 6062fec3..1b025b9c 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -4,3 +4,4 @@ pip-tools autopep8==1.7.0 # This should be kept in sync with the version in .pre-commit-config.yaml pydocstyle[toml]==6.1.1 # This should be kept in sync with the version in .pre-commit-config.yaml pylint +biopython diff --git a/requirements-dev.txt b/requirements-dev.txt index 1a4c845f..0949d2f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,67 +3,72 @@ git+https://github.com/broadinstitute/gnomad_methods@main git+https://github.com/broadinstitute/gnomad_qc@main -# This file is autogenerated by pip-compile with Python 3.12 +# +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile requirements-dev.in # -astroid==2.11.7 +astroid==3.3.11 # via pylint autopep8==1.7.0 # via -r requirements-dev.in +biopython==1.85 + # via -r requirements-dev.in black==24.3.0 # via -r requirements-dev.in -build==0.8.0 +build==1.3.0 # via pip-tools -click==8.1.3 +click==8.2.1 # via # black # pip-tools -dill==0.3.5.1 +dill==0.4.0 # via pylint isort==5.12.0 # via # -r requirements-dev.in # pylint -lazy-object-proxy==1.7.1 - # via astroid mccabe==0.7.0 # via pylint -mypy-extensions==0.4.3 +mypy-extensions==1.1.0 # via black -packaging==23.1 +numpy==2.3.3 + # via biopython +packaging==25.0 # via # black # build -pathspec==0.9.0 +pathspec==0.12.1 # via black -pep517==0.13.0 - # via build -pip-tools==6.8.0 +pip-tools==7.5.0 # via -r requirements-dev.in -platformdirs==2.5.2 +platformdirs==4.4.0 # via # black # pylint -pycodestyle==2.9.1 +pycodestyle==2.14.0 # via autopep8 pydocstyle[toml]==6.1.1 + # via + # -r requirements-dev.in + # pydocstyle +pylint==3.3.8 # via -r requirements-dev.in -pylint==2.14.5 - # via -r requirements-dev.in -snowballstemmer==2.2.0 +pyproject-hooks==1.2.0 + # via + # build + # pip-tools +snowballstemmer==3.0.1 # via pydocstyle toml==0.10.2 # via # autopep8 # pydocstyle -tomlkit==0.11.4 +tomlkit==0.13.3 # via pylint -wheel==0.37.1 +wheel==0.45.1 # via pip-tools -wrapt==1.14.1 - # via astroid # The following packages are considered to be unsafe in a requirements file: # pip