@@ -1861,25 +1861,19 @@ def generate_freq_group_membership_array(
18611861 return ht
18621862
18631863
1864- def compute_freq_by_strata (
1864+ def agg_by_strata (
18651865 mt : hl .MatrixTable ,
18661866 entry_agg_funcs : Optional [Dict [str , Tuple [Callable , Callable ]]] = None ,
18671867 select_fields : Optional [List [str ]] = None ,
1868- group_membership_includes_raw_group : bool = True ,
1868+ group_membership_ht : Optional [ hl . Table ] = None ,
18691869) -> hl .Table :
18701870 """
1871- Compute call statistics and, when passed, entry aggregation function(s) by strata.
1872-
1873- The computed call statistics are AC, AF, AN, and homozygote_count. The entry
1874- aggregation functions are applied to the MatrixTable entries and aggregated. The
1875- MatrixTable must contain a 'group_membership' annotation (like the one added by
1876- `generate_freq_group_membership_array`) that is a list of bools to aggregate the
1877- columns by.
1871+ Get row expression for annotations of each entry aggregation function(s) by strata.
18781872
1879- .. note::
1880- This function is primarily used through `annotate_freq` but can be used
1881- independently if desired. Please see the `annotate_freq` function for more
1882- complete documentation .
1873+ The entry aggregation functions are applied to the MatrixTable entries and
1874+ aggregated. If no `group_membership_ht` (like the one returned by
1875+ `generate_freq_group_membership_array`) is supplied, `mt` must contain a
1876+ 'group_membership' annotation that is a list of bools to aggregate the columns by .
18831877
18841878 :param mt: Input MatrixTable.
18851879 :param entry_agg_funcs: Optional dict of entry aggregation functions. When
@@ -1890,15 +1884,9 @@ def compute_freq_by_strata(
18901884 function.
18911885 :param select_fields: Optional list of row fields from `mt` to keep on the output
18921886 Table.
1893- :param group_membership_includes_raw_group: Whether the 'group_membership'
1894- annotation includes an entry for the 'raw' group, representing all samples. If
1895- False, the 'raw' group is inserted as the second element in all added
1896- annotations using the same 'group_membership', resulting
1897- in array lengths of 'group_membership'+1. If True, the second element of each
1898- added annotation is still the 'raw' group, but the group membership is
1899- determined by the values in the second element of 'group_membership', and the
1900- output annotations will be the same length as 'group_membership'. Default is
1901- True.
1887+ :param group_membership_ht: Optional Table containing group membership annotations
1888+ to stratify the coverage stats by. If not provided, the 'group_membership'
1889+ annotation is expected to be present on `mt`.
19021890 :return: Table or MatrixTable with allele frequencies by strata.
19031891 """
19041892 if entry_agg_funcs is None :
@@ -1907,79 +1895,204 @@ def compute_freq_by_strata(
19071895 select_fields = []
19081896
19091897 n_samples = mt .count_cols ()
1910- n_groups = len (mt .group_membership .take (1 )[0 ])
1911- ht = mt .localize_entries ("entries" , "cols" )
1912- ht = ht .annotate_globals (
1913- indices_by_group = hl .range (n_groups ).map (
1914- lambda g_i : hl .range (n_samples ).filter (
1915- lambda s_i : ht .cols [s_i ].group_membership [g_i ]
1898+ global_expr = {}
1899+ if "adj_group" in mt .index_globals ():
1900+ global_expr ["adj_group" ] = mt .index_globals ().adj_group
1901+ logger .info ("Using the 'adj_group' global annotation found on the input MT." )
1902+
1903+ if group_membership_ht is None and "group_membership" not in mt .col :
1904+ raise ValueError (
1905+ "The 'group_membership' annotation is not found in the input MatrixTable "
1906+ "and 'group_membership_ht' is not specified."
1907+ )
1908+ elif group_membership_ht is None :
1909+ logger .info (
1910+ "'group_membership_ht' is not specified, using sample stratification "
1911+ "indicated by the 'group_membership' annotation on mt."
1912+ )
1913+ n_groups = len (mt .group_membership .take (1 )[0 ])
1914+ else :
1915+ logger .info (
1916+ "'group_membership_ht' is specified, using sample stratification indicated "
1917+ "by its 'group_membership' annotation."
1918+ )
1919+ n_groups = len (group_membership_ht .group_membership .take (1 )[0 ])
1920+ mt = mt .annotate_cols (
1921+ group_membership = group_membership_ht [mt .col_key ].group_membership
1922+ )
1923+ if "adj_group" not in global_expr :
1924+ if "adj_group" in group_membership_ht .index_globals ():
1925+ global_expr ["adj_group" ] = mt .index_globals ().adj_group
1926+ logger .info (
1927+ "Using the 'adj_group' global annotation on 'group_membership_ht'."
1928+ )
1929+ elif "freq_meta" in group_membership_ht .index_globals ():
1930+ logger .info (
1931+ "The 'freq_meta' global annotation is found in "
1932+ "'group_membership_ht', using it to determine the adj filtered "
1933+ "stratification groups."
1934+ )
1935+ freq_meta = group_membership_ht .index_globals ().freq_meta
1936+
1937+ global_expr ["adj_group" ] = freq_meta .map (
1938+ lambda x : x .get ("group" , "NA" ) == "adj"
19161939 )
1940+
1941+ if "adj_group" not in global_expr :
1942+ global_expr ["adj_group" ] = hl .range (n_groups ).map (lambda x : False )
1943+
1944+ n_adj_group = hl .eval (hl .len (global_expr ["adj_group" ]))
1945+ if hl .eval (hl .len (global_expr ["adj_group" ])) != n_groups :
1946+ raise ValueError (
1947+ f"The number of elements in the 'adj_group' ({ n_adj_group } ) global "
1948+ "annotation does not match the number of elements in the "
1949+ f"'group_membership' annotation ({ n_groups } )!" ,
1950+ )
1951+
1952+ # Keep only the entries needed for the aggregation functions.
1953+ select_expr = {}
1954+ if hl .eval (hl .any (global_expr ["adj_group" ])):
1955+ select_expr ["adj" ] = mt .adj
1956+
1957+ select_expr .update (** {ann : f [0 ](mt ) for ann , f in entry_agg_funcs .items ()})
1958+ mt = mt .select_entries (** select_expr )
1959+
1960+ # Convert MT to HT with a row annotation that is an array of all samples entries
1961+ # for that variant.
1962+ ht = mt .localize_entries ("entries" , "cols" )
1963+
1964+ # For each stratification group in group_membership, determine the indices of the
1965+ # samples that belong to that group.
1966+ global_expr ["indices_by_group" ] = hl .range (n_groups ).map (
1967+ lambda g_i : hl .range (n_samples ).filter (
1968+ lambda s_i : ht .cols [s_i ].group_membership [g_i ]
19171969 )
19181970 )
1971+ ht = ht .annotate_globals (** global_expr )
1972+
19191973 # Pull out each annotation that will be used in the array aggregation below as its
19201974 # own ArrayExpression. This is important to prevent memory issues when performing
19211975 # the below array aggregations.
19221976 ht = ht .select (
1923- * select_fields ,
1924- adj_array = ht .entries .map (lambda e : e .adj ),
1925- gt_array = ht .entries .map (lambda e : e .GT ),
19261977 ** {
1927- ann : hl . map (lambda e , s : f [ 0 ]( e , s ), ht . entries , ht . cols )
1928- for ann , f in entry_agg_funcs . items ( )
1929- },
1978+ ann : ht . entries . map (lambda e : e [ ann ] )
1979+ for ann in select_fields + list ( select_expr . keys () )
1980+ }
19301981 )
19311982
19321983 def _agg_by_group (
1933- ht : hl .Table , agg_func : Callable , ann_expr : hl .expr .ArrayExpression , * args
1984+ ht : hl .Table , agg_func : Callable , ann_expr : hl .expr .ArrayExpression
19341985 ) -> hl .expr .ArrayExpression :
19351986 """
19361987 Aggregate `agg_expr` by group using the `agg_func` function.
19371988
19381989 :param ht: Input Hail Table.
1939- :param agg_func: Aggregation function to apply to `agg_expr`.
1940- :param agg_expr: Expression to aggregate by group.
1941- :param args: Additional arguments to pass to the `agg_func`.
1990+ :param agg_func: Aggregation function to apply to `ann_expr`.
1991+ :param ann_expr: Expression to aggregate by group.
19421992 :return: Aggregated array expression.
19431993 """
1944- adj_agg_expr = ht .indices_by_group .map (
1945- lambda s_indices : s_indices .aggregate (
1946- lambda i : hl .agg .filter (ht .adj_array [i ], agg_func (ann_expr [i ], * args ))
1947- )
1948- )
1949- # Create final agg list by inserting or changing the "raw" group,
1950- # representing all samples, in the adj_agg_list.
1951- raw_agg_expr = ann_expr .aggregate (lambda x : agg_func (x , * args ))
1952- if group_membership_includes_raw_group :
1953- extend_idx = 2
1954- else :
1955- extend_idx = 1
1956-
1957- adj_agg_expr = (
1958- adj_agg_expr [:1 ].append (raw_agg_expr ).extend (adj_agg_expr [extend_idx :])
1994+ return hl .map (
1995+ lambda s_indices , adj : s_indices .aggregate (
1996+ lambda i : hl .if_else (
1997+ adj ,
1998+ hl .agg .filter (ht .adj [i ], agg_func (ann_expr [i ])),
1999+ agg_func (ann_expr [i ]),
2000+ )
2001+ ),
2002+ ht .indices_by_group ,
2003+ ht .adj_group ,
19592004 )
19602005
1961- return adj_agg_expr
1962-
1963- freq_expr = _agg_by_group (ht , hl .agg .call_stats , ht .gt_array , ht .alleles )
1964-
1965- # Select non-ref allele (assumes bi-allelic).
1966- freq_expr = freq_expr .map (
1967- lambda cs : cs .annotate (
1968- AC = cs .AC [1 ],
1969- AF = cs .AF [1 ],
1970- homozygote_count = cs .homozygote_count [1 ],
1971- )
1972- )
19732006 # Add annotations for any supplied entry transform and aggregation functions.
19742007 ht = ht .select (
19752008 * select_fields ,
19762009 ** {ann : _agg_by_group (ht , f [1 ], ht [ann ]) for ann , f in entry_agg_funcs .items ()},
1977- freq = freq_expr ,
19782010 )
19792011
19802012 return ht .drop ("cols" )
19812013
19822014
2015+ def compute_freq_by_strata (
2016+ mt : hl .MatrixTable ,
2017+ entry_agg_funcs : Optional [Dict [str , Tuple [Callable , Callable ]]] = None ,
2018+ select_fields : Optional [List [str ]] = None ,
2019+ group_membership_includes_raw_group : bool = True ,
2020+ ) -> hl .Table :
2021+ """
2022+ Compute call statistics and, when passed, entry aggregation function(s) by strata.
2023+
2024+ The computed call statistics are AC, AF, AN, and homozygote_count. The entry
2025+ aggregation functions are applied to the MatrixTable entries and aggregated. The
2026+ MatrixTable must contain a 'group_membership' annotation (like the one added by
2027+ `generate_freq_group_membership_array`) that is a list of bools to aggregate the
2028+ columns by.
2029+
2030+ .. note::
2031+ This function is primarily used through `annotate_freq` but can be used
2032+ independently if desired. Please see the `annotate_freq` function for more
2033+ complete documentation.
2034+
2035+ :param mt: Input MatrixTable.
2036+ :param entry_agg_funcs: Optional dict of entry aggregation functions. When
2037+ specified, additional annotations are added to the output Table/MatrixTable.
2038+ The keys of the dict are the names of the annotations and the values are tuples
2039+ of functions. The first function is used to transform the `mt` entries in some
2040+ way, and the second function is used to aggregate the output from the first
2041+ function.
2042+ :param select_fields: Optional list of row fields from `mt` to keep on the output
2043+ Table.
2044+ :param group_membership_includes_raw_group: Whether the 'group_membership'
2045+ annotation includes an entry for the 'raw' group, representing all samples. If
2046+ False, the 'raw' group is inserted as the second element in all added
2047+ annotations using the same 'group_membership', resulting
2048+ in array lengths of 'group_membership'+1. If True, the second element of each
2049+ added annotation is still the 'raw' group, but the group membership is
2050+ determined by the values in the second element of 'group_membership', and the
2051+ output annotations will be the same length as 'group_membership'. Default is
2052+ True.
2053+ :return: Table or MatrixTable with allele frequencies by strata.
2054+ """
2055+ if not group_membership_includes_raw_group :
2056+ # Add the 'raw' group to the 'group_membership' annotation.
2057+ mt = mt .annotate_cols (
2058+ group_membership = hl .array ([mt .group_membership [0 ]]).extend (
2059+ mt .group_membership
2060+ )
2061+ )
2062+
2063+ # Add adj_group global annotation indicating that the second element in
2064+ # group_membership is 'raw' and all others are 'adj'.
2065+ mt = mt .annotate_globals (
2066+ adj_group = hl .range (hl .len (mt .group_membership .take (1 )[0 ])).map (lambda x : x != 1 )
2067+ )
2068+
2069+ if entry_agg_funcs is None :
2070+ entry_agg_funcs = {}
2071+
2072+ def _get_freq_expr (gt_expr : hl .expr .CallExpression ) -> hl .expr .StructExpression :
2073+ """
2074+ Get struct expression with call statistics.
2075+
2076+ :param gt_expr: CallExpression to compute call statistics on.
2077+ :return: StructExpression with call statistics.
2078+ """
2079+ # Get the source Table for the CallExpression to grab alleles.
2080+ ht = gt_expr ._indices .source
2081+ freq_expr = hl .agg .call_stats (gt_expr , ht .alleles )
2082+ # Select non-ref allele (assumes bi-allelic).
2083+ freq_expr = freq_expr .annotate (
2084+ AC = freq_expr .AC [1 ],
2085+ AF = freq_expr .AF [1 ],
2086+ homozygote_count = freq_expr .homozygote_count [1 ],
2087+ )
2088+
2089+ return freq_expr
2090+
2091+ entry_agg_funcs ["freq" ] = (lambda x : x .GT , _get_freq_expr )
2092+
2093+ return agg_by_strata (mt , entry_agg_funcs , select_fields ).drop ("adj_group" )
2094+
2095+
19832096def update_structured_annotations (
19842097 ht : hl .Table ,
19852098 annotation_update_exprs : Dict [str , hl .Expression ],
0 commit comments