Skip to content

Commit 7c5b819

Browse files
committed
[RF] Handle RooProdPdf integrals in CPU backend without hidden caches
1 parent 71771d6 commit 7c5b819

File tree

4 files changed

+60
-25
lines changed

4 files changed

+60
-25
lines changed

roofit/roofitcore/inc/RooProdPdf.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ class RooProdPdf : public RooAbsPdf {
102102

103103
std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;
104104

105+
std::unique_ptr<RooAbsArg> compileToFixedProdPdf(RooArgSet const &normSet, RooArgSet const *intSet,
106+
const char *rangeName, RooFit::Detail::CompileContext &ctx,
107+
bool compileServers) const;
108+
105109
// The cache object. Internal, do not use.
106110
class CacheElem final : public RooAbsCacheElement {
107111
public:
@@ -211,7 +215,8 @@ namespace RooFit::Detail {
211215
/// normalization set.
212216
class RooFixedProdPdf : public RooAbsPdf {
213217
public:
214-
RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet);
218+
RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet, RooArgSet const *intSet,
219+
const char *rangeName);
215220
RooFixedProdPdf(const RooFixedProdPdf &other, const char *name = nullptr);
216221

217222
inline TObject *clone(const char *newname) const override { return new RooFixedProdPdf(*this, newname); }
@@ -269,10 +274,12 @@ class RooFixedProdPdf : public RooAbsPdf {
269274
private:
270275
double evaluate() const override;
271276

277+
RooArgSet _intSet;
272278
RooArgSet _normSet;
273279
RooSetProxy _servers;
274280
std::unique_ptr<RooProdPdf> _prodPdf;
275281
bool _isRearranged = false;
282+
std::string _rangeName;
276283

277284
ClassDefOverride(RooFit::Detail::RooFixedProdPdf, 0);
278285
};

roofit/roofitcore/src/RooProdPdf.cxx

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,34 +2147,53 @@ RooProdPdf::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileC
21472147
}
21482148
}
21492149

2150+
return compileToFixedProdPdf(normSet, nullptr, nullptr, ctx, true);
2151+
}
2152+
2153+
std::unique_ptr<RooAbsArg> RooProdPdf::compileToFixedProdPdf(RooArgSet const &normSet, RooArgSet const *intSet,
2154+
const char *rangeName, RooFit::Detail::CompileContext &ctx,
2155+
bool compileServers) const
2156+
{
21502157
std::unique_ptr<RooProdPdf> prodPdfClone{static_cast<RooProdPdf *>(this->Clone())};
21512158
ctx.markAsCompiled(*prodPdfClone);
21522159

2153-
for (const auto server : prodPdfClone->servers()) {
2154-
auto nsetForServer = fillNormSetForServer(normSet, *server);
2155-
RooArgSet const &nset = nsetForServer ? *nsetForServer : normSet;
2160+
if (compileServers)
2161+
for (const auto server : prodPdfClone->servers()) {
2162+
auto nsetForServer = fillNormSetForServer(normSet, *server);
2163+
RooArgSet const &nset = nsetForServer ? *nsetForServer : normSet;
21562164

2157-
RooArgSet depList;
2158-
server->getObservables(&nset, depList);
2165+
RooArgSet depList;
2166+
server->getObservables(&nset, depList);
21592167

2160-
ctx.compileServer(*server, *prodPdfClone, depList);
2161-
}
2168+
ctx.compileServer(*server, *prodPdfClone, depList);
2169+
}
21622170

2163-
auto fixedProdPdf = std::make_unique<RooFit::Detail::RooFixedProdPdf>(std::move(prodPdfClone), normSet);
2164-
ctx.markAsCompiled(*fixedProdPdf);
2171+
auto fixedProdPdf =
2172+
std::make_unique<RooFit::Detail::RooFixedProdPdf>(std::move(prodPdfClone), normSet, intSet, rangeName);
2173+
if (compileServers)
2174+
ctx.markAsCompiled(*fixedProdPdf);
2175+
else
2176+
ctx.compileServers(*fixedProdPdf, normSet);
21652177

21662178
return fixedProdPdf;
21672179
}
21682180

21692181
namespace RooFit::Detail {
21702182

2171-
RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet)
2183+
RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSet const &normSet,
2184+
RooArgSet const *intSet, const char *rangeName)
21722185
: RooAbsPdf(prodPdf->GetName(), prodPdf->GetTitle()),
21732186
_normSet{normSet},
21742187
_servers("!servers", "List of servers", this),
2175-
_prodPdf{std::move(prodPdf)}
2188+
_prodPdf{std::move(prodPdf)},
2189+
_rangeName{rangeName ? rangeName : ""}
21762190
{
2177-
auto cache = _prodPdf->createCacheElem(&_normSet, nullptr);
2191+
if (intSet) {
2192+
_intSet.add(*intSet);
2193+
}
2194+
2195+
RooArgSet const *iset = _intSet.empty() ? nullptr : &_intSet;
2196+
auto cache = _prodPdf->createCacheElem(&_normSet, iset, _rangeName.empty() ? nullptr : _rangeName.c_str());
21782197
_isRearranged = cache->_isRearranged;
21792198

21802199
// The actual servers for a given normalization set depend on whether the
@@ -2190,25 +2209,27 @@ RooFixedProdPdf::RooFixedProdPdf(std::unique_ptr<RooProdPdf> &&prodPdf, RooArgSe
21902209
// We don't want to carry the full cache object around, so we let it go out
21912210
// of scope and transfer the ownership of the args that we actually need.
21922211
cache->_ownedList.releaseOwnership();
2193-
std::vector<std::unique_ptr<RooAbsArg>> owned;
2194-
for (RooAbsArg *arg : cache->_ownedList) {
2195-
owned.emplace_back(arg);
2196-
}
2212+
// std::vector<std::unique_ptr<RooAbsArg>> owned;
2213+
// for (RooAbsArg *arg : cache->_ownedList) {
2214+
// owned.emplace_back(arg);
2215+
//}
21972216
for (RooAbsArg *arg : cache->_partList) {
21982217
_servers.add(*arg);
2199-
auto found = std::find_if(owned.begin(), owned.end(), [&](auto const &ptr) { return arg == ptr.get(); });
2200-
if (found != owned.end()) {
2201-
addOwnedComponents(std::move(owned[std::distance(owned.begin(), found)]));
2202-
}
2218+
// auto found = std::find_if(owned.begin(), owned.end(), [&](auto const &ptr) { return arg == ptr.get(); });
2219+
// if (found != owned.end()) {
2220+
// addOwnedComponents(std::move(owned[std::distance(owned.begin(), found)]));
2221+
//}
22032222
}
22042223
}
22052224

22062225
RooFixedProdPdf::RooFixedProdPdf(const RooFixedProdPdf &other, const char *name)
22072226
: RooAbsPdf(other, name),
2227+
_intSet{other._intSet},
22082228
_normSet{other._normSet},
22092229
_servers("!servers", this, other._servers),
22102230
_prodPdf{static_cast<RooProdPdf *>(other._prodPdf->Clone())},
2211-
_isRearranged{other._isRearranged}
2231+
_isRearranged{other._isRearranged},
2232+
_rangeName{other._rangeName}
22122233
{
22132234
}
22142235

roofit/roofitcore/src/RooProjectedPdf.cxx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,9 @@ RooProjectedPdf::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::Com
286286

287287
auto newArgPdf = std::make_unique<RooWrapperPdf>(namePdf.c_str(), namePdf.c_str(), *newArg);
288288

289-
ctx.markAsCompiled(*newArg);
290-
ctx.markAsCompiled(*newArgPdf);
289+
ctx.compileServers(*newArgPdf, normSet);
291290

292-
newArgPdf->addOwnedComponents(std::move(newArg));
291+
ctx.markAsCompiled(*newArgPdf);
293292

294293
return newArgPdf;
295294
}

roofit/roofitcore/src/RooRealIntegral.cxx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ integration is performed in the various implementations of the RooAbsIntegrator
4040
#include <RooNameReg.h>
4141
#include <RooNumIntConfig.h>
4242
#include <RooNumIntFactory.h>
43+
#include <RooProdPdf.h>
4344
#include <RooRealBinding.h>
4445
#include <RooSuperCategory.h>
4546
#include <RooTrace.h>
@@ -1114,6 +1115,13 @@ Int_t RooRealIntegral::getCacheAllNumeric()
11141115
std::unique_ptr<RooAbsArg>
11151116
RooRealIntegral::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext &ctx) const
11161117
{
1118+
if (auto *prodPdf = dynamic_cast<RooProdPdf *>(&*_function)) {
1119+
getVal();
1120+
auto out = prodPdf->compileToFixedProdPdf(_funcNormSet ? *_funcNormSet : normSet, &_anaList,
1121+
RooNameReg::str(_rangeName), ctx, false);
1122+
out->SetName(GetName());
1123+
return out;
1124+
}
11171125
return RooAbsReal::compileForNormSet(_funcNormSet ? *_funcNormSet : normSet, ctx);
11181126
}
11191127

0 commit comments

Comments
 (0)