
Complex broadcasting AD gives nothing when using CUDA #1215

@pawbz

The following snippet works on the CPU, i.e., it returns the correct gradient, but it fails on the GPU: there, the gradient with respect to `x` is `nothing`.

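```julia
using Zygote

y = complex.([4, 1])
x = complex.([3, 2])

function f1215(x, y)
    x = 2 .* x
    return sum(abs2.(x .- y))
end

# On the CPU this returns the correct gradient
gs = gradient(() -> f1215(x, y), Zygote.Params([x]))
gs[x]

# Move both arrays to the GPU and repeat the same call
using CUDA
x = cu(x)
y = cu(y)
gs = gradient(() -> f1215(x, y), Zygote.Params([x]))
gs[x]  # returns nothing when x and y are on the GPU
```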

[Edited so that the example no longer needs Flux]
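To make "the correct gradient" concrete, the expected value can be worked out by hand. Assuming Zygote's documented convention for a real-valued function of a complex input (return ∂f/∂Re(x) + i·∂f/∂Im(x)), the loss f(x) = Σₖ |2xₖ − yₖ|² has gradient 4(2x − y), which is `[8 + 0im, 12 + 0im]` for the inputs in the snippet. The sketch below is plain Julia with no Zygote or CUDA; `expected_grad` and `fd_grad` are hypothetical helper names, and the central-difference check is only an independent sanity test of the closed form.

```julia
# Inputs and loss copied from the snippet above (plain CPU arrays).
y = complex.([4, 1])
x = complex.([3, 2])

function f1215(x, y)
    x = 2 .* x
    return sum(abs2.(x .- y))
end

# Closed-form gradient, assuming Zygote's convention for real-valued f of
# complex x: df/dRe(x) + im * df/dIm(x). For f = sum(abs2.(2x - y)) this is 4(2x - y).
expected_grad(x, y) = 4 .* (2 .* x .- y)

# Hypothetical helper: central finite differences on the real and the
# imaginary part of each entry, combined with the same convention.
function fd_grad(f, x; h = 1e-6)
    n = length(x)
    g = zeros(ComplexF64, n)
    for k in 1:n
        e = zeros(ComplexF64, n)
        e[k] = 1
        d_re = (f(x .+ h .* e) - f(x .- h .* e)) / (2h)
        d_im = (f(x .+ im * h .* e) - f(x .- im * h .* e)) / (2h)
        g[k] = d_re + im * d_im
    end
    return g
end

expected_grad(x, y)            # [8 + 0im, 12 + 0im]
fd_grad(z -> f1215(z, y), x)   # approximately [8 + 0im, 12 + 0im]
```

Comparing `gs[x]` from the CPU run with `expected_grad(x, y)` shows the value the GPU run would be expected to return instead of `nothing`.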

Labels: CUDA (All things GPU)
