30#include "velox/exec/tests/utils/AssertQueryBuilder.h"
31#include "velox/exec/tests/utils/PlanBuilder.h"
32#include "velox/exec/tests/utils/TempDirectoryPath.h"
33#include "velox/ml_functions/BaseFunction.h"
34#include "velox/vector/tests/utils/VectorTestBase.h"
36using namespace facebook::velox;
37using namespace facebook::velox::test;
38using namespace facebook::velox::exec::test;
39using namespace facebook::velox::memory;
/// Capacity of the fixed-size Node array that each Tree holds.
43#define MAX_NUM_NODES_PER_TREE 512
/// Shared-pointer alias for Tree.
46typedef std::shared_ptr<Tree> TreePtr;
93 std::vector<std::string> relationships;
94 std::vector<std::string> innerNodes;
95 std::vector<std::string> leafNodes;
97 treePathIn, relationships, innerNodes, leafNodes);
111 std::string treePathIn,
112 std::vector<std::string>& relationships,
113 std::vector<std::string>& innerNodes,
114 std::vector<std::string>& leafNodes) {
115 std::ifstream inputFile;
116 inputFile.open(treePathIn.data());
117 assert(inputFile.is_open());
120 while (getline(inputFile, line)) {
121 if ((line.size() == 0) || (line.find(
"graph") != std::string::npos) ||
122 (line.find(
"}") != std::string::npos)) {
124 if (line.find(
"->") != std::string::npos) {
125 relationships.push_back(line);
126 }
else if (line.find(
"leaf") != std::string::npos) {
127 leafNodes.push_back(line);
128 }
else if (line.find(
"label") != std::string::npos) {
129 innerNodes.push_back(line);
146 std::vector<std::string>& innerNodes,
148 int findStartPosition;
152 for (
int i = 0; i < innerNodes.size(); ++i) {
153 const std::string& currentLine = innerNodes[i];
158 if ((findEndPosition = currentLine.find(
"[ label")) !=
160 nodeID = std::stoi(currentLine.substr(4, findEndPosition - 1 - 4));
162 LOG(ERROR) <<
"[ERROR] Error in extracting inner node nodeID\n";
166 if ((findStartPosition = currentLine.find(
"f")) != std::string::npos &&
167 (findEndPosition = currentLine.find(
"<")) != std::string::npos) {
168 indexID = std::stoi(currentLine.substr(
169 findStartPosition + 1, findEndPosition - findStartPosition - 1));
171 LOG(ERROR) <<
"[Error] Error in extracting inner node indexID\n";
175 if ((findStartPosition = currentLine.find(
"<")) != std::string::npos &&
176 (findEndPosition = currentLine.find(
"\" ]")) != std::string::npos) {
177 threshold = std::stod(currentLine.substr(
178 findStartPosition + 1, findEndPosition - findStartPosition - 1));
180 LOG(ERROR) <<
"[ERROR] Error in extracting inner node threshold\n";
183 tree[nodeID].isMissTrackLeft =
false;
185 tree[nodeID].indexID = indexID;
186 tree[nodeID].isLeaf =
false;
187 tree[nodeID].leftChild = -1;
188 tree[nodeID].rightChild = -1;
189 tree[nodeID].threshold = threshold;
199 std::vector<std::string>& leafNodes,
201 int findStartPosition;
205 for (
int i = 0; i < leafNodes.size(); ++i) {
206 const std::string& currentLine = leafNodes[i];
208 float leafValue = -1.0f;
210 if ((findEndPosition = currentLine.find(
"[")) != std::string::npos) {
211 nodeID = std::stoi(currentLine.substr(4, findEndPosition - 1 - 4));
213 LOG(ERROR) <<
"[ERROR] Error in extracting leaf node nodeID\n";
217 if ((findStartPosition = currentLine.find(
"leaf=")) !=
219 (findEndPosition = currentLine.find(
"\" ]")) != std::string::npos) {
220 leafValue = std::stod(currentLine.substr(
221 findStartPosition + 5,
222 findEndPosition - 3 - findStartPosition - 5));
224 std::cout <<
"Error in extracting leaf node leafValue\n";
228 tree[nodeID].indexID = -1;
229 tree[nodeID].isLeaf =
true;
230 tree[nodeID].leftChild = -1;
231 tree[nodeID].rightChild = -1;
232 tree[nodeID].leafValue = leafValue;
233 tree[nodeID].isMissTrackLeft =
true;
243 std::vector<std::string>& relationships,
245 int findStartPosition;
249 for (
int i = 0; i < relationships.size(); ++i) {
250 const std::string& currentLine = relationships[i];
254 if ((findMidPosition = currentLine.find(
"->")) != std::string::npos) {
256 std::stoi(currentLine.substr(4, findMidPosition - 1 - 4));
258 std::cout <<
"Error in extracting parentNodeID\n";
262 if ((findEndPosition = currentLine.find(
"[")) != std::string::npos) {
263 childNodeID = std::stoi(currentLine.substr(
264 findMidPosition + 3, findEndPosition - 1 - findMidPosition - 3));
266 std::cout <<
"Error in extracting childNodeID\n";
270 if (currentLine.find(
"yes, missing") != std::string::npos) {
271 tree[parentNodeID].isMissTrackLeft =
275 if (
tree[parentNodeID].leftChild == -1) {
276 tree[parentNodeID].leftChild = childNodeID;
277 }
else if (
tree[parentNodeID].rightChild == -1) {
278 tree[parentNodeID].rightChild = childNodeID;
281 <<
"Error in parsing trees: children nodes were updated again: "
282 << parentNodeID <<
"->" << childNodeID << std::endl;
295 while (!
tree[curIndex].isLeaf) {
296 const float featureValue = input[curBase +
tree[curIndex].indexID];
297 curIndex = featureValue <
tree[curIndex].threshold
298 ?
tree[curIndex].leftChild
299 :
tree[curIndex].rightChild;
301 float result = (float)(
tree[curIndex].leafValue);
314 std::vector<float>& resultVector,
317 auto inputFeatures = input->as<ArrayVector>()->elements();
318 float* inputValues = inputFeatures->values()->asMutable<
float>();
319 float* outData = resultVector.data();
321 for (
int rowIndex = 0; rowIndex < numInputs; rowIndex++) {
323 int curBase = rowIndex * numFeatures;
324 while (!
tree[curIndex].isLeaf) {
325 const float featureValue =
326 inputValues[curBase +
tree[curIndex].indexID];
327 curIndex = featureValue <
tree[curIndex].threshold
328 ?
tree[curIndex].leftChild
329 :
tree[curIndex].rightChild;
331 outData[rowIndex] = (float)(
tree[curIndex].leafValue);
344 std::vector<float>& resultVector,
347 auto inputFeatures = input->as<ArrayVector>()->elements();
348 float* inputValues = inputFeatures->values()->asMutable<
float>();
349 float* outData = resultVector.data();
351 for (
int rowIndex = 0; rowIndex < numInputs; rowIndex++) {
353 int curBase = rowIndex * numFeatures;
354 while (!
tree[curIndex].isLeaf) {
355 const float featureValue =
356 inputValues[curBase +
tree[curIndex].indexID];
357 if (std::isnan(featureValue)) {
358 curIndex =
tree[curIndex].isMissTrackLeft ?
tree[curIndex].leftChild
359 :
tree[curIndex].rightChild;
362 curIndex = featureValue <
tree[curIndex].threshold
363 ?
tree[curIndex].leftChild
364 :
tree[curIndex].rightChild;
367 outData[rowIndex] = (float)(
tree[curIndex].leafValue);
387 std::string treePath,
390 this->tree = std::make_shared<Tree>(treeId, treePath);
391 this->numFeatures = numFeatures;
392 this->hasMissing = hasMissing;
404 const SelectivityVector& rows,
405 std::vector<VectorPtr>& args,
407 exec::EvalCtx& context,
408 VectorPtr& output)
const override {
409 BaseVector::ensureWritable(rows, type, context.pool(), output);
411 int numInputs = rows.size();
412 std::vector<float> resultVector(numInputs);
415 this->tree->predictMissing(
416 args[0], resultVector, numInputs, this->numFeatures);
418 this->tree->predict(args[0], resultVector, numInputs, this->numFeatures);
421 VectorMaker maker{context.pool()};
422 output = maker.flatVector<
float>(resultVector, REAL());
429 static std::vector<std::shared_ptr<exec::FunctionSignature>>
signatures() {
430 return {exec::FunctionSignatureBuilder()
431 .argumentType(
"array(REAL)")
449 return "tree_predict";
465 CostEstimate
getCost(std::vector<int> inputDims) {
467 return CostEstimate(1, inputDims[0],
dims[1]);
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition DecisionTree.h:429
static std::string getName()
Returns the name of the function.
Definition DecisionTree.h:448
CostEstimate getCost(std::vector< int > inputDims)
Estimates the cost of the function.
Definition DecisionTree.h:465
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the tree prediction function to the input data.
Definition DecisionTree.h:403
TreePrediction(int treeId, std::string treePath, int numFeatures, bool hasMissing)
Constructor for TreePrediction.
Definition DecisionTree.h:385
std::string getFuncName()
Returns the function name.
Definition DecisionTree.h:456
float * getTensor() const override
Returns the tensor associated with the function.
Definition DecisionTree.h:440
Represents a decision tree used for predictions.
Definition DecisionTree.h:68
Node tree[MAX_NUM_NODES_PER_TREE]
Array of tree nodes.
Definition DecisionTree.h:70
void predictMissing(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Predicts the output for multiple inputs, handling missing values.
Definition DecisionTree.h:342
void predict(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Predicts the output for multiple inputs, with no special handling of missing (NaN) feature values.
Definition DecisionTree.h:312
static void constructTreeFromPathHelper(std::string treePathIn, std::vector< std::string > &relationships, std::vector< std::string > &innerNodes, std::vector< std::string > &leafNodes)
Parses the file and categorizes the lines into relationships, inner nodes, and leaf nodes.
Definition DecisionTree.h:110
static void processRelationships(std::vector< std::string > &relationships, Node *tree)
Parses the lines corresponding to tree relationships.
Definition DecisionTree.h:242
static void processInnerNodes(std::vector< std::string > &innerNodes, Node *tree)
Parses the lines corresponding to tree inner nodes.
Definition DecisionTree.h:145
static void constructTreeFromPath(std::string treePathIn, Node *tree)
Constructs a tree from a file dumped from an xgboost model.
Definition DecisionTree.h:92
float predictSingle(float *input, int curBase)
Predicts the output for a single input.
Definition DecisionTree.h:293
static void processLeafNodes(std::vector< std::string > &leafNodes, Node *tree)
Parses the lines corresponding to tree leaf nodes.
Definition DecisionTree.h:198
Tree(int id, std::string treePath)
Constructor that initializes the tree from an xgboost model dump.
Definition DecisionTree.h:83
int treeId
ID of the tree in the forest.
Definition DecisionTree.h:71
Tree()
Default constructor.
Definition DecisionTree.h:76
Represents a node in a decision tree.
Definition DecisionTree.h:52
int indexID
Index of the feature compared against the threshold (-1 for leaf nodes).
Definition DecisionTree.h:57
bool isLeaf
Whether the node is a leaf.
Definition DecisionTree.h:60
int rightChild
Index of the right child node.
Definition DecisionTree.h:59
float leafValue
Value for leaf nodes.
Definition DecisionTree.h:55
bool isMissTrackLeft
Whether to follow the left child when the feature value is missing (NaN).
Definition DecisionTree.h:61
float threshold
Split threshold for non-leaf nodes; feature values below it go to the left child, all others to the right child.
Definition DecisionTree.h:54
int leftChild
Index of the left child node.
Definition DecisionTree.h:58