ML functions
 
Loading...
Searching...
No Matches
HuggingFaceServerless.h
Go to the documentation of this file.
1
15
16#pragma once
17
18#include <cpr/cpr.h>
19#include <json/json.h>
20#include <nlohmann/json.hpp>
21#include <Eigen/Dense>
22#include <cmath>
23#include <iostream>
24#include "BaseFunction.h"
25#include "velox/exec/tests/utils/AssertQueryBuilder.h"
26#include "velox/exec/tests/utils/PlanBuilder.h"
27#include "velox/exec/tests/utils/TempDirectoryPath.h"
28#include "velox/ml_functions/UtilFunction.h"
29#include "velox/vector/tests/utils/VectorTestBase.h"
30
31using namespace facebook::velox;
32using namespace facebook::velox::test;
33using namespace facebook::velox::exec::test;
34using namespace facebook::velox::memory;
35
46
52 public:
58 HuggingFaceServerless(std::string apiURL, HuggingFaceTaskType taskType) {
59 apiURL_ = apiURL;
60 taskType_ = taskType;
61 apiToken_ = getEnvVar("HF_TOKEN");
62 if (apiToken_ == "") {
63 throw std::runtime_error(fmt::format(
64 "[ERROR] HuggingFace token is not set, please set HF_TOKEN"));
65 }
66 inputTokenNumber_ = 0;
67 outputTokenNumber_ = 0;
68 numFailures_ = 0;
69 }
70
76 std::string filename = "huggingfaceServerless.log";
77 std::ofstream file(filename, std::ios::app);
78 if (!file) {
79 std::cerr << "Unable to open file: " << filename << std::endl;
80 return;
81 }
82 // Get the current time
83 auto now = std::chrono::system_clock::now();
84 std::time_t now_c = std::chrono::system_clock::to_time_t(now);
85
86 // Write the timestamp to the file
87 file << std::put_time(std::localtime(&now_c), "%Y-%m-%d %H:%M:%S") << " ";
88
89 // Write the uint64_t values to the file
90 file << "[HuggingFaceServerless] # Input:" << inputTokenNumber_
91 << " # Output: " << outputTokenNumber_
92 << " # NumFailure: " << numFailures_ << std::endl;
93 file.close();
94 }
95
104 void apply(
105 const SelectivityVector& rows,
106 std::vector<VectorPtr>& args,
107 const TypePtr& type,
108 exec::EvalCtx& context,
109 VectorPtr& output) const override {
110 BaseVector::ensureWritable(rows, type, context.pool(), output);
111 // Read string input
112 exec::LocalDecodedVector decodedStringHolder(context, *args[0], rows);
113 auto decodedStringInput = decodedStringHolder.get();
114 int numInputs = rows.size();
115
116 cpr::Header headers{{"Authorization", fmt::format("Bearer {}", apiToken_)}};
117
118 std::vector<std::vector<float>> result(numInputs);
119
120 // Limit of number of inputs can be sent to serverless API at once
121 // HuggingFace itself suggests a maximum number of 10K, but the API
122 // may be busy and fail. Try reducing the limit or deploy a dedicated endpoint.
123 const int HF_SERVERLESS_INPUT_LIMIT = 5000;
124
125 // HuggingFace inputs are formatted as follows:
126 // "inputs": ["Sentence 1", "Sentence 2"],
127 std::string strInputs = "[";
128 int accuInputCount = 0;
129 int insertedDataIdx = 0;
130 for (int i = 0; i < numInputs; i++) {
131 std::string valString =
132 std::string(decodedStringInput->valueAt<StringView>(i));
133 const_cast<uint64_t&>(inputTokenNumber_) =
134 inputTokenNumber_ + countWords(valString);
135
136 strInputs += "\"" + valString + "\"";
137 accuInputCount += 1;
138
139 if (i != (numInputs - 1) && accuInputCount != HF_SERVERLESS_INPUT_LIMIT) {
140 strInputs += ",";
141 } else {
142 // Need to post the inputs to Hugging Face serverless API and get results
143 strInputs += "]";
144 auto huggingFaceInputs = "{\"inputs\": " + strInputs + "}";
145
146 cpr::Response response = cpr::Post(
147 cpr::Url{apiURL_},
148 cpr::Body{huggingFaceInputs},
149 cpr::Header{headers});
150
151 if (response.status_code == 200) {
152 // The response text can be parsed as JSON objects,
153 // it should be a list of results for each input
154 auto jsonObj = nlohmann::json::parse(response.text);
155 int processedEmbeddingCount = 0;
156 for (const auto& innerVector : jsonObj) {
157 // Iterate response for each sample
159 std::vector<float> floatVector(3);
160 for (const auto& value : innerVector) {
161 int dataIdx = 0;
162 // TODO: Different models come with different return value names
163 // Need more work here to handle such cases
164 if (value["label"] == "positive") {
165 dataIdx = 0;
166 } else if (value["label"] == "neutral") {
167 dataIdx = 1;
168 } else if (value["label"] == "negative") {
169 dataIdx = 2;
170 }
171 floatVector[dataIdx] = value["score"];
172 }
173 result[insertedDataIdx++] = floatVector;
174 } else if (
176 if (processedEmbeddingCount == accuInputCount) {
177 break;
178 }
179 // Need case-by-case handling for different models
180 if (apiURL_.find("all-MiniLM") != std::string::npos) {
181 auto returnedEmbedding = innerVector;
182 std::vector<float> embeddingVector;
183 for (const auto& val : returnedEmbedding) {
184 embeddingVector.push_back(val);
185 }
186 const_cast<uint64_t&>(outputTokenNumber_) =
187 outputTokenNumber_ + embeddingVector.size();
188 processedEmbeddingCount += 1;
189 result[insertedDataIdx++] = embeddingVector;
190 if (processedEmbeddingCount == accuInputCount) {
191 break;
192 }
193 } else {
194 auto returnedEmbedding = innerVector[0];
195 // FIXME: Sometimes it returns an unfixed number of embeddings,
196 // need further investigation
197 for (const auto& value : returnedEmbedding) {
198 std::vector<float> embeddingVector;
199 for (const auto& val : value) {
200 embeddingVector.push_back(val);
201 }
202 processedEmbeddingCount += 1;
203 result[insertedDataIdx++] = embeddingVector;
204 if (processedEmbeddingCount == accuInputCount) {
205 break;
206 }
207 }
208 }
209 } else {
210 throw std::runtime_error(fmt::format(
211 "Current HuggingFace Task Type {} is not supported",
212 taskType_));
213 }
214 }
215 } else {
216 // Handle error cases
217 std::cerr << "Error in fetching the results: "
218 << response.error.message << std::endl;
219 for (int l = 0; i < accuInputCount; i++) {
220 result[insertedDataIdx++] = {0.0};
221 }
222 }
223
224 // Reset
225 strInputs = "[";
226 accuInputCount = 0;
227 }
228 }
229
230 VectorMaker maker{context.pool()};
231 output = maker.arrayVector<float>(result, REAL());
232 }
233
238 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
239 return {exec::FunctionSignatureBuilder()
240 .argumentType("VARCHAR")
241 .returnType("array(REAL)")
242 .build()};
243 }
244
/// Returns the name of this function, "huggingface".
static std::string getName() {
  std::string functionName{"huggingface"};
  return functionName;
}
252
257 float* getTensor() const override {
258 // TODO: need to implement
259 return nullptr;
260 }
261
267 CostEstimate getCost(std::vector<int> inputDims) {
268 // TODO: need to implement
269 return CostEstimate(0, inputDims[0], inputDims[1]);
270 }
271
272 private:
273 std::string apiURL_;
274 std::string apiToken_;
275 HuggingFaceTaskType taskType_;
276 uint64_t inputTokenNumber_;
277 uint64_t outputTokenNumber_;
278 uint64_t numFailures_;
279};
HuggingFaceTaskType
Enumeration of supported Hugging Face task types.
Definition HuggingFaceServerless.h:40
@ TEXT_FEATURE_EXTRACTION
Text feature extraction task.
Definition HuggingFaceServerless.h:44
@ IMAGE_CLASSIFICATION
Image classification task.
Definition HuggingFaceServerless.h:42
@ TEXT_CLASSIFICATION
Text classification task.
Definition HuggingFaceServerless.h:41
@ REGRESSION
Regression task.
Definition HuggingFaceServerless.h:43
std::string getEnvVar(std::string const &key)
Retrieves the value of an environment variable.
Definition UtilFunction.h:483
int countWords(const std::string &input)
Counts the number of words in a string.
Definition UtilFunction.h:555
~HuggingFaceServerless()
Destructor for HuggingFaceServerless. Logs input/output token counts and failure statistics to a file...
Definition HuggingFaceServerless.h:75
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures.
Definition HuggingFaceServerless.h:238
static std::string getName()
Returns the name of the function.
Definition HuggingFaceServerless.h:249
float * getTensor() const override
Returns the tensor associated with the function.
Definition HuggingFaceServerless.h:257
CostEstimate getCost(std::vector< int > inputDims)
Estimates the cost of the function.
Definition HuggingFaceServerless.h:267
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the Hugging Face serverless API function to the input data.
Definition HuggingFaceServerless.h:104
HuggingFaceServerless(std::string apiURL, HuggingFaceTaskType taskType)
Constructor for HuggingFaceServerless.
Definition HuggingFaceServerless.h:58
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9