ML functions
 
Loading...
Searching...
No Matches
VeloxDecisionTree.h
Go to the documentation of this file.
1
15
16#pragma once
17
18#include <fcntl.h>
19#include <stdlib.h>
20#include <unistd.h>
21#include <cmath>
22#include <iostream>
23#include <memory>
24#include <string>
25#include "BaseFunction.h"
26#include "velox/common/base/VeloxException.h"
27#include "velox/common/base/tests/GTestUtils.h"
28#include "velox/exec/tests/utils/AssertQueryBuilder.h"
29#include "velox/exec/tests/utils/PlanBuilder.h"
30#include "velox/exec/tests/utils/TempDirectoryPath.h"
31#include "velox/expression/VectorFunction.h"
32#include "velox/functions/Macros.h"
33#include "velox/functions/Registerer.h"
34#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
35#include "velox/ml_functions/DecisionTree.h"
36#include "velox/type/OpaqueCustomTypes.h"
37#include "velox/vector/tests/utils/VectorTestBase.h"
38
39using namespace facebook::velox;
40using namespace facebook::velox::test;
41using namespace facebook::velox::exec::test;
42using namespace facebook::velox::memory;
43
44namespace ml {
45
52class TreeType : public OpaqueType {
53 TreeType() : OpaqueType(std::type_index(typeid(ml::Tree))) {}
54
55 public:
60 static const std::shared_ptr<const TreeType>& get() {
61 static const std::shared_ptr<const TreeType> instance{
62 new TreeType()
63 };
64 return instance;
65 }
66
71 std::string toString() const override {
72 return name();
73 }
74
79 const char* name() const override {
80 return "tree_type";
81 }
82};
83
88struct TreeT {
89 using type = std::shared_ptr<Tree>;
90
91 static constexpr const char* typeName = "tree_type";
92};
93
94using TheTree = CustomType<TreeT>;
95
100class TreeTypeFactories : public CustomTypeFactories {
101 public:
106 TypePtr getType() const override {
107 return TreeType::get();
108 }
109
115 exec::CastOperatorPtr getCastOperator() const override {
116 VELOX_UNSUPPORTED();
117 }
118};
119
124class AlwaysFailingTypeFactories : public CustomTypeFactories {
125 public:
130 TypePtr getType() const override {
131 VELOX_UNSUPPORTED();
132 }
133
138 exec::CastOperatorPtr getCastOperator() const override {
139 VELOX_UNSUPPORTED();
140 }
141};
142
147class VeloxTreeConstruction : public exec::VectorFunction {
148 public:
149 VeloxTreeConstruction() {}
150
159 void apply(
160 const SelectivityVector& rows,
161 std::vector<VectorPtr>& args,
162 const TypePtr& type,
163 exec::EvalCtx& context,
164 VectorPtr& output) const override {
165 auto flatInput = args[0]->as<SimpleVector<StringView>>();
166
167 BaseVector::ensureWritable(rows, type, context.pool(), output);
168
169 auto flatResult = output->asFlatVector<std::shared_ptr<void>>();
170
171 rows.applyToSelected([&](auto row) {
172 flatResult->set(
173 row, std::make_shared<Tree>(row, flatInput->valueAt(row)));
174 });
175 }
176
181 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
182 return {exec::FunctionSignatureBuilder()
183 .argumentType("VARCHAR")
184 .returnType("tree_type")
185 .build()};
186 }
187
192 static std::string getName() {
193 return "velox_tree_construct";
194 }
195};
196
202 public:
208 this->numFeatures = numFeatures;
209 dims.push_back(numFeatures);
210 }
211
217 float* getTensor() const override {
218 return new float[0]; // will this lead to memory leak?
219 }
220
229 void apply(
230 const SelectivityVector& rows,
231 std::vector<VectorPtr>& args,
232 const TypePtr& type,
233 exec::EvalCtx& context,
234 VectorPtr& output) const override {
235 BaseVector::ensureWritable(rows, type, context.pool(), output);
236
237 BaseVector* left = args[0].get();
238
239 exec::LocalDecodedVector leftHolder(context, *left, rows);
240
241 auto decodedLeftArray = leftHolder.get();
242
243 auto baseLeftArray =
244 decodedLeftArray->base()->as<ArrayVector>()->elements();
245
246 float* input1Values = baseLeftArray->values()->asMutable<float>();
247
248 auto flatInput = args[1]->as<SimpleVector<std::shared_ptr<void>>>();
249
250 auto flatResult = output->asFlatVector<float>();
251
252 rows.applyToSelected([&](auto row) {
253 flatResult->set(
254 row,
255 std::static_pointer_cast<Tree>(flatInput->valueAt(row))
256 ->predictSingle(input1Values, row * numFeatures)
257 );
258 });
259 }
260
265 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
266 return {exec::FunctionSignatureBuilder()
267 .argumentType("array(REAL)")
268 .argumentType("tree_type")
269 .returnType("REAL")
270 .build()};
271 }
272
277 static std::string getName() {
278 return "velox_tree_predict";
279 }
280
285 std::string getFuncName() {
286 return getName();
287 };
288
294 CostEstimate getCost(std::vector<int> inputDims) {
295 // TODO
296 return CostEstimate(1, inputDims[0], dims[1]);
297 }
298
300};
301
307template <typename T>
309 VELOX_DEFINE_FUNCTION_TYPES(T);
310
317 void call(
318 out_type<float>& result,
319 const arg_type<Array<float>>& a,
320 const arg_type<TheTree>& b) {
321 result = 0.0;
322 }
323};
324
325} // namespace ml
CustomType< TreeT > TheTree
Alias for the custom tree type.
Definition VeloxDecisionTree.h:94
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61
A custom-type factory whose getType() and getCastOperator() always throw VELOX_UNSUPPORTED.
Definition VeloxDecisionTree.h:124
exec::CastOperatorPtr getCastOperator() const override
Always throws VELOX_UNSUPPORTED; no cast operator is provided.
Definition VeloxDecisionTree.h:138
TypePtr getType() const override
Always throws VELOX_UNSUPPORTED; no type instance is provided.
Definition VeloxDecisionTree.h:130
Factory class for creating instances of TreeType.
Definition VeloxDecisionTree.h:100
exec::CastOperatorPtr getCastOperator() const override
Always throws VELOX_UNSUPPORTED; casts involving tree_type are not supported.
Definition VeloxDecisionTree.h:115
TypePtr getType() const override
Get the TreeType instance.
Definition VeloxDecisionTree.h:106
static const std::shared_ptr< const TreeType > & get()
Get a shared instance of TreeType.
Definition VeloxDecisionTree.h:60
std::string toString() const override
Convert the type to a string representation.
Definition VeloxDecisionTree.h:71
const char * name() const override
Get the name of the type.
Definition VeloxDecisionTree.h:79
Represents a decision tree used for predictions.
Definition DecisionTree.h:68
static std::string getName()
Get the name of the function.
Definition VeloxDecisionTree.h:192
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Get the function signatures.
Definition VeloxDecisionTree.h:181
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Apply the function to construct trees.
Definition VeloxDecisionTree.h:159
static std::string getName()
Get the name of the function.
Definition VeloxDecisionTree.h:277
VeloxTreePrediction(int numFeatures)
Construct a new VeloxTreePrediction object.
Definition VeloxDecisionTree.h:207
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Get the function signatures.
Definition VeloxDecisionTree.h:265
CostEstimate getCost(std::vector< int > inputDims)
Get the cost estimate for the function.
Definition VeloxDecisionTree.h:294
int numFeatures
The number of features in the input data.
Definition VeloxDecisionTree.h:299
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Apply the function to make predictions.
Definition VeloxDecisionTree.h:229
std::string getFuncName()
Get the function name.
Definition VeloxDecisionTree.h:285
float * getTensor() const override
Get the tensor data.
Definition VeloxDecisionTree.h:217
A struct representing the custom type for trees.
Definition VeloxDecisionTree.h:88
std::shared_ptr< Tree > type
The underlying type for the custom type.
Definition VeloxDecisionTree.h:89
static constexpr const char * typeName
The name of the custom type.
Definition VeloxDecisionTree.h:91
A placeholder simple function for tree-model predictions; its call() currently ignores its inputs and always returns 0.
Definition VeloxDecisionTree.h:308
void call(out_type< float > &result, const arg_type< Array< float > > &a, const arg_type< TheTree > &b)
Call the function to make predictions.
Definition VeloxDecisionTree.h:317