diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 46e20dc1c..6ecd8fdfe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,13 +10,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 - - name: Setup Python 3.8 - uses: actions/setup-python@v2 + uses: actions/checkout@v3 + - name: Setup Python 3.11 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.11 - name: Use pip cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: pip-${{ hashFiles('**/requirements*.txt') }} @@ -44,13 +44,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 - - name: Setup Python 3.8 - uses: actions/setup-python@v1 + uses: actions/checkout@v3 + - name: Setup Python 3.11 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.11 - name: Use pip cache - uses: actions/cache@v1 + uses: actions/cache@v3 with: path: ~/.cache/pip key: pip-${{ hashFiles('**/requirements*.txt') }} @@ -98,21 +98,21 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout gnomad_methods - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: gnomad_methods - name: Checkout gnomad_qc - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: submodules: recursive repository: broadinstitute/gnomad_qc path: gnomad_qc - - name: Setup Python 3.8 - uses: actions/setup-python@v2 + - name: Setup Python 3.11 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.11 - name: Use pip cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: pip-${{ hashFiles('gnomad_methods/**/requirements*.txt') }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ae3e0ab05..8527e1156 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Validate version run: | package_version=$(grep 'version' setup.py | sed -E 's|.*([0-9]+\.[0-9]+\.[0-9]+).*|\1|') @@ -19,12 +19,12 @@ jobs: echo "Tag version (${tag_version}) does not match package version (${package_version})" exit 1 fi - - name: Setup Python 3.8 - uses: actions/setup-python@v2 + - name: Setup Python 3.11 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.11 - name: Use pip cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: pip-${{ hashFiles('**/requirements*.txt') }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e14ed66ed..02e9788f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 22.10.0 # This should be kept in sync with the version in requirements-dev.in + rev: 23.7.0 # This should be kept in sync with the version in requirements-dev.in hooks: - id: black language_version: python3 diff --git a/docs/requirements.docs.in b/docs/requirements.docs.in index 72b0d8634..4677facae 100644 --- a/docs/requirements.docs.in +++ b/docs/requirements.docs.in @@ -1,6 +1,6 @@ -myst-parser==0.14.0 +myst-parser requests -Sphinx>=2.1,<3 -sphinx-autodoc-typehints==1.10.3 +Sphinx>=4.2.0,<7.0.0 +sphinx-autodoc-typehints sphinx-rtd-theme -Jinja2==3.0.3 +Jinja2 diff --git a/docs/requirements.docs.txt b/docs/requirements.docs.txt index 628e55c8f..e4d66330c 100644 --- 
a/docs/requirements.docs.txt +++ b/docs/requirements.docs.txt @@ -1,20 +1,18 @@ # -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: # # pip-compile docs/requirements.docs.in # alabaster==0.7.13 # via sphinx -attrs==21.4.0 - # via markdown-it-py babel==2.12.1 # via sphinx -certifi==2022.12.7 +certifi==2023.7.22 # via requests -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 # via requests -docutils==0.17.1 +docutils==0.18.1 # via # myst-parser # sphinx @@ -23,28 +21,28 @@ idna==3.4 # via requests imagesize==1.4.1 # via sphinx -jinja2==3.0.3 +jinja2==3.1.2 # via # -r docs/requirements.docs.in # myst-parser # sphinx -markdown-it-py==1.1.0 +markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser -markupsafe==2.1.2 +markupsafe==2.1.3 # via jinja2 -mdit-py-plugins==0.2.8 +mdit-py-plugins==0.4.0 # via myst-parser -myst-parser==0.14.0 +mdurl==0.1.2 + # via markdown-it-py +myst-parser==2.0.0 # via -r docs/requirements.docs.in -packaging==23.0 +packaging==23.1 # via sphinx -pygments==2.14.0 +pygments==2.16.1 # via sphinx -pytz==2023.3 - # via babel -pyyaml==6.0 +pyyaml==6.0.1 # via myst-parser requests==2.31.0 # via @@ -52,33 +50,35 @@ requests==2.31.0 # sphinx snowballstemmer==2.2.0 # via sphinx -sphinx==2.4.5 +sphinx==6.2.1 # via # -r docs/requirements.docs.in # myst-parser # sphinx-autodoc-typehints # sphinx-rtd-theme + # sphinxcontrib-applehelp + # sphinxcontrib-devhelp + # sphinxcontrib-htmlhelp # sphinxcontrib-jquery -sphinx-autodoc-typehints==1.10.3 + # sphinxcontrib-qthelp + # sphinxcontrib-serializinghtml +sphinx-autodoc-typehints==1.23.0 # via -r docs/requirements.docs.in -sphinx-rtd-theme==1.2.0 +sphinx-rtd-theme==1.2.2 # via -r docs/requirements.docs.in -sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-applehelp==1.0.7 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==1.0.5 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.0.4 # via sphinx sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-qthelp==1.0.6 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==1.1.8 # via sphinx -urllib3==1.26.15 +urllib3==2.0.4 # via requests - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/gnomad/assessment/validity_checks.py b/gnomad/assessment/validity_checks.py index 6267f4005..b94c67afc 100644 --- a/gnomad/assessment/validity_checks.py +++ b/gnomad/assessment/validity_checks.py @@ -180,9 +180,9 @@ def make_group_sum_expr_dict( else: logger.warning("%s is not in table's info field", field) - annot_dict[ - f"sum{delimiter}{field_prefix}{group}{delimiter}{sum_group}" - ] = hl.sum(sum_group_exprs) + annot_dict[f"sum{delimiter}{field_prefix}{group}{delimiter}{sum_group}"] = ( + hl.sum(sum_group_exprs) + ) # If metric_first_field is True, metric is AC, subset is tgp, sum_group is pop, and group is adj, then the values below are: # check_field_left = "AC-tgp-adj" diff --git a/gnomad/resources/grch38/gnomad.py b/gnomad/resources/grch38/gnomad.py index e1f77feb1..7f5badb37 100644 --- a/gnomad/resources/grch38/gnomad.py +++ b/gnomad/resources/grch38/gnomad.py @@ -1,6 +1,10 @@ # noqa: D100 -from typing import Optional +import json +import logging +from typing import Optional, Union + +import hail as hl from gnomad.resources.resource_utils import ( DataException, @@ -9,6 +13,20 @@ 
VersionedMatrixTableResource, VersionedTableResource, ) +from gnomad.sample_qc.ancestry import POP_NAMES +from gnomad.utils.annotations import ( + add_gks_va, + add_gks_vrs, + get_gks, + gks_compute_seqloc_digest, +) + +logging.basicConfig( + format="%(asctime)s (%(name)s %(lineno)s): %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) CURRENT_EXOME_RELEASE = "" CURRENT_GENOME_RELEASE = "3.1.2" @@ -272,6 +290,29 @@ ) +def get_coverage_ht( + coverage_ht: Union[str, hl.Table], data_type: str, coverage_version: str +): + """ + Load a coverage hail table if needed. + + If coverage_ht is 'auto', loads the default coverage table for the + data_type and coverage_version. If it's already a hail table, return it. + Otherwise return None. + + :param coverage_ht: a hail table, or 'auto' (otherwise return None). + :param data_type: a gnomad dataset type, as in 'genomes' or 'exomes' + :param coverage_version: gnomad release version the coverage table is built on + :return: hail table with coverage info, or None + """ + if coverage_ht == "auto": + return hl.read_table(coverage(data_type).versions[coverage_version].path) + elif isinstance(coverage_ht, hl.Table): + return coverage_ht + else: + return None + + def _public_release_ht_path(data_type: str, version: str) -> str: """ Get public release table path. @@ -407,3 +448,194 @@ def release_vcf_path(data_type: str, version: str, contig: str) -> str: contig = f".{contig}" if contig else "" version_prefix = "r" if version.startswith("3.0") else "v" return f"gs://gcp-public-data--gnomad/release/{version}/vcf/{data_type}/gnomad.{data_type}.{version_prefix}{version}.sites{contig}.vcf.bgz" + + +def gnomad_gks( + version: str, + variant: str, + data_type: str = "genomes", + by_ancestry_group: bool = False, + by_sex: bool = False, + vrs_only: bool = False, + custom_ht: hl.Table = None, + coverage_ht: Union[str, hl.Table] = "auto", +) -> dict: + """ + Call get_gks() and return VRS information and frequency information for the specified gnomAD release version and variant. + + :param version: String of version of gnomAD release to use. + :param variant: String of variant to search for (chromosome, position, ref, and alt, separated by '-'). Example for a variant in build GRCh38: "chr5-38258681-C-T". + :param data_type: String of either "exomes" or "genomes" for the type of reads that are desired. + :param by_ancestry_group: Boolean to pass to obtain frequency information for each ancestry group in the desired gnomAD version. + :param by_sex: Boolean to pass if want to return frequency information for each ancestry group split by chromosomal sex. + :param vrs_only: Boolean to pass if only want VRS information returned (will not include allele frequency information). + :param custom_ht: A Hail Table to use instead of what public_release() method would return for the version. + :param coverage_ht: An existing hail.Table object, or 'auto' to automatically lookup coverage ht, or None. + :return: Dictionary containing VRS information (and frequency information split by ancestry groups and sex if desired) for the specified variant. + + """ + # Read public_release table if no custom table provided + if custom_ht: + ht = custom_ht + else: + ht = hl.read_table(public_release(data_type).versions[version].path) + + high_level_version = f"v{version.split('.')[0]}" + + # Read coverage statistics. 
+ + if high_level_version == "v3": + coverage_version = "3.0.1" + else: + raise NotImplementedError( + "gnomad_gks() is currently only implemented for gnomAD v3." + ) + + coverage_ht = get_coverage_ht(coverage_ht, data_type, coverage_version) + + # Retrieve ancestry groups from the imported POPS dictionary. + pops_list = list(POPS[high_level_version]) if by_ancestry_group else None + + # Throw warnings if contradictory arguments passed. + if by_ancestry_group and vrs_only: + logger.warning( + "Both 'vrs_only' and 'by_ancestry_groups' have been specified. Ignoring" + " 'by_ancestry_groups' list and returning only VRS information." + ) + elif by_sex and not by_ancestry_group: + logger.warning( + "Splitting whole database by sex is not yet supported. If using 'by_sex'," + " please also specify 'by_ancestry_group' to stratify by." + ) + + # Call and return get_gks() for chosen arguments. + gks_info = get_gks( + ht=ht, + variant=variant, + label_name="gnomAD", + label_version=version, + coverage_ht=coverage_ht, + ancestry_groups=pops_list, + ancestry_groups_dict=POP_NAMES, + by_sex=by_sex, + vrs_only=vrs_only, + ) + + return gks_info + + +# VRS Annotation needs to be done separately. It needs to compute the sequence location +# digest and this cannot be done in hail. It needs to export the record to JSON, compute +# sequence location digests in python, and then that can be imported to a hail table. +def gnomad_gks_batch( + locus_interval: hl.IntervalExpression, + version: str, + data_type: str = "genomes", + by_ancestry_group: bool = False, + by_sex: bool = False, + vrs_only: bool = False, + custom_ht: hl.Table = None, + coverage_ht: Union[str, hl.Table] = "auto", +): + """ + Perform gnomad GKS annotations on a range of variants at once. + + :param locus_interval: Hail IntervalExpression of locus. e.g. hl.locus_interval('chr1', 1, 50000000, reference_genome="GRCh38") + :param version: String of version of gnomAD release to use. + :param data_type: String of either "exomes" or "genomes" for the type of reads that are desired. + :param by_ancestry_group: Boolean to pass to obtain frequency information for each ancestry group in the desired gnomAD version. + :param by_sex: Boolean to pass if want to return frequency information for each ancestry group split by chromosomal sex. + :param vrs_only: Boolean to pass if only want VRS information returned (will not include allele frequency information). + :param custom_ht: A Hail Table to use instead of what public_release() method would return for the version. + :param coverage_ht: Path of coverage_ht, an existing hail.Table object, or 'auto' to automatically lookup coverage ht. + :return: Dictionary containing VRS information (and frequency information split by ancestry groups and sex if desired) for the specified variant. + """ + # Read public_release table if no custom table provided + if custom_ht: + ht = custom_ht + else: + ht = hl.read_table(public_release(data_type).versions[version].path) + + high_level_version = f"v{version.split('.')[0]}" + + # Read coverage statistics. + + if high_level_version == "v3": + coverage_version = "3.0.1" + else: + raise NotImplementedError( + "gnomad_gks() is currently only implemented for gnomAD v3." + ) + + coverage_ht = get_coverage_ht(coverage_ht, data_type, coverage_version) + + # Retrieve ancestry groups from the imported POPS dictionary. + pops_list = list(POPS[high_level_version]) if by_ancestry_group else None + + # Throw warnings if contradictory arguments passed. 
+ if by_ancestry_group and vrs_only: + logger.warning( + "Both 'vrs_only' and 'by_ancestry_groups' have been specified. Ignoring" + " 'by_ancestry_groups' list and returning only VRS information." + ) + elif by_sex and not by_ancestry_group: + logger.warning( + "Splitting whole database by sex is not yet supported. If using 'by_sex'," + " please also specify 'by_ancestry_group' to stratify by." + ) + + # Call and return add_gks*() for chosen arguments. + # add_gks_va returns the table annotated with .gks_va_freq_dict + # add_gks_va does not fill in the .focusAllele value of + # .gks_va_freq_dict; this is the vrs variant and is mostly just based + # on the values in the variant and info column, but it also needs + # to compute the SequenceLocation digest, which cannot be done in hail + + # Add .vrs and .vrs_json (the JSON string representation of .vrs) + # Omits .location._id + ht_with_gks = add_gks_vrs(ht) + + # If not vrs_only, include the VA freq in various operations below + if not vrs_only: + # Add .gks_va_freq_dict + # Omits .focusAllele + ht_with_gks = add_gks_va( + ht=ht_with_gks, + label_name="gnomAD", + label_version=version, + coverage_ht=coverage_ht, + ancestry_groups=pops_list, + ancestry_groups_dict=POP_NAMES, + by_sex=by_sex, + ) + + filtered = hl.filter_intervals(ht_with_gks, [locus_interval]) + select_cols = {"vrs_json": filtered.vrs_json} + if not vrs_only: + select_cols["gks_va_freq_json"] = hl.json(filtered.gks_va_freq_dict) + annotations = filtered.select(**select_cols).collect() # might be big + outputs = [] + for ann in annotations: + vrs_json = ann.vrs_json + vrs_variant = json.loads(vrs_json) + # Fill in fields omitted by add_gks_vrs and add_gks_va + vrs_variant = gks_compute_seqloc_digest(vrs_variant) + + out = { + "locus": { + "contig": ann.locus.contig, + "position": ann.locus.position, + "reference_genome": ann.locus.reference_genome.name, + }, + "alleles": ann.alleles, + "gks_vrs_variant": vrs_variant, + } + + if not vrs_only: + va_freq_dict = json.loads(ann.gks_va_freq_json) # Hail Struct as json + va_freq_dict["focusAllele"] = vrs_variant + out["gks_va_freq"] = va_freq_dict + + outputs.append(out) + + return outputs diff --git a/gnomad/resources/grch38/reference_data.py b/gnomad/resources/grch38/reference_data.py index 60765088d..6a648e197 100644 --- a/gnomad/resources/grch38/reference_data.py +++ b/gnomad/resources/grch38/reference_data.py @@ -57,6 +57,22 @@ def _import_dbsnp(**kwargs) -> hl.Table: return dbsnp +def _import_methylation_sites(path) -> hl.Table: + """ + Import methylation data from bed file. + + :param path: Path to bed file containing methylation scores. + :return: Table with methylation data. + """ + ht = hl.import_bed(path, min_partitions=100, reference_genome="GRCh38") + ht = ht.select( + locus=ht.interval.start, + methylation_level=hl.int32(ht.target), + ) + + return ht.key_by("locus").drop("interval") + + def _import_ensembl_interval(path) -> hl.Table: """ Import and parse Ensembl intervals of protein-coding genes to a Hail Table. @@ -296,6 +312,16 @@ def _import_ensembl_interval(path) -> hl.Table: }, ) +# Methylation scores range from 0-15 and are described in Chen et al +# (https://www.biorxiv.org/content/10.1101/2022.03.20.485034v2.full).
+methylation_sites = GnomadPublicTableResource( + path="gs://gnomad-public-requester-pays/resources/grch38/methylation_sites/methylation.ht", + import_func=_import_methylation_sites, + import_args={ + "path": "gs://gnomad-public-requester-pays/resources/grch38/methylation_sites/methylation.bed", + }, +) + lcr_intervals = GnomadPublicTableResource( path="gs://gnomad-public-requester-pays/resources/grch38/lcr_intervals/LCRFromHengHg38.ht", import_func=hl.import_locus_intervals, diff --git a/gnomad/resources/resource_utils.py b/gnomad/resources/resource_utils.py index 57b397325..04688edfd 100644 --- a/gnomad/resources/resource_utils.py +++ b/gnomad/resources/resource_utils.py @@ -550,9 +550,9 @@ def _get_path(self) -> str: return self._path relative_path = reduce( - lambda path, bucket: path[5 + len(bucket) :] - if path.startswith(f"gs://{bucket}/") - else path, + lambda path, bucket: ( + path[5 + len(bucket) :] if path.startswith(f"gs://{bucket}/") else path + ), GNOMAD_PUBLIC_BUCKETS, self._path, ) diff --git a/gnomad/sample_qc/ancestry.py b/gnomad/sample_qc/ancestry.py index 90078049f..c426cd34e 100644 --- a/gnomad/sample_qc/ancestry.py +++ b/gnomad/sample_qc/ancestry.py @@ -2,10 +2,16 @@ import logging import random -from typing import Any, Counter, List, Optional, Tuple, Union +from collections import Counter +from typing import Any, Callable, List, Optional, Tuple, Union import hail as hl +import numpy as np +import onnx +import onnxruntime as rt import pandas as pd +from skl2onnx import convert_sklearn +from skl2onnx.common.data_types import FloatTensorType from gnomad.utils.filtering import filter_to_autosomes @@ -117,6 +123,90 @@ def pc_project( return mt.cols().select("scores") +def apply_onnx_classification_model( + data_pd: pd.DataFrame, fit: onnx.ModelProto +) -> Tuple[np.ndarray, pd.DataFrame]: + """ + Apply an ONNX classification model `fit` to a pandas dataframe `data_pd`. + + :param data_pd: Pandas dataframe containing the data to be classified. + :param fit: ONNX model to be applied. + :return: Tuple of classification and probabilities. + """ + if not isinstance(fit, onnx.ModelProto): + raise TypeError("The model supplied is not an onnx model!") + + sess = rt.InferenceSession( + fit.SerializeToString(), providers=["CPUExecutionProvider"] + ) + input_name = sess.get_inputs()[0].name + label_name = sess.get_outputs()[0].name + prob_name = sess.get_outputs()[1].name + classification = sess.run([label_name], {input_name: data_pd.astype(np.float32)})[0] + probs = sess.run([prob_name], {input_name: data_pd.astype(np.float32)})[0] + probs = pd.DataFrame.from_dict(probs) + probs = probs.add_prefix("prob_") + + return classification, probs + + +def apply_sklearn_classification_model( + data_pd: pd.DataFrame, fit: Any +) -> Tuple[np.ndarray, pd.DataFrame]: + """ + Apply an sklearn classification model `fit` to a pandas dataframe `data_pd`. + + :param data_pd: Pandas dataframe containing the data to be classified. + :param fit: Sklearn model to be applied. + :return: Tuple of classification and probabilities. 
+ """ + from sklearn.utils.validation import check_is_fitted + + try: + check_is_fitted(fit) + except TypeError: + raise TypeError("The supplied model is not an sklearn model!") + + classification = fit.predict(data_pd) + probs = fit.predict_proba(data_pd) + probs = pd.DataFrame(probs, columns=[f"prob_{p}" for p in fit.classes_]) + + return classification, probs + + +def convert_sklearn_rf_to_onnx( + fit: Any, target_opset: Optional[int] = None +) -> onnx.ModelProto: + """ + Convert a sklearn random forest model to ONNX. + + :param fit: Sklearn random forest model to be converted. + :param target_opset: An optional target ONNX opset version to convert the model to. + :return: ONNX model. + """ + from sklearn.utils.validation import check_is_fitted + + try: + check_is_fitted(fit) + except TypeError: + raise TypeError("The supplied model is not an sklearn model!") + + initial_type = [("float_input", FloatTensorType([None, fit.n_features_in_]))] + onx = convert_sklearn(fit, initial_types=initial_type, target_opset=target_opset) + + domains = onx.opset_import + opset_version = "" + for dom in domains: + opset_version += f"domain: {dom.domain}, version: {dom.version}\n" + + logger.info( + "sklearn model converted to onnx model with the following opset version: \n%s", + opset_version, + ) + + return onx + + def assign_population_pcs( pop_pca_scores: Union[hl.Table, pd.DataFrame], pc_cols: Union[hl.expr.ArrayExpression, List[int], List[str]], @@ -129,6 +219,10 @@ def assign_population_pcs( output_col: str = "pop", missing_label: str = "oth", pc_expr: Union[hl.expr.ArrayExpression, str] = "scores", + convert_model_func: Optional[Callable[[Any], Any]] = None, + apply_model_func: Callable[ + [pd.DataFrame, Any], Any + ] = apply_sklearn_classification_model, ) -> Tuple[ Union[hl.Table, pd.DataFrame], Any ]: # 2nd element of the tuple should be RandomForestClassifier but we do not want to import sklearn.RandomForestClassifier outside @@ -175,6 +269,12 @@ def assign_population_pcs( smaller than `min_prob`. :param pc_expr: Column storing the list of PCs. Only used if `pc_cols` is a List of integers. Default is scores. + :param convert_model_func: Optional function to convert the model to ONNX format. + Default is no conversion. + :param apply_model_func: Function to apply the model to the data. Default is + `apply_sklearn_classification_model`, which will apply a sklearn classification + model to the data. This default will work if no `fit` is set, or the supplied + `fit` is a sklearn classification model. :return: Hail Table or Pandas Dataframe (depending on input) containing sample IDs and imputed population labels, trained random forest model. """ @@ -224,7 +324,7 @@ def assign_population_pcs( ) pop_pc_pd = pop_pca_scores - # Split training data into subsamples for fitting and evaluating + # Split training data into subsamples for fitting and evaluating. if not fit: train_data = pop_pc_pd.loc[~pop_pc_pd[known_col].isnull()] N = len(train_data) @@ -234,33 +334,34 @@ def assign_population_pcs( fit_samples = [x for x in train_fit["s"]] evaluate_fit = train_data.loc[~train_data["s"].isin(fit_samples)] - # Train RF + # Train RF. 
training_set_known_labels = train_fit[known_col].values training_set_pcs = train_fit[pc_cols].values evaluation_set_pcs = evaluate_fit[pc_cols].values pop_clf = RandomForestClassifier(n_estimators=n_estimators, random_state=seed) pop_clf.fit(training_set_pcs, training_set_known_labels) - print( - "Random forest feature importances are as follows: {}".format( - pop_clf.feature_importances_ - ) + logger.info( + "Random forest feature importances are as follows: %s", + pop_clf.feature_importances_, ) - # Evaluate RF + # Evaluate RF. predictions = pop_clf.predict(evaluation_set_pcs) error_rate = 1 - sum(evaluate_fit[known_col] == predictions) / float( len(predictions) ) - print("Estimated error rate for RF model is {}".format(error_rate)) + logger.info("Estimated error rate for RF model is %.4f", error_rate) else: pop_clf = fit - # Classify data - pop_pc_pd[output_col] = pop_clf.predict(pop_pc_pd[pc_cols].values) - probs = pop_clf.predict_proba(pop_pc_pd[pc_cols].values) - probs = pd.DataFrame(probs, columns=[f"prob_{p}" for p in pop_clf.classes_]) - pop_pc_pd = pd.concat([pop_pc_pd, probs], axis=1) + # Classify data. + classifications, probs = apply_model_func(pop_pc_pd[pc_cols].values, pop_clf) + + pop_pc_pd[output_col] = classifications + pop_pc_pd = pd.concat( + [pop_pc_pd.reset_index(drop=True), probs.reset_index(drop=True)], axis=1 + ) probs["max"] = probs.max(axis=1) pop_pc_pd.loc[probs["max"] < min_prob, output_col] = missing_label pop_pc_pd = pop_pc_pd.drop(pc_cols, axis="columns") @@ -272,6 +373,9 @@ def assign_population_pcs( ), ) + if convert_model_func is not None: + pop_clf = convert_model_func(pop_clf) + if hail_input: pops_ht = hl.Table.from_pandas(pop_pc_pd, key=list(pop_pca_scores.key)) pops_ht = pops_ht.annotate_globals( diff --git a/gnomad/sample_qc/filtering.py b/gnomad/sample_qc/filtering.py index 663504262..d7aa3d3d7 100644 --- a/gnomad/sample_qc/filtering.py +++ b/gnomad/sample_qc/filtering.py @@ -148,7 +148,9 @@ def compute_stratified_metrics_filter( upper_threshold: float = 4.0, metric_threshold: Optional[Dict[str, Tuple[float, float]]] = None, filter_name: str = "qc_metrics_filters", - comparison_sample_expr: Optional[hl.expr.CollectionExpression] = None, + comparison_sample_expr: Optional[ + Union[hl.expr.BooleanExpression, hl.expr.CollectionExpression] + ] = None, ) -> hl.Table: """ Compute median, MAD, and upper and lower thresholds for each metric used in outlier filtering. @@ -163,10 +165,11 @@ def compute_stratified_metrics_filter( :param metric_threshold: Can be used to specify different (lower, upper) thresholds for one or more metrics. :param filter_name: Name of resulting filters annotation. - :param comparison_sample_expr: Optional CollectionExpression of sample IDs to use - for computation of the metric median, MAD, and upper and lower thresholds to - use for each sample. For instance, this works well with the output of - `determine_nearest_neighbors`. + :param comparison_sample_expr: Optional BooleanExpression or CollectionExpression + of sample IDs to use for computation of the metric median, MAD, and upper and + lower thresholds to use for each sample. For instance, this works well with the + output of `determine_nearest_neighbors` or a boolean expression defining + releasable samples. :return: Table grouped by strata, with upper and lower threshold values computed for each sample QC metric. 
""" @@ -187,16 +190,26 @@ def compute_stratified_metrics_filter( "_strata": hl.tuple([x[1] for x in strata]), } + sample_explode = False if comparison_sample_expr is not None: - select_expr["_comparison_sample"] = comparison_sample_expr - pre_explode_ht = ht.select(**select_expr) - ht = pre_explode_ht.explode(pre_explode_ht._comparison_sample) - ht = ht.annotate( - _comparison_qc_metrics=ht[ht._comparison_sample]._qc_metrics, - _comparison_strata=ht[ht._comparison_sample]._strata, - ) - metric_ann = "_comparison_qc_metrics" - strata_ann = "_comparison_strata" + if isinstance(comparison_sample_expr, hl.expr.BooleanExpression): + select_expr["_comparison_qc_metrics"] = hl.or_missing( + comparison_sample_expr, qc_metrics + ) + ht = ht.select(**select_expr) + metric_ann = "_comparison_qc_metrics" + strata_ann = "_strata" + else: + sample_explode = True + select_expr["_comparison_sample"] = comparison_sample_expr + pre_explode_ht = ht.select(**select_expr) + ht = pre_explode_ht.explode(pre_explode_ht._comparison_sample) + ht = ht.annotate( + _comparison_qc_metrics=ht[ht._comparison_sample]._qc_metrics, + _comparison_strata=ht[ht._comparison_sample]._strata, + ) + metric_ann = "_comparison_qc_metrics" + strata_ann = "_comparison_strata" else: ht = ht.select(**select_expr) metric_ann = "_qc_metrics" @@ -223,7 +236,7 @@ def compute_stratified_metrics_filter( ) select_expr = {} - if comparison_sample_expr is not None: + if sample_explode: ht = pre_explode_ht.annotate( **ht.group_by(ht.s).aggregate(qc_metrics_stats=agg_expr)[pre_explode_ht.key] ) @@ -238,8 +251,7 @@ def compute_stratified_metrics_filter( **{ f"fail_{metric}": ( ht._qc_metrics[metric] <= metrics_stats_expr[metric].lower - ) - | (ht._qc_metrics[metric] >= metrics_stats_expr[metric].upper) + ) | (ht._qc_metrics[metric] >= metrics_stats_expr[metric].upper) for metric in qc_metrics } ) @@ -255,10 +267,11 @@ def compute_stratified_metrics_filter( if no_strata: ann_expr = {"qc_metrics_stats": ht.qc_metrics_stats[(True,)]} - if comparison_sample_expr is None: - ht = ht.annotate_globals(**ann_expr) - else: + if sample_explode: ht = ht.annotate(**ann_expr) + else: + ht = ht.annotate_globals(**ann_expr) + else: ht = ht.annotate_globals(strata=hl.tuple([x[0] for x in strata])) ht = ht.annotate_globals(qc_metrics=list(qc_metrics.keys())) diff --git a/gnomad/sample_qc/pipeline.py b/gnomad/sample_qc/pipeline.py index c8a687cff..31b0edb3f 100644 --- a/gnomad/sample_qc/pipeline.py +++ b/gnomad/sample_qc/pipeline.py @@ -227,21 +227,27 @@ def get_qc_mt( snv_only=snv_only, adj_only=adj_only, min_af=min_af if min_af is not None else hl.null(hl.tfloat32), - min_callrate=min_callrate - if min_callrate is not None - else hl.null(hl.tfloat32), - inbreeding_coeff_threshold=min_inbreeding_coeff_threshold - if min_inbreeding_coeff_threshold is not None - else hl.null(hl.tfloat32), - min_hardy_weinberg_threshold=min_hardy_weinberg_threshold - if min_hardy_weinberg_threshold is not None - else hl.null(hl.tfloat32), + min_callrate=( + min_callrate if min_callrate is not None else hl.null(hl.tfloat32) + ), + inbreeding_coeff_threshold=( + min_inbreeding_coeff_threshold + if min_inbreeding_coeff_threshold is not None + else hl.null(hl.tfloat32) + ), + min_hardy_weinberg_threshold=( + min_hardy_weinberg_threshold + if min_hardy_weinberg_threshold is not None + else hl.null(hl.tfloat32) + ), apply_hard_filters=apply_hard_filters, ld_r2=ld_r2 if ld_r2 is not None else hl.null(hl.tfloat32), filter_exome_low_coverage_regions=filter_exome_low_coverage_regions, - 
high_conf_regions=high_conf_regions - if high_conf_regions is not None - else hl.null(hl.tarray(hl.tstr)), + high_conf_regions=( + high_conf_regions + if high_conf_regions is not None + else hl.null(hl.tarray(hl.tstr)) + ), ) ) return qc_mt.annotate_cols(sample_callrate=hl.agg.fraction(hl.is_defined(qc_mt.GT))) diff --git a/gnomad/sample_qc/relatedness.py b/gnomad/sample_qc/relatedness.py index ffe07c2c5..1a0be1669 100644 --- a/gnomad/sample_qc/relatedness.py +++ b/gnomad/sample_qc/relatedness.py @@ -690,21 +690,30 @@ def create_fake_pedigree( exclude_real_probands: bool = False, max_tries: int = 10, real_pedigree: Optional[hl.Pedigree] = None, + sample_list_stratification: Optional[Dict[str, str]] = None, ) -> hl.Pedigree: """ Generate a pedigree made of trios created by sampling 3 random samples in the sample list. - - If `real_pedigree` is given, then children in the resulting fake trios will not include any trio with proband - parents - that are in the real ones. + - If `real_pedigree` is given, then children in the resulting fake trios will not + include any trio with proband - parents that are in the real ones. - Each sample can be used only once as a proband in the resulting trios. - Sex of probands in fake trios is random. - :param n: Number of fake trios desired in the pedigree - :param sample_list: List of samples - :param exclude_real_probands: If set, then fake trios probands cannot be in the real trios probands. - :param max_tries: Maximum number of sampling to try before bailing out (preventing infinite loop if `n` is too large w.r.t. the number of samples) - :param real_pedigree: Optional pedigree to exclude children from - :return: Fake pedigree + :param n: Number of fake trios desired in the pedigree. + :param sample_list: List of samples. + :param exclude_real_probands: If set, then fake trios probands cannot be in the + real trios probands. + :param max_tries: Maximum number of sampling to try before bailing out (preventing + infinite loop if `n` is too large w.r.t. the number of samples). + :param real_pedigree: Optional pedigree to exclude children from. + :param sample_list_stratification: Optional dictionary with samples as keys and + a value that should be used to stratify samples in `sample_list` into groups + that the trio should be picked from. This ensures that each fake trio will + contain samples from only the same stratification. For example, if all samples + within a fake trio should be chosen from the same platform, this can be a + dictionary of sample: platform. + :return: Fake pedigree. """ real_trios = ( {trio.s: trio for trio in real_pedigree.trios} @@ -712,6 +721,18 @@ def create_fake_pedigree( else dict() ) + if sample_list_stratification is not None: + sample_list_stratified = defaultdict(list) + for s in sample_list: + s_strata = sample_list_stratification.get(s) + if s_strata is None: + raise ValueError( + f"Sample {s} not found in 'sample_list_stratification' dict!" 
+ ) + sample_list_stratified[s_strata].append(s) + else: + sample_list_stratified = None + if exclude_real_probands and len(real_trios) == len(set(sample_list)): logger.warning( "All samples are in the real probands list; cannot create any fake" @@ -722,7 +743,13 @@ def create_fake_pedigree( fake_trios = {} tries = 0 while len(fake_trios) < n and tries < max_tries: - s, mat_id, pat_id = random.sample(sample_list, 3) + s = random.choice(sample_list) + if sample_list_stratified is None: + curr_sample_list = sample_list + else: + curr_sample_list = sample_list_stratified[sample_list_stratification[s]] + + mat_id, pat_id = random.sample(curr_sample_list, 2) if ( s in real_trios and ( @@ -1070,7 +1097,7 @@ def _is_dnm( locus.in_autosome(), proband_gt.is_het() & father_gt.is_hom_ref() & mother_gt.is_hom_ref(), ) - return hl.cond( + return hl.if_else( locus.in_autosome_or_par() | (proband_is_female & locus.in_x_nonpar()), proband_gt.is_het() & father_gt.is_hom_ref() & mother_gt.is_hom_ref(), hl.or_missing( diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index cd978f64f..0c3f0e429 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1,12 +1,20 @@ # noqa: D100 +import csv import itertools +import json import logging -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from timeit import default_timer as timer +import ga4gh.core as ga4gh_core +import ga4gh.vrs as ga4gh_vrs import hail as hl +from hail.utils.misc import new_temp_file +import gnomad.utils.filtering as filter_utils from gnomad.utils.gen_stats import to_phred +from gnomad.utils.reference_genome import get_reference_genome logging.basicConfig( format="%(asctime)s (%(name)s %(lineno)s): %(message)s", @@ -33,6 +41,61 @@ "pab_max": (0, 1, 50), } +VRS_CHROM_IDS = { + "GRCh38": { + "chr1": "ga4gh:SQ.Ya6Rs7DHhDeg7YaOSg1EoNi3U_nQ9SvO", + "chr2": "ga4gh:SQ.pnAqCRBrTsUoBghSD1yp_jXWSmlbdh4g", + "chr3": "ga4gh:SQ.Zu7h9AggXxhTaGVsy7h_EZSChSZGcmgX", + "chr4": "ga4gh:SQ.HxuclGHh0XCDuF8x6yQrpHUBL7ZntAHc", + "chr5": "ga4gh:SQ.aUiQCzCPZ2d0csHbMSbh2NzInhonSXwI", + "chr6": "ga4gh:SQ.0iKlIQk2oZLoeOG9P1riRU6hvL5Ux8TV", + "chr7": "ga4gh:SQ.F-LrLMe1SRpfUZHkQmvkVKFEGaoDeHul", + "chr8": "ga4gh:SQ.209Z7zJ-mFypBEWLk4rNC6S_OxY5p7bs", + "chr9": "ga4gh:SQ.KEO-4XBcm1cxeo_DIQ8_ofqGUkp4iZhI", + "chr10": "ga4gh:SQ.ss8r_wB0-b9r44TQTMmVTI92884QvBiB", + "chr11": "ga4gh:SQ.2NkFm8HK88MqeNkCgj78KidCAXgnsfV1", + "chr12": "ga4gh:SQ.6wlJpONE3oNb4D69ULmEXhqyDZ4vwNfl", + "chr13": "ga4gh:SQ._0wi-qoDrvram155UmcSC-zA5ZK4fpLT", + "chr14": "ga4gh:SQ.eK4D2MosgK_ivBkgi6FVPg5UXs1bYESm", + "chr15": "ga4gh:SQ.AsXvWL1-2i5U_buw6_niVIxD6zTbAuS6", + "chr16": "ga4gh:SQ.yC_0RBj3fgBlvgyAuycbzdubtLxq-rE0", + "chr17": "ga4gh:SQ.dLZ15tNO1Ur0IcGjwc3Sdi_0A6Yf4zm7", + "chr18": "ga4gh:SQ.vWwFhJ5lQDMhh-czg06YtlWqu0lvFAZV", + "chr19": "ga4gh:SQ.IIB53T8CNeJJdUqzn9V_JnRtQadwWCbl", + "chr20": "ga4gh:SQ.-A1QmD_MatoqxvgVxBLZTONHz9-c7nQo", + "chr21": "ga4gh:SQ.5ZUqxCmDDgN4xTRbaSjN8LwgZironmB8", + "chr22": "ga4gh:SQ.7B7SHsmchAR0dFcDCuSFjJAo7tX87krQ", + "chrX": "ga4gh:SQ.w0WZEvgJF0zf_P4yyTzjjv9oW1z61HHP", + "chrY": "ga4gh:SQ.8_liLu1aycC0tPQPFmUaGXJLDs5SbPZ5", + }, + "GRCh37": { + "1": "ga4gh:SQ.S_KjnFVz-FE7M0W6yoaUDgYxLPc1jyWU", + "2": "ga4gh:SQ.9KdcA9ZpY1Cpvxvg8bMSLYDUpsX6GDLO", + "3": "ga4gh:SQ.VNBualIltAyi2AI_uXcKU7M9XUOuA7MS", + "4": "ga4gh:SQ.iy7Zfceb5_VGtTQzJ-v5JpPbpeifHD_V", + "5": "ga4gh:SQ.vbjOdMfHJvTjK_nqvFvpaSKhZillW0SX", + "6": 
"ga4gh:SQ.KqaUhJMW3CDjhoVtBetdEKT1n6hM-7Ek", + "7": "ga4gh:SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86", + "8": "ga4gh:SQ.tTm7wmhz0G4lpt8wPspcNkAD_qiminj6", + "9": "ga4gh:SQ.HBckYGQ4wYG9APHLpjoQ9UUe9v7NxExt", + "10": "ga4gh:SQ.-BOZ8Esn8J88qDwNiSEwUr5425UXdiGX", + "11": "ga4gh:SQ.XXi2_O1ly-CCOi3HP5TypAw7LtC6niFG", + "12": "ga4gh:SQ.105bBysLoDFQHhajooTAUyUkNiZ8LJEH", + "13": "ga4gh:SQ.Ewb9qlgTqN6e_XQiRVYpoUfZJHXeiUfH", + "14": "ga4gh:SQ.5Ji6FGEKfejK1U6BMScqrdKJK8GqmIGf", + "15": "ga4gh:SQ.zIMZb3Ft7RdWa5XYq0PxIlezLY2ccCgt", + "16": "ga4gh:SQ.W6wLoIFOn4G7cjopxPxYNk2lcEqhLQFb", + "17": "ga4gh:SQ.AjWXsI7AkTK35XW9pgd3UbjpC3MAevlz", + "18": "ga4gh:SQ.BTj4BDaaHYoPhD3oY2GdwC_l0uqZ92UD", + "19": "ga4gh:SQ.ItRDD47aMoioDCNW_occY5fWKZBKlxCX", + "20": "ga4gh:SQ.iy_UbUrvECxFRX5LPTH_KPojdlT7BKsf", + "21": "ga4gh:SQ.LpTaNW-hwuY_yARP0rtarCnpCQLkgVCg", + "22": "ga4gh:SQ.XOgHwwR3Upfp5sZYk6ZKzvV25a4RBVu8", + "X": "ga4gh:SQ.v7noePfnNpK8ghYXEqZ9NukMXW7YeNsm", + "Y": "ga4gh:SQ.BT7QyW5iXaX_1PSX-msSGYsqRdMKqkj-", + }, +} + def pop_max_expr( freq: hl.expr.ArrayExpression, @@ -242,22 +305,27 @@ def qual_hist_expr( dp_expr: Optional[hl.expr.NumericExpression] = None, ad_expr: Optional[hl.expr.ArrayNumericExpression] = None, adj_expr: Optional[hl.expr.BooleanExpression] = None, + ab_expr: Optional[hl.expr.NumericExpression] = None, + split_adj_and_raw: bool = False, ) -> hl.expr.StructExpression: """ - Return a struct expression with genotype quality histograms based on the arguments given (dp, gq, ad). + Return a struct expression with genotype quality histograms based on the arguments given (dp, gq, ad, ab). .. note:: - If `gt_expr` is provided, will return histograms for non-reference samples only as well as all samples. - `gt_expr` is required for the allele-balance histogram, as it is only computed on het samples. + - If `ab_expr` is provided, the allele-balance histogram is computed using this expression instead of the ad_expr. - If `adj_expr` is provided, additional histograms are computed using only adj samples. - :param gt_expr: Entry expression containing genotype - :param gq_expr: Entry expression containing genotype quality - :param dp_expr: Entry expression containing depth - :param ad_expr: Entry expression containing allelic depth (bi-allelic here) - :param adj_expr: Entry expression containing adj (high quality) genotype status - :return: Genotype quality histograms expression + :param gt_expr: Entry expression containing genotype. + :param gq_expr: Entry expression containing genotype quality. + :param dp_expr: Entry expression containing depth. + :param ad_expr: Entry expression containing allelic depth (bi-allelic here). + :param adj_expr: Entry expression containing adj (high quality) genotype status. + :param ab_expr: Entry expression containing allele balance (bi-allelic here). + :param split_adj_and_raw: Whether to split the adj and raw histograms into separate fields in the returned struct expr. + :return: Genotype quality histograms expression. """ qual_hists = {} if gq_expr is not None: @@ -278,7 +346,14 @@ def qual_hist_expr( for qual_hist_name, qual_hist_expr in qual_hists.items() }, } - if ad_expr is not None: + ab_hist_msg = "Using the %s to compute allele balance histogram..." 
+ if ab_expr is not None: + logger.info(ab_hist_msg, "ab_expr") + qual_hists["ab_hist_alt"] = hl.agg.filter( + gt_expr.is_het(), hl.agg.hist(ab_expr, 0, 1, 20) + ) + elif ad_expr is not None: + logger.info(ab_hist_msg, "ad_expr") qual_hists["ab_hist_alt"] = hl.agg.filter( gt_expr.is_het(), hl.agg.hist(ad_expr[1] / hl.sum(ad_expr), 0, 1, 20) ) @@ -290,13 +365,17 @@ def qual_hist_expr( } if adj_expr is not None: - qual_hists.update( - { - f"{qual_hist_name}_adj": hl.agg.filter(adj_expr, qual_hist_expr) - for qual_hist_name, qual_hist_expr in qual_hists.items() - } - ) - + adj_qual_hists = { + qual_hist_name: hl.agg.filter(adj_expr, qual_hist_expr) + for qual_hist_name, qual_hist_expr in qual_hists.items() + } + if split_adj_and_raw: + return hl.struct( + raw_qual_hists=hl.struct(**qual_hists), + qual_hists=hl.struct(**adj_qual_hists), + ) + else: + qual_hists.update({f"{k}_adj": v for k, v in adj_qual_hists.items()}) return hl.struct(**qual_hists) @@ -331,280 +410,6 @@ def age_hists_expr( ) -def annotate_freq( - mt: hl.MatrixTable, - sex_expr: Optional[hl.expr.StringExpression] = None, - pop_expr: Optional[hl.expr.StringExpression] = None, - subpop_expr: Optional[hl.expr.StringExpression] = None, - additional_strata_expr: Optional[ - Union[ - List[Dict[str, hl.expr.StringExpression]], - Dict[str, hl.expr.StringExpression], - ] - ] = None, - downsamplings: Optional[List[int]] = None, -) -> hl.MatrixTable: - """ - Annotate `mt` with stratified allele frequencies. - - The output Matrix table will include: - - row annotation `freq` containing the stratified allele frequencies - - global annotation `freq_meta` with metadata - - global annotation `freq_sample_count` with sample count information - - .. note:: - - Currently this only supports bi-allelic sites. - - The input `mt` needs to have the following entry fields: - - GT: a CallExpression containing the genotype - - adj: a BooleanExpression containing whether the genotype is of high quality or not. - - All expressions arguments need to be expression on the input `mt`. - - .. rubric:: `freq` row annotation - - The `freq` row annotation is an Array of Struct, with each Struct containing the following fields: - - - AC: int32 - - AF: float64 - - AN: int32 - - homozygote_count: int32 - - Each element of the array corresponds to a stratification of the data, and the metadata about these annotations is - stored in the globals. - - .. rubric:: Global `freq_meta` metadata annotation - - The global annotation `freq_meta` is added to the input `mt`. It is a list of dict. - Each element of the list contains metadata on a frequency stratification and the index in the list corresponds - to the index of that frequency stratification in the `freq` row annotation. - - .. rubric:: Global `freq_sample_count` annotation - - The global annotation `freq_sample_count` is added to the input `mt`. This is a sample count per sample grouping - defined in the `freq_meta` global annotation. - - .. rubric:: The `downsamplings` parameter - - If the `downsamplings` parameter is used, frequencies will be computed for all samples and by population - (if `pop_expr` is specified) by downsampling the number of samples without replacement to each of the numbers specified in the - `downsamplings` array, provided that there are enough samples in the dataset. - In addition, if `pop_expr` is specified, a downsampling to each of the exact number of samples present in each population is added. 
- Note that samples are randomly sampled only once, meaning that the lower downsamplings are subsets of the higher ones. - - .. rubric:: The `additional_strata_expr` parameter - - If the `additional_strata_expr` parameter is used, frequencies will be computed for each of the strata dictionaries across all - values. For example, if `additional_strata_expr` is set to `[{'platform': mt.platform}, {'platform':mt.platform, 'pop': mt.pop}, - {'age_bin': mt.age_bin}]`, then frequencies will be computed for each of the values of `mt.platform`, each of the combined values - of `mt.platform` and `mt.pop`, and each of the values of `mt.age_bin`. - - :param mt: Input MatrixTable - :param sex_expr: When specified, frequencies are stratified by sex. If `pop_expr` is also specified, then a pop/sex stratifiction is added. - :param pop_expr: When specified, frequencies are stratified by population. If `sex_expr` is also specified, then a pop/sex stratifiction is added. - :param subpop_expr: When specified, frequencies are stratified by sub-continental population. Note that `pop_expr` is required as well when using this option. - :param additional_strata_expr: When specified, frequencies are stratified by the given additional strata. This can e.g. be used to stratify by platform, platform-pop, platform-pop-sex. - :param downsamplings: When specified, frequencies are computed by downsampling the data to the number of samples given in the list. Note that if `pop_expr` is specified, downsamplings by population is also computed. - :return: MatrixTable with `freq` annotation - """ - if subpop_expr is not None and pop_expr is None: - raise NotImplementedError( - "annotate_freq requires pop_expr when using subpop_expr" - ) - - if additional_strata_expr is None: - additional_strata_expr = [{}] - - if isinstance(additional_strata_expr, dict): - additional_strata_expr = [additional_strata_expr] - - _freq_meta_expr = hl.struct( - **{k: v for d in additional_strata_expr for k, v in d.items()} - ) - if sex_expr is not None: - _freq_meta_expr = _freq_meta_expr.annotate(sex=sex_expr) - if pop_expr is not None: - _freq_meta_expr = _freq_meta_expr.annotate(pop=pop_expr) - if subpop_expr is not None: - _freq_meta_expr = _freq_meta_expr.annotate(subpop=subpop_expr) - - # Annotate cols with provided cuts - mt = mt.annotate_cols(_freq_meta=_freq_meta_expr) - - # Get counters for sex, pop and if set subpop and additional strata - cut_dict = { - cut: hl.agg.filter( - hl.is_defined(mt._freq_meta[cut]), hl.agg.counter(mt._freq_meta[cut]) - ) - for cut in mt._freq_meta - if cut != "subpop" - } - if "subpop" in mt._freq_meta: - cut_dict["subpop"] = hl.agg.filter( - hl.is_defined(mt._freq_meta.pop) & hl.is_defined(mt._freq_meta.subpop), - hl.agg.counter( - hl.struct(subpop=mt._freq_meta.subpop, pop=mt._freq_meta.pop) - ), - ) - - cut_data = mt.aggregate_cols(hl.struct(**cut_dict)) - sample_group_filters = [] - - # Create downsamplings if needed - if downsamplings is not None: - # Add exact pop size downsampling if pops were provided - if cut_data.get("pop"): - downsamplings = list( - set(downsamplings + list(cut_data.get("pop").values())) - ) # Add the pops values if not in yet - downsamplings = sorted( - [x for x in downsamplings if x <= sum(cut_data.get("pop").values())] - ) - logger.info("Found %d downsamplings: %s", len(downsamplings), downsamplings) - - # Shuffle the samples, then create a global index for downsampling - # And a pop-index if pops were provided - downsampling_ht = mt.cols() - downsampling_ht = 
downsampling_ht.annotate(r=hl.rand_unif(0, 1)) - downsampling_ht = downsampling_ht.order_by(downsampling_ht.r) - scan_expr = {"global_idx": hl.scan.count()} - if cut_data.get("pop"): - scan_expr["pop_idx"] = hl.scan.counter(downsampling_ht._freq_meta.pop).get( - downsampling_ht._freq_meta.pop, 0 - ) - downsampling_ht = downsampling_ht.annotate(**scan_expr) - downsampling_ht = downsampling_ht.key_by("s").select(*scan_expr) - mt = mt.annotate_cols(downsampling=downsampling_ht[mt.s]) - mt = mt.annotate_globals(downsamplings=downsamplings) - - # Create downsampled sample groups - sample_group_filters.extend( - [ - ( - {"downsampling": str(ds), "pop": "global"}, - mt.downsampling.global_idx < ds, - ) - for ds in downsamplings - ] - ) - if cut_data.get("pop"): - sample_group_filters.extend( - [ - ( - {"downsampling": str(ds), "pop": pop}, - (mt.downsampling.pop_idx < ds) & (mt._freq_meta.pop == pop), - ) - for ds in downsamplings - for pop, pop_count in cut_data.get("pop", {}).items() - if ds <= pop_count - ] - ) - - # Build a list of strata filters from the additional strata - additional_strata_filters = [] - for additional_strata in additional_strata_expr: - additional_strata_values = [ - cut_data.get(strata, {}) for strata in additional_strata - ] - additional_strata_combinations = itertools.product(*additional_strata_values) - - additional_strata_filters.extend( - [ - ( - { - strata: str(value) - for strata, value in zip(additional_strata, combination) - }, - hl.all( - list( - mt._freq_meta[strata] == value - for strata, value in zip(additional_strata, combination) - ) - ), - ) - for combination in additional_strata_combinations - ] - ) - - # Add all desired strata, starting with the full set and ending with - # downsamplings (if any) - sample_group_filters = ( - [({}, True)] - + [({"pop": pop}, mt._freq_meta.pop == pop) for pop in cut_data.get("pop", {})] - + [({"sex": sex}, mt._freq_meta.sex == sex) for sex in cut_data.get("sex", {})] - + [ - ( - {"pop": pop, "sex": sex}, - (mt._freq_meta.sex == sex) & (mt._freq_meta.pop == pop), - ) - for sex in cut_data.get("sex", {}) - for pop in cut_data.get("pop", {}) - ] - + [ - ( - {"subpop": subpop.subpop, "pop": subpop.pop}, - (mt._freq_meta.pop == subpop.pop) - & (mt._freq_meta.subpop == subpop.subpop), - ) - for subpop in cut_data.get("subpop", {}) - ] - + additional_strata_filters - + sample_group_filters - ) - - freq_sample_count = mt.aggregate_cols( - [hl.agg.count_where(x[1]) for x in sample_group_filters] - ) - - # Annotate columns with group_membership - mt = mt.annotate_cols(group_membership=[x[1] for x in sample_group_filters]) - - # Create and annotate global expression with meta and sample count information - freq_meta_expr = [ - dict(**sample_group[0], group="adj") for sample_group in sample_group_filters - ] - freq_meta_expr.insert(1, {"group": "raw"}) - freq_sample_count.insert(1, freq_sample_count[0]) - mt = mt.annotate_globals( - freq_meta=freq_meta_expr, - freq_sample_count=freq_sample_count, - ) - - # Create frequency expression array from the sample groups - # Adding sample_group_filters_range_array to reduce memory usage in this array_agg - mt = mt.annotate_rows( - sample_group_filters_range_array=hl.range(len(sample_group_filters)) - ) - freq_expr = hl.agg.array_agg( - lambda i: hl.agg.filter( - mt.group_membership[i] & mt.adj, hl.agg.call_stats(mt.GT, mt.alleles) - ), - mt.sample_group_filters_range_array, - ) - - # Insert raw as the second element of the array - freq_expr = ( - freq_expr[:1] - 
.extend([hl.agg.call_stats(mt.GT, mt.alleles)]) - .extend(freq_expr[1:]) - ) - - # Select non-ref allele (assumes bi-allelic) - freq_expr = freq_expr.map( - lambda cs: cs.annotate( - AC=cs.AC[1], - AF=cs.AF[ - 1 - ], # TODO This is NA in case AC and AN are 0 -- should we set it to 0? - homozygote_count=cs.homozygote_count[1], - ) - ) - - # Return MT with freq row annotation - return mt.annotate_rows(freq=freq_expr).drop("_freq_meta") - - def get_lowqual_expr( alleles: hl.expr.ArrayExpression, qual_approx_expr: Union[hl.expr.ArrayNumericExpression, hl.expr.NumericExpression], @@ -636,7 +441,7 @@ def get_lowqual_expr( if isinstance(qual_approx_expr, hl.expr.ArrayNumericExpression): return hl.range(1, hl.len(alleles)).map( - lambda ai: hl.cond( + lambda ai: hl.if_else( hl.is_snp(alleles[0], alleles[ai]), qual_approx_expr[ai - 1] < min_snv_qual, qual_approx_expr[ai - 1] < min_indel_qual, @@ -763,7 +568,7 @@ def get_adj_expr( """ return ( (gq_expr >= adj_gq) - & hl.cond(gt_expr.is_haploid(), dp_expr >= haploid_adj_dp, dp_expr >= adj_dp) + & hl.if_else(gt_expr.is_haploid(), dp_expr >= haploid_adj_dp, dp_expr >= adj_dp) & ( hl.case() .when(~gt_expr.is_het(), True) @@ -788,9 +593,21 @@ def annotate_adj( Defaults correspond to gnomAD values. """ + if "GT" not in mt.entry and "LGT" in mt.entry: + logger.warning("No GT field found, using LGT instead.") + gt_expr = mt.LGT + else: + gt_expr = mt.GT + + if "AD" not in mt.entry and "LAD" in mt.entry: + logger.warning("No AD field found, using LAD instead.") + ad_expr = mt.LAD + else: + ad_expr = mt.AD + return mt.annotate_entries( adj=get_adj_expr( - mt.GT, mt.GQ, mt.DP, mt.AD, adj_gq, adj_dp, adj_ab, haploid_adj_dp + gt_expr, mt.GQ, mt.DP, ad_expr, adj_gq, adj_dp, adj_ab, haploid_adj_dp ) ) @@ -801,12 +618,12 @@ def add_variant_type(alt_alleles: hl.expr.ArrayExpression) -> hl.expr.StructExpr alts = alt_alleles[1:] non_star_alleles = hl.filter(lambda a: a != "*", alts) return hl.struct( - variant_type=hl.cond( + variant_type=hl.if_else( hl.all(lambda a: hl.is_snp(ref, a), non_star_alleles), - hl.cond(hl.len(non_star_alleles) > 1, "multi-snv", "snv"), - hl.cond( + hl.if_else(hl.len(non_star_alleles) > 1, "multi-snv", "snv"), + hl.if_else( hl.all(lambda a: hl.is_indel(ref, a), non_star_alleles), - hl.cond(hl.len(non_star_alleles) > 1, "multi-indel", "indel"), + hl.if_else(hl.len(non_star_alleles) > 1, "multi-indel", "indel"), "mixed", ), ), @@ -994,7 +811,7 @@ def fs_from_sb( # Normalize table if counts get too large if normalize: fs_expr = hl.bind( - lambda sb, sb_sum: hl.cond( + lambda sb, sb_sum: hl.if_else( sb_sum <= 2 * min_cell_count, sb, sb.map(lambda x: hl.int(x / (sb_sum / min_cell_count))), @@ -1149,7 +966,7 @@ def region_flag_expr( :return: `region_flag` struct row annotation """ prob_flags_expr = ( - {"non_par": (t.locus.in_x_nonpar() | t.locus.in_y_nonpar())} if non_par else {} + {"non_par": t.locus.in_x_nonpar() | t.locus.in_y_nonpar()} if non_par else {} ) if prob_regions is not None: @@ -1231,3 +1048,1417 @@ def hemi_expr( # mt.GT[0] is alternate allele gt.is_haploid() & (sex_expr == male_str) & (gt[0] == 1), ) + + +def merge_freq_arrays( + farrays: List[hl.expr.ArrayExpression], + fmeta: List[List[Dict[str, str]]], + operation: str = "sum", + set_negatives_to_zero: bool = False, + count_arrays: Optional[Dict[str, List[hl.expr.ArrayExpression]]] = None, +) -> Union[ + Tuple[hl.expr.ArrayExpression, List[Dict[str, int]]], + Tuple[ + hl.expr.ArrayExpression, + List[Dict[str, int]], + Dict[str, List[hl.expr.ArrayExpression]], + ], +]: + """ + 
Merge a list of frequency arrays based on the supplied `operation`. + + .. warning:: + Arrays must be on the same Table. + + .. note:: + + Arrays do not have to contain the same groupings or order of groupings but + the array indices for a freq array in `farrays` must be the same as its associated + frequency metadata index in `fmeta`, i.e., if `farrays = [freq1, freq2]` then `fmeta` + must equal `[fmeta1, fmeta2]` where fmeta1 contains the metadata information + for freq1. + + If `operation` is set to "sum", groups in the merged array + will be the union of groupings found within the arrays' metadata and all arrays + will be summed by grouping. If `operation` is set to "diff", the merged array + will contain groups only found in the first array of `fmeta`. Any array containing + any of these groups will have their values subtracted from the values of the first array. + + :param farrays: List of frequency arrays to merge. First entry in the list is the primary array to which other arrays will be added or subtracted. All arrays must be on the same Table. + :param fmeta: List of frequency metadata for arrays being merged. + :param operation: Merge operation to perform. Options are "sum" and "diff". If "diff" is passed, the first freq array in the list will have the other arrays subtracted from it. + :param set_negatives_to_zero: If True, set negative array values to 0 for AC, AN, AF, and homozygote_count. If False, raise a ValueError. Default is False. + :param count_arrays: Dictionary of Lists of arrays containing counts to merge using the passed operation. Must use the same group indexing as fmeta. Keys are the descriptor names, values are Lists of arrays to merge. Default is None. + :return: Tuple of merged frequency array, frequency metadata list and if `count_arrays` is not None, a dictionary of merged count arrays. + """ + if len(farrays) < 2: + raise ValueError("Must provide at least two frequency arrays to merge!") + if len(farrays) != len(fmeta): + raise ValueError("Length of farrays and fmeta must be equal!") + if operation not in ["sum", "diff"]: + raise ValueError("Operation must be either 'sum' or 'diff'!") + if count_arrays is not None: + for k, count_array in count_arrays.items(): + if len(count_array) != len(fmeta): + raise ValueError( + f"Length of count_array '{k}' and fmeta must be equal!" + ) + + # Create a list where each entry is a dictionary whose key is an aggregation + # group and the value is the corresponding index in the freq array. + fmeta = [hl.dict(hl.enumerate(f).map(lambda x: (x[1], [x[0]]))) for f in fmeta] + all_keys = hl.fold(lambda i, j: (i | j.key_set()), fmeta[0].key_set(), fmeta[1:]) + + # Merge dictionaries in the list into a single dictionary where key is aggregation + # group and the value is a list of the group's index in each of the freq arrays, if + # it exists. For "sum" operation, use keys, aka groups, found in all freq dictionaries. + # For "diff" operations, only use key_set from the first entry. + fmeta = hl.fold( + lambda i, j: hl.dict( + (hl.if_else(operation == "sum", all_keys, i.key_set())).map( + lambda k: ( + k, + i.get(k, [hl.missing(hl.tint32)]).extend( + j.get(k, [hl.missing(hl.tint32)]) + ), + ) + ) + ), + fmeta[0], + fmeta[1:], + ) + + # Create a list of tuples from the dictionary, sorted by the list of indices for + # each aggregation group. + fmeta = hl.sorted(fmeta.items(), key=lambda f: f[1]) + + # Create a list of the aggregation groups, maintaining the sorted order.
+ new_freq_meta = fmeta.map(lambda x: x[0]) + + # Create array for each aggregation group of arrays containing the group's freq + # values from each freq array. + freq_meta_idx = fmeta.map(lambda x: hl.zip(farrays, x[1]).map(lambda i: i[0][i[1]])) + + def _sum_or_diff_fields( + field_1_expr: str, field_2_expr: str + ) -> hl.expr.Int32Expression: + """ + Sum or subtract fields in call statistics struct. + + :param field_1_expr: First field to sum or diff. + :param field_2_expr: Second field to sum or diff. + :return: Merged field value. + """ + return hl.if_else( + operation == "sum", + hl.or_else(field_1_expr, 0) + hl.or_else(field_2_expr, 0), + hl.or_else(field_1_expr, 0) - hl.or_else(field_2_expr, 0), + ) + + # Iterate through the groups and their freq lists to merge callstats. + callstat_ann = ["AC", "AN", "homozygote_count"] + callstat_ann_af = ["AC", "AF", "AN", "homozygote_count"] + new_freq = freq_meta_idx.map( + lambda x: hl.bind( + lambda y: y.annotate(AF=hl.if_else(y.AN > 0, y.AC / y.AN, 0)).select( + *callstat_ann_af + ), + hl.fold( + lambda i, j: hl.struct( + **{ann: _sum_or_diff_fields(i[ann], j[ann]) for ann in callstat_ann} + ), + x[0].select(*callstat_ann), + x[1:], + ), + ) + ) + # Create count_array_meta_idx using the fmeta then iterate through each group + # in the list of tuples to access each group's entry per array. Sum or diff the + # values for each group across arrays to make a new_counts_array annotation. + if count_arrays: + new_counts_array_dict = {} + for k, count_array in count_arrays.items(): + count_array_meta_idx = fmeta.map( + lambda x: hl.zip(count_array, x[1]).map(lambda i: i[0][i[1]]) + ) + + new_counts_array_dict[k] = count_array_meta_idx.map( + lambda x: hl.fold( + lambda i, j: _sum_or_diff_fields(i, j), + x[0], + x[1:], + ), + ) + # Check and see if any annotation within the merged array is negative. If so, + # raise an error if set_negatives_to_zero is False or set the value to 0 if + # set_negatives_to_zero is True. + if operation == "diff": + negative_value_error_msg = ( + "Negative values found in merged %s array. Review data or set" + " `set_negatives_to_zero` to True to set negative values to 0." + ) + callstat_ann.append("AF") + new_freq = new_freq.map( + lambda x: x.annotate( + **{ + ann: ( + hl.case() + .when(set_negatives_to_zero, hl.max(x[ann], 0)) + .when(x[ann] >= 0, x[ann]) + .or_error(negative_value_error_msg % "freq") + ) + for ann in callstat_ann + } + ) + ) + if count_arrays: + for k, new_counts_array in new_counts_array_dict.items(): + new_counts_array_dict[k] = new_counts_array.map( + lambda x: hl.case() + .when(set_negatives_to_zero, hl.max(x, 0)) + .when(x >= 0, x) + .or_error(negative_value_error_msg % "counts") + ) + + new_freq_meta = hl.eval(new_freq_meta) + if count_arrays: + return new_freq, new_freq_meta, new_counts_array_dict + else: + return new_freq, new_freq_meta + + +def merge_histograms(hists: List[hl.expr.StructExpression]) -> hl.expr.Expression: + """ + Merge a list of histogram annotations. + + This function merges a list of histogram annotations by summing the arrays + in an element-wise fashion. It keeps one 'bin_edge' annotation but merges the + 'bin_freq', 'n_smaller', and 'n_larger' annotations by summing them. + + .. note:: + + Bin edges are assumed to be the same for all histograms. + + :param hists: List of histogram structs to merge. + :return: Merged histogram struct. 
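# A small illustration of the element-wise merge performed by merge_histograms,
# using toy literal histograms (the values below are made up for illustration).
import hail as hl

from gnomad.utils.annotations import merge_histograms

hist1 = hl.struct(
    bin_edges=[0.0, 10.0, 20.0], bin_freq=[5, 7], n_smaller=1, n_larger=2
)
hist2 = hl.struct(
    bin_edges=[0.0, 10.0, 20.0], bin_freq=[3, 4], n_smaller=0, n_larger=1
)

merged = merge_histograms([hist1, hist2])
# bin_freq is summed element-wise and the count fields are summed, so this
# evaluates to bin_freq=[8, 11], n_smaller=1, n_larger=3 with bin_edges kept.
print(hl.eval(merged))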
+ """ + return hl.fold( + lambda i, j: hl.struct( + **{ + "bin_edges": i.bin_edges, # Bin edges are the same for all histograms + "bin_freq": hl.zip(i.bin_freq, j.bin_freq).map(lambda x: x[0] + x[1]), + "n_smaller": i.n_smaller + j.n_smaller, + "n_larger": i.n_larger + j.n_larger, + } + ), + hists[0].select("bin_edges", "bin_freq", "n_smaller", "n_larger"), + hists[1:], + ) + + +# Functions used for computing allele frequency. +def annotate_freq( + mt: hl.MatrixTable, + sex_expr: Optional[hl.expr.StringExpression] = None, + pop_expr: Optional[hl.expr.StringExpression] = None, + subpop_expr: Optional[hl.expr.StringExpression] = None, + additional_strata_expr: Optional[ + Union[ + List[Dict[str, hl.expr.StringExpression]], + Dict[str, hl.expr.StringExpression], + ] + ] = None, + downsamplings: Optional[List[int]] = None, + downsampling_expr: Optional[hl.expr.StructExpression] = None, + ds_pop_counts: Optional[Dict[str, int]] = None, + entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, + annotate_mt: bool = True, +) -> Union[hl.Table, hl.MatrixTable]: + """ + Annotate `mt` with stratified allele frequencies. + + The output Matrix table will include: + - row annotation `freq` containing the stratified allele frequencies + - global annotation `freq_meta` with metadata + - global annotation `freq_meta_sample_count` with sample count information + + .. note:: + + Currently this only supports bi-allelic sites. + + The input `mt` needs to have the following entry fields: + - GT: a CallExpression containing the genotype + - adj: a BooleanExpression containing whether the genotype is of high quality + or not. + + All expressions arguments need to be expression on the input `mt`. + + .. rubric:: `freq` row annotation + + The `freq` row annotation is an Array of Structs, with each Struct containing the + following fields: + + - AC: int32 + - AF: float64 + - AN: int32 + - homozygote_count: int32 + + Each element of the array corresponds to a stratification of the data, and the + metadata about these annotations is stored in the globals. + + .. rubric:: Global `freq_meta` metadata annotation + + The global annotation `freq_meta` is added to the input `mt`. It is a list of dict. + Each element of the list contains metadata on a frequency stratification and the + index in the list corresponds to the index of that frequency stratification in the + `freq` row annotation. + + .. rubric:: Global `freq_meta_sample_count` annotation + + The global annotation `freq_meta_sample_count` is added to the input `mt`. This is a + sample count per sample grouping defined in the `freq_meta` global annotation. + + .. rubric:: The `additional_strata_expr` parameter + + If the `additional_strata_expr` parameter is used, frequencies will be computed for + each of the strata dictionaries across all values. For example, if + `additional_strata_expr` is set to `[{'platform': mt.platform}, + {'platform':mt.platform, 'pop': mt.pop}, {'age_bin': mt.age_bin}]`, then + frequencies will be computed for each of the values of `mt.platform`, each of the + combined values of `mt.platform` and `mt.pop`, and each of the values of + `mt.age_bin`. + + .. 
rubric:: The `downsamplings` parameter + + If the `downsamplings` parameter is used without the `downsampling_expr`, + frequencies will be computed for all samples and by population (if `pop_expr` is + specified) by downsampling the number of samples without replacement to each of the + numbers specified in the `downsamplings` array, provided that there are enough + samples in the dataset. In addition, if `pop_expr` is specified, a downsampling to + each of the exact number of samples present in each population is added. Note that + samples are randomly sampled only once, meaning that the lower downsamplings are + subsets of the higher ones. If the `downsampling_expr` parameter is used with the + `downsamplings` parameter, the `downsamplings` parameter informs the function which + downsampling groups were already created and are to be used in the frequency + calculation. + + .. rubric:: The `downsampling_expr` and `ds_pop_counts` parameters + + If the `downsampling_expr` parameter is used, `downsamplings` must also be set + and frequencies will be computed for all samples and by population (if `pop_expr` + is specified) using the downsampling indices to each of the numbers specified in + the `downsamplings` array. The function expects a 'global_idx', and if `pop_expr` + is used, a 'pop_idx' within the `downsampling_expr` to be used to determine if a + sample belongs within a certain downsampling group, i.e. the index is less than + the group size. The function `annotate_downsamplings` can be used to create + the `downsampling_expr`, `downsamplings`, and `ds_pop_counts` expressions. + + .. rubric:: The `entry_agg_funcs` parameter + + If the `entry_agg_funcs` parameter is used, the output MatrixTable will also + contain the annotations specified in the `entry_agg_funcs` parameter. The keys of + the dict are the names of the annotations and the values are tuples of functions. + The first function is used to transform the `mt` entries in some way, and the + second function is used to aggregate the output from the first function. For + example, if `entry_agg_funcs` is set to `{'adj_samples': (get_adj_expr, hl.agg.sum)}`, + then the output MatrixTable will contain an annotation `adj_samples` which is an + array of the number of adj samples per strata in each row. + + :param mt: Input MatrixTable. + :param sex_expr: When specified, frequencies are stratified by sex. If `pop_expr` + is also specified, then a pop/sex stratification is added. + :param pop_expr: When specified, frequencies are stratified by population. If + `sex_expr` is also specified, then a pop/sex stratification is added. + :param subpop_expr: When specified, frequencies are stratified by sub-continental + population. Note that `pop_expr` is required as well when using this option. + :param additional_strata_expr: When specified, frequencies are stratified by the + given additional strata. This can e.g. be used to stratify by platform, + platform-pop, platform-pop-sex. + :param downsamplings: When specified, frequencies are computed by downsampling the + data to the number of samples given in the list. Note that if `pop_expr` is + specified, downsamplings by population is also computed. + :param downsampling_expr: When specified, frequencies are computed using the + downsampling indices in the provided StructExpression. Note that if `pop_idx` + is specified within the struct, downsamplings by population is also computed.
+ :param ds_pop_counts: When specified, frequencies are computed by downsampling the + data to the number of samples per pop in the dict. The key is the population + and the value is the number of samples. + :param entry_agg_funcs: When specified, additional annotations are added to the + output Table/MatrixTable. The keys of the dict are the names of the annotations + and the values are tuples of functions. The first function is used to transform + the `mt` entries in some way, and the second function is used to aggregate the + output from the first function. + :param annotate_mt: Whether to return the full MatrixTable with annotations added + instead of only a Table with `freq` and other annotations. Default is True. + :return: MatrixTable or Table with `freq` annotation. + """ + errors = [] + if downsampling_expr is not None: + if downsamplings is None: + errors.append( + "annotate_freq requires `downsamplings` when using `downsampling_expr`" + ) + if downsampling_expr.get("pop_idx") is not None: + if ds_pop_counts is None: + errors.append( + "annotate_freq requires `ds_pop_counts` when using " + "`downsampling_expr` with pop_idx" + ) + if errors: + raise ValueError("The following errors were found: \n" + "\n".join(errors)) + + # Generate downsamplings and assign downsampling_expr if it is None when + # downsamplings is supplied. + if downsamplings is not None and downsampling_expr is None: + ds_ht = annotate_downsamplings(mt, downsamplings, pop_expr=pop_expr).cols() + downsamplings = hl.eval(ds_ht.downsamplings) + ds_pop_counts = hl.eval(ds_ht.ds_pop_counts) + downsampling_expr = ds_ht[mt.col_key].downsampling + + # Build list of all stratification groups to be used in the frequency calculation. + strata_expr = build_freq_stratification_list( + sex_expr=sex_expr, + pop_expr=pop_expr, + subpop_expr=subpop_expr, + additional_strata_expr=additional_strata_expr, + downsampling_expr=downsampling_expr, + ) + + # Annotate the MT cols with each of the expressions in strata_expr and redefine + # strata_expr based on the column HT with added annotations. + ht = mt.annotate_cols(**{k: v for d in strata_expr for k, v in d.items()}).cols() + strata_expr = [{k: ht[k] for k in d} for d in strata_expr] + + # Annotate HT with a freq_meta global and group membership array for each sample + # indicating whether the sample belongs to the group defined by freq_meta elements. + ht = generate_freq_group_membership_array( + ht, + strata_expr, + downsamplings=downsamplings, + ds_pop_counts=ds_pop_counts, + ) + + freq_ht = compute_freq_by_strata( + mt.annotate_cols(group_membership=ht[mt.col_key].group_membership), + entry_agg_funcs=entry_agg_funcs, + ) + freq_ht = freq_ht.annotate_globals(**ht.index_globals()) + + if annotate_mt: + mt = mt.annotate_rows(**freq_ht[mt.row_key]) + mt = mt.annotate_globals(**freq_ht.index_globals()) + return mt + + else: + return freq_ht + + +def annotate_downsamplings( + t: Union[hl.MatrixTable, hl.Table], + downsamplings: List[int], + pop_expr: Optional[hl.expr.StringExpression] = None, +) -> Union[hl.MatrixTable, hl.Table]: + """ + Annotate MatrixTable or Table with downsampling groups. + + :param t: Input MatrixTable or Table. + :param downsamplings: List of downsampling sizes. + :param pop_expr: Optional expression for population group. When provided, population + sample sizes are added as values to downsamplings. + :return: MatrixTable or Table with downsampling annotations. 
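# A hedged usage sketch for annotate_freq: the MatrixTable path and the `pop`/
# `sex_karyotype` column fields are hypothetical; the input is assumed to be
# bi-allelic with GT and adj entry fields, as required by the docstring above.
import hail as hl

from gnomad.utils.annotations import annotate_freq

mt = hl.read_matrix_table("gs://my-bucket/split_biallelic_adj.mt")  # hypothetical

mt = annotate_freq(
    mt,
    sex_expr=mt.sex_karyotype,
    pop_expr=mt.pop,
    downsamplings=[1000, 5000],
)

# `freq` is index-aligned with the `freq_meta` global, so a grouping's call stats
# can be looked up by finding its metadata entry.
freq_meta = hl.eval(mt.freq_meta)
adj_idx = [i for i, m in enumerate(freq_meta) if dict(m) == {"group": "adj"}][0]
mt = mt.annotate_rows(AF_adj=mt.freq[adj_idx].AF)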
+ """ + if isinstance(t, hl.MatrixTable): + if pop_expr is not None: + ht = t.annotate_cols(pop=pop_expr).cols() + else: + ht = t.cols() + else: + if pop_expr is not None: + ht = t.annotate(pop=pop_expr) + else: + ht = t + + ht = ht.key_by(r=hl.rand_unif(0, 1)) + + # Add a global index for use in computing frequencies, or other aggregate stats on + # the downsamplings. + scan_expr = {"global_idx": hl.scan.count()} + + # If pop_expr is provided, add all pop counts to the downsamplings list. + if pop_expr is not None: + pop_counts = ht.aggregate( + hl.agg.filter(hl.is_defined(ht.pop), hl.agg.counter(ht.pop)) + ) + downsamplings = [x for x in downsamplings if x <= sum(pop_counts.values())] + downsamplings = sorted(set(downsamplings + list(pop_counts.values()))) + # Add an index by pop for use in computing frequencies, or other aggregate stats + # on the downsamplings. + scan_expr["pop_idx"] = hl.scan.counter(ht.pop).get(ht.pop, 0) + else: + pop_counts = None + logger.info("Found %i downsamplings: %s", len(downsamplings), downsamplings) + + ht = ht.annotate(**scan_expr) + ht = ht.key_by("s").select(*scan_expr) + + if isinstance(t, hl.MatrixTable): + t = t.annotate_cols(downsampling=ht[t.s]) + else: + t = t.annotate(downsampling=ht[t.s]) + + t = t.annotate_globals( + downsamplings=downsamplings, + ds_pop_counts=pop_counts, + ) + + return t + + +def build_freq_stratification_list( + sex_expr: Optional[hl.expr.StringExpression] = None, + pop_expr: Optional[hl.expr.StringExpression] = None, + subpop_expr: Optional[hl.expr.StringExpression] = None, + additional_strata_expr: Optional[ + Union[ + List[Dict[str, hl.expr.StringExpression]], + Dict[str, hl.expr.StringExpression], + ] + ] = None, + downsampling_expr: Optional[hl.expr.StructExpression] = None, +) -> List[Dict[str, hl.expr.StringExpression]]: + """ + Build a list of stratification groupings to be used in frequency calculations based on supplied parameters. + + .. note:: + This function is primarily used through `annotate_freq` but can be used + independently if desired. The returned list of stratifications can be passed to + `generate_freq_group_membership_array`. + + :param sex_expr: When specified, the returned list contains a stratification for + sex. If `pop_expr` is also specified, then the returned list also contains a + pop/sex stratification. + :param pop_expr: When specified, the returned list contains a stratification for + population. If `sex_expr` is also specified, then the returned list also + contains a pop/sex stratification. + :param subpop_expr: When specified, the returned list contains a stratification for + sub-continental population. Note that `pop_expr` is required as well when using + this option. + :param additional_strata_expr: When specified, the returned list contains a + stratification for each of the additional strata. This can e.g. be used to + stratify by platform, platform-pop, platform-pop-sex. + :param downsampling_expr: When specified, the returned list contains a + stratification for downsampling. If `pop_expr` is also specified, then the + returned list also contains a downsampling/pop stratification. + :return: List of dictionaries specifying stratification groups where the keys of + each dictionary are strings and the values are corresponding expressions that + define the values to stratify frequency calculations by. 
+ """ + errors = [] + if subpop_expr is not None and pop_expr is None: + errors.append("annotate_freq requires pop_expr when using subpop_expr") + + if downsampling_expr is not None: + if downsampling_expr.get("global_idx") is None: + errors.append( + "annotate_freq requires `downsampling_expr` with key 'global_idx'" + ) + if downsampling_expr.get("pop_idx") is None: + if pop_expr is not None: + errors.append( + "annotate_freq requires `downsampling_expr` with key 'pop_idx' when" + " using `pop_expr`" + ) + else: + if pop_expr is None: + errors.append( + "annotate_freq requires `pop_expr` when using `downsampling_expr` " + "with pop_idx" + ) + + if errors: + raise ValueError("The following errors were found: \n" + "\n".join(errors)) + + # Build list of strata expressions based on supplied parameters. + strata_expr = [] + if pop_expr is not None: + strata_expr.append({"pop": pop_expr}) + if sex_expr is not None: + strata_expr.append({"sex": sex_expr}) + if pop_expr is not None: + strata_expr.append({"pop": pop_expr, "sex": sex_expr}) + if subpop_expr is not None: + strata_expr.append({"pop": pop_expr, "subpop": subpop_expr}) + + # Add downsampling to strata expressions, include pop in the strata if supplied. + if downsampling_expr is not None: + downsampling_strata = {"downsampling": downsampling_expr} + if pop_expr is not None: + downsampling_strata["pop"] = pop_expr + strata_expr.append(downsampling_strata) + + # Add additional strata expressions. + if additional_strata_expr is not None: + if isinstance(additional_strata_expr, dict): + additional_strata_expr = [additional_strata_expr] + strata_expr.extend(additional_strata_expr) + + return strata_expr + + +def generate_freq_group_membership_array( + ht: hl.Table, + strata_expr: List[Dict[str, hl.expr.StringExpression]], + downsamplings: Optional[List[int]] = None, + ds_pop_counts: Optional[Dict[str, int]] = None, +) -> hl.Table: + """ + Generate a Table with a 'group_membership' array for each sample indicating whether the sample belongs to specific stratification groups. + + .. note:: + This function is primarily used through `annotate_freq` but can be used + independently if desired. Please see the `annotate_freq` function for more + complete documentation. + + The following global annotations are added to the returned Table: + - freq_meta: Each element of the list contains metadata on a stratification + group. + - freq_meta_sample_count: sample count per grouping defined in `freq_meta`. + - If downsamplings or ds_pop_counts are specified, they are also added as + global annotations on the returned Table. + + Each sample is annotated with a 'group_membership' array indicating whether the + sample belongs to specific stratification groups. All possible value combinations + are determined for each stratification grouping in the `strata_expr` list. + + :param ht: Input Table that contains Expressions specified by `strata_expr`. + :param strata_expr: List of dictionaries specifying stratification groups where + the keys of each dictionary are strings and the values are corresponding + expressions that define the values to stratify frequency calculations by. + :param downsamplings: List of downsampling values to include in the stratifications. + :param ds_pop_counts: Dictionary of population counts for each downsampling value. + :return: Table with the 'group_membership' array annotation. 
+ """ + errors = [] + ds_in_strata = any("downsampling" in s for s in strata_expr) + global_idx_in_ds_expr = any( + "global_idx" in s["downsampling"] for s in strata_expr if "downsampling" in s + ) + pop_in_strata = any("pop" in s for s in strata_expr) + pop_idx_in_ds_expr = any( + "pop_idx" in s["downsampling"] + for s in strata_expr + if "downsampling" in s and ds_pop_counts is not None + ) + + if downsamplings is not None and not ds_in_strata: + errors.append( + "Strata must contain a downsampling expression when downsamplings" + "are provided." + ) + if downsamplings is not None and not global_idx_in_ds_expr: + errors.append( + "Strata must contain a downsampling expression with 'global_idx' when " + "downsamplings are provided." + ) + if ds_pop_counts is not None and not pop_in_strata: + errors.append( + "Strata must contain a population expression 'pop' when ds_pop_counts " + " are provided." + ) + if ds_pop_counts is not None and not pop_idx_in_ds_expr: + errors.append( + "Strata must contain a downsampling expression with 'pop_idx' when " + "ds_pop_counts are provided." + ) + + if errors: + raise ValueError("The following errors were found: \n" + "\n".join(errors)) + + # Get counters for all strata. + strata_counts = ht.aggregate( + hl.struct( + **{ + k: hl.agg.filter(hl.is_defined(v), hl.agg.counter({k: v})) + for strata in strata_expr + for k, v in strata.items() + } + ) + ) + + # Add all desired strata to sample group filters. + sample_group_filters = [({}, True)] + for strata in strata_expr: + downsampling_expr = strata.get("downsampling") + strata_values = [] + # Add to all downsampling groups, both global and population-specific, to + # strata. + for s in strata: + if s == "downsampling": + v = [("downsampling", d) for d in downsamplings] + else: + v = [(s, k[s]) for k in strata_counts.get(s, {})] + if s == "pop" and downsampling_expr is not None: + v.append(("pop", "global")) + strata_values.append(v) + + # Get all combinations of strata values. + strata_combinations = itertools.product(*strata_values) + # Create sample group filters that are evaluated on each sample for each strata + # combination. Strata combinations are evaluated as a logical AND, e.g. + # {"pop":nfe, "downsampling":1000} or "nfe-10000" creates the filter expression + # pop == nfe AND downsampling pop_idx < 10000. + for combo in strata_combinations: + combo = dict(combo) + ds = combo.get("downsampling") + pop = combo.get("pop") + # If combo contains downsampling, determine the downsampling index + # annotation to use. + downsampling_idx = "global_idx" + if ds is not None: + if pop is not None and pop != "global": + # Don't include population downsamplings where the downsampling is + # larger than the number of samples in the population. + if ds > ds_pop_counts[pop]: + continue + downsampling_idx = "pop_idx" + + # If combo contains downsampling, add downsampling filter expression. + combo_filter_exprs = [] + for s, v in combo.items(): + if s == "downsampling": + combo_filter_exprs.append(downsampling_expr[downsampling_idx] < v) + else: + if s != "pop" or v != "global": + combo_filter_exprs.append(strata[s] == v) + combo = {k: str(v) for k, v in combo.items()} + sample_group_filters.append((combo, hl.all(combo_filter_exprs))) + + n_groups = len(sample_group_filters) + logger.info("number of filters: %i", n_groups) + + # Get sample count per strata group. + freq_meta_sample_count = ht.aggregate( + [hl.agg.count_where(x[1]) for x in sample_group_filters] + ) + + # Annotate columns with group_membership. 
+ ht = ht.select(group_membership=[x[1] for x in sample_group_filters]) + + # Create and annotate global expression with meta and sample count information. + freq_meta = [ + dict(**sample_group[0], group="adj") for sample_group in sample_group_filters + ] + + # Add the "raw" group, representing all samples, to the freq_meta_expr list. + freq_meta.insert(1, {"group": "raw"}) + freq_meta_sample_count.insert(1, freq_meta_sample_count[0]) + + global_expr = { + "freq_meta": freq_meta, + "freq_meta_sample_count": freq_meta_sample_count, + } + + if downsamplings is not None: + global_expr["downsamplings"] = downsamplings + if ds_pop_counts is not None: + global_expr["ds_pop_counts"] = ds_pop_counts + + ht = ht.select_globals(**global_expr) + ht = ht.checkpoint(hl.utils.new_temp_file("group_membership", "ht")) + + return ht + + +def compute_freq_by_strata( + mt: hl.MatrixTable, + entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, +) -> hl.Table: + """ + Compute call statistics and, when passed, entry aggregation function(s) by strata. + + The computed call statistics are AC, AF, AN, and homozygote_count. The entry + aggregation functions are applied to the MatrixTable entries and aggregated. The + MatrixTable must contain a 'group_membership' annotation (like the one added by + `generate_freq_group_membership_array`) that is a list of bools to aggregate the + columns by. + + .. note:: + This function is primarily used through `annotate_freq` but can be used + independently if desired. Please see the `annotate_freq` function for more + complete documentation. + + :param mt: Input MatrixTable. + :param entry_agg_funcs: Optional dict of entry aggregation functions. When + specified, additional annotations are added to the output Table/MatrixTable. + The keys of the dict are the names of the annotations and the values are tuples + of functions. The first function is used to transform the `mt` entries in some + way, and the second function is used to aggregate the output from the first + function. + :return: Table or MatrixTable with allele frequencies by strata. + """ + if entry_agg_funcs is None: + entry_agg_funcs = {} + + n_samples = mt.count_cols() + n_groups = len(mt.group_membership.take(1)[0]) + ht = mt.localize_entries("entries", "cols") + ht = ht.annotate_globals( + indices_by_group=hl.range(n_groups).map( + lambda g_i: hl.range(n_samples).filter( + lambda s_i: ht.cols[s_i].group_membership[g_i] + ) + ) + ) + # Pull out each annotation that will be used in the array aggregation below as its + # own ArrayExpression. This is important to prevent memory issues when performing + # the below array aggregations. + ht = ht.select( + adj_array=ht.entries.map(lambda e: e.adj), + gt_array=ht.entries.map(lambda e: e.GT), + **{ + ann: hl.map(lambda e, s: f[0](e, s), ht.entries, ht.cols) + for ann, f in entry_agg_funcs.items() + }, + ) + + def _agg_by_group( + ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression, *args + ) -> hl.expr.ArrayExpression: + """ + Aggregate `agg_expr` by group using the `agg_func` function. + + :param ht: Input Hail Table. + :param agg_func: Aggregation function to apply to `agg_expr`. + :param agg_expr: Expression to aggregate by group. + :param args: Additional arguments to pass to the `agg_func`. + :return: Aggregated array expression. 
+ """ + adj_agg_expr = ht.indices_by_group.map( + lambda s_indices: s_indices.aggregate( + lambda i: hl.agg.filter(ht.adj_array[i], agg_func(ann_expr[i], *args)) + ) + ) + raw_agg_expr = ann_expr.aggregate(lambda x: agg_func(x, *args)) + # Create final agg list by inserting the "raw" group, representing all samples, + # into the adj_agg_list. + return adj_agg_expr[:1].append(raw_agg_expr).extend(adj_agg_expr[1:]) + + freq_expr = _agg_by_group(ht, hl.agg.call_stats, ht.gt_array, ht.alleles) + + # Select non-ref allele (assumes bi-allelic). + freq_expr = freq_expr.map( + lambda cs: cs.annotate( + AC=cs.AC[1], + AF=cs.AF[1], + homozygote_count=cs.homozygote_count[1], + ) + ) + # Add annotations for any supplied entry transform and aggregation functions. + ht = ht.select( + **{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()}, + freq=freq_expr, + ) + + return ht.drop("cols") + + +def update_structured_annotations( + ht: hl.Table, + annotation_update_exprs: Dict[str, hl.Expression], + annotation_update_label: Optional[str] = None, +) -> hl.Table: + """ + Update highly structured annotations on a Table. + + This function recursively updates annotations defined by `annotation_update_exprs` + and if `annotation_update_label` is supplied, it checks if the sample annotations + are different from the input and adds a flag to the Table, indicating which + annotations have been updated for each sample. + + :param ht: Input Table with structured annotations to update. + :param annotation_update_exprs: Dictionary of annotations to update, structured as + they are structured on the input `ht`. + :param annotation_update_label: Optional string of the label to use for an + annotation indicating which annotations have been updated. Default is None, so + no annotation is added. + :return: Table with updated annotations and optionally a flag indicating which + annotations were changed. + """ + + def _update_struct( + struct_expr: hl.expr.StructExpression, + update_exprs: Union[Dict[str, hl.expr.Expression], hl.expr.Expression], + ) -> Tuple[Dict[str, hl.expr.BooleanExpression], Any]: + """ + Update a StructExpression. + + :param struct_expr: StructExpression to update. + :param update_exprs: Dictionary of annotations to update. + :return: Tuple of the updated annotations and the updated flag. + """ + if isinstance(update_exprs, dict): + updated_struct_expr = {} + updated_flag_expr = {} + for ann, expr in update_exprs.items(): + updated_flag, updated_ann = _update_struct(struct_expr[ann], expr) + updated_flag_expr.update( + {ann + ("." + k if k else ""): v for k, v in updated_flag.items()} + ) + updated_struct_expr[ann] = updated_ann + return updated_flag_expr, struct_expr.annotate(**updated_struct_expr) + else: + return {"": update_exprs != struct_expr}, update_exprs + + annotation_update_flag, updated_rows = _update_struct( + ht.row_value, annotation_update_exprs + ) + if annotation_update_label is not None: + updated_rows = updated_rows.annotate( + **{ + annotation_update_label: filter_utils.add_filters_expr( + filters=annotation_update_flag + ) + } + ) + + return ht.annotate(**updated_rows) +def get_gks( + ht: hl.Table, + variant: str, + label_name: str, + label_version: str, + coverage_ht: hl.Table = None, + ancestry_groups: list = None, + ancestry_groups_dict: dict = None, + by_sex: bool = False, + vrs_only: bool = False, +) -> dict: + """ + Filter to a specified variant and return a Python dictionary containing the GA4GH variant report schema. 
+ + :param ht: Hail Table to parse for desired variant. + :param variant: String of variant to search for (chromosome, position, ref, and alt, separated by '-'). Example for a variant in build GRCh38: "chr5-38258681-C-T". + :param label_name: Label name to use within the returned dictionary. Example: "gnomAD". + :param label_version: String listing the version of the HT being used. Example: "3.1.2" . + :param coverage_ht: Hail Table containing coverage statistics, with mean depth stored in "mean" annotation. If None, omit coverage in return. + :param ancestry_groups: List of strings of shortened names of genetic ancestry groups to return results for. Example: ['afr','fin','nfe'] . + :param ancestry_groups_dict: Dict mapping shortened genetic ancestry group names to full names. Example: {'afr':'African/African American'} . + :param by_sex: Boolean to include breakdown of ancestry groups by inferred sex (XX and XY) as well. + :param vrs_only: Boolean to return only the VRS information and no general frequency information. Default is False. + :return: Dictionary containing VRS information (and frequency information split by ancestry groups and sex if desired) for the specified variant. + + """ + # Throw warnings if contradictory arguments passed. + if ancestry_groups and vrs_only: + logger.warning( + "Both 'vrs_only' and 'ancestry_groups' have been specified. Ignoring" + " 'ancestry_groups' list and returning only VRS information." + ) + elif by_sex and not ancestry_groups: + logger.warning( + "Splitting whole database by sex is not yet supported. If using 'by_sex'," + " please also specify 'ancestry_groups' to stratify by." + ) + + # Define variables for variant information. + build_in = get_reference_genome(ht.locus).name + chrom_dict = VRS_CHROM_IDS[build_in] + chr_in, pos_in, ref_in, alt_in = variant.split("-") + + # Filter HT to desired variant. + ht = ht.filter( + ( + ht.locus + == hl.locus(contig=chr_in, pos=int(pos_in), reference_genome=build_in) + ) + & (ht.alleles == [ref_in, alt_in]) + ) + ht = ht.checkpoint(new_temp_file("get_gks", extension="ht")) + # Check to ensure the ht is successfully filtered to 1 variant. + if ht.count() != 1: + raise ValueError( + "Error: can only work with one variant for this code, 0 or multiple" + " returned." + ) + + # Define VRS Attributes that will later be read into the dictionary to be returned. + vrs_id = f"{ht.info.vrs.VRS_Allele_IDs[1].collect()[0]}" + vrs_chrom_id = f"{chrom_dict[chr_in]}" + vrs_start_value = ht.info.vrs.VRS_Starts[1].collect()[0] + vrs_end_value = ht.info.vrs.VRS_Ends[1].collect()[0] + vrs_state_sequence = f"{ht.info.vrs.VRS_States[1].collect()[0]}" + + # Defining the dictionary for VRS information. + vrs_dict = { + "_id": vrs_id, + "location": { + "_id": "to-be-defined", + "interval": { + "end": {"type": "Number", "value": vrs_end_value}, + "start": { + "type": "Number", + "value": vrs_start_value, + }, + "type": "SequenceInterval", + }, + "sequence_id": vrs_chrom_id, + "type": "SequenceLocation", + }, + "state": {"sequence": vrs_state_sequence, "type": "LiteralSequenceExpression"}, + "type": "Allele", + } + + # Set location ID + location_dict = vrs_dict["location"] + location_dict.pop("_id") + location = ga4gh_vrs.models.SequenceLocation(**location_dict) + location_id = ga4gh_core._internal.identifiers.ga4gh_identify(location) + vrs_dict["location"]["_id"] = location_id + + logger.info(vrs_dict) + + # If vrs_only was passed, only return the above dict and stop. 
+ if vrs_only: + return vrs_dict + + # Create a list to then add the dictionaries for frequency reports for + # different ancestry groups to. + list_of_group_info_dicts = [] + + # Define function to return a frequency report dictionary for a given group + def _create_group_dicts( + variant_ht: hl.Table, + group_index: int, + group_id: str, + group_label: str, + group_sex: str = None, + ) -> dict: + """ + Return a dictionary for the frequency information of a given variant for a given subpopulation. + + :param variant_ht: Hail Table with only one row, only containing the desired variant. + :param group_index: Index of frequency within the 'freq' annotation containing the desired group. + :param group_id: String containing variant, genetic ancestry group, and sex (if requested). Example: "chr19-41094895-C-T.afr.XX". + :param group_label: String containing the full name of genetic ancestry group requested. Example: "African/African American". + :param group_sex: String indicating the sex of the group. Example: "XX", or "XY". + :return: Dictionary containing VRS information (and genetic ancestry group if desired) for specified variant. + """ + # Obtain frequency information for the specified variant + group_freq = variant_ht.freq[group_index] + + # Cohort characteristics + characteristics = [] + characteristics.append({"name": "genetic ancestry", "value": group_label}) + if group_sex is not None: + characteristics.append({"name": "biological sex", "value": group_sex}) + + # Dictionary to be returned containing information for a specified group + freq_record = { + "id": f"{variant}.{group_id.upper()}", + "type": "CohortAlleleFrequency", + "label": f"{group_label} Cohort Allele Frequency for {variant}", + "focusAllele": "#/focusAllele", + "focusAlleleCount": group_freq["AC"].collect()[0], + "locusAlleleCount": group_freq["AN"].collect()[0], + "alleleFrequency": group_freq["AF"].collect()[0], + "cohort": {"id": group_id.upper(), "characteristics": characteristics}, + "ancillaryResults": { + "homozygotes": group_freq["homozygote_count"].collect()[0] + }, + } + + return freq_record + + # Iterate through provided groups and generate dictionaries + if ancestry_groups: + for group in ancestry_groups: + key = f"{group}-adj" + index_value = ht.freq_index_dict.get(key) + group_result = _create_group_dicts( + variant_ht=ht, + group_index=index_value, + group_id=group, + group_label=ancestry_groups_dict[group], + ) + + # If specified, stratify group information by sex. 
+ if by_sex: + sex_list = [] + for sex in ["XX", "XY"]: + sex_key = f"{group}-{sex}-adj" + sex_index_value = ht.freq_index_dict.get(sex_key) + sex_label = f"{group}.{sex}" + sex_result = _create_group_dicts( + variant_ht=ht, + group_index=sex_index_value, + group_id=sex_label, + group_label=ancestry_groups_dict[group], + group_sex=sex, + ) + sex_list.append(sex_result) + + group_result["subcohortFrequency"] = sex_list + + list_of_group_info_dicts.append(group_result) + + # Overall frequency, via label 'adj' which is currently stored at + # position #1 (index 0) + overall_freq = ht.freq[0] + + # Final dictionary to be returned + final_freq_dict = { + "id": f"{label_name}{label_version}:{variant}", + "type": "CohortAlleleFrequency", + "label": f"Overall Cohort Allele Frequency for {variant}", + "derivedFrom": { + "id": f"{label_name}{label_version}", + "type": "DataSet", + "label": f"{label_name} v{label_version}", + "version": f"{label_version}", + }, + "focusAllele": vrs_dict, + "focusAlleleCount": overall_freq["AC"].collect()[0], + "locusAlleleCount": overall_freq["AN"].collect()[0], + "alleleFrequency": overall_freq["AF"].collect()[0], + "cohort": {"id": "ALL"}, + "ancillaryResults": { + "homozygotes": overall_freq["homozygote_count"].collect()[0] + }, + } + + # popmax FAF95 + popmax_95 = { + "frequency": ht.popmax.faf95.collect()[0], + "confidenceInterval": 0.95, + "popFreqId": f"{variant}.{ht.popmax.pop.collect()[0].upper()}", + } + final_freq_dict["ancillaryResults"]["popMaxFAF95"] = popmax_95 + + # Read coverage statistics if a table is provdied + if coverage_ht: + coverage_ht = coverage_ht.filter( + coverage_ht.locus + == hl.locus(contig=chr_in, pos=int(pos_in), reference_genome=build_in) + ) + mean_coverage = coverage_ht.mean.collect()[0] + final_freq_dict["ancillaryResults"]["meanDepth"] = mean_coverage + + # If ancestry_groups were passed, add the ancestry group dictionary to the + # final frequency dictionary to be returned. + if ancestry_groups: + final_freq_dict["subcohortFrequency"] = list_of_group_info_dicts + + # Validate that the constructed dictionary will convert to a JSON string. + try: + validated_json = json.dumps(final_freq_dict) + except BaseException: + raise SyntaxError("The dictionary did not convert to a valid JSON") + + # Returns the constructed dictionary. + return final_freq_dict + + +def gks_compute_seqloc_digest(vrs_variant: dict) -> dict: + """ + Compute and set the digest-based id for the sequence location. + + Take a dict of a VRS variant that has a sequence location that does not yet + have the digest-based id computed. Compute it and assign it to .location._id. + + :param vrs_variant: VRS variant dict + :return: VRS variant dict with the location id set to the computed digest-based id + """ + location = vrs_variant["location"] + location.pop("_id") + location_id = ga4gh_core._internal.identifiers.ga4gh_identify( + ga4gh_vrs.models.SequenceLocation(**location) + ) + location["_id"] = location_id + return vrs_variant + + +def gks_compute_seqloc_digest_batch( + ht: hl.Table, + export_tmpfile: str = new_temp_file("gks-seqloc-pre.tsv"), + computed_tmpfile: str = new_temp_file("gks-seqloc-post.tsv"), +): + """ + Compute sequence location digest-based id for a hail variant Table. + + Exports table to tsv, computes SequenceLocation digests, reimports and replaces + the vrs_json field with the result. Input table must have a .vrs field, like the + one added by add_gks_vrs, that can be used to construct ga4gh.vrs models. 
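# A hedged sketch of the intended VRS workflow for a release-style Table whose rows
# carry an `info.vrs` struct (the path below is hypothetical): annotate the VRS
# structs in Hail, then compute SequenceLocation digests outside Hail in batch.
import hail as hl

from gnomad.utils.annotations import add_gks_vrs, gks_compute_seqloc_digest_batch

ht = hl.read_table("gs://my-bucket/release_with_vrs.ht")  # hypothetical input

ht = add_gks_vrs(ht)  # adds the `vrs` struct and `vrs_json`, with location._id left empty
ht = gks_compute_seqloc_digest_batch(ht)
# The re-imported `vrs_json` strings now carry digest-based SequenceLocation `_id`s.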
+ + :param ht: hail table with VRS annotation + :param export_tmpfile: file path to export the table to. + :param computed_tmpfile: file path to write the updated rows to, which is then imported as a hail table + :return: a hail table with the VRS annotation updated with the new SequenceLocations + """ + logger.info("Exporting ht to %s", export_tmpfile) + ht.select("vrs_json").export(export_tmpfile, header=True) + + logger.info( + "Computing SequenceLocation digests and writing to %s", computed_tmpfile + ) + start = timer() + counter = 0 + with open(computed_tmpfile, "w", encoding="utf-8") as f_out: + with open(export_tmpfile, "r", encoding="utf-8") as f: + reader = csv.reader(f, delimiter="\t") + header = None + for line in reader: + if header is None: + header = line + f_out.write("\t".join(header)) + f_out.write("\n") + continue + else: + locus, alleles, vrs_json = line + vrs_variant = json.loads(vrs_json) + vrs_variant = gks_compute_seqloc_digest(vrs_variant) + # serialize outputs to JSON and write to TSV + vrs_json = json.dumps(vrs_variant) + alleles = json.dumps(json.loads(alleles)) + f_out.write("\t".join([locus, alleles, vrs_json])) + f_out.write("\n") + counter += 1 + end = timer() + logger.info( + "Computed %s SequenceLocation digests in %s seconds", counter, (end - start) + ) + logger.info("Importing VRS records with computed SequenceLocation digests") + ht_with_location = hl.import_table( + computed_tmpfile, types={"locus": "tstr", "alleles": "tstr", "vrs_json": "tstr"} + ) + ht_with_location_parsed = ht_with_location.annotate( + locus=hl.locus( + contig=ht_with_location.locus.split(":")[0], + pos=hl.int32(ht_with_location.locus.split(":")[1]), + reference_genome="GRCh38", + ), + alleles=hl.parse_json(ht_with_location.alleles, dtype=hl.tarray(hl.tstr)), + ).key_by("locus", "alleles") + + return ht.drop("vrs_json").join(ht_with_location_parsed, how="left") + + +def add_gks_vrs(ht: hl.Table): + """ + Add GKS VRS variant annotation to a hail table. + + Annotates ht with GA4GH GKS VRS structure, except for the variant.location._id, + which must be computed outside Hail. Use gks_compute_seqloc_digest for this. + + ht_out.vrs: Struct of the VRS representation of the variant + ht_out.vrs_json: JSON string representation of the .vrs struct. + """ + build_in = get_reference_genome(ht.locus).name + chr_in = ht.locus.contig + + vrs_chrom_ids_expr = hl.literal(VRS_CHROM_IDS) + chrom_dict = vrs_chrom_ids_expr[build_in] + vrs_id = ht.info.vrs.VRS_Allele_IDs[1] + vrs_chrom_id = chrom_dict[chr_in] + vrs_start_value = ht.info.vrs.VRS_Starts[1] + vrs_end_value = ht.info.vrs.VRS_Ends[1] + vrs_state_sequence = ht.info.vrs.VRS_States[1] + + ht_out = ht.annotate( + vrs=hl.struct( + _id=vrs_id, + type="Allele", + location=hl.struct( + _id="", + type="SequenceLocation", + interval=hl.struct(start=vrs_start_value, end=vrs_end_value), + sequence_id=vrs_chrom_id, + ), + state=hl.struct( + type="LiteralSequenceExpression", sequence=vrs_state_sequence + ), + ) + ) + ht_out = ht_out.annotate(vrs_json=hl.json(ht_out.vrs)) + return ht_out + + +def add_gks_va( + ht: hl.Table, + label_name: str, + label_version: str, + coverage_ht: hl.Table = None, + ancestry_groups: list = None, + ancestry_groups_dict: dict = None, + by_sex: bool = False, +) -> dict: + """ + Add GKS VA annotations to a hail table. + + Annotates the hail table with frequency information conforming to the GKS VA frequency schema. + If ancestry_groups or by_sex is provided, also include subcohort schemas for each cohort. 
+ This annotation is added under the gks_va_freq_dict field of the table. + The focusAllele field is not populated, and must be filled in by the caller. + + :param ht: Hail Table to parse for desired variant. + :param variant: String of variant to search for (chromosome, position, ref, and alt, separated by '-'). Example for a variant in build GRCh38: "chr5-38258681-C-T". + :param label_name: Label name to use within the returned dictionary. Example: "gnomAD". + :param label_version: String listing the version of the HT being used. Example: "3.1.2" . + :param coverage_ht: Hail Table containing coverage statistics, with mean depth stored in "mean" annotation. If None, omit coverage in return. + :param ancestry_groups: List of strings of shortened names of genetic ancestry groups to return results for. Example: ['afr','fin','nfe'] . + :param ancestry_groups_dict: Dict mapping shortened genetic ancestry group names to full names. Example: {'afr':'African/African American'} . + :param by_sex: Boolean to include breakdown of ancestry groups by inferred sex (XX and XY) as well. + :param vrs_only: Boolean to return only the VRS information and no general frequency information. Default is False. + :return: Dictionary containing VRS information (and frequency information split by ancestry groups and sex if desired) for the specified variant. + """ + # Throw warnings if contradictory arguments passed. + if by_sex and not ancestry_groups: + logger.warning( + "Splitting whole database by sex is not yet supported. If using 'by_sex'," + " please also specify 'ancestry_groups' to stratify by." + ) + + ht = ht.annotate( + gnomad_id=hl.format( + "%s-%s-%s-%s", + ht.locus.contig, + ht.locus.position, + ht.alleles[0], + ht.alleles[1], + ) + ) + + # Define function to return a frequency report dictionary for a given group + def _create_group_dicts( + group_index: int, + group_id: str, + group_label: str, + group_sex: str = None, + ) -> dict: + """ + Return a dictionary for the frequency information of a given variant for a given subpopulation. + + :param group_index: Index of frequency within the 'freq' annotation containing the desired group. + :param group_id: String containing variant, genetic ancestry group, and sex (if requested). Example: "chr19-41094895-C-T.afr.XX". + :param group_label: String containing the full name of genetic ancestry group requested. Example: "African/African American". + :param group_sex: String indicating the sex of the group. Example: "XX", or "XY". + :return: Dictionary containing VRS information (and genetic ancestry group if desired) for specified variant. 
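# A hedged usage sketch for add_gks_va; the Tables and ancestry-group inputs below
# are hypothetical stand-ins.
import hail as hl

from gnomad.utils.annotations import add_gks_va

ht = hl.read_table("gs://my-bucket/release.ht")  # hypothetical release table
coverage_ht = hl.read_table("gs://my-bucket/coverage.ht")  # has a `mean` annotation

ht = add_gks_va(
    ht,
    label_name="gnomAD",
    label_version="3.1.2",
    coverage_ht=coverage_ht,
    ancestry_groups=["afr", "fin", "nfe"],
    ancestry_groups_dict={
        "afr": "African/African American",
        "fin": "Finnish",
        "nfe": "Non-Finnish European",
    },
    by_sex=True,
)
# Per-variant VA frequency records land in `ht.gks_va_freq_dict`; the caller is
# still responsible for filling in `focusAllele` (e.g. from the VRS annotation).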
+ """ + # Obtain frequency information for the specified variant + group_freq = ht.freq[group_index] + + # Cohort characteristics + characteristics = [] + characteristics.append({"name": "genetic ancestry", "value": group_label}) + if group_sex is not None: + characteristics.append({"name": "biological sex", "value": group_sex}) + + # Dictionary to be returned containing information for a specified group + freq_record = { + "id": hl.format("%s.%s", ht.gnomad_id, group_id.upper()), + "type": "CohortAlleleFrequency", + "label": hl.format( + "%s Cohort Allele Frequency for %s", group_label, ht.gnomad_id + ), + "focusAllele": "#/focusAllele", + "focusAlleleCount": group_freq["AC"], + "locusAlleleCount": group_freq["AN"], + "alleleFrequency": group_freq["AF"], + "cohort": {"id": group_id.upper(), "characteristics": characteristics}, + "ancillaryResults": {"homozygotes": group_freq["homozygote_count"]}, + } + + return freq_record + + # Create a list to then add the dictionaries for frequency reports for + # different ancestry groups to. + list_of_group_info_dicts = [] + + # Iterate through provided groups and generate dictionaries + if ancestry_groups: + for group in ancestry_groups: + key = f"{group}-adj" + index_value = ht.freq_index_dict.get(key) + group_result = _create_group_dicts( + group_index=index_value, + group_id=group, + group_label=ancestry_groups_dict[group], + ) + + # If specified, stratify group information by sex. + if by_sex: + sex_list = [] + for sex in ["XX", "XY"]: + sex_key = f"{group}-{sex}-adj" + sex_index_value = ht.freq_index_dict.get(sex_key) + sex_label = f"{group}.{sex}" + sex_result = _create_group_dicts( + group_index=sex_index_value, + group_id=sex_label, + group_label=ancestry_groups_dict[group], + group_sex=sex, + ) + sex_list.append(sex_result) + + group_result["subcohortFrequency"] = sex_list + + list_of_group_info_dicts.append(group_result) + + # Overall frequency, via label 'adj' which is currently stored at + # position #1 (index 0) + overall_freq = ht.freq[0] + + # Final dictionary to be returned + final_freq_dict = hl.struct( + **{ + "id": hl.format("%s-%s:%s", label_name, label_version, ht.gnomad_id), + "type": "CohortAlleleFrequency", + "label": hl.format("Overall Cohort Allele Frequency for %s", ht.gnomad_id), + "derivedFrom": { + "id": f"{label_name}{label_version}", + "type": "DataSet", + "label": f"{label_name} v{label_version}", + "version": f"{label_version}", + }, + "focusAllele": "", # TODO load from vrs_json table + "focusAlleleCount": overall_freq["AC"], + "locusAlleleCount": overall_freq["AN"], + "alleleFrequency": overall_freq["AF"], + "cohort": {"id": "ALL"}, + } + ) + + ancillaryResults = hl.struct( + homozygotes=overall_freq["homozygote_count"], + popMaxFAF95=hl.struct( + frequency=ht.popmax.faf95, + confidenceInterval=0.95, + popFreqId=hl.format("%s.%s", ht.gnomad_id, ht.popmax.pop.upper()), + ), + ) + + # Read coverage statistics if a table is provided + # NOTE: this is slow, and doing the join outside this function and passing in the joined + # variant ht with the coverage table doesn't help much since the join is resolved dynamically. + # If the mean field was persisted into the variant table it would be faster but this increases + # the table size. + # It could be persisted with something like this, then doing a write out and read back from storage. 
+ # ht_with_cov = ht.annotate( + # meanDepth=coverage_ht[ht.locus].mean + # ) + if coverage_ht is not None: + ancillaryResults = ancillaryResults.annotate( + meanDepth=coverage_ht[ht.locus].mean + ) + + final_freq_dict = final_freq_dict.annotate(ancillaryResults=ancillaryResults) + + # If ancestry_groups were passed, add the ancestry group dictionary to the + # final frequency dictionary to be returned. + if ancestry_groups: + final_freq_dict = final_freq_dict.annotate( + subcohortFrequency=list_of_group_info_dicts + ) + + # Return the hail table with the GKS VA struct added + ht_out = ht.annotate(gks_va_freq_dict=final_freq_dict) + return ht_out diff --git a/gnomad/utils/file_utils.py b/gnomad/utils/file_utils.py index 15cef2ca8..8215dad5b 100644 --- a/gnomad/utils/file_utils.py +++ b/gnomad/utils/file_utils.py @@ -11,7 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import hail as hl -from hailtop.aiogoogle import GoogleStorageAsyncFS +from hailtop.aiocloud.aiogoogle import GoogleStorageAsyncFS from hailtop.aiotools import AsyncFS, LocalAsyncFS from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.utils import bounded_gather diff --git a/gnomad/utils/filtering.py b/gnomad/utils/filtering.py index 4d78a44ad..101de9c90 100644 --- a/gnomad/utils/filtering.py +++ b/gnomad/utils/filtering.py @@ -3,12 +3,12 @@ import functools import logging import operator -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import hail as hl +import gnomad.utils.annotations as annotate_utils from gnomad.resources.resource_utils import DataException -from gnomad.utils.annotations import annotate_adj from gnomad.utils.reference_genome import get_reference_genome logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s") @@ -19,7 +19,7 @@ def filter_to_adj(mt: hl.MatrixTable) -> hl.MatrixTable: """Filter genotypes to adj criteria.""" if "adj" not in list(mt.entry): - mt = annotate_adj(mt) + mt = annotate_utils.annotate_adj(mt) mt = mt.filter_entries(mt.adj) return mt.drop(mt.adj) @@ -237,7 +237,7 @@ def add_filters_expr( lambda x, y: x.union(y), current_filters, [ - hl.cond(filter_condition, hl.set([filter_name]), hl.empty_set(hl.tstr)) + hl.if_else(filter_condition, hl.set([filter_name]), hl.empty_set(hl.tstr)) for filter_name, filter_condition in filters.items() ], ) @@ -511,3 +511,119 @@ def filter_for_mu( ) return ht + + +def split_vds_by_strata( + vds: hl.vds.VariantDataset, strata_expr: hl.expr.Expression +) -> Dict[str, hl.vds.VariantDataset]: + """ + Split a VDS into multiple VDSs based on `strata_expr`. + + :param vds: Input VDS. + :param strata_expr: Expression on VDS variant_data MT to split on. + :return: Dictionary where strata value is key and VDS is value. + """ + vmt = vds.variant_data + s_by_strata = vmt.aggregate_cols( + hl.agg.group_by(strata_expr, hl.agg.collect_as_set(vmt.s)) + ) + + return { + strata: hl.vds.filter_samples(vds, list(s)) for strata, s in s_by_strata.items() + } + + +def filter_arrays_by_meta( + meta_expr: hl.expr.ArrayExpression, + meta_indexed_exprs: Union[ + Dict[str, hl.expr.ArrayExpression], hl.expr.ArrayExpression + ], + items_to_filter: Union[Dict[str, List[str]], List[str]], + keep: bool = True, + combine_operator: str = "and", +) -> Tuple[ + hl.expr.ArrayExpression, + Union[Dict[str, hl.expr.ArrayExpression], hl.expr.ArrayExpression], +]: + """ + Filter both metadata array expression and meta data indexed expression by `items_to_filter`. 
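# A hedged sketch of the filtering described below: drop all downsampling groupings
# from a hypothetical release Table's freq/freq_meta pair and its sample counts.
import hail as hl

from gnomad.utils.filtering import filter_arrays_by_meta

ht = hl.read_table("gs://my-bucket/release.ht")  # hypothetical input

freq_meta, array_exprs = filter_arrays_by_meta(
    meta_expr=ht.freq_meta,
    meta_indexed_exprs={
        "freq": ht.freq,
        "freq_meta_sample_count": ht.index_globals().freq_meta_sample_count,
    },
    items_to_filter=["downsampling"],
    keep=False,
)
ht = ht.annotate(freq=array_exprs["freq"])
ht = ht.annotate_globals(
    freq_meta=freq_meta,
    freq_meta_sample_count=array_exprs["freq_meta_sample_count"],
)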
+ + The `items_to_filter` can be used to filter in the following ways based on + `meta_expr` items: + - By a list of keys, e.g. ["sex", "downsampling"]. + - By specific key: value pairs, e.g. to filter where 'pop' is 'han' or 'papuan' + {"pop": ["han", "papuan"]}, or where 'pop' is 'afr' and/or 'sex' is 'XX' + {"pop": ["afr"], "sex": ["XX"]}. + + The items can be kept or removed from `meta_indexed_expr` and `meta_expr` based on + the value of `keep`. For example if `meta_indexed_exprs` is {'freq': ht.freq, + 'freq_meta_sample_count': ht.index_globals().freq_meta_sample_count} and `meta_expr` + is ht.freq_meta then if `keep` is True, the items specified by `items_to_filter` + such as 'pop' = 'han' will be kept and all other items will be removed from the + ht.freq, ht.freq_meta_sample_count, and ht.freq_meta. `meta_indexed_exprs` can also + be a single array expression such as ht.freq. + + The filtering can also be applied such that all criteria must be met + (`combine_operator` = "and") by the `meta_expr` item in order to be filtered, + or at least one of the specified criteria must be met (`combine_operator` = "or") + by the `meta_expr` item in order to be filtered. + + :param meta_expr: Metadata expression that contains the values of the elements in + `meta_indexed_expr`. The most often used expression is `freq_meta` to index into + a 'freq' array. + :param meta_indexed_expr: Either a Dictionary where the keys are the expression name + and the values are the expressions indexed by the `meta_expr` such as a 'freq' + array or just a single expression indexed by the `meta_expr`. + :param items_to_filter: Items to filter by, either a list or a dictionary. + :param keep: Whether to keep or remove the items specified by `items_to_filter`. + :param combine_operator: Whether to use "and" or "or" to combine the items + specified by `items_to_filter`. + :param meta_based_array_expr: Optional array based on freq meta expression to be filtered. + :return: A Tuple of the filtered metadata expression and a dictionary of metadata + indexed expressions when meta_indexed_expr is a Dictionary or a single filtered + array expression when meta_indexed_expr is a single array expression. + """ + meta_expr = meta_expr.collect(_localize=False)[0] + + if isinstance(meta_indexed_exprs, hl.expr.ArrayExpression): + meta_indexed_exprs = {"_tmp": meta_indexed_exprs} + + if combine_operator == "and": + operator_func = hl.all + elif combine_operator == "or": + operator_func = hl.any + else: + raise ValueError( + "combine_operator must be one of 'and' or 'or', but found" + f" {combine_operator}!" 
+ ) + + if isinstance(items_to_filter, list): + filter_func = lambda m, k: m.contains(k) + items_to_filter = [[k] for k in items_to_filter] + elif isinstance(items_to_filter, dict): + filter_func = lambda m, k: (m.get(k[0], "") == k[1]) + items_to_filter = [ + [(k, v) for v in values] for k, values in items_to_filter.items() + ] + else: + raise TypeError("items_to_filter must be a list or a dictionary!") + + meta_expr = hl.enumerate(meta_expr).filter( + lambda m: hl.bind( + lambda x: hl.if_else(keep, x, ~x), + operator_func( + [hl.any([filter_func(m[1], v) for v in k]) for k in items_to_filter] + ), + ), + ) + + meta_indexed_exprs = { + k: meta_expr.map(lambda x: v[x[0]]) for k, v in meta_indexed_exprs.items() + } + meta_expr = meta_expr.map(lambda x: x[1]) + + if "_tmp" in meta_indexed_exprs: + meta_indexed_exprs = meta_indexed_exprs["_tmp"] + + return meta_expr, meta_indexed_exprs diff --git a/gnomad/utils/gen_stats.py b/gnomad/utils/gen_stats.py index fdc309df1..a25d29cf1 100644 --- a/gnomad/utils/gen_stats.py +++ b/gnomad/utils/gen_stats.py @@ -136,9 +136,9 @@ def add_stats( # If `stdev` is present, then compute it from the variance return agg_stats.select( **{ - metric: agg_stats[metric] - if metric != "stdev" - else hl.sqrt(agg_stats.variance) + metric: ( + agg_stats[metric] if metric != "stdev" else hl.sqrt(agg_stats.variance) + ) for metric in metrics } ) diff --git a/gnomad/utils/release.py b/gnomad/utils/release.py index feb9b1412..61382ab2b 100644 --- a/gnomad/utils/release.py +++ b/gnomad/utils/release.py @@ -1,7 +1,10 @@ # noqa: D100 +import logging from typing import Dict, List, Optional +import hail as hl + from gnomad.resources.grch38.gnomad import ( CURRENT_MAJOR_RELEASE, GROUPS, @@ -9,7 +12,14 @@ SEXES, SUBSETS, ) -from gnomad.utils.vcf import index_globals +from gnomad.utils.vcf import SORT_ORDER, index_globals + +logging.basicConfig( + format="%(asctime)s (%(name)s %(lineno)s): %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) def make_faf_index_dict( @@ -92,3 +102,52 @@ def _get_index(label_groups): ) return index_dict + + +def make_freq_index_dict_from_meta( + freq_meta: List[Dict[str, str]], + label_delimiter: str = "_", + sort_order: Optional[List[str]] = SORT_ORDER, +) -> Dict[str, int]: + """ + Create a dictionary for accessing frequency array. + + The dictionary is keyed by the grouping combinations found in the frequency metadata + array, where values are the corresponding 0-based indices for the groupings in the + frequency array. For example, if the `freq_meta` entry [{'pop': 'nfe'}, {'sex': 'XX'}] + corresponds to the 5th entry in the frequency array, the returned dictionary entry + would be {'nfe_XX': 4}. + + :param freq_meta: List of dictionaries containing frequency metadata. + :param label_delimiter: Delimiter to use when joining frequency metadata labels. + :param sort_order: List of frequency metadata labels to use when sorting the dictionary. + :return: Dictionary of frequency metadata. + """ + # Confirm all groups in freq_meta are in sort_order. Warn user if not. + if sort_order is not None: + diff = hl.eval(hl.set(freq_meta.flatmap(lambda i: i.keys()))) - set(sort_order) + if diff: + logger.warning( + "Found unexpected frequency metadata groupings: %s. These groupings" + " are not present in the provided sort_order: %s. 
These groupings" + " will not be included in the returned dictionary.", + diff, + sort_order, + ) + + index_dict = {} + for i, f in enumerate(hl.eval(freq_meta)): + if sort_order is None or len(set(f.keys()) - set(sort_order)) < 1: + index_dict[ + label_delimiter.join( + [ + f[g] + for g in sorted( + f.keys(), + key=(lambda x: sort_order.index(x)) if sort_order else None, + ) + ] + ) + ] = i + + return index_dict diff --git a/gnomad/utils/sparse_mt.py b/gnomad/utils/sparse_mt.py index 7d1c29b79..286dcbcfa 100644 --- a/gnomad/utils/sparse_mt.py +++ b/gnomad/utils/sparse_mt.py @@ -22,10 +22,19 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -INFO_SUM_AGG_FIELDS = ["QUALapprox"] -INFO_INT32_SUM_AGG_FIELDS = ["VarDP"] -INFO_MEDIAN_AGG_FIELDS = ["ReadPosRankSum", "MQRankSum"] -INFO_ARRAY_SUM_AGG_FIELDS = ["SB", "RAW_MQandDP"] +INFO_AGG_FIELDS = { + "sum_agg_fields": ["QUALapprox"], + "int32_sum_agg_fields": ["VarDP"], + "median_agg_fields": ["ReadPosRankSum", "MQRankSum"], + "array_sum_agg_fields": ["SB", "RAW_MQandDP"], +} + +AS_INFO_AGG_FIELDS = { + "sum_agg_fields": ["AS_QUALapprox", "AS_RAW_MQ"], + "int32_sum_agg_fields": ["AS_VarDP"], + "median_agg_fields": ["AS_RAW_ReadPosRankSum", "AS_RAW_MQRankSum"], + "array_sum_agg_fields": ["AS_SB_TABLE"], +} def compute_last_ref_block_end(mt: hl.MatrixTable) -> hl.Table: @@ -144,38 +153,46 @@ def _get_info_agg_expr( mt: hl.MatrixTable, sum_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["sum_agg_fields"], int32_sum_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_INT32_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["int32_sum_agg_fields"], median_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_MEDIAN_AGG_FIELDS, + ] = INFO_AGG_FIELDS["median_agg_fields"], array_sum_agg_fields: Union[ List[str], Dict[str, hl.expr.ArrayNumericExpression] - ] = INFO_ARRAY_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["array_sum_agg_fields"], prefix: str = "", + treat_fields_as_allele_specific: bool = False, ) -> Dict[str, hl.expr.Aggregation]: """ Create Aggregators for both site or AS info expression aggregations. .. note:: - - If `SB` is specified in array_sum_agg_fields, it will be aggregated as `AS_SB_TABLE`, according to GATK standard nomenclature. - - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for the `MQ` calculation and then dropped according to GATK recommendation. - - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation and then dropped according to GATK recommendation. - - If the fields to be aggregate (`sum_agg_fields`, `int32_sum_agg_fields`, `median_agg_fields`) are passed as - list of str, then they should correspond to entry fields in `mt` or in `mt.gvcf_info`. - - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in case of a name clash. + - If `SB` is specified in array_sum_agg_fields, it will be aggregated as + `AS_SB_TABLE`, according to GATK standard nomenclature. + - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for + the `MQ` calculation and then dropped according to GATK recommendation. + - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation + and then dropped according to GATK recommendation. + - If the fields to be aggregated (`sum_agg_fields`, `int32_sum_agg_fields`, + `median_agg_fields`) are passed as list of str, then they should correspond + to entry fields in `mt` or in mt.gvcf_info`. 
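To make the behavior of `make_freq_index_dict_from_meta` above concrete, here is a minimal sketch; the freq_meta contents and the explicit sort_order are hypothetical, chosen only so the result is easy to verify.

import hail as hl

from gnomad.utils.release import make_freq_index_dict_from_meta

# Hypothetical frequency metadata; in practice this usually comes from
# ht.index_globals().freq_meta on a release Table.
freq_meta = hl.literal(
    [
        {"group": "adj"},
        {"group": "raw"},
        {"group": "adj", "pop": "nfe"},
        {"group": "adj", "pop": "nfe", "sex": "XX"},
    ]
)

# An explicit sort_order keeps the example deterministic; by default the
# SORT_ORDER constant imported from gnomad.utils.vcf is used.
index_dict = make_freq_index_dict_from_meta(
    freq_meta, sort_order=["pop", "sex", "group"]
)
# index_dict == {"adj": 0, "raw": 1, "nfe_adj": 2, "nfe_XX_adj": 3}

A natural pairing is to rebuild this dictionary after the freq_meta filtering helper earlier in this diff has removed groupings, so the keys stay aligned with the filtered 'freq' array.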
+ - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in + case of a name clash. :param mt: Input MT :param sum_agg_fields: Fields to aggregate using sum. :param int32_sum_agg_fields: Fields to aggregate using sum using int32. :param median_agg_fields: Fields to aggregate using (approximate) median. - :param median_agg_fields: Fields to aggregate using element-wise summing over an array. + :param array_sum_agg_fields: Fields to aggregate using element-wise summing over an + array. :param prefix: Optional prefix for the fields. Used for adding 'AS_' in the AS case. - - :return: Dictionary of expression names and their corresponding aggregation Expression + :param treat_fields_as_allele_specific: Treat info fields as allele-specific. Defaults to False. + :return: Dictionary of expression names and their corresponding aggregation + Expression. """ def _agg_list_to_dict( @@ -187,7 +204,7 @@ def _agg_list_to_dict( out_fields.update({f: mt[f] for f in fields if f in mt.entry}) - # Check that all fields were found + # Check that all fields were found. missing_fields = [f for f in fields if f not in out_fields] if missing_fields: raise ValueError( @@ -195,9 +212,24 @@ def _agg_list_to_dict( " under mt.gvcf_info: {}".format(",".join(missing_fields)) ) + if treat_fields_as_allele_specific: + # TODO: Change to use hl.vds.local_to_global when fill_value can accept + # missing (error in v0.2.119). + out_fields = { + f: hl.bind( + lambda x: hl.if_else(f == "AS_SB_TABLE", x, x[1:]), + hl.range(hl.len(mt.alleles)).map( + lambda i: hl.or_missing( + mt.LA.contains(i), out_fields[f][mt.LA.index(i)] + ) + ), + ) + for f in fields + } + return out_fields - # Map str to expressions where needed + # Map str to expressions where needed. if isinstance(sum_agg_fields, list): sum_agg_fields = _agg_list_to_dict(mt, sum_agg_fields) @@ -210,33 +242,38 @@ def _agg_list_to_dict( if isinstance(array_sum_agg_fields, list): array_sum_agg_fields = _agg_list_to_dict(mt, array_sum_agg_fields) - # Create aggregators + aggs = [ + (median_agg_fields, lambda x: hl.agg.approx_quantiles(x, 0.5)), + (sum_agg_fields, hl.agg.sum), + (int32_sum_agg_fields, lambda x: hl.int32(hl.agg.sum(x))), + (array_sum_agg_fields, hl.agg.array_sum), + ] + + # Create aggregators. agg_expr = {} + for agg_fields, agg_func in aggs: + for k, expr in agg_fields.items(): + if treat_fields_as_allele_specific: + # If annotation is of the form 'AS_RAW_*_RankSum' it has a histogram + # representation where keys give the per-variant rank sum value to one + # decimal place followed by a comma and the corresponding count for + # that value, so we want to sum the rank sum value (first element). + # Rename annotation in the form 'AS_RAW_*_RankSum' to 'AS_*_RankSum'. 
+ if k.startswith("AS_RAW_") and k.endswith("RankSum"): + agg_expr[f"{prefix}{k.replace('_RAW', '')}"] = hl.agg.array_agg( + lambda x: agg_func(hl.or_missing(hl.is_defined(x), x[0])), expr + ) + else: + agg_expr[f"{prefix}{k}"] = hl.agg.array_agg( + lambda x: agg_func(x), expr + ) + else: + agg_expr[f"{prefix}{k}"] = agg_func(expr) - agg_expr.update( - { - f"{prefix}{k}": hl.agg.approx_quantiles(expr, 0.5) - for k, expr in median_agg_fields.items() - } - ) - agg_expr.update( - {f"{prefix}{k}": hl.agg.sum(expr) for k, expr in sum_agg_fields.items()} - ) - agg_expr.update( - { - f"{prefix}{k}": hl.int32(hl.agg.sum(expr)) - for k, expr in int32_sum_agg_fields.items() - } - ) - agg_expr.update( - { - f"{prefix}{k}": hl.agg.array_agg(lambda x: hl.agg.sum(x), expr) - for k, expr in array_sum_agg_fields.items() - } - ) + if treat_fields_as_allele_specific: + prefix = "AS_" # Handle annotations combinations and casting for specific annotations - # If RAW_MQandDP is in agg_expr or if both MQ_DP and RAW_MQ are, compute MQ instead mq_tuple = None if f"{prefix}RAW_MQandDP" in agg_expr: @@ -246,6 +283,15 @@ def _agg_list_to_dict( *[prefix] * 5, ) mq_tuple = agg_expr.pop(f"{prefix}RAW_MQandDP") + elif "AS_RAW_MQ" in agg_expr and treat_fields_as_allele_specific: + logger.info( + "Computing AS_MQ as sqrt(AS_RAW_MQ[i]/AD[i+1]). " + "Note that AS_MQ will be set to 0 if AS_RAW_MQ == 0." + ) + ad_expr = hl.vds.local_to_global( + mt.LAD, mt.LA, hl.len(mt.alleles), fill_value=0, number="R" + ) + mq_tuple = hl.zip(agg_expr.pop("AS_RAW_MQ"), hl.agg.array_sum(ad_expr[1:])) elif f"{prefix}RAW_MQ" in agg_expr and f"{prefix}MQ_DP" in agg_expr: logger.info( "Computing %sMQ as sqrt(%sRAW_MQ/%sMQ_DP). " @@ -255,9 +301,14 @@ def _agg_list_to_dict( mq_tuple = (agg_expr.pop(f"{prefix}RAW_MQ"), agg_expr.pop(f"{prefix}MQ_DP")) if mq_tuple is not None: - agg_expr[f"{prefix}MQ"] = hl.cond( - mq_tuple[1] > 0, hl.sqrt(mq_tuple[0] / mq_tuple[1]), 0 - ) + if treat_fields_as_allele_specific: + agg_expr[f"{prefix}MQ"] = mq_tuple.map( + lambda x: hl.if_else(x[1] > 0, hl.sqrt(x[0] / x[1]), 0) + ) + else: + agg_expr[f"{prefix}MQ"] = hl.if_else( + mq_tuple[1] > 0, hl.sqrt(mq_tuple[0] / mq_tuple[1]), 0 + ) # If both VarDP and QUALapprox are present, also compute QD. if f"{prefix}VarDP" in agg_expr and f"{prefix}QUALapprox" in agg_expr: @@ -266,15 +317,26 @@ def _agg_list_to_dict( "Note that %sQD will be set to 0 if %sVarDP == 0.", *[prefix] * 5, ) - var_dp = hl.int32(hl.agg.sum(int32_sum_agg_fields["VarDP"])) - agg_expr[f"{prefix}QD"] = hl.cond( - var_dp > 0, agg_expr[f"{prefix}QUALapprox"] / var_dp, 0 - ) + var_dp = agg_expr[f"{prefix}VarDP"] + qual_approx = agg_expr[f"{prefix}QUALapprox"] + if treat_fields_as_allele_specific: + agg_expr[f"{prefix}QD"] = hl.map( + lambda x: hl.if_else(x[1] > 0, x[0] / x[1], 0), + hl.zip(qual_approx, var_dp), + ) + else: + agg_expr[f"{prefix}QD"] = hl.if_else(var_dp > 0, qual_approx / var_dp, 0) - # SB needs to be cast to int32 for FS down the line + # SB needs to be cast to int32 for FS down the line. if f"{prefix}SB" in agg_expr: agg_expr[f"{prefix}SB"] = agg_expr[f"{prefix}SB"].map(lambda x: hl.int32(x)) + # SB needs to be cast to int32 for FS down the line. 
+ if "AS_SB_TABLE" in agg_expr: + agg_expr["AS_SB_TABLE"] = agg_expr["AS_SB_TABLE"].map( + lambda x: x.map(lambda y: hl.int32(y)) + ) + return agg_expr @@ -282,36 +344,52 @@ def get_as_info_expr( mt: hl.MatrixTable, sum_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["sum_agg_fields"], int32_sum_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_INT32_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["int32_sum_agg_fields"], median_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_MEDIAN_AGG_FIELDS, + ] = INFO_AGG_FIELDS["median_agg_fields"], array_sum_agg_fields: Union[ List[str], Dict[str, hl.expr.ArrayNumericExpression] - ] = INFO_ARRAY_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["array_sum_agg_fields"], alt_alleles_range_array_field: str = "alt_alleles_range_array", + treat_fields_as_allele_specific: bool = False, ) -> hl.expr.StructExpression: """ Return an allele-specific annotation Struct containing typical VCF INFO fields from GVCF INFO fields stored in the MT entries. .. note:: - - If `SB` is specified in array_sum_agg_fields, it will be aggregated as `AS_SB_TABLE`, according to GATK standard nomenclature. - - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for the `MQ` calculation and then dropped according to GATK recommendation. - - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation and then dropped according to GATK recommendation. - - If the fields to be aggregate (`sum_agg_fields`, `int32_sum_agg_fields`, `median_agg_fields`) are passed as list of str, - then they should correspond to entry fields in `mt` or in `mt.gvcf_info`. - - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in case of a name clash. + - If `SB` is specified in array_sum_agg_fields, it will be aggregated as + `AS_SB_TABLE`, according to GATK standard nomenclature. + - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for + the `MQ` calculation and then dropped according to GATK recommendation. + - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation + and then dropped according to GATK recommendation. + - If the fields to be aggregate (`sum_agg_fields`, `int32_sum_agg_fields`, + `median_agg_fields`) are passed as list of str, then they should correspond + to entry fields in `mt` or in `mt.gvcf_info`. + - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in + case of a name clash. + - If `treat_fields_as_allele_specific` is False, it's expected that there is a + single value for each entry field to be aggregated. Then when performing the + aggregation per global alternate allele, that value is included in the + aggregation if the global allele is present in the entry's list of local + alleles. If `treat_fields_as_allele_specific` is True, it's expected that + each entry field to be aggregated has one value per local allele, and each + of those is mapped to a global allele for aggregation. :param mt: Input Matrix Table :param sum_agg_fields: Fields to aggregate using sum. :param int32_sum_agg_fields: Fields to aggregate using sum using int32. :param median_agg_fields: Fields to aggregate using (approximate) median. :param array_sum_agg_fields: Fields to aggregate using array sum. 
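A short sketch of calling `get_as_info_expr` in both modes follows; the MatrixTable `mt` is assumed to be a sparse MT with `gvcf_info`, `LA`, `LAD`, and `LGT` entries, mirroring the setup that `default_compute_info` performs further down.

import hail as hl

from gnomad.utils.sparse_mt import AS_INFO_AGG_FIELDS, get_as_info_expr

# Mirror the preparation done in default_compute_info: expose gvcf_info fields
# as entry fields and add the required alt_alleles_range_array row annotation.
mt = mt.transmute_entries(**mt.gvcf_info)
mt = mt.annotate_rows(alt_alleles_range_array=hl.range(1, hl.len(mt.alleles)))

# Quasi allele-specific mode: site-level gvcf_info fields are aggregated per
# global alternate allele using the default INFO_AGG_FIELDS.
quasi_as_info = get_as_info_expr(mt)

# True allele-specific mode: gvcf_info already carries AS_* fields with one
# value per local allele, which are mapped to global alleles before aggregation.
as_info = get_as_info_expr(
    mt, **AS_INFO_AGG_FIELDS, treat_fields_as_allele_specific=True
)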
- :param alt_alleles_range_array_field: Annotation containing an array of the range of alternate alleles e.g., `hl.range(1, hl.len(mt.alleles))` + :param alt_alleles_range_array_field: Annotation containing an array of the range + of alternate alleles e.g., `hl.range(1, hl.len(mt.alleles))` + :param treat_fields_as_allele_specific: Treat info fields as allele-specific. + Defaults to False. :return: Expression containing the AS info fields """ if "DP" in list(sum_agg_fields) + list(int32_sum_agg_fields): @@ -327,13 +405,10 @@ def get_as_info_expr( int32_sum_agg_fields=int32_sum_agg_fields, median_agg_fields=median_agg_fields, array_sum_agg_fields=array_sum_agg_fields, - prefix="AS_", + prefix="" if treat_fields_as_allele_specific else "AS_", + treat_fields_as_allele_specific=treat_fields_as_allele_specific, ) - # Rename AS_SB to AS_SB_TABLE if present - if "AS_SB" in agg_expr: - agg_expr["AS_SB_TABLE"] = agg_expr.pop("AS_SB") - if alt_alleles_range_array_field not in mt.row or mt[ alt_alleles_range_array_field ].dtype != hl.dtype("array"): @@ -344,29 +419,39 @@ def get_as_info_expr( logger.error(msg) raise ValueError(msg) - # Modify aggregations to aggregate per allele - agg_expr = { - f: hl.agg.array_agg( - lambda ai: hl.agg.filter(mt.LA.contains(ai), expr), - mt[alt_alleles_range_array_field], - ) - for f, expr in agg_expr.items() - } + if not treat_fields_as_allele_specific: + # Modify aggregations to aggregate per allele + agg_expr = { + f: hl.agg.array_agg( + lambda ai: hl.agg.filter(mt.LA.contains(ai), expr), + mt[alt_alleles_range_array_field], + ) + for f, expr in agg_expr.items() + } # Run aggregations info = hl.struct(**agg_expr) - # Add SB Ax2 aggregation logic and FS if SB is present - if "AS_SB_TABLE" in info: - as_sb_table = hl.array( - [ - info.AS_SB_TABLE.filter(lambda x: hl.is_defined(x)).fold( - lambda i, j: i[:2] + j[:2], [0, 0] - ) # ref - ] - ).extend( - info.AS_SB_TABLE.map(lambda x: x[2:]) # each alt - ) + # Add FS and SOR if SB is present. + if "AS_SB_TABLE" in info or "AS_SB" in info: + # Rename AS_SB to AS_SB_TABLE if present and add SB Ax2 aggregation logic. + if "AS_SB" in agg_expr: + if "AS_SB_TABLE" in agg_expr: + logger.warning( + "Both `AS_SB` and `AS_SB_TABLE` were specified for aggregation." + " `AS_SB` will be used for aggregation." + ) + as_sb_table = hl.array( + [ + info.AS_SB.filter(lambda x: hl.is_defined(x)).fold( + lambda i, j: i[:2] + j[:2], [0, 0] + ) # ref + ] + ).extend( + info.AS_SB.map(lambda x: x[2:]) # each alt + ) + else: + as_sb_table = info.AS_SB_TABLE info = info.annotate( AS_SB_TABLE=as_sb_table, AS_FS=hl.range(1, hl.len(mt.alleles)).map( @@ -384,27 +469,31 @@ def get_site_info_expr( mt: hl.MatrixTable, sum_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["sum_agg_fields"], int32_sum_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_INT32_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["int32_sum_agg_fields"], median_agg_fields: Union[ List[str], Dict[str, hl.expr.NumericExpression] - ] = INFO_MEDIAN_AGG_FIELDS, + ] = INFO_AGG_FIELDS["median_agg_fields"], array_sum_agg_fields: Union[ List[str], Dict[str, hl.expr.ArrayNumericExpression] - ] = INFO_ARRAY_SUM_AGG_FIELDS, + ] = INFO_AGG_FIELDS["array_sum_agg_fields"], ) -> hl.expr.StructExpression: """ Create a site-level annotation Struct aggregating typical VCF INFO fields from GVCF INFO fields stored in the MT entries. .. 
note:: - - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for the `MQ` calculation and then dropped according to GATK recommendation. - - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation and then dropped according to GATK recommendation. - - If the fields to be aggregate (`sum_agg_fields`, `int32_sum_agg_fields`, `median_agg_fields`) are passed as - list of str, then they should correspond to entry fields in `mt` or in `mt.gvcf_info`. - - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in case of a name clash. + - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for + the `MQ` calculation and then dropped according to GATK recommendation. + - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation + and then dropped according to GATK recommendation. + - If the fields to be aggregate (`sum_agg_fields`, `int32_sum_agg_fields`, + `median_agg_fields`) are passed as list of str, then they should correspond + to entry fields in `mt` or in `mt.gvcf_info`. + - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in + case of a name clash. :param mt: Input Matrix Table :param sum_agg_fields: Fields to aggregate using sum. @@ -427,7 +516,7 @@ def get_site_info_expr( ) # Add FS and SOR if SB is present - # This is done outside of _get_info_agg_expr as the behavior is different + # This is done outside _get_info_agg_expr as the behavior is different # in site vs allele-specific versions if "SB" in agg_expr: agg_expr["FS"] = fs_from_sb(agg_expr["SB"]) @@ -449,7 +538,10 @@ def get_site_info_expr( def default_compute_info( mt: hl.MatrixTable, site_annotations: bool = False, - n_partitions: int = 5000, + as_annotations: bool = False, + # Set to True by default to prevent a breaking change. + quasi_as_annotations: bool = True, + n_partitions: Optional[int] = 5000, lowqual_indel_phred_het_prior: int = 40, ac_filter_groups: Optional[Dict[str, hl.Expression]] = None, ) -> hl.Table: @@ -458,11 +550,21 @@ def default_compute_info( .. note:: - This table doesn't split multi-allelic sites. + - This table doesn't split multi-allelic sites. + - At least one of `site_annotations`, `as_annotations` or `quasi_as_annotations` + must be True. :param mt: Input MatrixTable. Note that this table should be filtered to nonref sites. - :param site_annotations: Whether to also generate site level info fields. Default is False. - :param n_partitions: Number of desired partitions for output Table. Default is 5000. + :param site_annotations: Whether to generate site level info fields. Default is False. + :param as_annotations: Whether to generate allele-specific info fields using + allele-specific annotations in gvcf_info. Default is False. + :param quasi_as_annotations: Whether to generate allele-specific info fields using + non-allele-specific annotations in gvcf_info, but performing per allele + aggregations. This method can be used in cases where genotype data doesn't + contain allele-specific annotations to approximate allele-specific annotations. + Default is True. + :param n_partitions: Optional number of desired partitions for output Table. If + specified, naive_coalesce is performed. Default is 5000. :param lowqual_indel_phred_het_prior: Phred-scaled prior for a het genotype at a site with a low quality indel. Default is 40. 
We use 1/10k bases (phred=40) to be more consistent with the filtering used by Broad's Data Sciences Platform @@ -472,26 +574,52 @@ def default_compute_info( :return: Table with info fields :rtype: Table """ + if not site_annotations and not as_annotations and not quasi_as_annotations: + raise ValueError( + "At least one of `site_annotations`, `as_annotations`, or " + "`quasi_as_annotations` must be True!" + ) + # Add a temporary annotation for allele count groupings. ac_filter_groups = {"": True, **(ac_filter_groups or {})} mt = mt.annotate_cols(_ac_filter_groups=ac_filter_groups) - # Move gvcf info entries out from nested struct + # Move gvcf info entries out from nested struct. mt = mt.transmute_entries(**mt.gvcf_info) # Adding alt_alleles_range_array as a required annotation for - # get_as_info_expr to reduce memory usage + # get_as_info_expr to reduce memory usage. mt = mt.annotate_rows(alt_alleles_range_array=hl.range(1, hl.len(mt.alleles))) - # Compute AS info expr - info_expr = get_as_info_expr(mt) + info_expr = None + quasi_info_expr = None + + # Compute quasi-AS info expr. + if quasi_as_annotations: + info_expr = get_as_info_expr(mt) + + # Compute AS info expr using gvcf_info allele specific annotations. + if as_annotations: + if info_expr is not None: + quasi_info_expr = info_expr + info_expr = get_as_info_expr( + mt, + **AS_INFO_AGG_FIELDS, + treat_fields_as_allele_specific=True, + ) + + if info_expr is not None: + # Add allele specific pab_max + info_expr = info_expr.annotate( + AS_pab_max=pab_max_expr(mt.LGT, mt.LAD, mt.LA, hl.len(mt.alleles)) + ) - # Add allele specific pab_max - info_expr = info_expr.annotate( - AS_pab_max=pab_max_expr(mt.LGT, mt.LAD, mt.LA, hl.len(mt.alleles)) - ) if site_annotations: - info_expr = info_expr.annotate(**get_site_info_expr(mt)) + site_expr = get_site_info_expr(mt) + if info_expr is None: + info_expr = site_expr + else: + info_expr = info_expr.annotate(**site_expr) # Add 'AC' and 'AC_raw' for each allele count filter group requested. # First compute ACs for each non-ref allele, grouped by adj. 
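A hedged usage sketch of the reworked `default_compute_info`; the input `mt` is assumed to be a sparse MatrixTable filtered to non-ref sites, as the docstring requires.

from gnomad.utils.sparse_mt import default_compute_info

# Generate site-level, true allele-specific, and quasi allele-specific fields.
# When both AS flavors are requested, the true AS fields end up in `info` and
# the quasi-AS fields in a separate `quasi_info` struct on the returned Table.
info_ht = default_compute_info(
    mt,
    site_annotations=True,
    as_annotations=True,
    quasi_as_annotations=True,
    n_partitions=None,  # Skip the naive_coalesce of the output Table.
)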
@@ -518,18 +646,22 @@ def default_compute_info( # 'AC_raw' as the sum of adj and non-adj groups info_expr = info_expr.annotate( **{ - f"AC{'_'+f if f else f}_raw": grp.map( + f"AC{'_' + f if f else f}_raw": grp.map( lambda i: hl.int32(i.get(True, 0) + i.get(False, 0)) ) for f, grp in grp_ac_expr.items() }, **{ - f"AC{'_'+f if f else f}": grp.map(lambda i: hl.int32(i.get(True, 0))) + f"AC{'_' + f if f else f}": grp.map(lambda i: hl.int32(i.get(True, 0))) for f, grp in grp_ac_expr.items() }, ) - info_ht = mt.select_rows(info=info_expr).rows() + ann_expr = {"info": info_expr} + if quasi_info_expr is not None: + ann_expr["quasi_info"] = quasi_info_expr + + info_ht = mt.select_rows(**ann_expr).rows() # Add AS lowqual flag info_ht = info_ht.annotate( @@ -550,7 +682,10 @@ def default_compute_info( ) ) - return info_ht.naive_coalesce(n_partitions) + if n_partitions is not None: + info_ht = info_ht.naive_coalesce(n_partitions) + + return info_ht def split_info_annotation( @@ -724,21 +859,22 @@ def get_chr_dp_ann(chrom: str) -> hl.Table: f"{chrom}_mean_dp": hl.agg.filter( chr_mt.LGT.is_non_ref(), hl.agg.sum(chr_mt.DP), - ) - / hl.agg.filter(chr_mt.LGT.is_non_ref(), hl.agg.count()) + ) / hl.agg.filter(chr_mt.LGT.is_non_ref(), hl.agg.count()) } ).cols() else: return chr_mt.select_cols( **{ - f"{chrom}_mean_dp": hl.agg.sum( - hl.if_else( - chr_mt.LGT.is_hom_ref(), - chr_mt.DP * (1 + chr_mt.END - chr_mt.locus.position), - chr_mt.DP, + f"{chrom}_mean_dp": ( + hl.agg.sum( + hl.if_else( + chr_mt.LGT.is_hom_ref(), + chr_mt.DP * (1 + chr_mt.END - chr_mt.locus.position), + chr_mt.DP, + ) ) + / contig_size ) - / contig_size } ).cols() @@ -753,10 +889,12 @@ def get_chr_dp_ann(chrom: str) -> hl.Table: return ht.annotate( **{ - f"{chr_x}_ploidy": ht[f"{chr_x}_mean_dp"] - / (ht[f"{normalization_contig}_mean_dp"] / 2), - f"{chr_y}_ploidy": ht[f"{chr_y}_mean_dp"] - / (ht[f"{normalization_contig}_mean_dp"] / 2), + f"{chr_x}_ploidy": ht[f"{chr_x}_mean_dp"] / ( + ht[f"{normalization_contig}_mean_dp"] / 2 + ), + f"{chr_y}_ploidy": ht[f"{chr_y}_mean_dp"] / ( + ht[f"{normalization_contig}_mean_dp"] / 2 + ), } ) @@ -845,7 +983,7 @@ def compute_coverage_stats( # Annotate rows now return mt.select_rows( - mean=hl.cond(hl.is_nan(mean_expr), 0, mean_expr), + mean=hl.if_else(hl.is_nan(mean_expr), 0, mean_expr), median_approx=hl.or_else(hl.agg.approx_median(hl.or_else(mt.DP, 0)), 0), total_DP=hl.agg.sum(mt.DP), **{ diff --git a/gnomad/utils/vep.py b/gnomad/utils/vep.py index 47741fbe4..8b006f18e 100644 --- a/gnomad/utils/vep.py +++ b/gnomad/utils/vep.py @@ -15,6 +15,12 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +VEP_VERSIONS = ["101", "105"] +CURRENT_VEP_VERSION = VEP_VERSIONS[-1] +""" +Versions of VEP used in gnomAD data, the latest version is 105. 
+""" + # Note that this is the current as of v81 with some included for backwards # compatibility (VEP <= 75) CSQ_CODING_HIGH_IMPACT = [ @@ -85,12 +91,18 @@ Constant that contains the local path to the VEP config file """ -VEP_CSQ_FIELDS = "Allele|Consequence|IMPACT|SYMBOL|Gene|Feature_type|Feature|BIOTYPE|EXON|INTRON|HGVSc|HGVSp|cDNA_position|CDS_position|Protein_position|Amino_acids|Codons|ALLELE_NUM|DISTANCE|STRAND|VARIANT_CLASS|MINIMISED|SYMBOL_SOURCE|HGNC_ID|CANONICAL|TSL|APPRIS|CCDS|ENSP|SWISSPROT|TREMBL|UNIPARC|GENE_PHENO|SIFT|PolyPhen|DOMAINS|HGVS_OFFSET|MOTIF_NAME|MOTIF_POS|HIGH_INF_POS|MOTIF_SCORE_CHANGE|LoF|LoF_filter|LoF_flags|LoF_info" +VEP_CSQ_FIELDS = { + "101": "Allele|Consequence|IMPACT|SYMBOL|Gene|Feature_type|Feature|BIOTYPE|EXON|INTRON|HGVSc|HGVSp|cDNA_position|CDS_position|Protein_position|Amino_acids|Codons|ALLELE_NUM|DISTANCE|STRAND|VARIANT_CLASS|MINIMISED|SYMBOL_SOURCE|HGNC_ID|CANONICAL|TSL|APPRIS|CCDS|ENSP|SWISSPROT|TREMBL|UNIPARC|GENE_PHENO|SIFT|PolyPhen|DOMAINS|HGVS_OFFSET|MOTIF_NAME|MOTIF_POS|HIGH_INF_POS|MOTIF_SCORE_CHANGE|LoF|LoF_filter|LoF_flags|LoF_info", + "105": "Allele|Consequence|IMPACT|SYMBOL|Gene|Feature_type|Feature|BIOTYPE|EXON|INTRON|HGVSc|HGVSp|cDNA_position|CDS_position|Protein_position|Amino_acids|Codons|ALLELE_NUM|DISTANCE|STRAND|FLAGS|VARIANT_CLASS|SYMBOL_SOURCE|HGNC_ID|CANONICAL|MANE_SELECT|MANE_PLUS_CLINICAL|TSL|APPRIS|CCDS|ENSP|UNIPROT_ISOFORM|SOURCE|SIFT|PolyPhen|DOMAINS|miRNA|HGVS_OFFSET|PUBMED|MOTIF_NAME|MOTIF_POS|HIGH_INF_POS|MOTIF_SCORE_CHANGE|TRANSCRIPTION_FACTORS|LoF|LoF_filter|LoF_flags|LoF_info", +} """ -Constant that defines the order of VEP annotations used in VCF export. +Constant that defines the order of VEP annotations used in VCF export, currently stored in a dictionary with the VEP version as the key. """ -VEP_CSQ_HEADER = f"Consequence annotations from Ensembl VEP. Format: {VEP_CSQ_FIELDS}" +VEP_CSQ_HEADER = ( + "Consequence annotations from Ensembl VEP. Format:" + f" {VEP_CSQ_FIELDS[CURRENT_VEP_VERSION]}" +) """ Constant that contains description for VEP used in VCF export. """ @@ -388,7 +400,8 @@ def filter_vep_to_synonymous_variants( def vep_struct_to_csq( - vep_expr: hl.expr.StructExpression, csq_fields: str = VEP_CSQ_FIELDS + vep_expr: hl.expr.StructExpression, + csq_fields: str = VEP_CSQ_FIELDS[CURRENT_VEP_VERSION], ) -> hl.expr.ArrayExpression: """ Given a VEP Struct, returns and array of VEP VCF CSQ strings (one per consequence in the struct). @@ -403,7 +416,7 @@ def vep_struct_to_csq( hl.str(), so it may differ from their usual VEP CSQ representation. :param vep_expr: The input VEP Struct - :param csq_fields: The | delimited list of fields to include in the CSQ (in that order) + :param csq_fields: The | delimited list of fields to include in the CSQ (in that order), default is the CSQ fields of the CURRENT_VEP_VERSION. 
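Since `VEP_CSQ_FIELDS` is now keyed by version, a caller working with VEP 101 annotations can pin the CSQ layout explicitly; a small sketch, assuming a Table `ht` that already carries a standard `vep` struct annotation:

from gnomad.utils.vep import VEP_CSQ_FIELDS, vep_struct_to_csq

# Build CSQ strings using the VEP 101 field order rather than the current
# default (CURRENT_VEP_VERSION, i.e. "105").
ht = ht.annotate(
    vep_csq=vep_struct_to_csq(ht.vep, csq_fields=VEP_CSQ_FIELDS["101"])
)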
:return: The corresponding CSQ strings """ _csq_fields = [f.lower() for f in csq_fields.split("|")] @@ -423,11 +436,15 @@ def get_csq_from_struct( "feature": ( element.transcript_id if "transcript_id" in element - else element.regulatory_feature_id - if "regulatory_feature_id" in element - else element.motif_feature_id - if "motif_feature_id" in element - else "" + else ( + element.regulatory_feature_id + if "regulatory_feature_id" in element + else ( + element.motif_feature_id + if "motif_feature_id" in element + else "" + ) + ) ), "variant_class": vep_expr.variant_class, } @@ -437,37 +454,40 @@ def get_csq_from_struct( if feature_type == "Transcript": fields.update( { - "canonical": hl.cond(element.canonical == 1, "YES", ""), + "canonical": hl.if_else(element.canonical == 1, "YES", ""), "ensp": element.protein_id, "gene": element.gene_id, "symbol": element.gene_symbol, "symbol_source": element.gene_symbol_source, - "cdna_position": hl.str(element.cdna_start) - + hl.cond( + "cdna_position": hl.str(element.cdna_start) + hl.if_else( element.cdna_start == element.cdna_end, "", "-" + hl.str(element.cdna_end), ), - "cds_position": hl.str(element.cds_start) - + hl.cond( + "cds_position": hl.str(element.cds_start) + hl.if_else( element.cds_start == element.cds_end, "", "-" + hl.str(element.cds_end), ), - "protein_position": hl.str(element.protein_start) - + hl.cond( + "mirna": hl.delimit(element.mirna, "&"), + "protein_position": hl.str(element.protein_start) + hl.if_else( element.protein_start == element.protein_end, "", "-" + hl.str(element.protein_end), ), - "sift": element.sift_prediction - + "(" - + hl.format("%.3f", element.sift_score) - + ")", - "polyphen": element.polyphen_prediction - + "(" - + hl.format("%.3f", element.polyphen_score) - + ")", + "uniprot_isoform": hl.delimit(element.uniprot_isoform, "&"), + "sift": ( + element.sift_prediction + + "(" + + hl.format("%.3f", element.sift_score) + + ")" + ), + "polyphen": ( + element.polyphen_prediction + + "(" + + hl.format("%.3f", element.polyphen_score) + + ")" + ), "domains": hl.delimit( element.domains.map(lambda d: d.db + ":" + d.name), "&" ), @@ -475,6 +495,9 @@ def get_csq_from_struct( ) elif feature_type == "MotifFeature": fields["motif_score_change"] = hl.format("%.3f", element.motif_score_change) + fields["transcription_factors"] = hl.delimit( + element.transcription_factors, "&" + ) return hl.delimit( [hl.or_else(hl.str(fields.get(f, "")), "") for f in _csq_fields], "|" diff --git a/gnomad/variant_qc/evaluation.py b/gnomad/variant_qc/evaluation.py index e9cbfd0e1..b62ac3457 100644 --- a/gnomad/variant_qc/evaluation.py +++ b/gnomad/variant_qc/evaluation.py @@ -66,7 +66,7 @@ def compute_ranked_bin( if compute_snv_indel_separately: # For each bin, add a SNV / indel stratification bin_expr = { - f"{bin_id}_{snv}": (bin_expr & snv_expr) + f"{bin_id}_{snv}": bin_expr & snv_expr for bin_id, bin_expr in bin_expr.items() for snv, snv_expr in [ ("snv", hl.is_snp(ht.alleles[0], ht.alleles[1])), diff --git a/gnomad/variant_qc/random_forest.py b/gnomad/variant_qc/random_forest.py index 05179ce70..ce00e5a68 100644 --- a/gnomad/variant_qc/random_forest.py +++ b/gnomad/variant_qc/random_forest.py @@ -45,7 +45,7 @@ def run_rf_test( ) mt = mt.annotate_rows( - label=hl.cond(mt["feature1"] & (mt["feature2"] > 0), "TP", "FP") + label=hl.if_else(mt["feature1"] & (mt["feature2"] > 0), "TP", "FP") ) ht = mt.rows() diff --git a/pyproject.toml b/pyproject.toml index 66a8b1ec2..0114d39eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ 
[tool.black] -target-version = ['py36', 'py37', 'py38'] +target-version = ['py39', 'py310', 'py311'] preview = true [tool.isort] profile = "black" diff --git a/requirements-dev.in b/requirements-dev.in index f43bc5047..b8d3e217e 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,7 +1,7 @@ # Used to create the requirements-dev.txt file # Note: `pip-compile requirements-dev.in` needs to be run before changes in requirements-dev.in are reflected in # requirements-dev.txt, which is used for GitHub Actions -black==22.10.0 # This should be kept in sync with the version in .pre-commit-config.yaml +black==23.7.0 # This should be kept in sync with the version in .pre-commit-config.yaml isort==5.12.0 # This should be kept in sync with the version in .pre-commit-config.yaml autopep8==2.0.2 # This should be kept in sync with the version in .pre-commit-config.yaml pre-commit diff --git a/requirements-dev.txt b/requirements-dev.txt index ff1a4d764..77146ebd2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,91 +1,74 @@ # -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: # # pip-compile requirements-dev.in # -astroid==2.15.1 +astroid==2.15.6 # via pylint -attrs==21.4.0 - # via pytest autopep8==2.0.2 # via -r requirements-dev.in -backports-entry-points-selectable==1.1.1 - # via virtualenv -black==22.10.0 +black==23.7.0 # via -r requirements-dev.in -cfgv==3.3.1 +cfgv==3.4.0 # via pre-commit -click==8.0.3 +click==8.1.6 # via black -dill==0.3.6 +dill==0.3.7 # via pylint -distlib==0.3.3 +distlib==0.3.7 # via virtualenv -filelock==3.4.0 +filelock==3.12.2 # via virtualenv -identify==2.4.0 +identify==2.5.26 # via pre-commit -iniconfig==1.1.1 +iniconfig==2.0.0 # via pytest isort==5.12.0 # via # -r requirements-dev.in # pylint -lazy-object-proxy==1.6.0 +lazy-object-proxy==1.9.0 # via astroid -mccabe==0.6.1 +mccabe==0.7.0 # via pylint -mypy-extensions==0.4.3 +mypy-extensions==1.0.0 # via black -nodeenv==1.6.0 +nodeenv==1.8.0 # via pre-commit -packaging==23.0 - # via pytest -pathspec==0.9.0 +packaging==23.1 + # via + # black + # pytest +pathspec==0.11.2 # via black -platformdirs==2.4.0 +platformdirs==3.10.0 # via # black # pylint # virtualenv -pluggy==1.0.0 +pluggy==1.2.0 # via pytest -pre-commit==2.15.0 +pre-commit==3.3.3 # via -r requirements-dev.in -py==1.11.0 - # via pytest -pycodestyle==2.10.0 +pycodestyle==2.11.0 # via autopep8 pydocstyle==6.3.0 # via -r requirements-dev.in -pylint==2.17.1 +pylint==2.17.5 # via -r requirements-dev.in -pytest==6.2.5 +pytest==7.4.0 # via -r requirements-dev.in -pyyaml==6.0 +pyyaml==6.0.1 # via pre-commit -six==1.16.0 - # via virtualenv snowballstemmer==2.2.0 # via pydocstyle -toml==0.10.2 - # via - # pre-commit - # pytest -tomli==2.0.1 - # via - # autopep8 - # black - # pylint -tomlkit==0.11.7 +tomlkit==0.12.1 # via pylint -typing-extensions==4.0.0 - # via - # astroid - # black - # pylint -virtualenv==20.10.0 +virtualenv==20.24.3 # via pre-commit -wrapt==1.13.3 +wrapt==1.15.0 # via astroid + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.txt b/requirements.txt index 60d326875..97bc9a199 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,11 @@ annoy +ga4gh.vrs[extras] hail hdbscan ipywidgets networkx +onnx +onnxruntime scikit-learn +skl2onnx slackclient==2.5.0 diff --git a/tests/resources/test_resource_utils.py 
b/tests/resources/test_resource_utils.py index fd3bc0320..fac02b9f3 100644 --- a/tests/resources/test_resource_utils.py +++ b/tests/resources/test_resource_utils.py @@ -188,17 +188,20 @@ def test_default_source_from_environment_overrides_cloud_spark_provider(self): Make sure the environment variables is used over the one for the current cloud Spark provider. """ - with patch( - "hail.utils.guess_cloud_spark_provider", - return_value="hdinsight", - create=True, - ), patch.dict( - os.environ, - { - "GNOMAD_DEFAULT_PUBLIC_RESOURCE_SOURCE": ( - "gs://my-bucket/gnomad-resources" - ) - }, + with ( + patch( + "hail.utils.guess_cloud_spark_provider", + return_value="hdinsight", + create=True, + ), + patch.dict( + os.environ, + { + "GNOMAD_DEFAULT_PUBLIC_RESOURCE_SOURCE": ( + "gs://my-bucket/gnomad-resources" + ) + }, + ), ): assert ( get_default_public_resource_source()
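The reformatted test above relies on parenthesized context managers; a standalone sketch of the same pattern with hypothetical patch targets is shown below. The parenthesized spelling is equivalent to the older comma-separated form and requires a parser that accepts it (officially documented as of Python 3.10).

from unittest.mock import patch

# The parentheses only group the context managers; behavior is identical to
# `with patch(...) as mock_exists, patch.dict(...):` on one line.
with (
    patch("os.path.exists", return_value=True) as mock_exists,
    patch.dict("os.environ", {"MY_FLAG": "1"}),
):
    ...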