Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 50 additions & 34 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,58 @@
#############################################################################

const _JuMPTypes = Union{AbstractJuMPScalar, NonlinearExpression}
const Constant = Union{Number, UniformScaling}
_float(x::Number) = convert(Float64, x)
_float(J::UniformScaling) = _float(J.λ)

# Overloads
#
# Different objects that must all interact:
# 1. Number
# 1. Constant
# 2. AbstractVariableRef
# 4. GenericAffExpr
# 5. GenericQuadExpr

# Number
# Number--Number obviously already taken care of!
# Number--VariableRef
Base.:+(lhs::Number, rhs::AbstractVariableRef) = GenericAffExpr(convert(Float64, lhs), rhs => 1.0)
Base.:-(lhs::Number, rhs::AbstractVariableRef) = GenericAffExpr(convert(Float64, lhs), rhs => -1.0)
Base.:*(lhs::Number, rhs::AbstractVariableRef) = GenericAffExpr(0.0, rhs => convert(Float64,lhs))
# Number--GenericAffExpr
function Base.:+(lhs::Number, rhs::GenericAffExpr)
# Constant
# Constant--Constant obviously already taken care of!
# Constant--VariableRef
Base.:+(lhs::Constant, rhs::AbstractVariableRef) = GenericAffExpr(_float(lhs), rhs => 1.0)
Base.:-(lhs::Constant, rhs::AbstractVariableRef) = GenericAffExpr(_float(lhs), rhs => -1.0)
Base.:*(lhs::Constant, rhs::AbstractVariableRef) = GenericAffExpr(0.0, rhs => _float(lhs))
# Constant--GenericAffExpr
function Base.:+(lhs::Constant, rhs::GenericAffExpr)
result = copy(rhs)
result.constant += lhs
return result
end
function Base.:-(lhs::Number, rhs::GenericAffExpr)
function Base.:-(lhs::Constant, rhs::GenericAffExpr)
result = -rhs
result.constant += lhs
return result
end
Base.:*(lhs::Number, rhs::GenericAffExpr) = map_coefficients(c -> lhs * c, rhs)
# Number--QuadExpr
Base.:+(lhs::Number, rhs::GenericQuadExpr) = GenericQuadExpr(lhs+rhs.aff, copy(rhs.terms))
function Base.:-(lhs::Number, rhs::GenericQuadExpr)
function Base.:*(lhs::Constant, rhs::GenericAffExpr)
f = _float(lhs)
return map_coefficients(c -> f * c, rhs)
end
# Constant--QuadExpr
Base.:+(lhs::Constant, rhs::GenericQuadExpr) = GenericQuadExpr(lhs+rhs.aff, copy(rhs.terms))
function Base.:-(lhs::Constant, rhs::GenericQuadExpr)
result = -rhs
result.aff.constant += lhs
return result
end
Base.:*(lhs::Number, rhs::GenericQuadExpr) = map_coefficients(c -> lhs * c, rhs)
Base.:*(lhs::Constant, rhs::GenericQuadExpr) = map_coefficients(c -> lhs * c, rhs)

# AbstractVariableRef (or, AbstractJuMPScalar)
# TODO: What is the role of AbstractJuMPScalar??
Base.:+(lhs::AbstractJuMPScalar) = lhs
Base.:-(lhs::AbstractVariableRef) = GenericAffExpr(0.0, lhs => -1.0)
Base.:*(lhs::AbstractJuMPScalar) = lhs # make this more generic so extensions don't have to define unary multiplication for our macros
# AbstractVariableRef--Number
Base.:+(lhs::AbstractVariableRef, rhs::Number) = (+)( rhs,lhs)
Base.:-(lhs::AbstractVariableRef, rhs::Number) = (+)(-rhs,lhs)
Base.:*(lhs::AbstractVariableRef, rhs::Number) = (*)(rhs,lhs)
Base.:/(lhs::AbstractVariableRef, rhs::Number) = (*)(1.0/rhs,lhs)
# AbstractVariableRef--Constant
Base.:+(lhs::AbstractVariableRef, rhs::Constant) = (+)( rhs,lhs)
Base.:-(lhs::AbstractVariableRef, rhs::Constant) = (+)(-rhs,lhs)
Base.:*(lhs::AbstractVariableRef, rhs::Constant) = (*)(rhs,lhs)
Base.:/(lhs::AbstractVariableRef, rhs::Constant) = (*)(1.0/rhs,lhs)
# AbstractVariableRef--AbstractVariableRef
Base.:+(lhs::V, rhs::V) where {V <: AbstractVariableRef} = GenericAffExpr(0.0, lhs => 1.0, rhs => 1.0)
Base.:-(lhs::V, rhs::V) where {V <: AbstractVariableRef} = GenericAffExpr(0.0, lhs => 1.0, rhs => -1.0)
Expand Down Expand Up @@ -111,11 +117,11 @@ end
# GenericAffExpr
Base.:+(lhs::GenericAffExpr) = lhs
Base.:-(lhs::GenericAffExpr) = map_coefficients(-, lhs)
# GenericAffExpr--Number
Base.:+(lhs::GenericAffExpr, rhs::Number) = (+)(+rhs,lhs)
Base.:-(lhs::GenericAffExpr, rhs::Number) = (+)(-rhs,lhs)
Base.:*(lhs::GenericAffExpr, rhs::Number) = (*)(rhs,lhs)
Base.:/(lhs::GenericAffExpr, rhs::Number) = map_coefficients(c -> c/rhs, lhs)
# GenericAffExpr--Constant
Base.:+(lhs::GenericAffExpr, rhs::Constant) = (+)(rhs,lhs)
Base.:-(lhs::GenericAffExpr, rhs::Constant) = (+)(-rhs,lhs)
Base.:*(lhs::GenericAffExpr, rhs::Constant) = (*)(rhs,lhs)
Base.:/(lhs::GenericAffExpr, rhs::Constant) = map_coefficients(c -> c/rhs, lhs)
function Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Integer)
if rhs == 2
return lhs*lhs
Expand All @@ -127,7 +133,7 @@ function Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Integer)
error("Only exponents of 0, 1, or 2 are currently supported. Are you trying to build a nonlinear problem? Make sure you use @NLconstraint/@NLobjective.")
end
end
Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Number) = error("Only exponents of 0, 1, or 2 are currently supported. Are you trying to build a nonlinear problem? Make sure you use @NLconstraint/@NLobjective.")
Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Constant) = error("Only exponents of 0, 1, or 2 are currently supported. Are you trying to build a nonlinear problem? Make sure you use @NLconstraint/@NLobjective.")
# GenericAffExpr--AbstractVariableRef
function Base.:+(lhs::GenericAffExpr{C, V}, rhs::V) where {C, V <: AbstractVariableRef}
return add_to_expression!(copy(lhs), one(C), rhs)
Expand Down Expand Up @@ -187,11 +193,11 @@ end
# GenericQuadExpr
Base.:+(lhs::GenericQuadExpr) = lhs
Base.:-(lhs::GenericQuadExpr) = map_coefficients(-, lhs)
# GenericQuadExpr--Number
Base.:+(lhs::GenericQuadExpr, rhs::Number) = (+)(+rhs,lhs)
Base.:-(lhs::GenericQuadExpr, rhs::Number) = (+)(-rhs,lhs)
Base.:*(lhs::GenericQuadExpr, rhs::Number) = (*)(rhs,lhs)
Base.:/(lhs::GenericQuadExpr, rhs::Number) = (*)(inv(rhs),lhs)
# GenericQuadExpr--Constant
Base.:+(lhs::GenericQuadExpr, rhs::Constant) = (+)(+rhs,lhs)
Base.:-(lhs::GenericQuadExpr, rhs::Constant) = (+)(-rhs,lhs)
Base.:*(lhs::GenericQuadExpr, rhs::Constant) = (*)(rhs,lhs)
Base.:/(lhs::GenericQuadExpr, rhs::Constant) = (*)(inv(rhs),lhs)
# GenericQuadExpr--AbstractVariableRef
Base.:+(q::GenericQuadExpr, v::AbstractVariableRef) = GenericQuadExpr(q.aff+v, copy(q.terms))
Base.:-(q::GenericQuadExpr, v::AbstractVariableRef) = GenericQuadExpr(q.aff-v, copy(q.terms))
Expand Down Expand Up @@ -271,8 +277,8 @@ end
# for scalars, so instead of defining them one-by-one, we will
# fallback to the multiplication operator
LinearAlgebra.dot(lhs::_JuMPTypes, rhs::_JuMPTypes) = lhs*rhs
LinearAlgebra.dot(lhs::_JuMPTypes, rhs::Number) = lhs*rhs
LinearAlgebra.dot(lhs::Number, rhs::_JuMPTypes) = lhs*rhs
LinearAlgebra.dot(lhs::_JuMPTypes, rhs::Constant) = lhs*rhs
LinearAlgebra.dot(lhs::Constant, rhs::_JuMPTypes) = lhs*rhs

LinearAlgebra.dot(lhs::AbstractVector{T}, rhs::AbstractVector{S}) where {T <: _JuMPTypes, S <: _JuMPTypes} = _dot(lhs,rhs)
LinearAlgebra.dot(lhs::AbstractVector{T}, rhs::AbstractVector{S}) where {T <: _JuMPTypes, S} = _dot(lhs,rhs)
Expand Down Expand Up @@ -606,6 +612,16 @@ function Base.:-(x::AbstractArray{T}) where {T <: _JuMPTypes}
return ret
end

# Fix https://github.com/JuliaLang/julia/issues/32374 as done in
# https://github.com/JuliaLang/julia/pull/32375. This hack should
# be removed once we drop Julia v1.0.
function Base.:-(A::Symmetric{<:JuMP.AbstractVariableRef})
return Symmetric(-A.data, LinearAlgebra.sym_uplo(A.uplo))
end
function Base.:-(A::Hermitian{<:JuMP.AbstractVariableRef})
return Hermitian(-A.data, LinearAlgebra.sym_uplo(A.uplo))
end

###############################################################################
# nonlinear function fallbacks for JuMP built-in types
###############################################################################
Expand All @@ -625,6 +641,6 @@ Base.:*(lhs::GenericQuadExpr, rhs::GenericQuadExpr) =
Base.:*(::S, ::T) where {T <: GenericQuadExpr,
S <: Union{AbstractVariableRef, GenericAffExpr, GenericQuadExpr}} =
error( "*(::$S,::$T) is not defined. $op_hint")
Base.:/(::S, ::T) where {S <: Union{Number, AbstractVariableRef, GenericAffExpr, GenericQuadExpr},
Base.:/(::S, ::T) where {S <: Union{Constant, AbstractVariableRef, GenericAffExpr, GenericQuadExpr},
T <: Union{AbstractVariableRef, GenericAffExpr, GenericQuadExpr}} =
error( "/(::$S,::$T) is not defined. $op_hint")
Loading