ML functions
 
Loading...
Searching...
No Matches
SVD.h
Go to the documentation of this file.
1
15
16#pragma once
17
18#include <Eigen/Dense>
19#include <cmath>
20#include <iostream>
21
22using namespace facebook::velox;
23using namespace facebook::velox::test;
24using namespace facebook::velox::exec::test;
25using namespace facebook::velox::memory;
26
31class SVD : public MLFunction {
32 public:
43 SVD(float* bu,
44 float* bi,
45 float* pu,
46 float* qi,
47 int numUser,
48 int numItem,
49 int latentDims) {
50 // Create a deep copy of the weights
51 bu_ = new float[numUser];
52 bi_ = new float[numItem];
53 pu_ = new float[numUser * latentDims];
54 qi_ = new float[numItem * latentDims];
55 std::memcpy(bu_, bu, numUser * sizeof(float));
56 std::memcpy(bi_, bi, numItem * sizeof(float));
57 std::memcpy(pu_, pu, numUser * latentDims * sizeof(float));
58 std::memcpy(qi_, qi, numItem * latentDims * sizeof(float));
59 dims.push_back(numUser);
60 dims.push_back(numItem);
61 dims.push_back(latentDims);
62 }
63
72 void apply(
73 const SelectivityVector& rows,
74 std::vector<VectorPtr>& args,
75 const TypePtr& outputType,
76 exec::EvalCtx& context,
77 VectorPtr& output) const override {
78 BaseVector::ensureWritable(rows, outputType, context.pool(), output);
79
80 exec::DecodedArgs decodedArgs(rows, args, context);
81 auto decodedUser = decodedArgs.at(0);
82 auto decodedItem = decodedArgs.at(1);
83
84 auto arrayOutput = output->asFlatVector<float>();
85 float* outputValues = arrayOutput->mutableRawValues<float>();
86
87 rows.applyToSelected([&](vector_size_t i) {
88 auto userId = decodedUser->valueAt<int>(i);
89 auto itemId = decodedItem->valueAt<int>(i);
90
91 if (userId > dims[0]) {
92 LOG(WARNING) << "User id out of bound: " << userId << " / " << dims[0];
93 userId = 0;
94 } else if (itemId > dims[1]) {
95 itemId = 0;
96 LOG(WARNING) << "Item id out of bound: " << itemId << " / " << dims[1];
97 }
98
99 Eigen::Map<Eigen::VectorXf> qiVec(qi_ + itemId * dims[2], dims[2]);
100 Eigen::Map<Eigen::VectorXf> puVec(pu_ + userId * dims[2], dims[2]);
101
102 float prediction = bu_[userId] + bi_[itemId] + puVec.dot(qiVec);
103 outputValues[i] = prediction;
104 });
105 }
106
111 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
112 return {exec::FunctionSignatureBuilder()
113 .returnType("REAL")
114 .argumentType("INTEGER")
115 .argumentType("INTEGER")
116 .build()};
117 }
118
123 float* getTensor() const override {
124 return weights_;
125 }
126
131 std::string getFuncName() {
132 return getName();
133 };
134
139 static std::string getName() {
140 return "svd";
141 };
142
147 std::string getWeightsFile() {
148 return weightsFile_;
149 }
150
155 void setWeights(float* weights) {
156 weights_ = weights;
157 }
158
164 CostEstimate getCost(std::vector<int> inputDims) {
165 std::vector<double> coefficientVector = getCoefficientVector(getName());
166 float cost = coefficientVector[0] * inputDims[0] * inputDims[1];
167
168 return CostEstimate(cost, inputDims[0], inputDims[1]);
169 }
170
171 private:
172 float* weights_;
173 float* bu_;
174 float* bi_;
175 float* pu_;
176 float* qi_;
177 std::string weightsFile_;
178};
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< double > getCoefficientVector(std::string name)
Retrieves the cost coefficients for the function.
Definition BaseFunction.h:83
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61
static std::string getName()
Returns the name of the function.
Definition SVD.h:139
float * getTensor() const override
Returns the tensor associated with the function.
Definition SVD.h:123
void setWeights(float *weights)
Sets the weights for the function.
Definition SVD.h:155
CostEstimate getCost(std::vector< int > inputDims)
Estimates the cost of the function.
Definition SVD.h:164
std::string getFuncName()
Returns the name of the function.
Definition SVD.h:131
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition SVD.h:111
SVD(float *bu, float *bi, float *pu, float *qi, int numUser, int numItem, int latentDims)
Constructor for SVD; deep-copies the bu, bi, pu and qi weight arrays.
Definition SVD.h:43
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &outputType, exec::EvalCtx &context, VectorPtr &output) const override
Predicts a value for each selected (user id, item id) row as bu[user] + bi[item] + dot(pu[user], qi[item]).
Definition SVD.h:72
std::string getWeightsFile()
Returns the path to the weights file.
Definition SVD.h:147