@@ -293,7 +293,9 @@ def _get_criteria(i: hl.expr.Int32Expression) -> hl.expr.Int32Expression:
293293
294294
295295def annotate_mutation_type (
296- t : Union [hl .MatrixTable , hl .Table ]
296+ t : Union [hl .MatrixTable , hl .Table ],
297+ context_length : Optional [int ] = None ,
298+ num_scan_context_length : Optional [int ] = 100 ,
297299) -> Union [hl .MatrixTable , hl .Table ]:
298300 """
299301 Annotate mutation types.
@@ -310,20 +312,36 @@ def annotate_mutation_type(
310312 used in this repo to indicate a variant's multiallelic and SNP/indel status.
311313
312314 :param t: Input Table or MatrixTable.
315+ :param context_length: Length of the 'context' annotation in 't'. If this is not
316+ specified, the value will be determined by examining the first
317+ `num_scan_context_length` values of the 'context' annotation. Default is None.
318+ :param num_scan_context_length: Number of values in the 'context' annotation to use
319+ for determining `context_length` if it is not specified. If set to None, all
320+ values in 'context' will be used. Default is 100.
313321 :return: Table or MatrixTable with mutation type annotations added.
314322 """
315- # Determine the context length by collecting all the context lengths.
316- context_lengths = list (
317- filter (None , t .aggregate (hl .agg .collect_as_set (hl .len (t .context ))))
318- )
319- if len (context_lengths ) > 1 :
320- raise ValueError (
321- "More than one length was found among the first 100 'context' values."
322- " Length of 'context' should be consistent."
323- )
324- else :
325- context_length = context_lengths [0 ]
326- logger .info ("Detected a length of %d for context length" , context_length )
323+ if context_length is None :
324+ # Determine the context length from the observed 'context' lengths (all values, or only the first `num_scan_context_length`).
325+ if num_scan_context_length is None :
326+ context_lengths = t .aggregate (hl .agg .collect_as_set (hl .len (t .context )))
327+ msg = "all"
328+ else :
329+ context_lengths = hl .len (t .context ).take (num_scan_context_length )
330+ msg = f"the first { num_scan_context_length } "
331+ context_lengths = list (filter (None , set (context_lengths )))
332+ if len (context_lengths ) > 1 :
333+ raise ValueError (
334+ f"More than one length was found among { msg } 'context' values. Length "
335+ "of 'context' should be consistent." ,
336+ )
337+ else :
338+ context_length = context_lengths [0 ]
339+ logger .info (
340+ "Detected a length of %d for context length using %s 'context' values." ,
341+ context_length ,
342+ msg ,
343+ )
344+
327345 # Determine the middle index of the context annotation.
328346 if context_length == 3 :
329347 mid_index = 1
@@ -348,10 +366,15 @@ def annotate_mutation_type(
348366 else :
349367 t = t .annotate (transition = transition_expr , cpg = cpg_expr )
350368 mutation_type_expr = (
351- hl .case ()
352- .when (t .cpg , "CpG" )
353- .when (t .transition , "non-CpG transition" )
354- .default ("transversion" )
369+ hl .switch (hl .len (t .context ))
370+ .when (
371+ context_length ,
372+ hl .case ()
373+ .when (t .cpg , "CpG" )
374+ .when (t .transition , "non-CpG transition" )
375+ .default ("transversion" ),
376+ )
377+ .or_error ("Found 'context' value with unexpected context length!" )
355378 )
356379 mutation_type_model_expr = hl .if_else (t .cpg , t .context , "non-CpG" )
357380 if isinstance (t , hl .MatrixTable ):
0 commit comments