Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions gnomad/utils/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression:


def annotate_mutation_type(
t: Union[hl.MatrixTable, hl.Table]
t: Union[hl.MatrixTable, hl.Table],
context_length: Optional[int] = None,
num_scan_context_length: Optional[int] = 100,
) -> Union[hl.MatrixTable, hl.Table]:
"""
Annotate mutation types.
Expand All @@ -310,20 +312,36 @@ def annotate_mutation_type(
used in this repo to indicate a variant's multiallelic and SNP/indel status.

:param t: Input Table or MatrixTable.
:param context_length: Length of the 'context' annotation in 't'. If this is not
specified, the value will be determined by examining the first
`num_scan_context_length` values of the 'context' annotation. Default is None.
:param num_scan_context_length: Number of values in the 'context' annotation to use
for determining `context_length` if it is not specified. If set to None, all
values in 'context' will be used. Default is 100.
:return: Table with mutation type annotations added.
"""
# Determine the context length by collecting all the context lengths.
context_lengths = list(
filter(None, t.aggregate(hl.agg.collect_as_set(hl.len(t.context))))
)
if len(context_lengths) > 1:
raise ValueError(
"More than one length was found among the first 100 'context' values."
" Length of 'context' should be consistent."
)
else:
context_length = context_lengths[0]
logger.info("Detected a length of %d for context length", context_length)
if context_length is None:
# Determine the context length by collecting all the context lengths.
if num_scan_context_length is None:
context_lengths = t.aggregate(hl.agg.collect_as_set(hl.len(t.context)))
msg = "all"
else:
context_lengths = hl.len(t.context).take(num_scan_context_length)
msg = f"the first {num_scan_context_length}"
context_lengths = list(filter(None, set(context_lengths)))
if len(context_lengths) > 1:
raise ValueError(
f"More than one length was found among {msg} 'context' values. Length "
"of 'context' should be consistent.",
)
else:
context_length = context_lengths[0]
logger.info(
"Detected a length of %d for context length using %s 'context' values.",
context_length,
msg,
)

# Determine the middle index of the context annotation.
if context_length == 3:
mid_index = 1
Expand All @@ -348,10 +366,15 @@ def annotate_mutation_type(
else:
t = t.annotate(transition=transition_expr, cpg=cpg_expr)
mutation_type_expr = (
hl.case()
.when(t.cpg, "CpG")
.when(t.transition, "non-CpG transition")
.default("transversion")
hl.switch(hl.len(t.context))
.when(
context_length,
hl.case()
.when(t.cpg, "CpG")
.when(t.transition, "non-CpG transition")
.default("transversion"),
)
.or_error("Found 'context' value with unexpected context length!")
)
mutation_type_model_expr = hl.if_else(t.cpg, t.context, "non-CpG")
if isinstance(t, hl.MatrixTable):
Expand Down