Skip to content

Commit c9ea50c

Browse files
committed
Extract strata aggregation into its own function and use it in compute_freq_by_strata
1 parent b9f65e1 commit c9ea50c

File tree

1 file changed

+179
-66
lines changed

1 file changed

+179
-66
lines changed

gnomad/utils/annotations.py

Lines changed: 179 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,25 +1861,19 @@ def generate_freq_group_membership_array(
18611861
return ht
18621862

18631863

1864-
def compute_freq_by_strata(
1864+
def agg_by_strata(
    mt: hl.MatrixTable,
    entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None,
    select_fields: Optional[List[str]] = None,
    group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
    """
    Get row expression for annotations of each entry aggregation function(s) by strata.

    The entry aggregation functions are applied to the MatrixTable entries and
    aggregated. If no `group_membership_ht` (like the one returned by
    `generate_freq_group_membership_array`) is supplied, `mt` must contain a
    'group_membership' annotation that is a list of bools to aggregate the columns by.

    An 'adj_group' global (array of bools, one per stratification group) controls
    which groups are restricted to entries where `mt.adj` is True. It is taken, in
    order of preference, from: the globals of `mt`, the globals of
    `group_membership_ht`, or the 'freq_meta' global of `group_membership_ht`
    (groups whose 'group' value is "adj"). If none of those is available, no group
    is adj filtered.

    :param mt: Input MatrixTable.
    :param entry_agg_funcs: Optional dict of entry aggregation functions. The keys
        of the dict are the names of the output annotations and the values are
        tuples of functions. The first function is used to transform the `mt`
        entries in some way, and the second function is used to aggregate the
        output from the first function.
    :param select_fields: Optional list of row fields from `mt` to keep on the output
        Table.
    :param group_membership_ht: Optional Table containing group membership annotations
        to stratify the aggregations by. If not provided, the 'group_membership'
        annotation is expected to be present on `mt`.
    :return: Table with one array annotation per entry aggregation function, where
        each array element is the aggregation for one stratification group.
    :raises ValueError: If no 'group_membership' annotation is available, or if the
        length of 'adj_group' does not match the number of groups.
    """
    if entry_agg_funcs is None:
        entry_agg_funcs = {}
    if select_fields is None:
        select_fields = []

    n_samples = mt.count_cols()
    global_expr = {}
    if "adj_group" in mt.index_globals():
        global_expr["adj_group"] = mt.index_globals().adj_group
        logger.info("Using the 'adj_group' global annotation found on the input MT.")

    if group_membership_ht is None and "group_membership" not in mt.col:
        raise ValueError(
            "The 'group_membership' annotation is not found in the input MatrixTable "
            "and 'group_membership_ht' is not specified."
        )
    elif group_membership_ht is None:
        logger.info(
            "'group_membership_ht' is not specified, using sample stratification "
            "indicated by the 'group_membership' annotation on mt."
        )
        n_groups = len(mt.group_membership.take(1)[0])
    else:
        logger.info(
            "'group_membership_ht' is specified, using sample stratification indicated "
            "by its 'group_membership' annotation."
        )
        n_groups = len(group_membership_ht.group_membership.take(1)[0])
        mt = mt.annotate_cols(
            group_membership=group_membership_ht[mt.col_key].group_membership
        )
        if "adj_group" not in global_expr:
            ht_globals = group_membership_ht.index_globals()
            if "adj_group" in ht_globals:
                # The annotation lives on 'group_membership_ht' in this branch; `mt`
                # was already checked above and does not have it.
                global_expr["adj_group"] = ht_globals.adj_group
                logger.info(
                    "Using the 'adj_group' global annotation on 'group_membership_ht'."
                )
            elif "freq_meta" in ht_globals:
                logger.info(
                    "The 'freq_meta' global annotation is found in "
                    "'group_membership_ht', using it to determine the adj filtered "
                    "stratification groups."
                )
                freq_meta = ht_globals.freq_meta

                global_expr["adj_group"] = freq_meta.map(
                    lambda x: x.get("group", "NA") == "adj"
                )

    # Without any 'adj_group' information, treat every group as not adj filtered.
    if "adj_group" not in global_expr:
        global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False)

    n_adj_group = hl.eval(hl.len(global_expr["adj_group"]))
    if n_adj_group != n_groups:
        raise ValueError(
            f"The number of elements in the 'adj_group' ({n_adj_group}) global "
            "annotation does not match the number of elements in the "
            f"'group_membership' annotation ({n_groups})!",
        )

    # Keep only the entries needed for the aggregation functions. The 'adj' entry
    # field is only carried along when at least one group is adj filtered.
    has_adj = hl.eval(hl.any(global_expr["adj_group"]))
    select_expr = {}
    if has_adj:
        select_expr["adj"] = mt.adj

    select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()})
    mt = mt.select_entries(**select_expr)

    # Convert MT to HT with a row annotation that is an array of all samples entries
    # for that variant.
    ht = mt.localize_entries("entries", "cols")

    # For each stratification group in group_membership, determine the indices of the
    # samples that belong to that group.
    global_expr["indices_by_group"] = hl.range(n_groups).map(
        lambda g_i: hl.range(n_samples).filter(
            lambda s_i: ht.cols[s_i].group_membership[g_i]
        )
    )
    ht = ht.annotate_globals(**global_expr)

    # Pull out each annotation that will be used in the array aggregation below as its
    # own ArrayExpression. This is important to prevent memory issues when performing
    # the below array aggregations. `select_fields` are row fields, so they are passed
    # through unchanged rather than being looked up on each entry.
    ht = ht.select(
        *select_fields,
        **{ann: ht.entries.map(lambda e: e[ann]) for ann in select_expr},
    )

    def _agg_by_group(
        ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression
    ) -> hl.expr.ArrayExpression:
        """
        Aggregate `ann_expr` by group using the `agg_func` function.

        :param ht: Input Hail Table.
        :param agg_func: Aggregation function to apply to `ann_expr`.
        :param ann_expr: Expression to aggregate by group.
        :return: Aggregated array expression.
        """
        if not has_adj:
            # No group is adj filtered, so the 'adj' array was never selected onto
            # `ht` and must not be referenced.
            return ht.indices_by_group.map(
                lambda s_indices: s_indices.aggregate(
                    lambda i: agg_func(ann_expr[i])
                )
            )

        return hl.map(
            lambda s_indices, adj: s_indices.aggregate(
                lambda i: hl.if_else(
                    adj,
                    hl.agg.filter(ht.adj[i], agg_func(ann_expr[i])),
                    agg_func(ann_expr[i]),
                )
            ),
            ht.indices_by_group,
            ht.adj_group,
        )

    # Add annotations for any supplied entry transform and aggregation functions.
    ht = ht.select(
        *select_fields,
        **{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()},
    )

    return ht.drop("cols")
19812013

19822014

2015+
def compute_freq_by_strata(
    mt: hl.MatrixTable,
    entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None,
    select_fields: Optional[List[str]] = None,
    group_membership_includes_raw_group: bool = True,
) -> hl.Table:
    """
    Compute call statistics and, when passed, entry aggregation function(s) by strata.

    The computed call statistics are AC, AF, AN, and homozygote_count. The entry
    aggregation functions are applied to the MatrixTable entries and aggregated. The
    MatrixTable must contain a 'group_membership' annotation (like the one added by
    `generate_freq_group_membership_array`) that is a list of bools to aggregate the
    columns by.

    .. note::
        This function is primarily used through `annotate_freq` but can be used
        independently if desired. Please see the `annotate_freq` function for more
        complete documentation.

    :param mt: Input MatrixTable.
    :param entry_agg_funcs: Optional dict of entry aggregation functions. When
        specified, additional annotations are added to the output Table.
        The keys of the dict are the names of the annotations and the values are tuples
        of functions. The first function is used to transform the `mt` entries in some
        way, and the second function is used to aggregate the output from the first
        function. The supplied dict is not modified; a 'freq' key, if present, is
        overridden by the call statistics aggregation.
    :param select_fields: Optional list of row fields from `mt` to keep on the output
        Table.
    :param group_membership_includes_raw_group: Whether the 'group_membership'
        annotation includes an entry for the 'raw' group, representing all samples. If
        False, the 'raw' group is inserted as the second element in all added
        annotations using the same 'group_membership', resulting
        in array lengths of 'group_membership'+1. If True, the second element of each
        added annotation is still the 'raw' group, but the group membership is
        determined by the values in the second element of 'group_membership', and the
        output annotations will be the same length as 'group_membership'. Default is
        True.
    :return: Table with allele frequencies by strata.
    """
    if not group_membership_includes_raw_group:
        # Add the 'raw' group to the 'group_membership' annotation. Prepending a copy
        # of the first element leaves that copy at index 1, which is the position
        # marked as not adj filtered ('raw') below.
        mt = mt.annotate_cols(
            group_membership=hl.array([mt.group_membership[0]]).extend(
                mt.group_membership
            )
        )

    # Add adj_group global annotation indicating that the second element in
    # group_membership is 'raw' and all others are 'adj'.
    n_groups = len(mt.group_membership.take(1)[0])
    mt = mt.annotate_globals(adj_group=hl.range(n_groups).map(lambda x: x != 1))

    # Work on a shallow copy so the caller's dict is not mutated when the 'freq'
    # aggregation is added below.
    agg_funcs = dict(entry_agg_funcs) if entry_agg_funcs is not None else {}

    def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression:
        """
        Get struct expression with call statistics.

        :param gt_expr: CallExpression to compute call statistics on.
        :return: StructExpression with call statistics.
        """
        # Get the source Table for the CallExpression to grab alleles.
        ht = gt_expr._indices.source
        freq_expr = hl.agg.call_stats(gt_expr, ht.alleles)
        # Select non-ref allele (assumes bi-allelic).
        freq_expr = freq_expr.annotate(
            AC=freq_expr.AC[1],
            AF=freq_expr.AF[1],
            homozygote_count=freq_expr.homozygote_count[1],
        )

        return freq_expr

    agg_funcs["freq"] = (lambda x: x.GT, _get_freq_expr)

    return agg_by_strata(mt, agg_funcs, select_fields).drop("adj_group")
2094+
2095+
19832096
def update_structured_annotations(
19842097
ht: hl.Table,
19852098
annotation_update_exprs: Dict[str, hl.Expression],

0 commit comments

Comments
 (0)