22#include <xgboost/c_api.h>
27#include "BaseFunction.h"
28#include "velox/exec/tests/utils/AssertQueryBuilder.h"
29#include "velox/exec/tests/utils/PlanBuilder.h"
30#include "velox/exec/tests/utils/TempDirectoryPath.h"
31#include "velox/ml_functions/DecisionTree.h"
32#include "velox/vector/tests/utils/VectorTestBase.h"
34using namespace facebook::velox;
35using namespace facebook::velox::test;
36using namespace facebook::velox::exec::test;
37using namespace facebook::velox::memory;
56 XGBoosterCreate(NULL, 0, &
booster);
57 XGBoosterSetParam(
booster,
"seed",
"0");
58 XGBoosterLoadModel(
booster, pathToJSON.c_str());
70 std::vector<float>& resultVector,
73 auto inputFeatures = input->as<ArrayVector>()->elements();
74 float* inputValues = inputFeatures->values()->asMutable<
float>();
76 XGDMatrixCreateFromMat(inputValues, numInputs, numFeatures, NAN, &dtest);
77 unsigned long numOutputs;
78 float const* outData = NULL;
79 XGBoosterPredictFromDMatrix(
booster, dtest, 0, 0, &numOutputs, &outData);
80 assert(numOutputs == numInputs);
81 memcpy(resultVector.data(), outData, numOutputs *
sizeof(
float));
101 this->forest = std::make_shared<XGBoost>(forestPath);
102 this->numFeatures = numFeatures;
103 this->forestPath = forestPath;
115 const SelectivityVector& rows,
116 std::vector<VectorPtr>& args,
118 exec::EvalCtx& context,
119 VectorPtr& output)
const override {
120 BaseVector::ensureWritable(rows, type, context.pool(), output);
122 int numInputs = rows.size();
124 std::vector<float> resultVector(numInputs);
126 this->forest->predict(args[0], resultVector, numInputs, this->numFeatures);
128 VectorMaker maker{context.pool()};
130 output = maker.flatVector<
float>(resultVector, REAL());
137 static std::vector<std::shared_ptr<exec::FunctionSignature>>
signatures() {
138 return {exec::FunctionSignatureBuilder()
139 .argumentType(
"array(REAL)")
158 return "xgboost_predict";
174 return this->forestPath;
182 std::string forestPath;
std::shared_ptr< XGBoost > XGBoostPtr
Shared-pointer alias for the XGBoost class, which is only forward-declared at this point.
Definition XGBoost.h:43
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Get the function signatures.
Definition XGBoost.h:137
XGBoostPrediction(std::string forestPath, int numFeatures)
Construct a new XGBoostPrediction object.
Definition XGBoost.h:100
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Apply the function to make predictions.
Definition XGBoost.h:114
std::string & getForestPath()
Get the path to the XGBoost model file.
Definition XGBoost.h:173
float * getTensor() const override
Get the tensor data.
Definition XGBoost.h:149
int getNumFeatures()
Get the number of features.
Definition XGBoost.h:165
static std::string getName()
Get the name of the function.
Definition XGBoost.h:157
A class for managing XGBoost models and making predictions.
Definition XGBoost.h:49
void predict(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Make predictions using the XGBoost model.
Definition XGBoost.h:68
XGBoost(std::string pathToJSON)
Construct a new XGBoost object.
Definition XGBoost.h:55
BoosterHandle booster
Handle to the XGBoost booster.
Definition XGBoost.h:86