Currently, the `DynamicPPL.float_type_with_fallback` function is used in at least two places:
- Case (1): when deciding the type of `logp` from the type of the parameters, inside `unflatten` (lines 368 to 385 at commit 9a2607b):
```julia
function unflatten(vi::VarInfo, x::AbstractVector)
    md = unflatten_metadata(vi.metadata, x)
    # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
    # a gradient type of some AD backend.
    # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
    # for ThreadSafeVarInfo. In that one, if the map produces e.g. a ForwardDiff.Dual, but
    # the accumulators in the VarInfo are plain floats, we error since we can't change the
    # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
    # messes with cases like using Float32 for logprobs and Float64 for x. Also, this is just
    # plain ugly and hacky.
    # The below line is finicky for type stability. For instance, assigning the eltype to
    # convert to into an intermediate variable makes this unstable (constant propagation
    # fails). Take care when editing.
    accs = map(
        acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi))
    )
    return VarInfo(md, accs)
end
```
- Case (2): inside the compiler, when handling arguments like `::Type{T}=Vector{Float64}` (although see also #823, "Do we still need parametric types in model definition for autodiff?"); lines 747 to 753 at commit 9a2607b:
```julia
get_matching_type(_, ::Type{T}) where {T} = T
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
    return Union{Missing,float_type_with_fallback(eltype(vi))}
end
function get_matching_type(vi, ::Type{<:AbstractFloat})
    return float_type_with_fallback(eltype(vi))
end
```
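Both call sites above rely on the same fallback rule. As a minimal sketch of that rule, assuming `float_type_with_fallback` returns `float(T)` for subtypes of `Real` and `Float64` for anything else (`fallback_float_type` below is a hypothetical stand-in written for illustration, not DynamicPPL's code):

```julia
# Hypothetical stand-in for DynamicPPL.float_type_with_fallback (illustration only).
# Assumed behaviour: float(T) for subtypes of Real, Float64 for everything else.
fallback_float_type(::Type{T}) where {T<:Real} = float(T)
fallback_float_type(::Type) = Float64

# Element types that could show up as eltype(x) when unflatten is called.
for T in (Float32, Float64, Int, Real, Any)
    println(T, " => ", fallback_float_type(T))
end
```

Under that assumption, concrete float types such as `Float32` pass through unchanged, while `Real` and `Any` both collapse to `Float64`. AD number types such as `ForwardDiff.Dual` are subtypes of `Real`, so they also go through `float(T)`.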
I am fairly sure these two uses are independent: although the same function happens to work in both cases, I think they should be separated.
For example, in case (2), if you have `::Type{T}=Vector{Real}` and are running ForwardDiff AD, you probably want it promoted to `Dual{...,Real}`, not `Dual{...,Float64}` as currently happens.
In contrast, in case (1) we don't want `Real` parameters to make the log probabilities `Real`.
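One way to picture the proposed separation is as two functions that agree on concrete float types but disagree on `Real`. This is only a sketch: `logp_type`, `param_value_type`, `ad_param_type`, and `ToyDual` (a stand-in for `ForwardDiff.Dual`) are hypothetical names invented for illustration, not DynamicPPL or ForwardDiff API.

```julia
# Hypothetical sketch of the proposed split; none of these names exist in DynamicPPL.

# Toy stand-in for an AD number type such as ForwardDiff.Dual{Tag,V,N}.
struct ToyDual{V<:Real} <: Real
    value::V
    partial::V
end

# Case (1): element type for log-probability accumulators. Always a concrete
# float, so abstract parameter types such as Real do not make log probs Real.
logp_type(::Type{T}) where {T<:AbstractFloat} = T
logp_type(::Type) = Float64

# Case (2): value type for a user-supplied `::Type{T}` model argument.
# Abstract requests such as Real are kept instead of being narrowed to Float64.
param_value_type(::Type{T}) where {T<:Real} = T
param_value_type(::Type) = Float64

# Under AD, the matched type wraps the requested value type in the dual type.
ad_param_type(::Type{V}) where {V} = ToyDual{param_value_type(V)}

println(logp_type(Real))        # log probs stay concrete
println(ad_param_type(Real))    # the user asked for Real, so Real stays inside the dual
println(ad_param_type(Float64))
```

Keeping the two rules in separate functions means that changing how case (2) treats `Real` cannot silently change the element type of the log-probability accumulators, and vice versa.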