ML functions
 
Loading...
Searching...
No Matches
DecisionForest.h
1/*
2 * Copyright (c) 2025 ASU Cactus Lab.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
23class Forest {
24public:
25 Node forest[MAX_NUM_TREES][MAX_NUM_NODES_PER_TREE];
28
32 Forest() {}
33
/// Builds the forest from every entry of `pathToFolder` by delegating to
/// constructForestFromFolder().
///
/// @param pathToFolder     Folder whose entries are each loaded as one tree.
/// @param isClassification Classification-mode flag for this forest.
///
/// NOTE(review): this listing omits source line 41, which sits between the
/// signature and the body. Presumably it opens the body and stores
/// `isClassification` into the member of the same name (predict() reads that
/// member). Confirm against the real header.
40 Forest(std::string pathToFolder, bool isClassification)
42 this->constructForestFromFolder(pathToFolder);
43 }
44
/// Collects the path of every entry (except "." and "..") in a folder.
///
/// @param pathToFolder Folder to scan. A trailing '/' is appended when it is
///                     missing, so each result is "<folder>/<entry>".
/// @param pathVector   Output. One path is appended per directory entry, in
///                     readdir() order, which POSIX leaves unspecified.
/// @throws std::runtime_error if the folder cannot be opened (including an
///         empty path).
static void vectorizeForestFolder(
    std::string pathToFolder,
    std::vector<std::string>& pathVector) {
  // Guard the last-character access: indexing length() - 1 on an empty string
  // is undefined behavior.
  if (!pathToFolder.empty() && pathToFolder.back() != '/') {
    pathToFolder += '/';
  }

  // Own the DIR* so that it is closed even if push_back throws.
  std::unique_ptr<DIR, int (*)(DIR*)> dir(
      opendir(pathToFolder.c_str()), &closedir);
  if (dir == nullptr) {
    // readdir(NULL) is undefined behavior, so fail loudly instead.
    throw std::runtime_error(
        "[Forest-vectorizeForestFolder] Cannot open folder: " + pathToFolder);
  }

  while (const dirent* entry = readdir(dir.get())) {
    if (strcmp(entry->d_name, ".") != 0 && strcmp(entry->d_name, "..") != 0) {
      pathVector.push_back(pathToFolder + entry->d_name);
    }
  }
}
72
78 void constructForestFromFolder(std::string pathToFolder) {
79 std::vector<std::string> treePaths;
80
81 vectorizeForestFolder(pathToFolder, treePaths);
82
83 constructForestFromPaths(treePaths);
84 }
85
91 void constructForestFromPaths(std::vector<std::string>& treesPathIn) {
92 this->numTrees = treesPathIn.size();
93
94 for (int n = 0; n < numTrees; ++n) {
95 Tree::constructTreeFromPath(treesPathIn[n], &(forest[n][0]));
96 }
97
98 // STATS ABOUT THE FOREST
99 LOG(INFO)
100 << "[Forest-constructForestFromPaths] Number of trees in the forest: "
101 << numTrees << std::endl;
102 }
103
112 inline void predict(
113 VectorPtr& input,
114 std::vector<float>& resultVector,
115 int numInputs,
116 int numFeatures) {
117 auto inputFeatures = input->as<ArrayVector>()->elements();
118
119 float* inputValues = inputFeatures->values()->asMutable<float>();
120
121 float* outData = resultVector.data();
122
123 for (int rowIndex = 0; rowIndex < numInputs; rowIndex++) {
124 int curBase = rowIndex * numFeatures;
125
126 float accumulatedResult = 0.0;
127
128 for (int treeIndex = 0; treeIndex < numTrees; treeIndex++) {
129 int curIndex = 0;
130
131 Node* tree = forest[treeIndex];
132
133 while (!tree[curIndex].isLeaf) {
134 const float featureValue =
135 inputValues[curBase + tree[curIndex].indexID];
136
137 curIndex = featureValue < tree[curIndex].threshold
138 ? tree[curIndex].leftChild
139 : tree[curIndex].rightChild;
140 }
141
142 accumulatedResult += (float)(tree[curIndex].leafValue);
143 }
144
145 accumulatedResult /= numTrees;
146
147 if (isClassification) {
148 accumulatedResult = (accumulatedResult > 0.0) ? 1.0 : 0.0;
149 }
150
151 outData[rowIndex] = accumulatedResult;
152 }
153 }
154};
155
161class ForestPrediction : public MLFunction {
162public:
// NOTE(review): this listing omits the constructor's opening line (source
// line 170) and its doc comment. The lines below are the parameter list and
// the body.
//
// Builds a forest-backed prediction function.
//   forestPath        Folder whose entries are each loaded as one tree. It is
//                     checked with std::filesystem::exists first.
//   numFeatures       Number of features per input row. Stored here and
//                     passed to Forest::predict by apply().
//   isClassification  Stored here and forwarded to the Forest constructor.
//                     Presumably the Forest keeps it for predict(), which
//                     thresholds the averaged result to 0/1 in that mode;
//                     confirm.
// Throws std::runtime_error if forestPath does not exist.
171 std::string forestPath,
172 int numFeatures,
173 bool isClassification) {
174 if (!std::filesystem::exists(forestPath)) {
175 throw std::runtime_error(
176 "[ForestPrediction] Path not exists: " + forestPath);
177 }
178
// Loading every tree file happens inside the Forest constructor.
179 this->forest = std::make_shared<Forest>(forestPath, isClassification);
180
181 this->numFeatures = numFeatures;
182
183 this->forestPath = forestPath;
184
185 this->isClassification = isClassification;
186 }
187
200 void apply(
201 const SelectivityVector& rows,
202 std::vector<VectorPtr>& args,
203 const TypePtr& type,
204 exec::EvalCtx& context,
205 VectorPtr& output) const override {
206 BaseVector::ensureWritable(rows, type, context.pool(), output);
207
208 int numInputs = rows.size();
209
210 std::vector<float> resultVector(numInputs);
211
212 this->forest->predict(args[0], resultVector, numInputs, this->numFeatures);
213
214 VectorMaker maker{context.pool()};
215
216 output = maker.flatVector<float>(resultVector, REAL());
217 }
218
224 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
225 return {exec::FunctionSignatureBuilder()
226 .argumentType("array(REAL)")
227 .returnType("REAL")
228 .build()};
229 }
230
236 float* getTensor() const override {
237 return new float[0];
238 }
239
/// @return The function's name, "tree_predict".
static std::string getName() {
  return std::string{"tree_predict"};
}
248
254 int getNumFeatures() {
255 return numFeatures;
256 }
257
263 std::string& getForestPath() {
264 return this->forestPath;
265 }
266
272 bool getClassification() {
273 return this->isClassification;
274 }
275
276private:
277 ForestPtr forest;
278 int numFeatures;
279 std::string forestPath;
280 bool isClassification;
A class that implements a random forest prediction function, inheriting from MLFunction.
void predict(VectorPtr &input, std::vector< float > &resultVector, int numInputs, int numFeatures)
Makes predictions using the forest.
Definition DecisionForest.h:112
bool isClassification
Flag indicating whether the forest is used for classification.
Definition DecisionForest.h:27
static void vectorizeForestFolder(std::string pathToFolder, std::vector< std::string > &pathVector)
Scans a folder and collects paths to tree files.
Definition DecisionForest.h:51
void constructForestFromPaths(std::vector< std::string > &treesPathIn)
Constructs the forest from a list of tree file paths.
Definition DecisionForest.h:91
void constructForestFromFolder(std::string pathToFolder)
Constructs the forest from a folder of tree files.
Definition DecisionForest.h:78
Forest(std::string pathToFolder, bool isClassification)
Constructor that initializes the forest from a folder of tree files.
Definition DecisionForest.h:40
int numTrees
Number of trees in the forest.
Definition DecisionForest.h:26
Forest()
Default constructor.
Definition DecisionForest.h:32
Node forest[MAX_NUM_TREES][MAX_NUM_NODES_PER_TREE]
Array of trees in the forest.
Definition DecisionForest.h:25
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9