ML functions
 
Loading...
Searching...
No Matches
Concat.h
/*
 * Copyright (c) 2025 ASU Cactus Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
23class Concat : public MLFunction {
24public:
31 Concat(int input1Dims, int input2Dims) {
32 input1Dims_ = input1Dims;
33 input2Dims_ = input2Dims;
34 LOG(ERROR) << "[ERROR UDF-CONCAT] Bug exists in the apply function when decoding the right input arrays of filtered rows. Use built-in concat instead!";
35 }
36
49 void apply(
50 const SelectivityVector& rows,
51 std::vector<VectorPtr>& args,
52 const TypePtr& type,
53 exec::EvalCtx& context,
54 VectorPtr& output) const override {
55 BaseVector::ensureWritable(rows, type, context.pool(), output);
56
57 // Decoder is required to handle address error, reference code:
58 // ArrayIntersectExcept.cpp
59 BaseVector* left = args[0].get();
60 BaseVector* right = args[1].get();
61
62 exec::LocalDecodedVector leftHolder(context, *left, rows);
63 auto decodedLeftArray = leftHolder.get();
64 auto baseLeftArray =
65 decodedLeftArray->base()->as<ArrayVector>()->elements();
66
67 exec::LocalDecodedVector rightHolder(context, *right, rows);
68 auto decodedRightArray = rightHolder.get();
69 auto baseRightArray =
70 decodedRightArray->base()->as<ArrayVector>()->elements();
71
72 float* input1Values = baseLeftArray->values()->asMutable<float>();
73 float* input2Values = baseRightArray->values()->asMutable<float>();
74 // std::cout << "[DEBUG]: rows.size(): " << rows.size()
75 // << " # selected: " << rows.countSelected() << std::endl;
76 // std::cout << "[DEBUG]: size of Elements: " << baseLeftArray->size() << ",
77 // "
78 // << baseRightArray->size() << std::endl;
79 // std::cout << "[DEBUG] input1Dims_: " << input1Dims_
80 // << ", input2Dims_: " << input2Dims_ << std::endl;
81
82 std::vector<std::vector<float>> results;
83
84 for (int i = 0; i < rows.size(); i++) {
85 std::vector<float> concatenatedVector(input1Dims_ + input2Dims_);
86 // copy the 1st array
87 std::memcpy(
88 concatenatedVector.data(),
89 input1Values + i * input1Dims_,
90 input1Dims_ * sizeof(float));
91 // copy the 2nd array
92 std::memcpy(
93 concatenatedVector.data() + input1Dims_,
94 input2Values + i * input2Dims_,
95 input2Dims_ * sizeof(float));
96 results.push_back(concatenatedVector);
97 }
98
99 VectorMaker maker{context.pool()};
100 output = maker.arrayVector<float>(results, REAL());
101 }
102
108 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
109 return {exec::FunctionSignatureBuilder()
110 .argumentType("array(REAL)")
111 .argumentType("array(REAL)")
112 .returnType("array(REAL)")
113 .build()};
114 }
115
121 static std::string getName() {
122 return "concat";
123 }
124
130 float* getTensor() const override {
131 return nullptr;
132 }
133
140 CostEstimate getCost(std::vector<int> inputDims) {
141 return CostEstimate(0, inputDims[0], inputDims[1]);
142 }
143
144private:
145 int input1Dims_;
146 int input2Dims_;
147};
CostEstimate getCost(std::vector< int > inputDims)
Estimates the computational cost of applying the concatenation operation.
Definition Concat.h:140
Concat(int input1Dims, int input2Dims)
Constructor that initializes the concatenation operation with input dimensions.
Definition Concat.h:31
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures supported by this class.
Definition Concat.h:108
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the concatenation operation to the input arrays.
Definition Concat.h:49
static std::string getName()
Returns the name of the function.
Definition Concat.h:121
float * getTensor() const override
Returns the tensor associated with this function.
Definition Concat.h:130
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9