diff --git a/roofit/hs3/src/RooJSONFactoryWSTool.cxx b/roofit/hs3/src/RooJSONFactoryWSTool.cxx index 20d87473e408a..d83fc4cf5fc87 100644 --- a/roofit/hs3/src/RooJSONFactoryWSTool.cxx +++ b/roofit/hs3/src/RooJSONFactoryWSTool.cxx @@ -307,6 +307,114 @@ void importAttributes(RooAbsArg *arg, JSONNode const &node) } } +void addIfPresent(RooArgSet &out, RooArgSet const *args) +{ + if (args) { + out.add(*args, true); + } +} + +void collectParameterStepWidthCandidatesFromModelConfigs(RooWorkspace const &workspace, RooArgSet &candidates, + RooArgSet &excluded) +{ + for (TObject *obj : workspace.allGenericObjects()) { + auto const *mc = dynamic_cast(obj); + if (!mc) { + continue; + } + + addIfPresent(candidates, mc->GetParametersOfInterest()); + addIfPresent(candidates, mc->GetNuisanceParameters()); + + addIfPresent(excluded, mc->GetObservables()); + addIfPresent(excluded, mc->GetGlobalObservables()); + addIfPresent(excluded, mc->GetConditionalObservables()); + } +} + +void collectParameterStepWidthCandidatesFromPdfs(std::vector const &pdfs, + std::vector const &data, RooArgSet &candidates, + RooArgSet &excluded) +{ + for (RooAbsPdf const *pdf : pdfs) { + RooArgSet observables; + for (RooAbsData const *dataset : data) { + std::unique_ptr pdfObs{pdf->getObservables(*dataset->get())}; + observables.add(*pdfObs, true); + } + + if (observables.empty()) { + continue; + } + + RooArgSet params; + pdf->getParameters(&observables, params); + candidates.add(params, true); + excluded.add(observables, true); + } +} + +void exportParameterStepWidths(RooWorkspace const &workspace, std::vector const &pdfs, + std::vector const &data, JSONNode &rootnode) +{ + RooArgSet candidates; + RooArgSet excluded; + + collectParameterStepWidthCandidatesFromModelConfigs(workspace, candidates, excluded); + collectParameterStepWidthCandidatesFromPdfs(pdfs, data, candidates, excluded); + + candidates.sort(); + + JSONNode *parameterStepWidthsNode = nullptr; + for (RooAbsArg *arg : candidates) { + if (excluded.find(*arg)) { + continue; + } + + auto *var = dynamic_cast(arg); + if (!var || !var->hasError()) { + continue; + } + + if (!parameterStepWidthsNode) { + parameterStepWidthsNode = &rootnode["misc"]["minimization"]["parameter_stepwidths"].set_seq(); + } + + JSONNode &stepWidthNode = RooJSONFactoryWSTool::appendNamedChild(*parameterStepWidthsNode, var->GetName()); + stepWidthNode["step_width"] << var->getError(); + } +} + +void importParameterStepWidths(RooWorkspace &workspace, JSONNode const &rootnode) +{ + auto const *parameterStepWidthsNode = rootnode.find("misc", "minimization", "parameter_stepwidths"); + if (!parameterStepWidthsNode) { + return; + } + if (!parameterStepWidthsNode->is_seq()) { + RooJSONFactoryWSTool::warning("RooFitHS3: misc.minimization.parameter_stepwidths is not a sequence, skipping."); + return; + } + + for (JSONNode const &stepWidthNode : parameterStepWidthsNode->children()) { + if (!stepWidthNode.is_map() || !stepWidthNode.has_child("name") || !stepWidthNode.has_child("step_width")) { + RooJSONFactoryWSTool::warning("RooFitHS3: skipping malformed parameter_stepwidths entry."); + continue; + } + + const std::string name = RooJSONFactoryWSTool::name(stepWidthNode); + RooAbsArg *arg = workspace.arg(name); + auto *var = dynamic_cast(arg); + if (!var) { + RooJSONFactoryWSTool::warning( + "RooFitHS3: skipping parameter_stepwidths entry for unknown or non-real variable '" + name + "'."); + continue; + } + + var->setError(stepWidthNode.find("step_width")->val_double()); + } +} + // RooWSFactoryTool expression handling std::string generate(const RooFit::JSONIO::ImportExpression &ex, const JSONNode &p, RooJSONFactoryWSTool *tool) { @@ -594,7 +702,7 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons for (const auto &d : datasets) { if (d->GetName() == nameNode.val()) { found = true; - observables.add(*d->get()); + observables.add(*d->get(), true); } } if (nameNode.val() != "0" && !found) @@ -758,7 +866,7 @@ void combineDatasets(const JSONNode &rootnode, std::vectorGetName() == componentName; }); if (!component) RooJSONFactoryWSTool::error("unable to obtain component matching component name '" + componentName + "'"); - allVars.add(*component->get()); + allVars.add(*component->get(), true); dsMap.insert({labels[iChannel], std::move(component)}); indexCat.defineType(labels[iChannel], indices[iChannel]); } @@ -1789,7 +1897,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooFit::M nllNode["data"].set_seq(); if (dataComponents) { - auto simPdf = static_cast(pdf); + auto simPdf = dynamic_cast(pdf); if (simPdf) { for (auto const &item : simPdf->indexCat()) { const auto &dataComp = dataComponents->find(item.first); @@ -1926,6 +2034,8 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n) } } + exportParameterStepWidths(_workspace, allpdfs, alldata, n); + for (auto *snsh : static_range_cast(_workspace.getSnapshots())) { RooArgSet snapshotSorted; // We only want to add the variables that actually got exported and skip @@ -2251,10 +2361,12 @@ bool RooJSONFactoryWSTool::importJSON(std::istream &is) { // import a JSON file to the workspace std::unique_ptr tree = JSONTree::create(is); - this->importAllNodes(tree->rootnode()); + JSONNode const &rootnode = tree->rootnode(); + this->importAllNodes(rootnode); if (this->workspace()->getSnapshot("default_values")) { this->workspace()->loadSnapshot("default_values"); } + importParameterStepWidths(*this->workspace(), rootnode); return true; } @@ -2287,7 +2399,9 @@ bool RooJSONFactoryWSTool::importYML(std::istream &is) { // import a YML file to the workspace std::unique_ptr tree = JSONTree::create(is); - this->importAllNodes(tree->rootnode()); + JSONNode const &rootnode = tree->rootnode(); + this->importAllNodes(rootnode); + importParameterStepWidths(*this->workspace(), rootnode); return true; } diff --git a/roofit/hs3/test/testRooFitHS3.cxx b/roofit/hs3/test/testRooFitHS3.cxx index 86481361a5812..d949326ccd43f 100644 --- a/roofit/hs3/test/testRooFitHS3.cxx +++ b/roofit/hs3/test/testRooFitHS3.cxx @@ -132,6 +132,20 @@ int validate(RooAbsArg const &arg, bool exact = true) return validate(ws, arg.GetName(), exact); } +std::string parameterStepWidthsNode(std::string const &json) +{ + const std::string key = "\"parameter_stepwidths\":["; + const auto begin = json.find(key); + if (begin == std::string::npos) { + return ""; + } + const auto end = json.find("]", begin); + if (end == std::string::npos) { + return ""; + } + return json.substr(begin, end - begin + 1); +} + } // namespace // Test that the IO of attributes and string attributes works. @@ -165,6 +179,128 @@ TEST(RooFitHS3, AttributesIO) EXPECT_STREQ(pdf.getStringAttribute("key1"), nullptr) << "unexpected string attribute found!"; } +TEST(RooFitHS3, ParameterStepWidthsModelConfigRoundTrip) +{ + RooWorkspace ws1{"workspace"}; + ws1.factory("Gaussian::sig(x[-5, 5], mu[0, -10, 10], sigma[1, 0.1, 10])"); + ws1.factory("Polynomial::bkg(x, {theta[0, -1, 1]})"); + ws1.factory("SUM::model(fsig[0.5, 0, 1] * sig, bkg)"); + + RooRealVar &x = *ws1.var("x"); + RooDataSet data{"data", "data", RooArgSet{x}}; + for (double val : {-1.0, 0.5, 1.5}) { + x.setVal(val); + data.add(RooArgSet{x}); + } + ws1.import(data); + + ws1.var("x")->setError(9.0); + ws1.var("mu")->setError(0.12); + ws1.var("theta")->setError(0.33); + ws1.var("sigma")->setError(0.20); + ws1.var("sigma")->setAsymError(-0.18, 0.25); + + RooFit::ModelConfig mc{"mc", &ws1}; + mc.SetPdf(*ws1.pdf("model")); + mc.SetObservables("x"); + mc.SetParametersOfInterest("mu"); + mc.SetNuisanceParameters("sigma"); + ws1.import(mc); + + const std::string json = RooJSONFactoryWSTool{ws1}.exportJSONtoString(); + const std::string parameterStepWidths = parameterStepWidthsNode(json); + ASSERT_FALSE(parameterStepWidths.empty()) << json; + EXPECT_NE(parameterStepWidths.find("\"name\":\"mu\""), std::string::npos) << parameterStepWidths; + EXPECT_NE(parameterStepWidths.find("\"name\":\"sigma\""), std::string::npos) << parameterStepWidths; + EXPECT_NE(parameterStepWidths.find("\"name\":\"theta\""), std::string::npos) << parameterStepWidths; + EXPECT_NE(parameterStepWidths.find("\"step_width\":0.12"), std::string::npos) << parameterStepWidths; + EXPECT_NE(parameterStepWidths.find("\"step_width\":0.2"), std::string::npos) << parameterStepWidths; + EXPECT_EQ(parameterStepWidths.find("\"error_lo\""), std::string::npos) << parameterStepWidths; + EXPECT_EQ(parameterStepWidths.find("\"error_hi\""), std::string::npos) << parameterStepWidths; + EXPECT_EQ(parameterStepWidths.find("\"name\":\"x\""), std::string::npos) << parameterStepWidths; + + RooWorkspace ws2{"workspace2"}; + ASSERT_TRUE(RooJSONFactoryWSTool{ws2}.importJSONfromString(json)); + + ASSERT_NE(ws2.var("mu"), nullptr); + ASSERT_NE(ws2.var("theta"), nullptr); + ASSERT_NE(ws2.var("sigma"), nullptr); + ASSERT_NE(ws2.var("x"), nullptr); + EXPECT_TRUE(ws2.var("mu")->hasError()); + EXPECT_DOUBLE_EQ(ws2.var("mu")->getError(), 0.12); + EXPECT_TRUE(ws2.var("theta")->hasError()); + EXPECT_DOUBLE_EQ(ws2.var("theta")->getError(), 0.33); + EXPECT_TRUE(ws2.var("sigma")->hasError()); + EXPECT_DOUBLE_EQ(ws2.var("sigma")->getError(), 0.20); + EXPECT_FALSE(ws2.var("sigma")->hasAsymError()); + EXPECT_FALSE(ws2.var("x")->hasError()); +} + +TEST(RooFitHS3, ParameterStepWidthsFallbackExcludesDataAxes) +{ + RooWorkspace ws1{"workspace"}; + ws1.factory("Gaussian::model(x[-5, 5], mu[0, -10, 10], sigma[1, 0.1, 10])"); + + RooRealVar &x = *ws1.var("x"); + RooDataSet data{"data", "data", RooArgSet{x}}; + for (double val : {-1.0, 0.5, 1.5}) { + x.setVal(val); + data.add(RooArgSet{x}); + } + ws1.import(data); + + ws1.var("x")->setError(9.0); + ws1.var("mu")->setError(0.12); + ws1.var("sigma")->setError(0.20); + + const std::string json = RooJSONFactoryWSTool{ws1}.exportJSONtoString(); + const std::string parameterStepWidths = parameterStepWidthsNode(json); + ASSERT_FALSE(parameterStepWidths.empty()) << json; + EXPECT_NE(parameterStepWidths.find("\"name\":\"mu\""), std::string::npos) << parameterStepWidths; + EXPECT_NE(parameterStepWidths.find("\"name\":\"sigma\""), std::string::npos) << parameterStepWidths; + EXPECT_EQ(parameterStepWidths.find("\"name\":\"x\""), std::string::npos) << parameterStepWidths; + + RooWorkspace ws2{"workspace2"}; + ASSERT_TRUE(RooJSONFactoryWSTool{ws2}.importJSONfromString(json)); + + ASSERT_NE(ws2.var("mu"), nullptr); + ASSERT_NE(ws2.var("sigma"), nullptr); + ASSERT_NE(ws2.var("x"), nullptr); + EXPECT_DOUBLE_EQ(ws2.var("mu")->getError(), 0.12); + EXPECT_DOUBLE_EQ(ws2.var("sigma")->getError(), 0.20); + EXPECT_FALSE(ws2.var("x")->hasError()); +} + +TEST(RooFitHS3, ParameterStepWidthsImportAfterDefaultSnapshot) +{ + const std::string json = R"({ + "metadata": {"hs3_version": "0.1.90"}, + "parameter_points": [ + { + "name": "default_values", + "parameters": [ + {"name": "mu", "value": 0.0, "err": 0.01} + ] + } + ], + "misc": { + "minimization": { + "parameter_stepwidths": [ + {"name": "mu", "step_width": 0.42}, + {"name": "missing", "step_width": 1.0} + ] + } + } + })"; + + RooWorkspace ws{"workspace"}; + ASSERT_TRUE(RooJSONFactoryWSTool{ws}.importJSONfromString(json)); + + ASSERT_NE(ws.var("mu"), nullptr); + EXPECT_TRUE(ws.var("mu")->hasError()); + EXPECT_DOUBLE_EQ(ws.var("mu")->getError(), 0.42); +} + TEST(RooFitHS3, RooAddPdf) { int status = validate({"Gaussian::sig(x[5.20, 5.30], sigmean[5.28, 5.20, 5.30], sigwidth[0.0027, 0.001, 1.])",