ML functions
 
Loading...
Searching...
No Matches
Dropout.h
Go to the documentation of this file.
1
15
16#pragma once
17
#include <cmath>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include "BaseFunction.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
#include "velox/vector/tests/utils/VectorTestBase.h"
25
26using namespace facebook::velox;
27using namespace facebook::velox::test;
28using namespace facebook::velox::exec::test;
29using namespace facebook::velox::memory;
30
35class Dropout : public MLFunction {
36 public:
41 Dropout(float p) {
42 p_ = p;
43 // std::random_device device;
44 // std::mt19937 gen(device());
45 // std::bernoulli_distribution coin_flip(0.5);
46 // bool outcome = coin_flip(gen);
47 }
48
57 void apply(
58 const SelectivityVector& rows,
59 std::vector<VectorPtr>& args,
60 const TypePtr& type,
61 exec::EvalCtx& context,
62 VectorPtr& output) const override {
63 std::random_device rd;
64 std::mt19937 gen(rd());
65 std::bernoulli_distribution bernoulli(p_);
66
67 BaseVector::ensureWritable(rows, type, context.pool(), output);
68
69 auto inputFeatures = args[0]->as<ArrayVector>()->elements();
70 float* inputValues = inputFeatures->values()->asMutable<float>();
71
72 int inputSize = inputFeatures->size();
73 int numInput = args[0]->size();
74 int numFeatures = inputSize / numInput;
75
76 // float* result[inputSize];
77 std::vector<std::vector<float>> result(
78 numInput, std::vector<float>(numFeatures));
79
80 for (int i = 0; i < numInput; i++) {
81 for (int j = 0; j < numFeatures; j++) {
82 bool outcome = bernoulli(gen);
83 if (outcome) {
84 result[i][j] = 0;
85 } else {
86 result[i][j] = inputValues[i * numFeatures + j];
87 }
88 }
89 }
90
91 VectorMaker maker{context.pool()};
92 output = maker.arrayVector<float>(result, REAL());
93 }
94
99 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
100 return {exec::FunctionSignatureBuilder()
101 .argumentType("array(REAL)")
102 .returnType("array(REAL)")
103 .build()};
104 }
105
110 static std::string getName() {
111 return "dropout";
112 };
113
118 float* getTensor() const override {
119 // FIXME
120 return nullptr;
121 }
122
127 void setWeight(float p) {
128 p_ = p;
129 }
130
131 private:
132 float p_;
133 std::mt19937 gen_;
134 std::bernoulli_distribution bernoulli_;
135};
float * getTensor() const override
Returns the tensor associated with the function. Not implemented for Dropout: always returns nullptr.
Definition Dropout.h:118
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition Dropout.h:99
Dropout(float p)
Constructs a Dropout function that zeroes each input element with probability p.
Definition Dropout.h:41
void setWeight(float p)
Sets the dropout probability.
Definition Dropout.h:127
static std::string getName()
Returns the name of the function.
Definition Dropout.h:110
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the dropout function to the input data.
Definition Dropout.h:57
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9