diff --git a/roofit/codegen/src/CodegenImpl.cxx b/roofit/codegen/src/CodegenImpl.cxx index ca79edb1b1839..9e9742b859e5c 100644 --- a/roofit/codegen/src/CodegenImpl.cxx +++ b/roofit/codegen/src/CodegenImpl.cxx @@ -137,10 +137,10 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx) { - if (arg.cache()._isRearranged) { - ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.cache()._rearrangedNum, *arg.cache()._rearrangedDen)); + if (arg.isRearranged()) { + ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.rearrangedNum(), *arg.rearrangedDen())); } else { - ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), arg.cache()._partList, arg.cache()._partList.size())); + ctx.addResult(&arg, ctx.buildCall(mathFunc("product"), *arg.partList(), arg.partList()->size())); } } diff --git a/roofit/roofitcore/inc/RooProdPdf.h b/roofit/roofitcore/inc/RooProdPdf.h index f6b87144130e5..627285f752e86 100644 --- a/roofit/roofitcore/inc/RooProdPdf.h +++ b/roofit/roofitcore/inc/RooProdPdf.h @@ -177,9 +177,7 @@ class RooProdPdf : public RooAbsPdf { void rearrangeProduct(CacheElem&) const; std::unique_ptr specializeIntegral(RooAbsReal& orig, const char* targetRangeName) const ; std::unique_ptr specializeRatio(RooFormulaVar& input, const char* targetRangeName) const ; - double calculate(const RooProdPdf::CacheElem& cache, bool verbose=false) const ; - void doEvalImpl(RooAbsArg const* caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &) const; - + double calculate(const RooProdPdf::CacheElem &cache, bool verbose = false) const; friend class RooProdGenContext ; friend class RooFit::Detail::RooFixedProdPdf ; @@ -202,15 +200,10 @@ class RooProdPdf : public RooAbsPdf { bool _selfNorm = true; ///< Is self-normalized RooArgSet _defNormSet ; ///< Default normalization set -private: - - - ClassDefOverride(RooProdPdf,6) // PDF representing a product of PDFs }; -namespace RooFit { -namespace Detail { +namespace RooFit::Detail { /// A RooProdPdf with a fixed normalization set can be replaced by this class. /// Its purpose is to provide the right client-server interface for the @@ -227,7 +220,7 @@ class RooFixedProdPdf : public RooAbsPdf { inline bool canComputeBatchWithCuda() const override { return true; } - inline void doEval(RooFit::EvalContext &ctx) const override { _prodPdf->doEvalImpl(this, *_cache, ctx); } + void doEval(RooFit::EvalContext &ctx) const override; inline ExtendMode extendMode() const override { return _prodPdf->extendMode(); } inline double expectedEvents(const RooArgSet * /*nset*/) const override @@ -260,22 +253,30 @@ class RooFixedProdPdf : public RooAbsPdf { return _prodPdf->analyticalIntegral(code, rangeName); } - RooProdPdf::CacheElem const &cache() const { return *_cache; } + bool isRearranged() const { return _isRearranged; } -private: - void initialize(); + RooAbsReal const *rearrangedNum() const + { + return _isRearranged ? static_cast(_servers[0]) : nullptr; + } + RooAbsReal const *rearrangedDen() const + { + return _isRearranged ? static_cast(_servers[1]) : nullptr; + } + + RooArgSet const *partList() const { return !_isRearranged ? static_cast(&_servers) : nullptr; } - inline double evaluate() const override { return _prodPdf->calculate(*_cache); } +private: + double evaluate() const override; RooArgSet _normSet; - std::unique_ptr _cache; RooSetProxy _servers; std::unique_ptr _prodPdf; + bool _isRearranged = false; ClassDefOverride(RooFit::Detail::RooFixedProdPdf, 0); }; -} // namespace Detail -} // namespace RooFit +} // namespace RooFit::Detail #endif diff --git a/roofit/roofitcore/src/RooProdPdf.cxx b/roofit/roofitcore/src/RooProdPdf.cxx index edc9092f0e561..4fb01e3466d4e 100644 --- a/roofit/roofitcore/src/RooProdPdf.cxx +++ b/roofit/roofitcore/src/RooProdPdf.cxx @@ -409,26 +409,6 @@ double RooProdPdf::calculate(const RooProdPdf::CacheElem& cache, bool /*verbose* } } -//////////////////////////////////////////////////////////////////////////////// -/// Evaluate product of PDFs in batch mode. -void RooProdPdf::doEvalImpl(RooAbsArg const *caller, const RooProdPdf::CacheElem &cache, RooFit::EvalContext &ctx) const -{ - if (cache._isRearranged) { - auto numerator = ctx.at(cache._rearrangedNum.get()); - auto denominator = ctx.at(cache._rearrangedDen.get()); - RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator}); - } else { - std::vector> factors; - factors.reserve(cache._partList.size()); - for (const RooAbsArg *i : cache._partList) { - auto span = ctx.at(i); - factors.push_back(span); - } - std::array special{static_cast(factors.size())}; - RooBatchCompute::compute(ctx.config(caller), RooBatchCompute::ProdPdf, ctx.output(), factors, special); - } -} - namespace { template @@ -2185,8 +2165,7 @@ RooProdPdf::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileC return fixedProdPdf; } -namespace RooFit { -namespace Detail { +namespace RooFit::Detail { RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr &&prodPdf, RooArgSet const &normSet) : RooAbsPdf(prodPdf->GetName(), prodPdf->GetTitle()), @@ -2194,35 +2173,76 @@ RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr &&prodPdf, RooArgSe _servers("!servers", "List of servers", this), _prodPdf{std::move(prodPdf)} { - initialize(); + auto cache = _prodPdf->createCacheElem(&_normSet, nullptr); + _isRearranged = cache->_isRearranged; + + // The actual servers for a given normalization set depend on whether the + // cache is rearranged or not. See RooProdPdf::calculate to see + // which args in the cache are used directly. + if (_isRearranged) { + _servers.add(*cache->_rearrangedNum); + _servers.add(*cache->_rearrangedDen); + addOwnedComponents(std::move(cache->_rearrangedNum)); + addOwnedComponents(std::move(cache->_rearrangedDen)); + return; + } + // We don't want to carry the full cache object around, so we let it go out + // of scope and transfer the ownership of the args that we actually need. + cache->_ownedList.releaseOwnership(); + std::vector> owned; + for (RooAbsArg *arg : cache->_ownedList) { + owned.emplace_back(arg); + } + for (RooAbsArg *arg : cache->_partList) { + _servers.add(*arg); + auto found = std::find_if(owned.begin(), owned.end(), [&](auto const &ptr) { return arg == ptr.get(); }); + if (found != owned.end()) { + addOwnedComponents(std::move(owned[std::distance(owned.begin(), found)])); + } + } } RooFixedProdPdf::RooFixedProdPdf(const RooFixedProdPdf &other, const char *name) : RooAbsPdf(other, name), _normSet{other._normSet}, - _servers("!servers", "List of servers", this), - _prodPdf{static_cast(other._prodPdf->Clone())} + _servers("!servers", this, other._servers), + _prodPdf{static_cast(other._prodPdf->Clone())}, + _isRearranged{other._isRearranged} +{ +} + +//////////////////////////////////////////////////////////////////////////////// +/// Evaluate product of PDFs in batch mode. + +void RooFixedProdPdf::doEval(RooFit::EvalContext &ctx) const { - initialize(); + if (_isRearranged) { + auto numerator = ctx.at(rearrangedNum()); + auto denominator = ctx.at(rearrangedDen()); + RooBatchCompute::compute(ctx.config(this), RooBatchCompute::Ratio, ctx.output(), {numerator, denominator}); + return; + } + std::vector> factors; + factors.reserve(partList()->size()); + for (const RooAbsArg *arg : *partList()) { + auto span = ctx.at(arg); + factors.push_back(span); + } + std::array special{static_cast(factors.size())}; + RooBatchCompute::compute(ctx.config(this), RooBatchCompute::ProdPdf, ctx.output(), factors, special); } -void RooFixedProdPdf::initialize() +double RooFixedProdPdf::evaluate() const { - _cache = _prodPdf->createCacheElem(&_normSet, nullptr); - auto &cache = *_cache; + if (_isRearranged) { + return rearrangedNum()->getVal() / rearrangedDen()->getVal(); + } + double value = 1.0; - // The actual servers for a given normalization set depend on whether the - // cache is rearranged or not. See RooProdPdf::calculateBatch to see - // which args in the cache are used directly. - if (cache._isRearranged) { - _servers.add(*cache._rearrangedNum); - _servers.add(*cache._rearrangedDen); - } else { - for (std::size_t i = 0; i < cache._partList.size(); ++i) { - _servers.add(cache._partList[i]); - } + for (auto *arg : static_range_cast(*partList())) { + value *= arg->getVal(); } + return value; } -} // namespace Detail -} // namespace RooFit +} // namespace RooFit::Detail