#include <sstream>
#include "RVersion.h"
#if ROOT_VERSION_CODE < ROOT_VERSION(6,07,00)
#define private public
#define protected public
#include "TMVA/Factory.h"
#include "TMVA/DataSetInfo.h"
#undef private
#undef protected
#else
#define private public
#define protected public
#include "TMVA/Factory.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/DataLoader.h"
#undef private
#undef protected
#endif
#include "QFramework/TQMVA.h"
#include "QFramework/TQIterator.h"
#include "TFile.h"
#include "QFramework/TQUtils.h"
#include "QFramework/TQLibrary.h"
// ROOT class-implementation macro for TQMVA (ClassImp is not defined in this file;
// presumably it comes from the ROOT headers pulled in above).
ClassImp(TQMVA)
// Default constructor: the object is named "TQMVA" and set up via init().
TQMVA::TQMVA()
  : TQNamedTaggable("TQMVA")
{
  this->init();
}
// Named constructor: the name handed to the base class is the result of
// TQStringUtils::replace(name_,"-","_"); afterwards init() is run.
TQMVA::TQMVA(const TString& name_)
  : TQNamedTaggable(TQStringUtils::replace(name_,"-","_"))
{
  this->init();
}
// Constructor taking the sample folder to read from; the object is named
// "TQMVA". The pointer is stored as-is (nothing here deletes it).
TQMVA::TQMVA(TQSampleFolder* sf)
  : TQNamedTaggable("TQMVA"),
    fSampleFolder(sf)
{
  this->init();
}
// Constructor taking both a name and the sample folder to read from.
// The name handed to the base class is the result of
// TQStringUtils::replace(name_,"-","_"); the folder pointer is stored as-is.
TQMVA::TQMVA(const TString& name_, TQSampleFolder* sf)
  : TQNamedTaggable(TQStringUtils::replace(name_,"-","_")),
    fSampleFolder(sf)
{
  this->init();
}
void TQMVA::init(){
this->SetTitle("TQMVA");
#ifndef LEGACY_INTERFACE_PRE_607
this->fDataLoader = new TMVA::DataLoader("TQMVADataLoader");
#endif
}
// Destructor: releases the data loader allocated in init(). The factory and
// the output file are not touched here (see deleteFactory/closeOutputFile).
TQMVA::~TQMVA(){
#ifndef LEGACY_INTERFACE_PRE_607
  delete fDataLoader;
#endif
}
void TQMVA::printListOfVariables() const {
for(size_t i=0; i<this->fNames.size(); i++){
std::cout << this->fNames[i] << "\t" << this->fExpressions[i] << std::endl;
}
}
// Look up the expression booked under the variable name 'var'.
// Returns TQStringUtils::emptyString if no booked name matches.
TString TQMVA::getVariableExpression(const TString& var){
  const size_t nExpressions = fExpressions.size();
  for(size_t idx = 0; idx < nExpressions; ++idx){
    if(!TQStringUtils::equal(var,fNames[idx])) continue;
    return fExpressions[idx];
  }
  return TQStringUtils::emptyString;
}
// C-string convenience overload; forwards to the TString version.
TString TQMVA::getVariableExpression(const char* var){
  const TString key(var);
  return this->getVariableExpression(key);
}
// Store the base cut pointer (stored as-is; no copy is made).
void TQMVA::setBaseCut(TQCut* cut){
  fBaseCut = cut;
}
// Return the base cut pointer previously stored via setBaseCut().
TQCut* TQMVA::getBaseCut(){
  return fBaseCut;
}
// Select the cut named 'name' (looked up via the base cut's getCut()) as the
// active cut used when reading samples.
// Returns the newly selected active cut, or NULL if no base cut has been set
// (previously this case dereferenced a null pointer).
TQCut* TQMVA::useCut(const TString& name){
  if(!this->fBaseCut){
    ERRORclass("cannot use cut '%s' - no base cut assigned, please use TQMVA::setBaseCut(...) to set a base cut",name.Data());
    // a NULL active cut makes readSample fall back to the base cut
    this->fActiveCut = NULL;
    return NULL;
  }
  this->fActiveCut = this->fBaseCut->getCut(name);
  return this->fActiveCut;
}
// Return the TMVA factory currently held by this object.
TMVA::Factory* TQMVA::getFactory(){
  return fMVA;
}
// Replace the held TMVA factory pointer; a previously held factory is not deleted here.
void TQMVA::setFactory(TMVA::Factory* mva){
  fMVA = mva;
}
// Return the sample folder samples are read from.
TQSampleFolder* TQMVA::getSampleFolder(){
  return fSampleFolder;
}
// Set the sample folder samples are read from (pointer stored as-is).
void TQMVA::setSampleFolder(TQSampleFolder* sf){
  fSampleFolder = sf;
}
void TQMVA::deleteFactory(){
if(this->fMVA) delete this->fMVA;
}
void TQMVA::closeOutputFile(){
if(this->fOutputFile){
this->fOutputFile->Close();
delete this->fOutputFile;
this->fOutputFile = NULL;
}
}
bool TQMVA::createFactory(const TString& filename, const TString& options){
if(!TQUtils::ensureDirectoryForFile(filename)){
TQLibrary::ERRORclass("unable to access directory for file '%s'",filename.Data());
return false;
}
this->fOutputFile = TFile::Open(filename,"RECREATE");
if(!fOutputFile || !fOutputFile->IsOpen()){
TQLibrary::ERRORclass("unable to open file '%s'",filename.Data());
if(fOutputFile) delete fOutputFile;
return false;
}
this->setTagString("outputFileName",filename);
this->fMVA = new TMVA::Factory(this->GetName(),fOutputFile,options.Data());
return true;
}
void TQMVA::addVariable(const TString& name, const TString& title, const TString& expression, const TString& unit, char vtype, double min, double max, bool spectator){
if(!this->fMVA){
throw std::runtime_error("unable to book variable without instance of TMVA!");
}
#ifdef LEGACY_INTERFACE_PRE_607
if(spectator) this->fMVA->AddSpectator(TString::Format("%s := %s",name.Data(),name.Data()),title,unit,min,max);
else this->fMVA->AddVariable(TString::Format("%s := %s",name.Data(),name.Data()),title,unit,vtype,min,max);
this->fNames.push_back(name);
this->fExpressions.push_back(expression);
std::vector<TMVA::VariableInfo>& variables = this->fMVA->DefaultDataSetInfo().GetVariableInfos();
for(size_t i=0; i<this->fNames.size(); i++){
if(TQStringUtils::equal(variables[i].GetInternalName(),name)){
variables[i].SetLabel(expression);
}
}
#else
if(spectator) this->fDataLoader->AddSpectator(TString::Format("%s := %s",name.Data(),name.Data()),title,unit,min,max);
else this->fDataLoader->AddVariable(TString::Format("%s := %s",name.Data(),name.Data()),title,unit,vtype,min,max);
this->fNames.push_back(name);
this->fExpressions.push_back(expression);
std::vector<TMVA::VariableInfo>& variables = this->fDataLoader->DefaultDataSetInfo().GetVariableInfos();
for(size_t i=0; i<this->fNames.size(); i++){
if(TQStringUtils::equal(variables[i].GetInternalName(),name)){
variables[i].SetLabel(expression);
}
}
#endif
}
void TQMVA::printInternalVariables() const {
#ifdef LEGACY_INTERFACE_PRE_607
std::vector<TMVA::VariableInfo>& variables = this->fMVA->DefaultDataSetInfo().GetVariableInfos();
for(size_t i=0; i<variables.size(); i++){
std::cout << variables[i].GetInternalName() << ":" << variables[i].GetLabel() << ":" << variables[i].GetExpression() << " (" << variables[i].GetVarType() << "/" << variables[i].GetUnit() << ")" << std::endl;
}
#else
std::vector<TMVA::VariableInfo>& variables = this->fDataLoader->DefaultDataSetInfo().GetVariableInfos();
for(size_t i=0; i<variables.size(); i++){
std::cout << variables[i].GetInternalName() << ":" << variables[i].GetLabel() << ":" << variables[i].GetExpression() << " (" << variables[i].GetVarType() << "/" << variables[i].GetUnit() << ")" << std::endl;
}
#endif
}
void TQMVA::prepareTrees(const TString& options){
TCut cut;
#ifdef LEGACY_INTERFACE_PRE_607
if(this->fMVA) this->fMVA->PrepareTrainingAndTestTree(cut,options);
# else
if(this->fDataLoader) this->fDataLoader->PrepareTrainingAndTestTree(cut,options);
#endif
}
void TQMVA::bookVariable(const TString& name_, const TString& expression_, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_) : expression_);
this->addVariable(name,name,expression,"",'F',min,max,false);
}
void TQMVA::bookVariable(const TString& name_, const TString& expression_, const TString& title, const TString& unit, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_) : expression_);
this->addVariable(name,title,expression,unit,'F',min,max,false);
}
void TQMVA::bookVariable(const TString& name_, const TString& expression_, const TString& title, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_) : expression_);
this->addVariable(name,title,expression,"",'F',min,max,false);
}
void TQMVA::bookVariable(const char* name_, const char* expression_, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_).Data() : expression_);
this->addVariable(name,name,expression,"",'F',min,max,false);
}
void TQMVA::bookVariable(const char* name_, const char* expression_, const char* title, const char* unit, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_).Data() : expression_);
this->addVariable(name,title,expression,unit,'F',min,max,false);
}
void TQMVA::bookVariable(const char* name_, const char* expression_, const char* title, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression = (this->fAliases ? this->fAliases->replaceInText(expression_).Data() : expression_);
this->addVariable(name,title,expression,"",'F',min,max,false);
}
void TQMVA::bookSpectator(const TString& name_, const TString& expression_, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_) : expression_);
this->addVariable(name,name,expression,"",'F',min,max,true);
}
void TQMVA::bookSpectator(const TString& name_, const TString& expression_, const TString& title, const TString& unit, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_) : expression_);
this->addVariable(name,title,expression,unit,'F',min,max,true);
}
void TQMVA::bookSpectator(const TString& name_, const TString& expression_, const TString& title, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_) : expression_);
this->addVariable(name,title,expression,"",'F',min,max,true);
}
void TQMVA::bookSpectator(const char* name_, const char* expression_, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_).Data() : expression_);
this->addVariable(name,name,expression,"",'F',min,max,true);
}
void TQMVA::bookSpectator(const char* name_, const char* expression_, const char* title, const char* unit, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression(this->fAliases ? this->fAliases->replaceInText(expression_).Data() : expression_);
this->addVariable(name,title,expression,unit,'F',min,max,true);
}
void TQMVA::bookSpectator(const char* name_, const char* expression_, const char* title, double min, double max){
TString name = TQFolder::makeValidIdentifier(name_);
TString expression = (this->fAliases ? this->fAliases->replaceInText(expression_).Data() : expression_);
this->addVariable(name,title,expression,"",'F',min,max,true);
}
void TQMVA::addSignal(const TString& path){
this->fSigPaths.push_back(path);
}
void TQMVA::addBackground(const TString& path){
this->fBkgPaths.push_back(path);
}
void TQMVA::clearSignal(){
this->fSigPaths.clear();
}
void TQMVA::clearBackground(){
this->fBkgPaths.clear();
}
int TQMVA::readSamples(){
return this->readSamples(TQMVA::EvenOddEventSelector());
}
// Read all signal samples, then all background samples, using 'evtsel' to
// decide which events go to training and which to testing.
// Returns the sum of the two readSamplesOfType() results, or 0 (after
// reporting an error) if no factory or no base cut has been assigned.
int TQMVA::readSamples(const TQMVA::EventSelector& evtsel){
  if(!fMVA){
    ERRORclass("cannot initialze - no TMVA::Factory assigned!");
    return 0;
  }
  if(!fBaseCut){
    ERRORclass("cannot initialze - no TQCut assigned!");
    return 0;
  }
  const int nSignal = this->readSamplesOfType(TQMVA::Signal,evtsel);
  const int nBackground = this->readSamplesOfType(TQMVA::Background,evtsel);
  return nSignal + nBackground;
}
// Read all samples of the given type using a TQMVA::EvenOddEventSelector.
int TQMVA::readSamplesOfType(TQMVA::SampleType type){
  // binding to a const reference keeps the temporary alive for the call below
  const TQMVA::EventSelector& selector = TQMVA::EvenOddEventSelector();
  return this->readSamplesOfType(type,selector);
}
// Call readSample() for every sample returned by getListOfSamples(type).
// Returns the number of samples iterated over (the per-sample result of
// readSample() is not taken into account).
int TQMVA::readSamplesOfType(TQMVA::SampleType type, const TQMVA::EventSelector& sel){
  DEBUGclass("function called for '%s'",type == TQMVA::Signal ? "Signal" : "Background");
  int nSamples = 0;
  TQSampleIterator samples(this->getListOfSamples(type),true);
  for(; samples.hasNext(); ++nSamples){
    TQSample* current = samples.readNext();
    this->readSample(current,type,sel);
  }
  return nSamples;
}
// Read all events of sample 's' that pass the active cut and hand them to
// TMVA as signal or background (depending on 'type'), as training or testing
// events (depending on sel.selectEvent(entry index)).
//
// The event weight is the active cut's global weight times the sample's
// normalisation. If the tag "makeCounters" is true, three TQCounter objects
// (training, testing, total) are filled and added to the sample's ".cutflow"
// folder.
//
// Returns the number of events that passed the cut, or -1 if the sample, the
// cut, the tree token, the tree or one of the observables is unavailable.
int TQMVA::readSample(TQSample* s, TQMVA::SampleType type, const TQMVA::EventSelector& sel){
  if(fOutputFile) this->fOutputFile->cd();
  if(!fActiveCut) fActiveCut = fBaseCut;
  if(!s){
    ERRORclass("sample is NULL");
    return -1;
  }
  // the base cut is initialized/finalized below, so it is required as well
  if(!fActiveCut || !fBaseCut){
    ERRORclass("cannot read sample '%s' without active cut, please use TQMVA::setBaseCut(...) to set a base cut",s->getPath().Data());
    return -1;
  }
  DEBUGclass("reading sample '%s'",s->getPath().Data());
  const size_t nVars = this->fExpressions.size();
  std::vector<double> vars(nVars);
  TQToken* tok = s->getTreeToken();
  if(!tok){
    ERRORclass("unable to obtain tree token for sample '%s'",s->getPath().Data());
    return -1;
  }
  TTree* t = (TTree*)(tok->getContent());
  if(!t){
    ERRORclass("unable to obtain tree from token for sample '%s'",s->getPath().Data());
    s->returnTreeToken(tok);
    return -1;
  }
  DEBUGclass("initializing cut '%s'",fBaseCut->GetName());
  this->fBaseCut->initialize(s);

  // obtain and initialize one observable per booked expression
  std::vector<TQObservable*> observables;
  observables.reserve(nVars);
  bool haveAllObservables = true;
  for(size_t i=0; i<nVars; i++){
    TQObservable* obs = TQObservable::getObservable(this->fExpressions[i],s);
    if(!obs){
      ERRORclass("unable to obtain observable for expression '%s' in TQMVA for sample '%s'",this->fExpressions[i].Data(),s->getPath().Data());
      haveAllObservables = false;
      break;
    }
    if (!obs->initialize(s)) {
      ERRORclass("Failed to initialize observable obtained from expression '%s' in TQMVA for sample '%s'",this->fExpressions[i].Data(),s->getPath().Data());
    }
    observables.push_back(obs);
  }
  if(!haveAllObservables){
    // undo what has been set up so far before bailing out
    for(size_t i=0; i<observables.size(); ++i){
      observables[i]->finalize();
    }
    this->fBaseCut->finalize();
    s->returnTreeToken(tok);
    return -1;
  }

  // disable all branches, then re-enable only those needed by the
  // observables and the cut
  t->SetBranchStatus("*", 0);
  for(size_t i=0; i<nVars; i++){
    TQObservable* obs = observables.at(i);
    DEBUGclass("activating branches for variable '%s'",this->fNames[i].Data());
    TObjArray* bNames = obs->getBranchNames();
    if (bNames) bNames->SetOwner(true);
#ifdef _DEBUG_
    DEBUGclass("enabling the following branches:");
    TQListUtils::printList(bNames);
#endif
    TQListUtils::setBranchStatus(t,bNames,1);
    delete bNames;
  }
  {
    DEBUGclass("activating branches for cut '%s'",this->fBaseCut->GetName());
    TObjArray* branchNames = this->fBaseCut->getListOfBranches();
    if(branchNames) branchNames->SetOwner(true);
#ifdef _DEBUG_
    DEBUGclass("enabling the following branches:");
    TQListUtils::printList(branchNames);
#endif
    TQListUtils::setBranchStatus(t,branchNames,1);
    // release the list (and, being owner, its entries) like for the observables above
    delete branchNames;
  }

  int nTrainEvent = 0;
  double sumWeightsTrain = 0;
  int nTestEvent = 0;
  double sumWeightsTest = 0;
  int nEvent = 0;
  DEBUGclass("entering event loop using cut '%s'",fActiveCut->GetName());
  if(this->fVerbose) TQLibrary::msgStream.startProcessInfo(TQMessageStream::INFO,80,"r",TString::Format("%s: %s",(type == TQMVA::Signal ? "sig" : "bkg"),s->getPath().Data()));

  TQCounter* cnt_train = NULL;
  TQCounter* cnt_test = NULL;
  TQCounter* cnt_total = NULL;
  const bool makeCounters = this->getTagBoolDefault("makeCounters",false);
  if(makeCounters){
    // the counter names/titles must match what the counters are filled with
    cnt_train = new TQCounter(TString::Format("TQMVA_%s_training",this->GetName()),TString::Format("%s training events",this->GetTitle()));
    cnt_test = new TQCounter(TString::Format("TQMVA_%s_testing", this->GetName()),TString::Format("%s testing events", this->GetTitle()));
    cnt_total = new TQCounter(TString::Format("TQMVA_%s_total", this->GetName()),TString::Format("%s testing+training events",this->GetTitle()));
  }

#ifndef _DEBUG_
  TQLibrary::redirect_stdout("/dev/null");
#endif
  const Long64_t nEntries = t->GetEntriesFast();
  for(Long64_t iEvent = 0; iEvent < nEntries; ++iEvent){
    t->GetEntry(iEvent);
    if(!this->fActiveCut->passedGlobally()){
      continue;
    }
    nEvent++;
    for(size_t i=0; i<observables.size(); ++i){
#ifdef _DEBUG_
      try {
#endif
        vars[i] = observables[i]->getValue();
#ifdef _DEBUG_
      } catch(const std::exception& e){
        BREAK("ERROR in TQMVA: observable '%s' with expression '%s' encountered error '%s'",observables[i]->GetName(),observables[i]->getActiveExpression().Data(),e.what());
      }
#endif
    }
    const double weight = this->fActiveCut->getGlobalWeight() * s->getNormalisation();
    if(sel.selectEvent(iEvent)){
#ifdef LEGACY_INTERFACE_PRE_607
      if(type==TQMVA::Signal) this->fMVA->AddSignalTrainingEvent(vars,weight);
      else this->fMVA->AddBackgroundTrainingEvent(vars,weight);
#else
      if(type==TQMVA::Signal) this->fDataLoader->AddSignalTrainingEvent(vars,weight);
      else this->fDataLoader->AddBackgroundTrainingEvent(vars,weight);
#endif
      if(makeCounters){
        cnt_train->add(weight);
        cnt_total->add(weight);
      }
      nTrainEvent++;
      sumWeightsTrain += weight;
    } else {
#ifdef LEGACY_INTERFACE_PRE_607
      if(type==TQMVA::Signal) this->fMVA->AddSignalTestEvent(vars,weight);
      else this->fMVA->AddBackgroundTestEvent(vars,weight);
#else
      if(type==TQMVA::Signal) this->fDataLoader->AddSignalTestEvent(vars,weight);
      else this->fDataLoader->AddBackgroundTestEvent(vars,weight);
#endif
      if(makeCounters){
        cnt_test->add(weight);
        cnt_total->add(weight);
      }
      nTestEvent++;
      sumWeightsTest += weight;
    }
  }
#ifndef _DEBUG_
  TQLibrary::restore_stdout();
#endif

  if(this->fVerbose){
    if(nTrainEvent == 0 || nTestEvent == 0 || nEvent == 0){
      TQLibrary::msgStream.endProcessInfo(TQMessageStream::WARN);
    } else {
      TQLibrary::msgStream.endProcessInfo(TQMessageStream::OK);
    }
    if(nEntries == 0){
      WARNclass("this sample was empty (tree had no entries)");
    } else if(nEvent == 0){
      WARNclass("no events from this sample passed the cut '%s' (from a total of %lld events)!",this->fActiveCut->GetName(),nEntries);
#ifdef _DEBUG_
      DEBUGclass("cut expression is as follows:");
      this->fActiveCut->printActiveCutExpression();
#endif
    } else {
      if(nTrainEvent == 0){
        WARNclass("event selector did not select any training events (from a total of %d selected events)!",nEvent);
      }
      if(nTestEvent == 0){
        WARNclass("event selector did not select any testing events (from a total of %d selected events)!",nEvent);
      }
    }
    if(this->fVerbose > 1){
      INFO("number of read events: %d training (%.1f weighted), %d test (%.1f weighted)",nTrainEvent,sumWeightsTrain,nTestEvent,sumWeightsTest);
    }
  }

  DEBUGclass("finalizing cut '%s'",fBaseCut->GetName());
  this->fBaseCut->finalize();
  for(size_t i=0; i<observables.size(); ++i){
    DEBUGclass("finalizing variable '%s'",this->fNames[i].Data());
    observables[i]->finalize();
  }
  if(makeCounters){
    TQFolder* counters = s->getFolder(".cutflow+");
    if(counters){
      counters->addObject(cnt_train);
      counters->addObject(cnt_test);
      counters->addObject(cnt_total);
    } else {
      // nobody takes the counters over in this case, so release them here
      delete cnt_train;
      delete cnt_test;
      delete cnt_total;
    }
  }
  s->returnTreeToken(tok);
  return nEvent;
}
// Print the registered paths of the given sample type and, below each path,
// the samples found for "<path>/*" in the sample folder.
// If no sample folder is assigned, an error is reported after the header
// instead of dereferencing a null pointer.
void TQMVA::printListOfSamples(TQMVA::SampleType type){
  std::vector<TString>* vec = (type == TQMVA::Signal ? &(this->fSigPaths) : &(this->fBkgPaths));
  std::cout << TQStringUtils::makeBoldBlue(this->GetName()) << TQStringUtils::makeBoldWhite(": samples of type '") << TQStringUtils::makeBoldBlue(type == TQMVA::Signal ? "Signal" : "Background") << TQStringUtils::makeBoldWhite("'") << std::endl;
  if(vec->size() < 1){
    std::cout << TQStringUtils::makeBoldRed("<no paths listed>") << std::endl;
    return;
  }
  if(!this->fSampleFolder){
    ERRORclass("cannot list samples - no sample folder assigned!");
    return;
  }
  for(size_t i=0; i<vec->size(); i++){
    TString path(vec->at(i));
    TQStringUtils::ensureTrailingText(path,"/*");
    std::cout << "\t" << TQStringUtils::makeBoldWhite(path) << std::endl;
    TList* l = this->fSampleFolder->getListOfSamples(path);
    if(l && l->GetEntries() > 0){
      // the iterator is constructed with 'true' like elsewhere in this file;
      // the list is not deleted manually in this branch
      TQSampleIterator itr(l,true);
      while(itr.hasNext()){
        std::cout << "\t\t";
        TQSample* s = itr.readNext();
        if(!s) std::cout << TQStringUtils::makeBoldRed("<NULL>");
        else std::cout << s->getPath();
        std::cout << std::endl;
      }
    } else {
      delete l; // deleting a null pointer is a no-op
      std::cout << "\t\t" << TQStringUtils::makeBoldRed("<no samples found under this path>" ) << std::endl;
    }
  }
}
// Collect all samples without sub-samples found under the registered paths of
// the given type ("<path>/*" is looked up in the sample folder).
// Returns a newly allocated TList (never NULL). The list is empty if no paths
// are registered or, after reporting an error, if no sample folder is
// assigned (previously the latter case dereferenced a null pointer).
TList* TQMVA::getListOfSamples(TQMVA::SampleType type){
  std::vector<TString>* vec = (type == TQMVA::Signal ? &(this->fSigPaths) : &(this->fBkgPaths));
  TList* retval = new TList();
  if(!this->fSampleFolder){
    if(!vec->empty()){
      ERRORclass("cannot retrieve samples - no sample folder assigned!");
    }
    return retval;
  }
  for(size_t i=0; i<vec->size(); i++){
    TString path(vec->at(i));
    TQStringUtils::ensureTrailingText(path,"/*");
    TQSampleIterator itr(this->fSampleFolder->getListOfSamples(path),true);
    while(itr.hasNext()) {
      TQSample* s = itr.readNext();
      if (s && !s->hasSubSamples()) retval->Add(s);
    }
  }
  return retval;
}
void TQMVA::setVerbose(int verbose){
this->fVerbose = verbose;
}
// Return the alias object applied to expressions when booking variables.
TQTaggable* TQMVA::getAliases(){
  return fAliases;
}
// Set the alias object applied to expressions when booking variables
// (pointer stored as-is).
void TQMVA::setAliases(TQTaggable* aliases){
  fAliases = aliases;
}