Decouple usage of float_type_with_fallback #1090

@penelopeysm


Currently, the DynamicPPL.float_type_with_fallback function is used in at least two places:

  1. When deciding the type of logp based on the type of parameters, inside unflatten

function unflatten(vi::VarInfo, x::AbstractVector)
    md = unflatten_metadata(vi.metadata, x)
    # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
    # a gradient type of some AD backend.
    # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
    # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but
    # the accumulators in the VarInfo are plain floats, we error since we can't change the
    # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
    # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
    # plain ugly and hacky.
    # The below line is finicky for type stability. For instance, assigning the eltype to
    # convert to into an intermediate variable makes this unstable (constant propagation)
    # fails. Take care when editing.
    accs = map(
        acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi))
    )
    return VarInfo(md, accs)
end

  2. Inside the compiler, when handling arguments like ::Type{T}=Vector{Float64}. (Although, see also #823, "Do we still need parametric types in model definition for autodiff?")

get_matching_type(_, ::Type{T}) where {T} = T
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
    return Union{Missing,float_type_with_fallback(eltype(vi))}
end
function get_matching_type(vi, ::Type{<:AbstractFloat})
    return float_type_with_fallback(eltype(vi))
end

I am fairly sure that these two uses are disjoint. It just so happens that the same function works in both cases, but I think they should be separated.

For example, in case (2), if you have ::Type{T}=Vector{Real} and you are running ForwardDiff AD, you probably want to promote it to Dual{...,Real}, not Dual{...,Float64} as is currently the case.

In contrast, in case (1), we don't want Real params to cause log probs to be Real.
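
To make the difference concrete, here is a rough sketch of what two separate functions could look like. None of these names exist in DynamicPPL: `logp_eltype`, `arg_eltype` and `ToyDual` are all made up for illustration, and `ToyDual` is only a stand-in for an AD number type like ForwardDiff.Dual so that the snippet runs without any packages. It only shows the intended type-level behaviour, not how this would be wired into unflatten or the compiler.

```julia
# Toy stand-in for an AD number type such as ForwardDiff.Dual (hypothetical).
struct ToyDual{V<:Real} <: Real
    value::V
    partial::V
end

# Case (1), hypothetical: element type for log-probability accumulators.
# Concrete parameter types (Float32, ToyDual{Float64}, ...) are kept, while
# abstract ones such as Real collapse to Float64 so that logp is never Real.
logp_eltype(::Type{T}) where {T<:Real} = isconcretetype(T) ? T : Float64
# Integer parameters should still give floating-point log probs.
logp_eltype(::Type{T}) where {T<:Integer} = Float64

# Case (2), hypothetical: element type for a model argument like
# ::Type{T}=Vector{R}. The requested type R is preserved, and only wrapped
# when the parameters are AD numbers.
arg_eltype(::Type{R}, ::Type{P}) where {R<:Real,P<:Real} = R
arg_eltype(::Type{R}, ::Type{<:ToyDual}) where {R<:Real} = ToyDual{R}

# The two cases disagree on what to do with Real:
logp_eltype(Real)                    # Float64
arg_eltype(Real, ToyDual{Float64})   # ToyDual{Real}
```

With the current shared function, both cases are forced to give the same answer for Real, which is the coupling this issue is about.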
