def update_structured_annotations(
    ht: hl.Table,
    annotation_update_exprs: Dict[str, hl.Expression],
    annotation_update_label: Optional[str] = None,
) -> hl.Table:
    """
    Update highly structured annotations on a Table.

    This function recursively walks `annotation_update_exprs` and replaces the
    matching (possibly nested) row annotations on `ht`. Annotations that are not
    named in `annotation_update_exprs` are left untouched.

    If `annotation_update_label` is supplied, each replaced leaf annotation is
    compared against its previous value and the per-annotation results, keyed by
    the dotted path of the annotation (e.g. "parent.child"), are passed to
    `add_filters_expr`; the result is stored on the Table under
    `annotation_update_label`.

    .. note::

        A leaf is considered updated when the old and new values differ, or when
        exactly one of them is missing. Two missing values are considered
        unchanged.

    :param ht: Input Table with structured annotations to update.
    :param annotation_update_exprs: Dictionary of annotations to update, structured as
        they are structured on the input `ht`. Nested dictionaries descend into
        nested structs; any non-dictionary value replaces the annotation at that
        position.
    :param annotation_update_label: Optional string of the label to use for an
        annotation indicating which annotations have been updated. Default is None, so
        no annotation is added.
    :return: Table with updated annotations and optionally a flag indicating which
        annotations were changed.
    """
    # Imported here so this function does not depend on module-level import state.
    import gnomad.utils.filtering as filter_utils

    def _update_struct(
        struct_expr: hl.expr.StructExpression,
        update_exprs: Union[Dict[str, hl.expr.Expression], hl.expr.Expression],
    ) -> Tuple[Dict[str, hl.expr.BooleanExpression], Any]:
        """
        Update a StructExpression.

        :param struct_expr: Current expression at this level: a StructExpression
            while `update_exprs` is a dictionary, otherwise the existing leaf
            value being replaced.
        :param update_exprs: Dictionary of annotations to update, or the
            replacement expression for a leaf annotation.
        :return: Tuple of the updated flags and the updated annotations, in that
            order. The flags are a dictionary mapping the dotted path of each
            replaced leaf, relative to `struct_expr`, to a boolean expression that
            is True when the value changed. For a leaf the only key is "".
        """
        if not isinstance(update_exprs, dict):
            # `!=` is missing when either operand is missing; fall back to
            # comparing definedness so the flag is always a defined boolean.
            changed_expr = hl.coalesce(
                update_exprs != struct_expr,
                hl.is_defined(update_exprs) != hl.is_defined(struct_expr),
            )
            return {"": changed_expr}, update_exprs

        replaced_fields = {}
        changed_flags = {}
        for ann, expr in update_exprs.items():
            child_flags, replaced_fields[ann] = _update_struct(struct_expr[ann], expr)
            for sub_path, flag in child_flags.items():
                # A leaf reports itself under "", so no "." separator is added.
                path = f"{ann}.{sub_path}" if sub_path else ann
                changed_flags[path] = flag

        return changed_flags, struct_expr.annotate(**replaced_fields)

    annotation_update_flag, updated_rows = _update_struct(
        ht.row_value, annotation_update_exprs
    )

    if annotation_update_label is not None:
        updated_rows = updated_rows.annotate(
            **{
                annotation_update_label: filter_utils.add_filters_expr(
                    filters=annotation_update_flag
                )
            }
        )

    return ht.annotate(**updated_rows)