ML functions
 
Loading...
Searching...
No Matches
CosineSimilarity.h
1/*
2 * Copyright (c) 2025 ASU Cactus Lab.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
24public:
31 dims.push_back(dim);
32 }
33
46 void apply(
47 const SelectivityVector& rows,
48 std::vector<VectorPtr>& args,
49 const TypePtr& type,
50 exec::EvalCtx& context,
51 VectorPtr& output) const override {
52 BaseVector::ensureWritable(rows, type, context.pool(), output);
53
54 BaseVector* left = args[0].get();
55 BaseVector* right = args[1].get();
56
57 exec::LocalDecodedVector leftHolder(context, *left, rows);
58 auto decodedLeftArray = leftHolder.get();
59 auto leftInputOffset =
60 decodedLeftArray->base()->as<ArrayVector>()->rawOffsets();
61 auto baseLeftArray =
62 decodedLeftArray->base()->as<ArrayVector>()->elements();
63
64 exec::LocalDecodedVector rightHolder(context, *right, rows);
65 auto decodedRightArray = rightHolder.get();
66 auto rightInputOffset =
67 decodedRightArray->base()->as<ArrayVector>()->rawOffsets();
68 auto baseRightArray =
69 decodedRightArray->base()->as<ArrayVector>()->elements();
70 float* input1Values = baseLeftArray->values()->asMutable<float>();
71 float* input2Values = baseRightArray->values()->asMutable<float>();
72
73 int numInput = rows.size();
74
75 std::vector<float> resultVector(numInput);
76
77 rows.applyToSelected([&](vector_size_t i) {
78 // Map the input values into Eigen vectors
79 auto leftIndexInRaw = decodedLeftArray->index(i);
80 auto rightIndexInRaw = decodedRightArray->index(i);
81 Eigen::Map<Eigen::VectorXf> vec1(
82 input1Values + leftInputOffset[leftIndexInRaw], dims[0]);
83 Eigen::Map<Eigen::VectorXf> vec2(
84 input2Values + rightInputOffset[rightIndexInRaw], dims[0]);
85
86 // Compute cosine similarity
87 float dotProduct = vec1.dot(vec2);
88 float norm1 = vec1.norm();
89 float norm2 = vec2.norm();
90 float cosineSim = dotProduct / (norm1 * norm2 + 1e-8);
91
92 // Store the result
93 resultVector[i] = cosineSim;
94 });
95
96 VectorMaker maker{context.pool()};
97 output = maker.flatVector<float>(resultVector, REAL());
98 }
99
105 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
106 return {exec::FunctionSignatureBuilder()
107 .argumentType("array(REAL)")
108 .argumentType("array(REAL)")
109 .returnType("REAL")
110 .build()};
111 }
112
/// Returns the name of the function: "cosine_similarity".
static std::string getName() {
  static constexpr const char* kFunctionName = "cosine_similarity";
  return std::string{kFunctionName};
}
121
127 float* getTensor() const override {
128 return nullptr;
129 }
130
137 CostEstimate getCost(std::vector<int> inputDims) {
138 return CostEstimate(0, inputDims[0], inputDims[1]);
139 }
140
141private:
142 std::vector<int> dims;
143};
static std::string getName()
Returns the name of the function.
Definition CosineSimilarity.h:118
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures supported by this class.
Definition CosineSimilarity.h:105
CostEstimate getCost(std::vector< int > inputDims)
Estimates the computational cost of applying the cosine similarity computation.
Definition CosineSimilarity.h:137
CosineSimilarity(int dim)
Constructor that initializes the cosine similarity computation with the dimension of the input arrays...
Definition CosineSimilarity.h:30
float * getTensor() const override
Returns the tensor associated with this function.
Definition CosineSimilarity.h:127
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the cosine similarity computation to the input arrays.
Definition CosineSimilarity.h:46
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9