          Decouple convergence checking from SamplerReport
          #6453
        
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6453      +/-   ##
==========================================
+ Coverage   94.78%   94.79%   +0.01%     
==========================================
  Files         148      148              
  Lines       27678    27678              
==========================================
+ Hits        26234    26238       +4     
+ Misses       1444     1440       -4     
c6dbdbb to f419ed3
f419ed3 to 6c2f7f2

The goal was to uncouple sampling functions from `MultiTrace` and `SamplerReport`. Some calls to `SamplerReport._log_summary()` were unnecessary because `MultiTrace._add_warnings()` was never called between instantiation and `_log_summary()`, so the traces never contained warnings. Running convergence checks and logging the warnings can also be done without `MultiTrace` or `SamplerReport` instances or methods.
6c2f7f2 to 49f5263

* Specify covariant input types in `StatsBijection`.
* Annotate `_choose_chains` to be independent of `BaseTrace` type.
I don't think I am qualified to review this.

I should have added comments to the diff earlier. GitHub suggested you because you edited the SMC code? Who else is familiar with it?

S = TypeVar("S", bound=Sized)


def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be Sized.
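As a rough illustration of what such an annotation expresses, the sketch below uses a `TypeVar` bound to `Sized` so that a type checker carries the element type from the input through to the output. `pick_nonempty` and its filtering logic are hypothetical stand-ins for illustration; they are not the real `_choose_chains` implementation.

```python
from typing import List, Sequence, Sized, Tuple, TypeVar

# S may be any type that supports len(); the checker tracks the concrete type.
S = TypeVar("S", bound=Sized)


def pick_nonempty(items: Sequence[S], minimum: int) -> Tuple[List[S], int]:
    """Hypothetical helper: keep items with at least `minimum` elements.

    Returns the kept items (same element type as the input) and the
    length of the shortest kept item (0 if nothing was kept).
    """
    kept = [item for item in items if len(item) >= minimum]
    shortest = min((len(item) for item in kept), default=0)
    return kept, shortest


# Passing a list of str gives back a List[str]; no trace class is involved.
kept, shortest = pick_nonempty(["abc", "de", "fghi"], minimum=3)
```

Because the only requirement on `S` is `len()`, the function does not need to know about any particular trace base class.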
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
f"took {t_sampling:.0f} seconds."
)
mtrace.report._log_summary()
Between line 574, `mtrace = MultiTrace(traces)[:length]`, where the `MultiTrace` was created, and this call, no warnings were added to `mtrace`.
Therefore, there are no warnings to log and the `_log_summary()` call can safely be removed.
warnings.warn(
    "The number of samples is too small to check convergence reliably.",
    stacklevel=2,
)
This is now checked by `run_convergence_checks`, just as it already checked for a minimum number of chains.
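A minimal sketch of that idea, assuming the checks are collected as plain records rather than emitted via `warnings.warn`. The `CheckMessage` class, the `check_sample_size` function, both thresholds, and the single-chain message text are assumptions made for illustration; only the sample-size message is taken from the diff above.

```python
from dataclasses import dataclass
from typing import List

MIN_CHAINS = 2   # assumed threshold, for illustration only
MIN_DRAWS = 100  # assumed threshold, for illustration only


@dataclass
class CheckMessage:
    """Hypothetical record describing one convergence-check finding."""

    level: str
    text: str


def check_sample_size(n_chains: int, n_draws: int) -> List[CheckMessage]:
    """Return findings to the caller instead of warning as a side effect."""
    findings: List[CheckMessage] = []
    if n_chains < MIN_CHAINS:
        # Hypothetical wording; the real message may differ.
        findings.append(CheckMessage("info", "Too few chains to check convergence."))
    if n_draws < MIN_DRAWS:
        findings.append(
            CheckMessage(
                "info",
                "The number of samples is too small to check convergence reliably.",
            )
        )
    return findings


few_draws = check_sample_size(n_chains=4, n_draws=10)
```

Keeping both guards in one function means the caller decides how the findings are surfaced.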
multitrace = MultiTrace(traces)
multitrace._report._log_summary()
Here too: the multitrace cannot have warnings that `_log_summary()` would print, because none were added here or in its `__init__`.
if idata is None:
    idata = to_inference_data(trace, log_likelihood=False)
warns = run_convergence_checks(idata, model)
trace.report._add_warnings(warns)
log_warnings(warns)
This replaces the `_compute_convergence_checks` function and turns `trace.report` into a dead end that can easily be removed in the future.
Remember from the other changes:
- The "number of samples is too small" warning is now issued by `run_convergence_checks`.
- `report._add_warnings` was done inside `report._run_convergence_checks`.
- `trace.report._log_summary()` internally called `log_warnings()`.
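The decoupling described above can be sketched roughly as follows: findings are logged by a free function, while the report object merely stores them. `Finding`, `log_findings`, `LegacyReport`, and the logger name are hypothetical names chosen for this sketch and are not the library's API.

```python
import logging
from dataclasses import dataclass
from typing import List, Sequence

logger = logging.getLogger("convergence_sketch")  # hypothetical logger name


@dataclass(frozen=True)
class Finding:
    """Hypothetical stand-in for one sampler warning."""

    level: int  # a logging level such as logging.INFO
    message: str


def log_findings(findings: Sequence[Finding]) -> int:
    """Log each finding at its own level and return how many were logged."""
    for finding in findings:
        logger.log(finding.level, finding.message)
    return len(findings)


class LegacyReport:
    """Hypothetical report object that only stores findings."""

    def __init__(self) -> None:
        self.findings: List[Finding] = []

    def add(self, findings: Sequence[Finding]) -> None:
        self.findings.extend(findings)


findings = [
    Finding(
        logging.INFO,
        "The number of samples is too small to check convergence reliably.",
    )
]
report = LegacyReport()
report.add(findings)               # dead end: nothing below reads from the report
n_logged = log_findings(findings)  # logging works without the report
```

In this arrangement the `report.add(...)` line could be deleted without affecting what gets logged, which is the sense in which the report becomes a dead end.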
"""Map between a `list` of stats to `dict` of stats."""

- def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
+ def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
Typing rule of thumb: Generic input types, exact output types.
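To illustrate the rule of thumb, here is a small sketch in which the parameter is annotated with the read-only `Mapping` (covariant in its value type, unlike the invariant `Dict`) and the return type is a concrete `List`. `collect_stat_names` and the example stat names are hypothetical and only serve the illustration.

```python
from types import MappingProxyType
from typing import Dict, List, Mapping, Sequence


def collect_stat_names(stats_dtypes: Sequence[Mapping[str, type]]) -> List[str]:
    """Accept any sequence of read-only mappings; return a concrete sorted list."""
    names = set()
    for dtypes in stats_dtypes:
        names.update(dtypes.keys())
    return sorted(names)


plain: Dict[str, type] = {"tune": bool, "energy": float}
frozen = MappingProxyType({"accept": float})

# A tuple of mixed mapping types is accepted because the parameter only
# requires Sequence[Mapping[...]], not List[Dict[...]].
names = collect_stat_names((plain, frozen))
```

Callers get the widest possible set of accepted arguments, while the exact return type tells them precisely what they can do with the result.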
          
I have only modified some docstrings 😅. I think @aloctavodia is the best choice.

LGTM!
Checklist
Minor changes
- The "The number of samples is too small to check convergence reliably." warning is now an `INFO` level log message instead of a `Warning`.
- The `SamplerReport._log_summary()` and `SamplerReport._run_convergence_checks` methods were removed.
Maintenance
- `MultiTrace` or `SamplerReport` to compute/log warnings.