Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions roofit/codegen/src/CodegenImpl.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R

void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx)
{
if (arg.cache()._isRearranged) {
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.cache()._rearrangedNum, *arg.cache()._rearrangedDen));
if (arg.isRearranged()) {
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.rearrangedNum(), *arg.rearrangedDen()));
} else {
ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), arg.cache()._partList, arg.cache()._partList.size()));
ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), *arg.partList(), arg.partList()->size()));
}
}

Expand Down
35 changes: 18 additions & 17 deletions roofit/roofitcore/inc/RooProdPdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,7 @@ class RooProdPdf : public RooAbsPdf {
void rearrangeProduct(CacheElem&) const;
std::unique_ptr<RooAbsReal> specializeIntegral(RooAbsReal& orig, const char* targetRangeName) const ;
std::unique_ptr<RooAbsReal> specializeRatio(RooFormulaVar& input, const char* targetRangeName) const ;
double calculate(const RooProdPdf::CacheElem& cache, bool verbose=false) const ;
void doEvalImpl(RooAbsArg const* caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &) const;

double calculate(const RooProdPdf::CacheElem &cache, bool verbose = false) const;

friend class RooProdGenContext ;
friend class RooFit::Detail::RooFixedProdPdf ;
Expand All @@ -202,15 +200,10 @@ class RooProdPdf : public RooAbsPdf {
bool _selfNorm = true; ///< Is self-normalized
RooArgSet _defNormSet ; ///< Default normalization set

private:



ClassDefOverride(RooProdPdf,6) // PDF representing a product of PDFs
};

namespace RooFit {
namespace Detail {
namespace RooFit::Detail {

/// A RooProdPdf with a fixed normalization set can be replaced by this class.
/// Its purpose is to provide the right client-server interface for the
Expand All @@ -227,7 +220,7 @@ class RooFixedProdPdf : public RooAbsPdf {

inline bool canComputeBatchWithCuda() const override { return true; }

inline void doEval(RooFit::EvalContext &ctx) const override { _prodPdf->doEvalImpl(this, *_cache, ctx); }
void doEval(RooFit::EvalContext &ctx) const override;

inline ExtendMode extendMode() const override { return _prodPdf->extendMode(); }
inline double expectedEvents(const RooArgSet * /*nset*/) const override
Expand Down Expand Up @@ -260,22 +253,30 @@ class RooFixedProdPdf : public RooAbsPdf {
return _prodPdf->analyticalIntegral(code, rangeName);
}

RooProdPdf::CacheElem const &cache() const { return *_cache; }
bool isRearranged() const { return _isRearranged; }

private:
void initialize();
RooAbsReal const *rearrangedNum() const
{
return _isRearranged ? static_cast<RooAbsReal const *>(_servers[0]) : nullptr;
}
RooAbsReal const *rearrangedDen() const
{
return _isRearranged ? static_cast<RooAbsReal const *>(_servers[1]) : nullptr;
}

RooArgSet const *partList() const { return !_isRearranged ? static_cast<RooArgSet const *>(&_servers) : nullptr; }

inline double evaluate() const override { return _prodPdf->calculate(*_cache); }
private:
double evaluate() const override;

RooArgSet _normSet;
std::unique_ptr<RooProdPdf::CacheElem> _cache;
RooSetProxy _servers;
std::unique_ptr<RooProdPdf> _prodPdf;
bool _isRearranged = false;

ClassDefOverride(RooFit::Detail::RooFixedProdPdf, 0);
};

} // namespace Detail
} // namespace RooFit
} // namespace RooFit::Detail

#endif
102 changes: 61 additions & 41 deletions roofit/roofitcore/src/RooProdPdf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -409,26 +409,6 @@ double RooProdPdf::calculate(const RooProdPdf::CacheElem& cache, bool /*verbose*
}
}

////////////////////////////////////////////////////////////////////////////////
/// Evaluate product of PDFs in batch mode.
void RooProdPdf::doEvalImpl(RooAbsArg const *caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &ctx) const
{
if (cache._isRearranged) {
auto numerator = ctx.at(cache._rearrangedNum.get());
auto denominator = ctx.at(cache._rearrangedDen.get());
RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator});
} else {
std::vector<std::span<const double>> factors;
factors.reserve(cache._partList.size());
for (const RooAbsArg *i : cache._partList) {
auto span = ctx.at(i);
factors.push_back(span);
}
std::array<double, 1> special{static_cast<double>(factors.size())};
RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::ProdPdf, ctx.output(), factors, special);
}
}

namespace {

template<class T>
Expand Down Expand Up @@ -2185,44 +2165,84 @@ RooProdPdf::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileC
return fixedProdPdf;
}

namespace RooFit {
namespace Detail {
namespace RooFit::Detail {

RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet)
: RooAbsPdf(prodPdf->GetName(), prodPdf->GetTitle()),
_normSet{normSet},
_servers("!servers", "List of servers", this),
_prodPdf{std::move(prodPdf)}
{
initialize();
auto cache = _prodPdf->createCacheElem(&_normSet, nullptr);
_isRearranged = cache->_isRearranged;

// The actual servers for a given normalization set depend on whether the
// cache is rearranged or not. See RooProdPdf::calculate to see
// which args in the cache are used directly.
if (_isRearranged) {
_servers.add(*cache->_rearrangedNum);
_servers.add(*cache->_rearrangedDen);
addOwnedComponents(std::move(cache->_rearrangedNum));
addOwnedComponents(std::move(cache->_rearrangedDen));
return;
}
// We don't want to carry the full cache object around, so we let it go out
// of scope and transfer the ownership of the args that we actually need.
cache->_ownedList.releaseOwnership();
std::vector<std::unique_ptr<RooAbsArg>> owned;
for (RooAbsArg *arg : cache->_ownedList) {
owned.emplace_back(arg);
}
for (RooAbsArg *arg : cache->_partList) {
_servers.add(*arg);
auto found = std::find_if(owned.begin(), owned.end(), [&](auto const &ptr) { return arg == ptr.get(); });
if (found != owned.end()) {
addOwnedComponents(std::move(owned[std::distance(owned.begin(), found)]));
}
}
}

RooFixedProdPdf::RooFixedProdPdf(const RooFixedProdPdf &other, const char *name)
: RooAbsPdf(other, name),
_normSet{other._normSet},
_servers("!servers", "List of servers", this),
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())}
_servers("!servers", this, other._servers),
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())},
_isRearranged{other._isRearranged}
{
}

////////////////////////////////////////////////////////////////////////////////
/// Evaluate product of PDFs in batch mode.

void RooFixedProdPdf::doEval(RooFit::EvalContext &ctx) const
{
initialize();
if (_isRearranged) {
auto numerator = ctx.at(rearrangedNum());
auto denominator = ctx.at(rearrangedDen());
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator});
return;
}
std::vector<std::span<const double>> factors;
factors.reserve(partList()->size());
for (const RooAbsArg *arg : *partList()) {
auto span = ctx.at(arg);
factors.push_back(span);
}
std::array<double, 1> special{static_cast<double>(factors.size())};
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::ProdPdf, ctx.output(), factors, special);
}

void RooFixedProdPdf::initialize()
double RooFixedProdPdf::evaluate() const
{
_cache = _prodPdf->createCacheElem(&_normSet, nullptr);
auto &cache = *_cache;
if (_isRearranged) {
return rearrangedNum()->getVal() / rearrangedDen()->getVal();
}
double value = 1.0;

// The actual servers for a given normalization set depend on whether the
// cache is rearranged or not. See RooProdPdf::calculateBatch to see
// which args in the cache are used directly.
if (cache._isRearranged) {
_servers.add(*cache._rearrangedNum);
_servers.add(*cache._rearrangedDen);
} else {
for (std::size_t i = 0; i < cache._partList.size(); ++i) {
_servers.add(cache._partList[i]);
}
for (auto *arg : static_range_cast<RooAbsReal *>(*partList())) {
value *= arg->getVal();
}
return value;
}

} // namespace Detail
} // namespace RooFit
} // namespace RooFit::Detail
Loading