ML functions
 
Loading...
Searching...
No Matches
DecisionTree.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 ASU Cactus Lab.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
20
21#pragma once
22
#include <fcntl.h>
#include <stdlib.h>
#include <unistd.h>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
#include "velox/ml_functions/BaseFunction.h"
#include "velox/vector/tests/utils/VectorTestBase.h"
35
36using namespace facebook::velox;
37using namespace facebook::velox::test;
38using namespace facebook::velox::exec::test;
39using namespace facebook::velox::memory;
40
41namespace ml {
42
/// Capacity of the fixed-size node array held by every Tree. Node IDs read
/// from a model dump index that array directly, so they must stay below this.
#define MAX_NUM_NODES_PER_TREE 512

class Tree;
/// Shared-ownership handle to a Tree.
using TreePtr = std::shared_ptr<Tree>;
47
/// Represents a node in a decision tree.
///
/// An inner (split) node compares feature `indexID` against `threshold`.
/// A leaf node holds `leafValue`. Both floats share storage, so only the one
/// matching `isLeaf` is meaningful.
struct Node {
  union {
    /// Threshold value for non-leaf nodes.
    float threshold;
    /// Value for leaf nodes.
    float leafValue;
  };
  /// Index of the feature to compare (the tree parser stores -1 for leaves).
  int indexID;
  /// Index of the left child node (-1 until an edge assigns it).
  int leftChild;
  /// Index of the right child node (-1 until an edge assigns it).
  int rightChild;
  /// Whether the node is a leaf.
  bool isLeaf;
  /// Whether to follow the left child when the feature value is missing (NaN).
  bool isMissTrackLeft;
};
63
68class Tree {
69 public:
70 Node tree[MAX_NUM_NODES_PER_TREE];
71 int treeId;
72
76 Tree() {}
77
83 Tree(int id, std::string treePath) : treeId{id} {
84 this->constructTreeFromPath(treePath, this->tree);
85 }
86
92 static void constructTreeFromPath(std::string treePathIn, Node* tree) {
93 std::vector<std::string> relationships;
94 std::vector<std::string> innerNodes;
95 std::vector<std::string> leafNodes;
97 treePathIn, relationships, innerNodes, leafNodes);
98 processInnerNodes(innerNodes, tree);
99 processLeafNodes(leafNodes, tree);
100 processRelationships(relationships, tree);
101 }
102
111 std::string treePathIn,
112 std::vector<std::string>& relationships,
113 std::vector<std::string>& innerNodes,
114 std::vector<std::string>& leafNodes) {
115 std::ifstream inputFile;
116 inputFile.open(treePathIn.data());
117 assert(inputFile.is_open());
118
119 std::string line;
120 while (getline(inputFile, line)) {
121 if ((line.size() == 0) || (line.find("graph") != std::string::npos) ||
122 (line.find("}") != std::string::npos)) {
123 } else {
124 if (line.find("->") != std::string::npos) {
125 relationships.push_back(line);
126 } else if (line.find("leaf") != std::string::npos) {
127 leafNodes.push_back(line);
128 } else if (line.find("label") != std::string::npos) {
129 innerNodes.push_back(line);
130 } else {
131 // skip the case of empty line, somehow it won't be captured by the
132 // first condition
133 }
134 }
135 }
136
137 inputFile.close();
138 }
139
145 static void processInnerNodes(
146 std::vector<std::string>& innerNodes,
147 Node* tree) {
148 int findStartPosition;
149 int findMidPosition;
150 int findEndPosition;
151
152 for (int i = 0; i < innerNodes.size(); ++i) {
153 const std::string& currentLine = innerNodes[i];
154 int nodeID;
155 int indexID;
156 float threshold;
157
158 if ((findEndPosition = currentLine.find("[ label")) !=
159 std::string::npos) {
160 nodeID = std::stoi(currentLine.substr(4, findEndPosition - 1 - 4));
161 } else {
162 LOG(ERROR) << "[ERROR] Error in extracting inner node nodeID\n";
163 exit(1);
164 }
165
166 if ((findStartPosition = currentLine.find("f")) != std::string::npos &&
167 (findEndPosition = currentLine.find("<")) != std::string::npos) {
168 indexID = std::stoi(currentLine.substr(
169 findStartPosition + 1, findEndPosition - findStartPosition - 1));
170 } else {
171 LOG(ERROR) << "[Error] Error in extracting inner node indexID\n";
172 exit(1);
173 }
174
175 if ((findStartPosition = currentLine.find("<")) != std::string::npos &&
176 (findEndPosition = currentLine.find("\" ]")) != std::string::npos) {
177 threshold = std::stod(currentLine.substr(
178 findStartPosition + 1, findEndPosition - findStartPosition - 1));
179 } else {
180 LOG(ERROR) << "[ERROR] Error in extracting inner node threshold\n";
181 exit(1);
182 }
183 tree[nodeID].isMissTrackLeft = false; // XGBoost default is noMissing/right
184
185 tree[nodeID].indexID = indexID;
186 tree[nodeID].isLeaf = false;
187 tree[nodeID].leftChild = -1;
188 tree[nodeID].rightChild = -1;
189 tree[nodeID].threshold = threshold;
190 }
191 }
192
198 static void processLeafNodes(
199 std::vector<std::string>& leafNodes,
200 Node* tree) {
201 int findStartPosition;
202 int findMidPosition;
203 int findEndPosition;
204
205 for (int i = 0; i < leafNodes.size(); ++i) {
206 const std::string& currentLine = leafNodes[i];
207 int nodeID;
208 float leafValue = -1.0f;
209
210 if ((findEndPosition = currentLine.find("[")) != std::string::npos) {
211 nodeID = std::stoi(currentLine.substr(4, findEndPosition - 1 - 4));
212 } else {
213 LOG(ERROR) << "[ERROR] Error in extracting leaf node nodeID\n";
214 exit(1);
215 }
216
217 if ((findStartPosition = currentLine.find("leaf=")) !=
218 std::string::npos &&
219 (findEndPosition = currentLine.find("\" ]")) != std::string::npos) {
220 leafValue = std::stod(currentLine.substr(
221 findStartPosition + 5,
222 findEndPosition - 3 - findStartPosition - 5));
223 } else {
224 std::cout << "Error in extracting leaf node leafValue\n";
225 exit(1);
226 }
227
228 tree[nodeID].indexID = -1;
229 tree[nodeID].isLeaf = true;
230 tree[nodeID].leftChild = -1;
231 tree[nodeID].rightChild = -1;
232 tree[nodeID].leafValue = leafValue;
233 tree[nodeID].isMissTrackLeft = true; // Doesn't matter to leave nodes
234 }
235 }
236
243 std::vector<std::string>& relationships,
244 Node* tree) {
245 int findStartPosition;
246 int findMidPosition;
247 int findEndPosition;
248
249 for (int i = 0; i < relationships.size(); ++i) {
250 const std::string& currentLine = relationships[i];
251 int parentNodeID;
252 int childNodeID;
253
254 if ((findMidPosition = currentLine.find("->")) != std::string::npos) {
255 parentNodeID =
256 std::stoi(currentLine.substr(4, findMidPosition - 1 - 4));
257 } else {
258 std::cout << "Error in extracting parentNodeID\n";
259 exit(1);
260 }
261
262 if ((findEndPosition = currentLine.find("[")) != std::string::npos) {
263 childNodeID = std::stoi(currentLine.substr(
264 findMidPosition + 3, findEndPosition - 1 - findMidPosition - 3));
265 } else {
266 std::cout << "Error in extracting childNodeID\n";
267 exit(1);
268 }
269
270 if (currentLine.find("yes, missing") != std::string::npos) {
271 tree[parentNodeID].isMissTrackLeft =
272 true; // in processInnerNodes(), default value is set to no/right
273 }
274
275 if (tree[parentNodeID].leftChild == -1) {
276 tree[parentNodeID].leftChild = childNodeID;
277 } else if (tree[parentNodeID].rightChild == -1) {
278 tree[parentNodeID].rightChild = childNodeID;
279 } else {
280 std::cout
281 << "Error in parsing trees: children nodes were updated again: "
282 << parentNodeID << "->" << childNodeID << std::endl;
283 }
284 }
285 }
286
293 inline float predictSingle(float* input, int curBase) {
294 int curIndex = 0;
295 while (!tree[curIndex].isLeaf) {
296 const float featureValue = input[curBase + tree[curIndex].indexID];
297 curIndex = featureValue < tree[curIndex].threshold
298 ? tree[curIndex].leftChild
299 : tree[curIndex].rightChild;
300 }
301 float result = (float)(tree[curIndex].leafValue);
302 return result;
303 }
304
312 inline void predict(
313 VectorPtr& input,
314 std::vector<float>& resultVector,
315 int numInputs,
316 int numFeatures) {
317 auto inputFeatures = input->as<ArrayVector>()->elements();
318 float* inputValues = inputFeatures->values()->asMutable<float>();
319 float* outData = resultVector.data();
320
321 for (int rowIndex = 0; rowIndex < numInputs; rowIndex++) {
322 int curIndex = 0;
323 int curBase = rowIndex * numFeatures;
324 while (!tree[curIndex].isLeaf) {
325 const float featureValue =
326 inputValues[curBase + tree[curIndex].indexID];
327 curIndex = featureValue < tree[curIndex].threshold
328 ? tree[curIndex].leftChild
329 : tree[curIndex].rightChild;
330 }
331 outData[rowIndex] = (float)(tree[curIndex].leafValue);
332 }
333 }
334
342 inline void predictMissing(
343 VectorPtr& input,
344 std::vector<float>& resultVector,
345 int numInputs,
346 int numFeatures) {
347 auto inputFeatures = input->as<ArrayVector>()->elements();
348 float* inputValues = inputFeatures->values()->asMutable<float>();
349 float* outData = resultVector.data();
350
351 for (int rowIndex = 0; rowIndex < numInputs; rowIndex++) {
352 int curIndex = 0;
353 int curBase = rowIndex * numFeatures;
354 while (!tree[curIndex].isLeaf) {
355 const float featureValue =
356 inputValues[curBase + tree[curIndex].indexID];
357 if (std::isnan(featureValue)) {
358 curIndex = tree[curIndex].isMissTrackLeft ? tree[curIndex].leftChild
359 : tree[curIndex].rightChild;
360
361 } else {
362 curIndex = featureValue < tree[curIndex].threshold
363 ? tree[curIndex].leftChild
364 : tree[curIndex].rightChild;
365 }
366 }
367 outData[rowIndex] = (float)(tree[curIndex].leafValue);
368 }
369 }
370};
371
377 public:
386 int treeId,
387 std::string treePath,
388 int numFeatures,
389 bool hasMissing) {
390 this->tree = std::make_shared<Tree>(treeId, treePath);
391 this->numFeatures = numFeatures;
392 this->hasMissing = hasMissing;
393 }
394
403 void apply(
404 const SelectivityVector& rows,
405 std::vector<VectorPtr>& args,
406 const TypePtr& type,
407 exec::EvalCtx& context,
408 VectorPtr& output) const override {
409 BaseVector::ensureWritable(rows, type, context.pool(), output);
410
411 int numInputs = rows.size();
412 std::vector<float> resultVector(numInputs);
413
414 if (hasMissing) {
415 this->tree->predictMissing(
416 args[0], resultVector, numInputs, this->numFeatures);
417 } else {
418 this->tree->predict(args[0], resultVector, numInputs, this->numFeatures);
419 }
420
421 VectorMaker maker{context.pool()};
422 output = maker.flatVector<float>(resultVector, REAL());
423 }
424
429 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
430 return {exec::FunctionSignatureBuilder()
431 .argumentType("array(REAL)")
432 .returnType("REAL")
433 .build()};
434 }
435
440 float* getTensor() const override {
441 return new float[0]; // will this lead to memory leak?
442 }
443
/// Returns the name of the function: "tree_predict".
static std::string getName() {
  return std::string{"tree_predict"};
}
451
456 std::string getFuncName() {
457 return getName();
458 };
459
465 CostEstimate getCost(std::vector<int> inputDims) {
466 // TODO
467 return CostEstimate(1, inputDims[0], dims[1]);
468 }
469
470 private:
471 TreePtr tree;
472 int numFeatures;
473 bool hasMissing;
474};
475
476} // namespace ml
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition DecisionTree.h:429
static std::string getName()
Returns the name of the function.
Definition DecisionTree.h:448
CostEstimate getCost(std::vector< int > inputDims)
Estimates the cost of the function.
Definition DecisionTree.h:465
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the tree prediction function to the input data.
Definition DecisionTree.h:403
TreePrediction(int treeId, std::string treePath, int numFeatures, bool hasMissing)
Constructor for TreePrediction.
Definition DecisionTree.h:385
std::string getFuncName()
Returns the function name.
Definition DecisionTree.h:456
float * getTensor() const override
Returns the tensor associated with the function.
Definition DecisionTree.h:440
Represents a decision tree used for predictions.
Definition DecisionTree.h:68
Node tree[MAX_NUM_NODES_PER_TREE]
Array of tree nodes.
Definition DecisionTree.h:70
void predictMissing(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Predicts the output for multiple inputs, handling missing values.
Definition DecisionTree.h:342
void predict(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Predicts the output for multiple inputs.
Definition DecisionTree.h:312
static void constructTreeFromPathHelper(std::string treePathIn, std::vector< std::string > &relationships, std::vector< std::string > &innerNodes, std::vector< std::string > &leafNodes)
Parses the file and categorizes the lines into relationships, inner nodes, and leaf nodes.
Definition DecisionTree.h:110
static void processRelationships(std::vector< std::string > &relationships, Node *tree)
Parses the edge lines (`parent -> child`) that link each node to its children.
Definition DecisionTree.h:242
static void processInnerNodes(std::vector< std::string > &innerNodes, Node *tree)
Parses the inner (split) node lines: node ID, feature index, and threshold.
Definition DecisionTree.h:145
static void constructTreeFromPath(std::string treePathIn, Node *tree)
Constructs a tree from a file dumped from an xgboost model.
Definition DecisionTree.h:92
float predictSingle(float *input, int curBase)
Predicts the output for a single input.
Definition DecisionTree.h:293
static void processLeafNodes(std::vector< std::string > &leafNodes, Node *tree)
Parses the leaf node lines: node ID and leaf value.
Definition DecisionTree.h:198
Tree(int id, std::string treePath)
Constructor that initializes the tree from an xgboost model dump.
Definition DecisionTree.h:83
int treeId
ID of the tree in the forest.
Definition DecisionTree.h:71
Tree()
Default constructor.
Definition DecisionTree.h:76
Represents a node in a decision tree.
Definition DecisionTree.h:52
int indexID
Index of the feature to compare.
Definition DecisionTree.h:57
bool isLeaf
Whether the node is a leaf.
Definition DecisionTree.h:60
int rightChild
Index of the right child node.
Definition DecisionTree.h:59
float leafValue
Value for leaf nodes.
Definition DecisionTree.h:55
bool isMissTrackLeft
Whether to follow the left child when the feature value is missing (NaN).
Definition DecisionTree.h:61
float threshold
Threshold value for non-leaf nodes.
Definition DecisionTree.h:54
int leftChild
Index of the left child node.
Definition DecisionTree.h:58