ML functions
 
Loading...
Searching...
No Matches
ChatGPT.h
1/*
2 * Copyright (c) 2025 ASU Cactus Lab.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
23class ChatGPT : public MLFunction {
24public:
32 apiKey_ = getEnvVar("OPENAI_API_KEY");
33 if (apiKey_ == "") {
34 throw std::runtime_error("[ERROR] OpenAI API key is not set, please set OPENAI_API_KEY");
35 }
36 numThreads_ = getEnvVar("NUM_THREADS") == "" ? 8 : std::stoi(getEnvVar("NUM_THREADS"));
37 url_ = "https://api.openai.com/v1/chat/completions";
38 model_ = "gpt-3.5-turbo";
39 inputTokenNumber_ = 0;
40 outputTokenNumber_ = 0;
41 numFailures_ = 0;
42 }
43
53 ChatGPT(std::string url, std::string model) {
54 apiKey_ = getEnvVar("OPENAI_API_KEY");
55 if (apiKey_ == "") {
56 throw std::runtime_error("[ERROR] OpenAI API key is not set, please set OPENAI_API_KEY");
57 }
58 numThreads_ = getEnvVar("NUM_THREADS") == "" ? 8 : std::stoi(getEnvVar("NUM_THREADS"));
59 url_ = url;
60 model_ = model;
61 inputTokenNumber_ = 0;
62 outputTokenNumber_ = 0;
63 numFailures_ = 0;
64 }
65
72 std::string filename = "chatgpt.log";
73 std::ofstream file(filename, std::ios::app);
74 if (!file) {
75 std::cerr << "Unable to open file: " << filename << std::endl;
76 return;
77 }
78 auto now = std::chrono::system_clock::now();
79 std::time_t now_c = std::chrono::system_clock::to_time_t(now);
80 file << std::put_time(std::localtime(&now_c), "%Y-%m-%d %H:%M:%S") << " ";
81 file << "[ChatGPT] # Input:" << inputTokenNumber_
82 << " # Output: " << outputTokenNumber_
83 << " # NumFailure: " << numFailures_ << std::endl;
84 file.close();
85 }
86
99 void apply(
100 const SelectivityVector& rows,
101 std::vector<VectorPtr>& args,
102 const TypePtr& type,
103 exec::EvalCtx& context,
104 VectorPtr& output) const override {
105 std::string promptPrefix = "";
106 BaseVector::ensureWritable(rows, type, context.pool(), output);
107
108 exec::LocalDecodedVector decodedStringHolder(context, *args[0], rows);
109 auto decodedStringInput = decodedStringHolder.get();
110
111 int numInput = rows.size();
112 int numSelected = rows.countSelected();
113 LOG(INFO) << "[INFO ChatGPT:] countSelected: " << rows.countSelected()
114 << " numInput: " << numInput << std::endl;
115
116 if (args.size() == 2) {
117 exec::LocalDecodedVector decodedStringHolder2(context, *args[1], rows);
118 auto decodedStringInput2 = decodedStringHolder2.get();
119 StringView val = decodedStringInput2->valueAt<StringView>(0);
120 promptPrefix = std::string(val);
121 }
122
123 std::vector<std::string> results;
124
125 cpr::Header headers{
126 {"Content-Type", "application/json"},
127 {"Authorization", "Bearer " + apiKey_}};
128
129 // Thread vector
130 std::vector<std::thread> threads;
131 std::vector<cpr::Response> responses(numInput);
132 std::vector<int> numFailureVector(numInput);
133
134 int numInputsPerThread = int(std::ceil(float(numSelected) / numThreads_));
135 int processedInputCount = 0;
136 std::vector<std::string> payloadsBatchVector;
137 int processedIndex = 0;
138
139 // Version 1
140 // This approach is more efficient by sending requests in batches and
141 // leveraging multiple threads to send requests concurrently, it requires
142 // additional isValid check to skip the rows that are not selected. Note: at
143 // the end of this approach, it is required to invoke
144 // context.moveOrCopyResult to copy the results back to the output vector
145 // since we only compute the results for selected ones
146 for (int i = 0; i < numInput; i++) {
147 // if the row is not selected, skip
148 if (!rows.isValid(i)) {
149 continue;
150 }
151 StringView val = decodedStringInput->valueAt<StringView>(i);
152 std::string valString = promptPrefix + std::string(val);
153 nlohmann::json messageArrays = nlohmann::json::array();
154 // Add message
155 messageArrays.push_back({{"role", "user"}, {"content", valString}});
156
157 nlohmann::json payload = {
158 {"model", model_}, {"messages", messageArrays}, {"max_tokens", 150}};
159
160 payloadsBatchVector.push_back(payload.dump());
161 processedInputCount++;
162
163 if (processedInputCount == numInputsPerThread || i == numInput - 1) {
164 threads.emplace_back(
165 sendRequestViaCprBatch,
166 url_,
167 headers,
168 payloadsBatchVector,
169 std::ref(responses),
170 processedIndex - processedInputCount + 1,
171 std::ref(numFailureVector));
172 processedInputCount = 0;
173 payloadsBatchVector.clear();
174 }
175 processedIndex++;
176 }
177
178 for (auto& thread : threads) {
179 thread.join();
180 }
181
182 for (int i = 0; i < numSelected; i++) {
183 if (responses[i].status_code == 200) {
184 // parse the returned value
185 nlohmann::json response_json = nlohmann::json::parse(responses[i].text);
186 std::string generated_message =
187 response_json["choices"][0]["message"]["content"];
188 results.push_back(generated_message);
189 const_cast<uint64_t&>(inputTokenNumber_) = inputTokenNumber_ +
190 response_json["usage"]["prompt_tokens"].get<int>();
191 const_cast<uint64_t&>(outputTokenNumber_) = outputTokenNumber_ +
192 response_json["usage"]["completion_tokens"].get<int>();
193 const_cast<uint64_t&>(numFailures_) =
194 numFailures_ + numFailureVector[i];
195 if (numFailureVector[i] > 0) {
196 LOG(WARNING)
197 << "[WARNING] Failed to send request to OpenAI API. Number of retries: "
198 << numFailureVector[i] << " numFailures_: " << numFailures_
199 << std::endl;
200 }
201 LOG(INFO) << fmt::format(
202 "[INFO] i: {} / {}, results: {}, numFailures: {}",
203 i + 1,
204 numSelected,
205 generated_message,
206 numFailureVector[i])
207 << std::endl;
208 } else {
209 LOG(ERROR) << "Error: " << responses[i].status_code << " - "
210 << responses[i].text << std::endl;
211 }
212 }
213
214 VectorMaker maker{context.pool()};
215 VectorPtr localResult = maker.flatVector<std::string>(results);
216
217 context.moveOrCopyResult(localResult, rows, output);
218 }
219
225 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
226 return {
227 exec::FunctionSignatureBuilder()
228 .argumentType("VARCHAR")
229 .returnType("VARCHAR")
230 .build(),
231 exec::FunctionSignatureBuilder()
232 .argumentType("VARCHAR")
233 .argumentType("VARCHAR")
234 .returnType("VARCHAR")
235 .build()};
236 }
237
243 float* getTensor() const override {
244 return nullptr;
245 }
246
/// Returns the registered SQL name of this function.
static std::string getName() {
  static const std::string kFunctionName = "chatgpt";
  return kFunctionName;
}
255
262 CostEstimate getCost(std::vector<int> inputDims) {
263 return CostEstimate(0, inputDims[0], inputDims[1]);
264 }
265
266private:
// API key taken from the OPENAI_API_KEY environment variable.
 267 std::string apiKey_;
// Chat-completions endpoint URL that requests are sent to.
 268 std::string url_;
// Model identifier included in each request payload.
 269 std::string model_;
// Running total of prompt tokens reported by the API.
// NOTE(review): mutated via const_cast in apply(); consider `mutable`.
 270 uint64_t inputTokenNumber_;
// Running total of completion tokens reported by the API.
 271 uint64_t outputTokenNumber_;
// Running total of request retries/failures.
 272 uint64_t numFailures_;
// Number of worker threads (NUM_THREADS env var, default 8).
 273 int numThreads_;
 274};
275
283public:
291 apiKey_ = getEnvVar("OPENAI_API_KEY");
292 if (apiKey_ == "") {
293 throw std::runtime_error("[ERROR] OpenAI API key is not set, please set OPENAI_API_KEY");
294 }
295 numThreads_ = getEnvVar("NUM_THREADS") == "" ? 8 : std::stoi(getEnvVar("NUM_THREADS"));
296 url_ = "https://api.openai.com/v1/chat/completions";
297 model_ = "gpt-3.5-turbo";
298 inputTokenNumber_ = 0;
299 outputTokenNumber_ = 0;
300 numFailures_ = 0;
301 }
302
312 ChatGPTRecommender(std::string url, std::string model) {
313 apiKey_ = getEnvVar("OPENAI_API_KEY");
314 if (apiKey_ == "") {
315 throw std::runtime_error("[ERROR] OpenAI API key is not set, please set OPENAI_API_KEY");
316 }
317 numThreads_ = getEnvVar("NUM_THREADS") == "" ? 8 : std::stoi(getEnvVar("NUM_THREADS"));
318 url_ = url;
319 model_ = model;
320 inputTokenNumber_ = 0;
321 outputTokenNumber_ = 0;
322 numFailures_ = 0;
323 }
324
331 std::string filename = "chatgpt.log";
332 std::ofstream file(filename, std::ios::app);
333 if (!file) {
334 std::cerr << "Unable to open file: " << filename << std::endl;
335 return;
336 }
337 auto now = std::chrono::system_clock::now();
338 std::time_t now_c = std::chrono::system_clock::to_time_t(now);
339 file << std::put_time(std::localtime(&now_c), "%Y-%m-%d %H:%M:%S") << " ";
340 file << "[ChatGPT Recommender] # Input:" << inputTokenNumber_
341 << " # Output: " << outputTokenNumber_
342 << " # NumFailure: " << numFailures_ << std::endl;
343 file.close();
344 }
345
358 void apply(
359 const SelectivityVector& rows,
360 std::vector<VectorPtr>& args,
361 const TypePtr& type,
362 exec::EvalCtx& context,
363 VectorPtr& output) const override {
364 std::string promptSuffix = "";
365 BaseVector::ensureWritable(rows, type, context.pool(), output);
366
367 exec::LocalDecodedVector decodedStringHolder1(context, *args[0], rows);
368 auto decodedStringInput1 = decodedStringHolder1.get();
369
370 exec::LocalDecodedVector decodedStringHolder2(context, *args[1], rows);
371 auto decodedStringInput2 = decodedStringHolder2.get();
372
373 int numInput = rows.size();
374 int numSelected = rows.countSelected();
375 LOG(INFO) << "[INFO ChatGPTRecommender:] countSelected: "
376 << rows.countSelected() << " numInput: " << numInput << std::endl;
377
378 if (args.size() == 3) {
379 exec::LocalDecodedVector decodedStringHolder3(context, *args[2], rows);
380 auto decodedStringInput3 = decodedStringHolder3.get();
381 StringView val = decodedStringInput3->valueAt<StringView>(0);
382 promptSuffix = std::string(val);
383 }
384
385 std::vector<std::string> results;
386
387 cpr::Header headers{
388 {"Content-Type", "application/json"},
389 {"Authorization", "Bearer " + apiKey_}};
390
391 // Thread vector
392 std::vector<std::thread> threads;
393 std::vector<cpr::Response> responses(numInput);
394 std::vector<int> numFailureVector(numInput);
395
396 int numInputsPerThread = int(std::ceil(float(numSelected) / numThreads_));
397 int processedInputCount = 0;
398 std::vector<std::string> payloadsBatchVector;
399 int processedIndex = 0;
400
401 for (int i = 0; i < numInput; i++) {
402 // if the row is not selected, skip
403 if (!rows.isValid(i)) {
404 continue;
405 }
406 StringView val1 = decodedStringInput1->valueAt<StringView>(i);
407 StringView val2 = decodedStringInput2->valueAt<StringView>(i);
408 std::string valString =
409 "Summarized user statistics data (preference): " + std::string(val1) +
410 ". \n Summarized user movie metadata: " + std::string(val2) + ".\n" +
411 promptSuffix;
412 nlohmann::json messageArrays = nlohmann::json::array();
413 // Add message
414 messageArrays.push_back({{"role", "user"}, {"content", valString}});
415
416 nlohmann::json payload = {
417 {"model", model_}, {"messages", messageArrays}, {"max_tokens", 500}};
418
419 payloadsBatchVector.push_back(payload.dump());
420 processedInputCount++;
421
422 if (processedInputCount == numInputsPerThread || i == numInput - 1) {
423 threads.emplace_back(
424 sendRequestViaCprBatch,
425 url_,
426 headers,
427 payloadsBatchVector,
428 std::ref(responses),
429 processedIndex - processedInputCount + 1,
430 std::ref(numFailureVector));
431 processedInputCount = 0;
432 payloadsBatchVector.clear();
433 }
434 processedIndex++;
435 }
436
437 for (auto& thread : threads) {
438 thread.join();
439 }
440
441 for (int i = 0; i < numSelected; i++) {
442 if (responses[i].status_code == 200) {
443 // parse the returned value
444 nlohmann::json response_json = nlohmann::json::parse(responses[i].text);
445 std::string generated_message =
446 response_json["choices"][0]["message"]["content"];
447 results.push_back(generated_message);
448 const_cast<uint64_t&>(inputTokenNumber_) = inputTokenNumber_ +
449 response_json["usage"]["prompt_tokens"].get<int>();
450 const_cast<uint64_t&>(outputTokenNumber_) = outputTokenNumber_ +
451 response_json["usage"]["completion_tokens"].get<int>();
452 const_cast<uint64_t&>(numFailures_) =
453 numFailures_ + numFailureVector[i];
454 if (numFailureVector[i] > 0) {
455 LOG(WARNING)
456 << "[WARNING] Failed to send request to OpenAI API. Number of retries: "
457 << numFailureVector[i] << " numFailures_: " << numFailures_
458 << std::endl;
459 }
460 LOG(INFO) << fmt::format(
461 "[INFO] i: {} / {}, results: {}, numFailures: {}",
462 i + 1,
463 numSelected,
464 generated_message,
465 numFailureVector[i])
466 << std::endl;
467 } else {
468 LOG(ERROR) << "Error: " << responses[i].status_code << " - "
469 << responses[i].text << std::endl;
470 }
471 }
472
473 VectorMaker maker{context.pool()};
474 VectorPtr localResult = maker.flatVector<std::string>(results);
475
476 context.moveOrCopyResult(localResult, rows, output);
477 }
478
484 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
485 return {
486 exec::FunctionSignatureBuilder()
487 .argumentType("VARCHAR")
488 .argumentType("VARCHAR")
489 .returnType("VARCHAR")
490 .build(),
491 exec::FunctionSignatureBuilder()
492 .argumentType("VARCHAR")
493 .argumentType("VARCHAR")
494 .argumentType("VARCHAR")
495 .returnType("VARCHAR")
496 .build()};
497 }
498
504 float* getTensor() const override {
505 return nullptr;
506 }
507
/// Returns the registered SQL name of this function.
static std::string getName() {
  static const std::string kFunctionName = "chatgpt_recommender";
  return kFunctionName;
}
516
523 CostEstimate getCost(std::vector<int> inputDims) {
524 return CostEstimate(0, inputDims[0], inputDims[1]);
525 }
526
527private:
// API key taken from the OPENAI_API_KEY environment variable.
 528 std::string apiKey_;
// Chat-completions endpoint URL that requests are sent to.
 529 std::string url_;
// Model identifier included in each request payload.
 530 std::string model_;
// Running total of prompt tokens reported by the API.
// NOTE(review): mutated via const_cast in apply(); consider `mutable`.
 531 uint64_t inputTokenNumber_;
// Running total of completion tokens reported by the API.
 532 uint64_t outputTokenNumber_;
// Running total of request retries/failures.
 533 uint64_t numFailures_;
// Number of worker threads (NUM_THREADS env var, default 8).
 534 int numThreads_;
 535};
std::string getEnvVar(std::string const &key)
Retrieves the value of an environment variable.
Definition UtilFunction.h:483
float * getTensor() const override
Returns the tensor associated with this function.
Definition ChatGPT.h:504
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures supported by this class.
Definition ChatGPT.h:484
static std::string getName()
Returns the name of the function.
Definition ChatGPT.h:513
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the ChatGPTRecommender function to the input array.
Definition ChatGPT.h:358
ChatGPTRecommender()
Default constructor.
Definition ChatGPT.h:290
~ChatGPTRecommender()
Destructor.
Definition ChatGPT.h:330
ChatGPTRecommender(std::string url, std::string model)
Constructor with custom URL and model.
Definition ChatGPT.h:312
CostEstimate getCost(std::vector< int > inputDims)
Estimates the computational cost of applying the ChatGPTRecommender function.
Definition ChatGPT.h:523
float * getTensor() const override
Returns the tensor associated with this function.
Definition ChatGPT.h:243
CostEstimate getCost(std::vector< int > inputDims)
Estimates the computational cost of applying the ChatGPT function.
Definition ChatGPT.h:262
ChatGPT()
Default constructor.
Definition ChatGPT.h:31
static std::vector< std::shared_ptr< exec::FunctionSignature > > signatures()
Returns the function signatures supported by this class.
Definition ChatGPT.h:225
static std::string getName()
Returns the name of the function.
Definition ChatGPT.h:252
ChatGPT(std::string url, std::string model)
Constructor with custom URL and model.
Definition ChatGPT.h:53
void apply(const SelectivityVector &rows, std::vector< VectorPtr > &args, const TypePtr &type, exec::EvalCtx &context, VectorPtr &output) const override
Applies the ChatGPT function to the input array.
Definition ChatGPT.h:99
~ChatGPT()
Destructor.
Definition ChatGPT.h:71
A base class for machine learning functions, inheriting from Velox's VectorFunction.
Definition BaseFunction.h:9