@@ -18,7 +18,7 @@ function isassumption(expr::Union{Symbol,Expr})
1818    vn =  gensym (:vn )
1919
2020    return  quote 
21-         let  $ vn =  $ (varname (expr))
21+         let  $ vn =  $ (AbstractPPL . drop_escape ( varname (expr) ))
2222            if  $ (DynamicPPL. contextual_isassumption)(__context__, $ vn)
2323                #  Considered an assumption by `__context__` which means either:
2424                #  1. We hit the default implementation, e.g. using `DefaultContext`,
@@ -133,17 +133,17 @@ variables.
133133
134134# Example 
135135```jldoctest; setup=:(using Distributions, LinearAlgebra) 
136- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string( vns[end])  
137- " x[:,2]" 
136+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] 
137+ x[:,2] 
138138
139- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:] )); string( vns[end])  
140- "x[:][ 1,2]" 
139+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] 
140+ x[ 1,2]
141141
142- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3 ), @varname(x[1 ])); string( vns[end])  
143- "x[1][3]" 
142+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2 ), @varname(x[: ])); vns[end] 
143+ x[:][1,2] 
144144
145- julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2,  3), @varname(x)); string( vns[end])  
146- " x[1,2,3]" 
145+ julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1] )); vns[end] 
146+ x[1][3]  
147147``` 
148148""" 
149149unwrap_right_left_vns (right, left, vns) =  right, left, vns
@@ -158,7 +158,7 @@ function unwrap_right_left_vns(
158158    #  for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
159159    #  and we therefore add the `Colon()` below.
160160    vns =  map (axes (left, 2 )) do  i
161-         return  VarName (vn, (vn . indexing ... , ( Colon (), i) ))
161+         return  vn  ∘  Setfield . IndexLens (( Colon (), i))
162162    end 
163163    return  unwrap_right_left_vns (right, left, vns)
164164end 
@@ -168,7 +168,7 @@ function unwrap_right_left_vns(
168168    vn:: VarName ,
169169)
170170    vns =  map (CartesianIndices (left)) do  i
171-         return  VarName (vn, (vn . indexing ... ,  Tuple (i) ))
171+         return  vn  ∘  Setfield . IndexLens ( Tuple (i))
172172    end 
173173    return  unwrap_right_left_vns (right, left, vns)
174174end 
@@ -317,6 +317,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
317317    #  Do not touch interpolated expressions
318318    expr. head ===  :$  &&  return  expr. args[1 ]
319319
320+     #  Do we don't want escaped expressions because we unfortunately
321+     #  escape the entire body afterwards.
322+     Meta. isexpr (expr, :escape ) &&  return  generate_mainbody (mod, found, expr. args[1 ], warn)
323+ 
320324    #  If it's a macro, we expand it
321325    if  Meta. isexpr (expr, :macrocall )
322326        return  generate_mainbody! (mod, found, macroexpand (mod, expr; recursive= true ), warn)
@@ -349,38 +353,36 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
349353    return  Expr (expr. head, map (x ->  generate_mainbody! (mod, found, x, warn), expr. args)... )
350354end 
351355
356+ function  generate_tilde_literal (left, right)
357+     #  If the LHS is a literal, it is always an observation
358+     return  quote 
359+         $ (DynamicPPL. tilde_observe!)(
360+             __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
361+         )
362+     end 
363+ end 
364+ 
352365""" 
353366    generate_tilde(left, right) 
354367
355368Generate an `observe` expression for data variables and `assume` expression for parameter 
356369variables. 
357370""" 
358371function  generate_tilde (left, right)
359-     #  If the LHS is a literal, it is always an observation
360-     if  isliteral (left)
361-         return  quote 
362-             $ (DynamicPPL. tilde_observe!)(
363-                 __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
364-             )
365-         end 
366-     end 
372+     isliteral (left) &&  return  generate_tilde_literal (left, right)
367373
368374    #  Otherwise it is determined by the model or its value,
369375    #  if the LHS represents an observation
370-     @gensym  vn inds isassumption
376+     @gensym  vn isassumption
377+ 
378+     #  HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379+     #  that in DynamicPPL we the entire function body. Instead we should be
380+     #  more selective with our escape. Until that's the case, we remove them all.
371381    return  quote 
372-         $ vn =  $ (varname (left))
373-         $ inds =  $ (vinds (left))
382+         $ vn =  $ (AbstractPPL. drop_escape (varname (left)))
374383        $ isassumption =  $ (DynamicPPL. isassumption (left))
375384        if  $ isassumption
376-             $ left =  $ (DynamicPPL. tilde_assume!)(
377-                 __context__,
378-                 $ (DynamicPPL. unwrap_right_vn)(
379-                     $ (DynamicPPL. check_tilde_rhs)($ right), $ vn
380-                 ). .. ,
381-                 $ inds,
382-                 __varinfo__,
383-             )
385+             $ (generate_tilde_assume (left, right, vn))
384386        else 
385387            #  If `vn` is not in `argnames`, we need to make sure that the variable is defined.
386388            if  ! $ (DynamicPPL. inargnames)($ vn, __model__)
@@ -392,44 +394,46 @@ function generate_tilde(left, right)
392394                $ (DynamicPPL. check_tilde_rhs)($ right),
393395                $ (maybe_view (left)),
394396                $ vn,
395-                 $ inds,
396397                __varinfo__,
397398            )
398399        end 
399400    end 
400401end 
401402
403+ function  generate_tilde_assume (left, right, vn)
404+     expr =  :(
405+         $ left =  $ (DynamicPPL. tilde_assume!)(
406+             __context__,
407+             $ (DynamicPPL. unwrap_right_vn)($ (DynamicPPL. check_tilde_rhs)($ right), $ vn). .. ,
408+             __varinfo__,
409+         )
410+     )
411+ 
412+     return  if  left isa  Expr
413+         AbstractPPL. drop_escape (
414+             Setfield. setmacro (BangBang. prefermutation, expr; overwrite= true )
415+         )
416+     else 
417+         return  expr
418+     end 
419+ end 
420+ 
402421""" 
403422    generate_dot_tilde(left, right) 
404423
405424Generate the expression that replaces `left .~ right` in the model body. 
406425""" 
407426function  generate_dot_tilde (left, right)
408-     #  If the LHS is a literal, it is always an observation
409-     if  isliteral (left)
410-         return  quote 
411-             $ (DynamicPPL. dot_tilde_observe!)(
412-                 __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
413-             )
414-         end 
415-     end 
427+     isliteral (left) &&  return  generate_tilde_literal (left, right)
416428
417429    #  Otherwise it is determined by the model or its value,
418430    #  if the LHS represents an observation
419-     @gensym  vn inds  isassumption
431+     @gensym  vn isassumption
420432    return  quote 
421-         $ vn =  $ (varname (left))
422-         $ inds =  $ (vinds (left))
433+         $ vn =  $ (AbstractPPL. drop_escape (varname (left)))
423434        $ isassumption =  $ (DynamicPPL. isassumption (left))
424435        if  $ isassumption
425-             $ left .=  $ (DynamicPPL. dot_tilde_assume!)(
426-                 __context__,
427-                 $ (DynamicPPL. unwrap_right_left_vns)(
428-                     $ (DynamicPPL. check_tilde_rhs)($ right), $ (maybe_view (left)), $ vn
429-                 ). .. ,
430-                 $ inds,
431-                 __varinfo__,
432-             )
436+             $ (generate_dot_tilde_assume (left, right, vn))
433437        else 
434438            #  If `vn` is not in `argnames`, we need to make sure that the variable is defined.
435439            if  ! $ (DynamicPPL. inargnames)($ vn, __model__)
@@ -441,13 +445,27 @@ function generate_dot_tilde(left, right)
441445                $ (DynamicPPL. check_tilde_rhs)($ right),
442446                $ (maybe_view (left)),
443447                $ vn,
444-                 $ inds,
445448                __varinfo__,
446449            )
447450        end 
448451    end 
449452end 
450453
454+ function  generate_dot_tilde_assume (left, right, vn)
455+     #  We don't need to use `Setfield.@set` here since
456+     #  `.=` is always going to be inplace + needs `left` to
457+     #  be something that supports `.=`.
458+     return  :(
459+         $ left .=  $ (DynamicPPL. dot_tilde_assume!)(
460+             __context__,
461+             $ (DynamicPPL. unwrap_right_left_vns)(
462+                 $ (DynamicPPL. check_tilde_rhs)($ right), $ (maybe_view (left)), $ vn
463+             ). .. ,
464+             __varinfo__,
465+         )
466+     )
467+ end 
468+ 
451469const  FloatOrArrayType =  Type{<: Union{AbstractFloat,AbstractArray} }
452470hasmissing (T:: Type{<:AbstractArray{TA}} ) where  {TA<: AbstractArray } =  hasmissing (TA)
453471hasmissing (T:: Type{<:AbstractArray{>:Missing}} ) =  true 
0 commit comments