Pull to refresh

Машинное обучение с помощью TMVA (ROOT)

Reading time4 min
Views4.2K

В последние пару лет только и слышно о том, что Python и scikit-learn являются неким золотым стандартом в data science.
Я же хочу рассказать Вам о возможности альтернативного развития в области machine learning, библиотеке написанной на С++.
TMVA (Toolkit for Multivariate Data Analysis with ROOT) — open-source библиотека алгоритмов машинного обучения, которая идёт в дополнение к пакету анализа больших данных ROOT, соответственно устанавливается вместе с ним. Про установку подробно написано в мануале, поэтому мы не будем рассматривать этот момент.
Основным сайтом проекта до недавнего времени считался TMVA, но, как мы видим, на нём уже давненько не было никаких обновлений. Это не повод для скепсиса и паники, т.к. теперь его развитием занимается новая команда церновских разработчиков.
CERN (Европейская организация по ядерным исследованиям) была первопроходцем в создании ПО для анализа больших объёмов данных. Именно там была разработана объектно-ориентированная библиотека ROOT, которая нашла применение не только в мире физики.
В ROOT'е данные хранятся в очень экономичном формате *.root, но можно работать и с любым текстовым форматом. Для простоты используем при работе с TMVA обычный тектовый формат csv/txt.
К сожалению, на данный момент, в TMVA используются только алгоритмы обучения с учителем.


Примеры графиков в TMVA

Так выглядят корреляционные матрицы в TMVA:
image
Соотношения фичей:
image
Ro-curve выглядит нестандартно:
image


Итак, представим, что у нас уже установлен ROOT и есть 2 текстовых файла: с "хорошими" и теми, кого нужно классифицировать (либо построить регрессию для прогнозирования). Для того, чтобы подать как инпут эти 2 файла, необходимо привести заголовок файла к необходимому формату:
id/F:Param1/I:Param2/I:Param3/F


Типичный пример входного формата текстового файла

id/F:Param1/I:Param2/I:Param3/F
2,59,1,0
3,85,0,44
4,39,0,78
...


В TMVA 2 типа данных: Float и Integer (в Reader'e только float)
В качестве разделителя переменных по умолчанию идёт знак запятой.
Ознакомиться со списком алгоритмов можно в User Guide


Давайте перейдём к коду.


#include "TMVA/Types.h"
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
using std::cout;
//For Reader
std::string outputListFileName;

void Model_BDT()
{
std::cout << std::endl;

std::cout << "===> Start TMVAClassification" << std::endl;
//Создаём выходное ROOT-дерево, в котором у нас будет содержаться информация о модели: корреляционные матрицы, корреляции параметров, RO-кривые)
TFile* outFputFile = new TFile("Model.root", "RECREATE"); 
//Создаём для записи файл построения модели (генерируемый методом MakeClass, он будет расположен в директории weights, вместе с xml файлом
TMVA::Factory *factory = new TMVA::Factory("TMVAClassification_Model",outFputFile,"V:!Silent:Color:Transformations=I:DrawProgressBar:AnalysisType=Classification");

//читаем файлы
TString sigFile="Signal.csv";
TString bkgFile ="Background.csv";

cout << ">>>> Adding variables phase\n";
factory->AddVariable("Param1",'I');
factory->AddVariable("Param2",'I');
factory->AddVariable("Param3",'F');
//Id в моём случае будет просто проверочной переменной
factory->AddSpectator("id", 'F');
Double_t sigWeight = 1.0; // overall weight for all signal events
Double_t bkgWeight = 1.0; // overall weight for all background events
factory->SetInputTrees( sigFile, bkgFile, sigWeight, bkgWeight );

cout << ">>>> Cutting\n";
//Отбираем значения для параметра Param1 и Param3;может пригодиться если данные с каким-то шумом
TCut preselectionCut("Param1 > 0. && Param3<350.0");
TCut mycutS = "";
//Можем взять каждое n-ое событие в Background,если данных очень много, а ноутбук не тянет
TCut mycutB = "id%100==0";
//Задаём объём тренировочного и тестового дерева
factory->PrepareTrainingAndTestTree(mycutS, mycutB, "nTrain_Signal=16000:nTest_Signal=1451:nTrain_Background=800000:nTest_Background=118416:VerboseLevel=Debug");

//Выбираем модель Boosted Decision and Regression Trees, вводим параметры
factory->BookMethod(TMVA::Types::kBDT, "BDT", "MaxDepth=5:NTrees=2000:MinNodeSize=9%:PruneStrength=10:SeparationType=GiniIndex");
//Выводим help для метода
factory->PrintHelpMessage("BDT");
//тренируем,тестируем и оцениваем модель
cout << ">>>> doing TrainAllMethods\n";
factory->TrainAllMethods();
cout << ">>>> doing TestAllMethods\n";
factory->TestAllMethods();
cout << ">>>> doing EvaluateAllMethods\n";
factory->EvaluateAllMethods();

 // Save the output
   outFputFile->Close();

   std::cout << "===> Wrote root file: " << outFputFile->GetName() << std::endl;
   std::cout << "===> TMVAClassification is done!" << std::endl;

   delete factory;
}

Запустить макрос можно командой из терминала "root Model_BDT.C".
После того, как всё досчитается, в консоли можно открыть ROOT-браузер, командой "TBrowser b;" и полюбоваться множеством симпатичных графиков.
В следующей статье я хочу рассказать про то, как написать Reader модели, который позволяет применять полученную модель на любых других данных и выгрузить отскоренный массив с определённой отсечкой скор-балла.

Tags:
Hubs:
+1
Comments5

Articles