25#include "BaseFunction.h"
26#include "velox/common/base/VeloxException.h"
27#include "velox/common/base/tests/GTestUtils.h"
28#include "velox/exec/tests/utils/AssertQueryBuilder.h"
29#include "velox/exec/tests/utils/PlanBuilder.h"
30#include "velox/exec/tests/utils/TempDirectoryPath.h"
31#include "velox/expression/VectorFunction.h"
32#include "velox/functions/Macros.h"
33#include "velox/functions/Registerer.h"
34#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
35#include "velox/ml_functions/DecisionTree.h"
36#include "velox/type/OpaqueCustomTypes.h"
37#include "velox/vector/tests/utils/VectorTestBase.h"
39using namespace facebook::velox;
40using namespace facebook::velox::test;
41using namespace facebook::velox::exec::test;
42using namespace facebook::velox::memory;
52class TreeType :
public OpaqueType {
53 TreeType() : OpaqueType(std::type_index(
typeid(
ml::Tree))) {}
60 static const std::shared_ptr<const TreeType>&
get() {
61 static const std::shared_ptr<const TreeType> instance{
79 const char*
name()
const override {
89 using type = std::shared_ptr<Tree>;
91 static constexpr const char*
typeName =
"tree_type";
147class VeloxTreeConstruction :
public exec::VectorFunction {
149 VeloxTreeConstruction() {}
160 const SelectivityVector& rows,
161 std::vector<VectorPtr>& args,
163 exec::EvalCtx& context,
164 VectorPtr& output)
const override {
165 auto flatInput = args[0]->as<SimpleVector<StringView>>();
167 BaseVector::ensureWritable(rows, type, context.pool(), output);
169 auto flatResult = output->asFlatVector<std::shared_ptr<void>>();
171 rows.applyToSelected([&](
auto row) {
173 row, std::make_shared<Tree>(row, flatInput->valueAt(row)));
181 static std::vector<std::shared_ptr<exec::FunctionSignature>>
signatures() {
182 return {exec::FunctionSignatureBuilder()
183 .argumentType(
"VARCHAR")
184 .returnType(
"tree_type")
193 return "velox_tree_construct";
230 const SelectivityVector& rows,
231 std::vector<VectorPtr>& args,
233 exec::EvalCtx& context,
234 VectorPtr& output)
const override {
235 BaseVector::ensureWritable(rows, type, context.pool(), output);
237 BaseVector* left = args[0].get();
239 exec::LocalDecodedVector leftHolder(context, *left, rows);
241 auto decodedLeftArray = leftHolder.get();
244 decodedLeftArray->base()->as<ArrayVector>()->elements();
246 float* input1Values = baseLeftArray->values()->asMutable<
float>();
248 auto flatInput = args[1]->as<SimpleVector<std::shared_ptr<void>>>();
250 auto flatResult = output->asFlatVector<
float>();
252 rows.applyToSelected([&](
auto row) {
255 std::static_pointer_cast<Tree>(flatInput->valueAt(row))
265 static std::vector<std::shared_ptr<exec::FunctionSignature>>
signatures() {
266 return {exec::FunctionSignatureBuilder()
267 .argumentType(
"array(REAL)")
268 .argumentType(
"tree_type")
278 return "velox_tree_predict";
294 CostEstimate
getCost(std::vector<int> inputDims) {
296 return CostEstimate(1, inputDims[0],
dims[1]);
309 VELOX_DEFINE_FUNCTION_TYPES(T);
318 out_type<float>& result,
319 const arg_type<Array<float>>& a,
320 const arg_type<TheTree>& b) {
CustomType< TreeT > TheTree
Alias for the custom tree type.
Definition VeloxDecisionTree.h:94
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61
A factory class that always fails to create instances.
Definition VeloxDecisionTree.h:124
exec::CastOperatorPtr getCastOperator() const override
Get the cast operator for the type.
Definition VeloxDecisionTree.h:138
TypePtr getType() const override
Get the type instance.
Definition VeloxDecisionTree.h:130
Factory class for creating instances of TreeType.
Definition VeloxDecisionTree.h:100
exec::CastOperatorPtr getCastOperator() const override
Get the cast operator for the type.
Definition VeloxDecisionTree.h:115
TypePtr getType() const override
Get the TreeType instance.
Definition VeloxDecisionTree.h:106
static const std::shared_ptr< const TreeType > & get()
Get a shared instance of TreeType.
Definition VeloxDecisionTree.h:60
std::string toString() const override
Convert the type to a string representation.
Definition VeloxDecisionTree.h:71
const char * name() const override
Get the name of the type.
Definition VeloxDecisionTree.h:79
Represents a decision tree used for predictions.
Definition DecisionTree.h:68
static std::string getName()
Get the name of the function.
Definition VeloxDecisionTree.h:192
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Get the function signatures.
Definition VeloxDecisionTree.h:181
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Apply the function to construct trees.
Definition VeloxDecisionTree.h:159
static std::string getName()
Get the name of the function.
Definition VeloxDecisionTree.h:277
VeloxTreePrediction(int numFeatures)
Construct a new VeloxTreePrediction object.
Definition VeloxDecisionTree.h:207
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Get the function signatures.
Definition VeloxDecisionTree.h:265
CostEstimate getCost(std::vector< int > inputDims)
Get the cost estimate for the function.
Definition VeloxDecisionTree.h:294
int numFeatures
The number of features in the input data.
Definition VeloxDecisionTree.h:299
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Apply the function to make predictions.
Definition VeloxDecisionTree.h:229
std::string getFuncName()
Get the function name.
Definition VeloxDecisionTree.h:285
float * getTensor() const override
Get the tensor data.
Definition VeloxDecisionTree.h:217
A struct representing the custom type for trees.
Definition VeloxDecisionTree.h:88
std::shared_ptr< Tree > type
The underlying type for the custom type.
Definition VeloxDecisionTree.h:89
static constexpr const char * typeName
The name of the custom type.
Definition VeloxDecisionTree.h:91
A simple function for making predictions using tree models.
Definition VeloxDecisionTree.h:308
void call(out_type< float > &result, const arg_type< Array< float > > &a, const arg_type< TheTree > &b)
Call the function to make predictions.
Definition VeloxDecisionTree.h:317