ML functions
 
Loading...
Searching...
No Matches
DotProduct.h
Go to the documentation of this file.
1
15
16#pragma once
17
18#include <Eigen/Dense>
19#include <cmath>
20#include <iostream>
21#include "BaseFunction.h"
22#include "velox/exec/tests/utils/AssertQueryBuilder.h"
23#include "velox/exec/tests/utils/PlanBuilder.h"
24#include "velox/exec/tests/utils/TempDirectoryPath.h"
25#include "velox/vector/tests/utils/VectorTestBase.h"
26
27using namespace facebook::velox;
28using namespace facebook::velox::test;
29using namespace facebook::velox::exec::test;
30using namespace facebook::velox::memory;
31
36class DotProduct : public MLFunction {
37 public:
42 DotProduct(int inputDims) {
43 inputDims_ = inputDims;
44 }
45
54 void apply(
55 const SelectivityVector& rows,
56 std::vector<VectorPtr>& args,
57 const TypePtr& type,
58 exec::EvalCtx& context,
59 VectorPtr& output) const override {
60 BaseVector::ensureWritable(rows, type, context.pool(), output);
61
62 // Decoder is required to handle address error, reference code:
63 // ArrayIntersectExcept.cpp
64 BaseVector* left = args[0].get();
65 BaseVector* right = args[1].get();
66
67 exec::LocalDecodedVector leftHolder(context, *left, rows);
68 auto decodedLeftArray = leftHolder.get();
69 auto baseLeftArray =
70 decodedLeftArray->base()->as<ArrayVector>()->elements();
71
72 exec::LocalDecodedVector rightHolder(context, *right, rows);
73 auto decodedRightArray = rightHolder.get();
74 auto baseRightArray = rightHolder->base()->as<ArrayVector>()->elements();
75
76 float* input1Values = baseLeftArray->values()->asMutable<float>();
77 float* input2Values = baseRightArray->values()->asMutable<float>();
78
79 auto numInput = rows.size();
80
81 Eigen::Map<
82 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
83 input1Matrix(input1Values, numInput, inputDims_);
84 Eigen::Map<
85 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
86 input2Matrix(input2Values, numInput, inputDims_);
87
88 std::vector<std::vector<float>> results;
89
90 for (int i = 0; i < rows.size(); i++) {
91 std::vector<float> r;
92 float dotProduct = input1Matrix.row(i).dot(input2Matrix.row(i));
93 r.push_back(dotProduct);
94 results.push_back(r);
95 }
96
97 VectorMaker maker{context.pool()};
98 output = maker.arrayVector<float>(results, REAL());
99 }
100
105 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
106 return {exec::FunctionSignatureBuilder()
107 .argumentType("array(REAL)")
108 .argumentType("array(REAL)")
109 .returnType("array(REAL)")
110 .build()};
111 }
112
117 static std::string getName() {
118 return "dot_product";
119 }
120
125 float* getTensor() const override {
126 // TODO: need to implement
127 return nullptr;
128 }
129
135 CostEstimate getCost(std::vector<int> inputDims) {
136 // TODO: need to implement
137 return CostEstimate(0, inputDims[0], inputDims[1]);
138 }
139
140 private:
141 int inputDims_;
142};
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the dot product function to the input data.
Definition DotProduct.h:54
CostEstimate getCost(std::vector< int > inputDims)
Estimates the cost of the function.
Definition DotProduct.h:135
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition DotProduct.h:105
float * getTensor() const override
Returns the tensor associated with the function.
Definition DotProduct.h:125
static std::string getName()
Returns the name of the function.
Definition DotProduct.h:117
DotProduct(int inputDims)
Constructor for DotProduct.
Definition DotProduct.h:42
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9