%%cpp -d #include #include #include #include #include "TChain.h" #include "TFile.h" #include "TTree.h" #include "TString.h" #include "TObjString.h" #include "TSystem.h" #include "TROOT.h" #include "TMVA/CrossValidation.h" #include "TMVA/DataLoader.h" #include "TMVA/Factory.h" #include "TMVA/Tools.h" #include "TMVA/TMVAGui.h" %%cpp -d TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100) { TRandom3 rng(seed); Float_t x = 0; Float_t y = 0; UInt_t eventID = 0; TTree *data = new TTree(); data->Branch("x", &x, "x/F"); data->Branch("y", &y, "y/F"); data->Branch("eventID", &eventID, "eventID/I"); for (Int_t n = 0; n < nPoints; ++n) { x = rng.Gaus(offset, scale); y = rng.Gaus(offset, scale); // For our simple example it is enough that the id's are uniformly // distributed and independent of the data. ++eventID; data->Fill(); } // Important: Disconnects the tree from the memory locations of x and y. data->ResetBranchAddresses(); return data; } bool useRandomSplitting = false; TMVA::Tools::Instance(); TTree *sigTree = genTree(1000, 1.0, 1.0, 100); TTree *bkgTree = genTree(1000, -1.0, 1.0, 101); TString outfileName("TMVACV.root"); TFile *outputFile = TFile::Open(outfileName, "RECREATE"); TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetcv"); dataloader->AddVariable("x", 'F'); dataloader->AddVariable("y", 'F'); dataloader->AddSpectator("eventID", 'I'); dataloader->AddSignalTree(sigTree, 1.0); dataloader->AddBackgroundTree(bkgTree, 1.0); dataloader->PrepareTrainingAndTestTree("", "", "nTest_Signal=1" ":nTest_Background=1" ":SplitMode=Random" ":NormMode=NumEvents" ":!V"); UInt_t numFolds = 2; TString analysisType = "Classification"; TString splitType = (useRandomSplitting) ? "Random" : "Deterministic"; TString splitExpr = (!useRandomSplitting) ? "int(fabs([eventID]))%int([NumFolds])" : ""; TString cvOptions = Form("!V" ":!Silent" ":ModelPersistence" ":AnalysisType=%s" ":SplitType=%s" ":NumFolds=%i" ":SplitExpr=%s", analysisType.Data(), splitType.Data(), numFolds, splitExpr.Data()); TMVA::CrossValidation cv{"TMVACrossValidation", dataloader, outputFile, cvOptions}; cv.BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad" ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=20" ":MaxDepth=2"); cv.BookMethod(TMVA::Types::kFisher, "Fisher", "!H:!V:Fisher:VarTransform=None"); cv.Evaluate(); size_t iMethod = 0; for (auto && result : cv.GetResults()) { std::cout << "Summary for method " << cv.GetMethods()[iMethod++].GetValue("MethodName") << std::endl; for (UInt_t iFold = 0; iFoldClose(); std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl; std::cout << "==> TMVACrossValidation is done!" << std::endl; if (!gROOT->IsBatch()) { // Draw cv-specific graphs cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for BDTG"); cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for Fisher"); // You can also use the classical gui TMVA::TMVAGui(outfileName); } return 0;