using namespace TMVA::Experimental; %%cpp -d void train(const std::string &filename) { // Create factory auto output = TFile::Open("TMVARR.root", "RECREATE"); auto factory = new TMVA::Factory("tmva003", output, "!V:!DrawProgressBar:AnalysisType=Classification"); // Open trees with signal and background events auto data = TFile::Open(filename.c_str()); auto signal = (TTree *)data->Get("TreeS"); auto background = (TTree *)data->Get("TreeB"); // Add variables and register the trees with the dataloader auto dataloader = new TMVA::DataLoader("tmva003_BDT"); const std::vector variables = {"var1", "var2", "var3", "var4"}; for (const auto &var : variables) { dataloader->AddVariable(var); } dataloader->AddSignalTree(signal, 1.0); dataloader->AddBackgroundTree(background, 1.0); dataloader->PrepareTrainingAndTestTree("", ""); // Train a TMVA method factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2"); factory->TrainAllMethods(); } const std::string filename = "http://root.cern/files/tmva_class_example.root"; train(filename); RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml"); auto variables = model.GetVariableNames(); auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5}); std::cout << "Single-event inference: " << prediction[0] << "\n\n"; ROOT::RDataFrame df("TreeS", filename); auto df2 = df.Range(3); // Read only a small subset of the dataset auto x = AsTensor(df2, variables); auto y = model.Compute(x); std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n"; std::cout << "Prediction performed on multiple events: " << y << "\n\n"; auto make_histo = [&](const std::string &treename) { ROOT::RDataFrame df(treename, filename); auto df2 = df.Define("y", Compute<4, float>(model), variables); return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y"); }; auto sig = make_histo("TreeS"); auto bkg = make_histo("TreeB"); gStyle->SetOptStat(0); auto c = new TCanvas("", "", 800, 800); 
// Style and overlay the signal and background BDT-score histograms on canvas `c`.
sig->SetLineColor(kRed);
bkg->SetLineColor(kBlue);
sig->SetLineWidth(2);
bkg->SetLineWidth(2);

// The first histogram drawn fixes the y-axis range and "SAME" does not extend it.
// Size the axis for the taller of the two so the signal peak cannot be clipped.
// (Both maxima are read before SetMaximum is applied.)
bkg->SetMaximum(1.1 * std::max(sig->GetMaximum(), bkg->GetMaximum()));
bkg->Draw("HIST");
sig->Draw("HIST SAME");

// Attach the histograms directly instead of looking them up by name in the pad.
TLegend legend(0.7, 0.7, 0.89, 0.89);
legend.SetBorderSize(0);
legend.AddEntry(sig.GetPtr(), "Signal", "l");
legend.AddEntry(bkg.GetPtr(), "Background", "l");
legend.Draw();

// Draw a copy of the finished canvas.
c->DrawClone();