ML functions
 
Loading...
Searching...
No Matches
BatchNorm.h
1/*
2 * Copyright (c) 2025 ASU Cactus Lab.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
24class BatchNorm1D : public MLFunction {
25public:
34 BatchNorm1D(float* weights, float* bias, int numDims, float eps = 1e-05) {
35 weights_ = new float[numDims];
36 bias_ = new float[numDims];
37 std::memcpy(weights_, weights, numDims * sizeof(float));
38 std::memcpy(bias_, bias, numDims * sizeof(float));
39 eps_ = eps;
40 dims.push_back(numDims);
41 }
42
54 void apply(
55 const SelectivityVector& rows,
56 std::vector<VectorPtr>& args,
57 const TypePtr& type,
58 exec::EvalCtx& context,
59 VectorPtr& output) const override {
60 BaseVector::ensureWritable(rows, type, context.pool(), output);
61 output->clearNulls(rows);
62 auto arrayOutput = output->as<ArrayVector>();
63 auto sizes = arrayOutput->mutableSizes(rows.end());
64 auto rawSizes = sizes->asMutable<int32_t>();
65 auto offsets = arrayOutput->mutableOffsets(rows.end());
66 auto rawOffsets = offsets->asMutable<int32_t>();
67
68 // Initialize sizes and offsets to zero.
69 std::fill(rawSizes, rawSizes + rows.end(), 0);
70 std::fill(rawOffsets, rawOffsets + rows.end(), 0);
71
72 auto elementsOutput = arrayOutput->elements();
73 auto elementsPool = context.pool();
74
75 exec::DecodedArgs decodedArgs(rows, args, context);
76 auto decodedInput = decodedArgs.at(0);
77 auto numRows = rows.size();
78
79 auto inputArray = decodedInput->base()->as<ArrayVector>();
80 auto inputElements = inputArray->elements();
81 float* inputValues = inputElements->values()->asMutable<float>();
82 auto inputOffsets = inputArray->rawOffsets();
83 auto inputSizes = inputArray->rawSizes();
84
85 std::map<vector_size_t, vector_size_t> rowMap;
86 std::unordered_set<vector_size_t> uniqueRawIndexeSet;
87 std::vector<vector_size_t> uniqueRawIndexeVector;
88 vector_size_t numUniqueRows = 0;
89 int numCols = dims[0];
90 rows.applyToSelected([&](vector_size_t row) {
91 auto mappedIndexInRowData = decodedInput->index(row);
92 if (uniqueRawIndexeSet.find(mappedIndexInRowData) ==
93 uniqueRawIndexeSet.end()) {
94 // add it
95 rowMap[row] = numUniqueRows;
96 uniqueRawIndexeSet.insert(mappedIndexInRowData);
97 uniqueRawIndexeVector.push_back(mappedIndexInRowData);
98 ++numUniqueRows;
99 } else {
100 // already added
101 rowMap[row] = rowMap[mappedIndexInRowData];
102 }
103 });
104
105 int numInputMatrixRows = numUniqueRows;
106 Eigen::MatrixXf inputMatrix(numInputMatrixRows, numCols);
107 int rowIndex = 0;
108 for (auto rawIndex : uniqueRawIndexeVector) {
109 Eigen::Map<const Eigen::VectorXf> rowVector(
110 inputValues + inputOffsets[rawIndex], numCols);
111 inputMatrix.row(rowIndex++) = rowVector;
112 }
113
114 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
115 resultMatrix(numInputMatrixRows, numCols);
116 for (int i = 0; i < numCols; i++) {
117 Eigen::VectorXf colData = inputMatrix.col(i);
118 float colMean = colData.mean();
119 float colVariance =
120 (colData.array() - colMean).square().sum() / (numInputMatrixRows - 1);
121
122 resultMatrix.col(i) =
123 (colData.array() - colMean) / sqrt(colVariance + eps_) * weights_[i] +
124 bias_[i];
125 }
126
127 auto baseOffset = elementsOutput->size();
128 elementsOutput->resize(baseOffset + rows.end() * numCols);
129 float* outputValues = elementsOutput->values()->asMutable<float>();
130 vector_size_t outputOffset = 0;
131 rows.applyToSelected([&](vector_size_t row) {
132 if (rowMap.find(row) == rowMap.end()) {
133 throw std::runtime_error(
134 "Mapped index not found for the result matrix.");
135 }
136 auto mappedIndexInResultMatrix = rowMap[row];
137 rawOffsets[row] = outputOffset;
138 rawSizes[row] = numCols;
139 std::memcpy(
140 outputValues + outputOffset,
141 resultMatrix.row(mappedIndexInResultMatrix).data(),
142 numCols * sizeof(float));
143 });
144
145 arrayOutput->setElements(elementsOutput);
146 }
147
153 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
154 return {exec::FunctionSignatureBuilder()
155 .argumentType("array(REAL)")
156 .returnType("array(REAL)")
157 .build()};
158 }
159
165 float* getTensor() const override {
166 return weights_;
167 }
168
174 float* getWeight() {
175 return weights_;
176 }
177
183 float* getBias() {
184 return bias_;
185 }
186
192 static std::string getName() {
193 return "batch_norm_1d";
194 }
195
201 std::string getWeightsFile() {
202 return weightsFile_;
203 }
204
210 void setWeights(float* weights) {
211 weights_ = weights;
212 }
213
220 CostEstimate getCost(std::vector<int> inputDims) {
221 return CostEstimate(0, inputDims[0], inputDims[1]);
222 }
223
224private:
225 float* weights_;
226 float* bias_;
227 float eps_;
228 std::string weightsFile_;
229 std::string biasFile_;
230};
CostEstimate getCost(std::vector< int > inputDims)
Estimates the computational cost of applying batch normalization.
Definition BatchNorm.h:220
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures supported by this class.
Definition BatchNorm.h:153
float * getBias()
Returns the biases of the batch normalization.
Definition BatchNorm.h:183
std::string getWeightsFile()
Returns the path to the weights file.
Definition BatchNorm.h:201
void setWeights(float *weights)
Sets the weights for the batch normalization.
Definition BatchNorm.h:210
float * getWeight()
Returns the weights of the batch normalization.
Definition BatchNorm.h:174
float * getTensor() const override
Returns the tensor associated with this function.
Definition BatchNorm.h:165
BatchNorm1D(float *weights, float *bias, int numDims, float eps=1e-05)
Constructor that initializes the batch normalization operation with weights, biases, the number of feature dimensions, and an epsilon for numerical stability.
Definition BatchNorm.h:34
static std::string getName()
Returns the name of the function.
Definition BatchNorm.h:192
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies 1D batch normalization to the input array.
Definition BatchNorm.h:54
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9
std::vector< int > dims
Dimensions of the function.
Definition BaseFunction.h:61