Skip to content
63 changes: 63 additions & 0 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import hail as hl

import gnomad.utils.filtering as filter_utils
from gnomad.utils.gen_stats import to_phred

logging.basicConfig(
Expand Down Expand Up @@ -1795,3 +1796,65 @@ def _agg_by_group(
)

return ht.drop("cols")


def update_structured_annotations(
ht: hl.Table,
annotation_update_exprs: Dict[str, hl.Expression],
annotation_update_label: Optional[str] = None,
) -> hl.Table:
"""
Update highly structured annotations on a Table.

This function recursively updates annotations defined by `annotation_update_exprs`
and if `annotation_update_label` is supplied, it checks if the sample annotations
are different from the input and adds a flag to the Table, indicating which
annotations have been updated for each sample.

:param ht: Input Table with structured annotations to update.
:param annotation_update_exprs: Dictionary of annotations to update, structured as
they are structured on the input `ht`.
:param annotation_update_label: Optional string of the label to use for an
annotation indicating which annotations have been updated. Default is None, so
no annotation is added.
:return: Table with updated annotations and optionally a flag indicating which
annotations were changed.
"""

def _update_struct(
struct_expr: hl.expr.StructExpression,
update_exprs: Union[Dict[str, hl.expr.Expression], hl.expr.Expression],
) -> Tuple[Dict[str, hl.expr.BooleanExpression], Any]:
"""
Update a StructExpression.

:param struct_expr: StructExpression to update.
:param update_exprs: Dictionary of annotations to update.
:return: Tuple of the updated annotations and the updated flag.
"""
if isinstance(update_exprs, dict):
updated_struct_expr = {}
updated_flag_expr = {}
for ann, expr in update_exprs.items():
updated_flag, updated_ann = _update_struct(struct_expr[ann], expr)
updated_flag_expr.update(
{ann + ("." + k if k else ""): v for k, v in updated_flag.items()}
)
updated_struct_expr[ann] = updated_ann
return updated_flag_expr, struct_expr.annotate(**updated_struct_expr)
else:
return {"": update_exprs != struct_expr}, update_exprs

annotation_update_flag, updated_rows = _update_struct(
ht.row_value, annotation_update_exprs
)
if annotation_update_label is not None:
updated_rows = updated_rows.annotate(
**{
annotation_update_label: filter_utils.add_filters_expr(
filters=annotation_update_flag
)
}
)

return ht.annotate(**updated_rows)