ML functions
 
Loading...
Searching...
No Matches
XGBoost.h
Go to the documentation of this file.
1
15
16#pragma once
17
#include <dirent.h>
#include <fcntl.h>
#include <stdlib.h>
#include <unistd.h>

#include <cmath>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <xgboost/c_api.h>

#include "BaseFunction.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
#include "velox/ml_functions/DecisionTree.h"
#include "velox/vector/tests/utils/VectorTestBase.h"
33
34using namespace facebook::velox;
35using namespace facebook::velox::test;
36using namespace facebook::velox::exec::test;
37using namespace facebook::velox::memory;
38
39namespace ml {
40
// Forward declaration so the shared-pointer alias below can name the class
// before its definition.
class XGBoost;

/// Shared-ownership handle to an XGBoost model.
using XGBoostPtr = std::shared_ptr<XGBoost>;
44
49class XGBoost {
50 public:
55 XGBoost(std::string pathToJSON) {
56 XGBoosterCreate(NULL, 0, &booster);
57 XGBoosterSetParam(booster, "seed", "0");
58 XGBoosterLoadModel(booster, pathToJSON.c_str());
59 }
60
68 inline void predict(
69 VectorPtr& input,
70 std::vector<float>& resultVector,
71 int numInputs,
72 int numFeatures) {
73 auto inputFeatures = input->as<ArrayVector>()->elements();
74 float* inputValues = inputFeatures->values()->asMutable<float>();
75 DMatrixHandle dtest;
76 XGDMatrixCreateFromMat(inputValues, numInputs, numFeatures, NAN, &dtest);
77 unsigned long numOutputs;
78 float const* outData = NULL;
79 XGBoosterPredictFromDMatrix(booster, dtest, 0, 0, &numOutputs, &outData);
80 assert(numOutputs == numInputs);
81 memcpy(resultVector.data(), outData, numOutputs * sizeof(float));
82 XGDMatrixFree(dtest);
83 XGBoosterFree(booster);
84 }
85
86 BoosterHandle booster;
87};
88
94 public:
100 XGBoostPrediction(std::string forestPath, int numFeatures) {
101 this->forest = std::make_shared<XGBoost>(forestPath);
102 this->numFeatures = numFeatures;
103 this->forestPath = forestPath;
104 }
105
114 void apply(
115 const SelectivityVector& rows,
116 std::vector<VectorPtr>& args,
117 const TypePtr& type,
118 exec::EvalCtx& context,
119 VectorPtr& output) const override {
120 BaseVector::ensureWritable(rows, type, context.pool(), output);
121
122 int numInputs = rows.size();
123
124 std::vector<float> resultVector(numInputs);
125
126 this->forest->predict(args[0], resultVector, numInputs, this->numFeatures);
127
128 VectorMaker maker{context.pool()};
129
130 output = maker.flatVector<float>(resultVector, REAL());
131 }
132
137 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
138 return {exec::FunctionSignatureBuilder()
139 .argumentType("array(REAL)")
140 .returnType("REAL")
141 .build()};
142 }
143
149 float* getTensor() const override {
150 return new float[0]; // will this lead to memory leak?
151 }
152
/// @brief Name under which this function is registered.
static std::string getName() {
  const std::string functionName{"xgboost_predict"};
  return functionName;
}
160
166 return numFeatures;
167 }
168
173 std::string& getForestPath() {
174 return this->forestPath;
175 }
176
177 private:
178 XGBoostPtr forest;
179
180 int numFeatures;
181
182 std::string forestPath;
183};
184
185} // namespace ml
std::shared_ptr< XGBoost > XGBoostPtr
Shared-ownership pointer alias for the XGBoost class (std::shared_ptr<XGBoost>).
Definition XGBoost.h:43
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Get the function signatures.
Definition XGBoost.h:137
XGBoostPrediction(std::string forestPath, int numFeatures)
Construct a new XGBoostPrediction object that loads the XGBoost model at forestPath and expects numFeatures features per input row.
Definition XGBoost.h:100
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Apply the function to make predictions.
Definition XGBoost.h:114
std::string & getForestPath()
Get the path to the XGBoost model file.
Definition XGBoost.h:173
float * getTensor() const override
Placeholder override: an XGBoost model exposes no tensor data.
Definition XGBoost.h:149
int getNumFeatures()
Get the number of features.
Definition XGBoost.h:165
static std::string getName()
Get the name of the function.
Definition XGBoost.h:157
A class for managing XGBoost models and making predictions.
Definition XGBoost.h:49
void predict(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Make predictions using the XGBoost model.
Definition XGBoost.h:68
XGBoost(std::string pathToJSON)
Construct a new XGBoost object by loading the model file at pathToJSON.
Definition XGBoost.h:55
BoosterHandle booster
Handle to the XGBoost booster.
Definition XGBoost.h:86