ML functions
 
Loading...
Searching...
No Matches
Embedding.h
Go to the documentation of this file.
1
15
16#pragma once
17
18#include <fmt/format.h>
19#include <iostream>
20#include "BaseFunction.h"
21#include "velox/exec/tests/utils/AssertQueryBuilder.h"
22#include "velox/exec/tests/utils/PlanBuilder.h"
23#include "velox/exec/tests/utils/TempDirectoryPath.h"
24#include "velox/vector/tests/utils/VectorTestBase.h"
25
26using namespace facebook::velox;
27using namespace facebook::velox::test;
28using namespace facebook::velox::exec::test;
29using namespace facebook::velox::memory;
30
35class Embedding : public MLFunction {
36 public:
43 Embedding(float* weights, int numEmbeddings, int embeddingDims) {
44 // Create a deep copy of the weights
45 weights_ = new float[numEmbeddings * embeddingDims];
46 std::memcpy(
47 weights_, weights, numEmbeddings * embeddingDims * sizeof(float));
48 // weights_ = std::move(weights);
49 dims.push_back(numEmbeddings);
50 dims.push_back(embeddingDims);
51 }
52
59 Embedding(std::string weightsFile, int numEmbeddings, int embeddingDims) {
60 weightsFile_ = weightsFile;
61 dims.push_back(numEmbeddings);
62 dims.push_back(embeddingDims);
63 }
64
73 void apply(
74 const SelectivityVector& rows,
75 std::vector<VectorPtr>& args,
76 const TypePtr& type,
77 exec::EvalCtx& context,
78 VectorPtr& output) const override {
79 BaseVector::ensureWritable(rows, type, context.pool(), output);
80 output->clearNulls(rows);
81 auto arrayOutput = output->as<ArrayVector>();
82 auto sizes = arrayOutput->mutableSizes(rows.end());
83 auto rawSizes = sizes->asMutable<int32_t>();
84 auto offsets = arrayOutput->mutableOffsets(rows.end());
85 auto rawOffsets = offsets->asMutable<int32_t>();
86
87 // Initialize sizes and offsets to zero.
88 std::fill(rawSizes, rawSizes + rows.end(), 0);
89 std::fill(rawOffsets, rawOffsets + rows.end(), 0);
90
91 auto elementsOutput = arrayOutput->elements();
92 auto elementsPool = context.pool();
93
94 exec::DecodedArgs decodedArgs(rows, args, context);
95 auto input = decodedArgs.at(0);
96 auto arrayVector = input->base()->as<ArrayVector>();
97
98 auto indicesVector = arrayVector->elements();
99 int* indicesValues = indicesVector->values()->asMutable<int>();
100 int numInputs = rows.size();
101 // You can also use sizeof(*arrayVector->rawSizes()) to compute the size of
102 // a single entry in BufferPtr
103
104 int numEmbeddingToRetireve = 0;
105 rows.applyToSelected([&](vector_size_t row) {
106 int numSubIndices = arrayVector->sizeAt(row);
107 numEmbeddingToRetireve += numSubIndices;
108 });
109
110 auto baseOffset = elementsOutput->size();
111 // here, we need to resize it to the number of embeddings we need to
112 // retrieve
113 elementsOutput->resize(baseOffset + numEmbeddingToRetireve * dims[1]);
114 float* outputValues = elementsOutput->values()->asMutable<float>();
115
116 vector_size_t outputOffset = 0;
117 rows.applyToSelected([&](vector_size_t row) {
118 int numSubIndices = arrayVector->sizeAt(row);
119 int indicesOffset = arrayVector->offsetAt(row);
120 rawOffsets[row] = outputOffset;
121 rawSizes[row] = numSubIndices * dims[1];
122 for (int i = 0; i < numSubIndices; i++) {
123 // Support of variadic indexes
124 int embedIndex = indicesValues[indicesOffset + i];
125 if (embedIndex >= dims[0]) {
126 throw std::runtime_error(fmt::format(
127 "[Embedding] Index out of bounds: {} >= {}",
128 embedIndex,
129 dims[0]));
130 }
131 std::memcpy(
132 outputValues + outputOffset,
133 weights_ + embedIndex * dims[1],
134 dims[1] * sizeof(float));
135 outputOffset += dims[1];
136 }
137 });
138 arrayOutput->setElements(elementsOutput);
139 }
140
145 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
146 return {exec::FunctionSignatureBuilder()
147 .argumentType("array(INTEGER)")
148 .returnType("array(REAL)")
149 .build()};
150 }
151
156 float* getTensor() const override {
157 return weights_;
158 }
159
164 static std::string getName() {
165 return "embedding";
166 };
167
172 std::string getWeightsFile() {
173 return weightsFile_;
174 }
175
180 void setWeights(float* weights) {
181 weights_ = weights;
182 }
183
184 private:
185 float* weights_;
186 std::string weightsFile_;
187};
static std::string getName()
Returns the name of the function.
Definition Embedding.h:164
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition Embedding.h:145
void setWeights(float *weights)
Sets the embedding weights.
Definition Embedding.h:180
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the embedding function to the input data.
Definition Embedding.h:73
Embedding(float *weights, int numEmbeddings, int embeddingDims)
Constructor for Embedding.
Definition Embedding.h:43
float * getTensor() const override
Returns the tensor associated with the function.
Definition Embedding.h:156
std::string getWeightsFile()
Returns the path to the weights file.
Definition Embedding.h:172
Embedding(std::string weightsFile, int numEmbeddings, int embeddingDims)
Constructor for Embedding.
Definition Embedding.h:59
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61