      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& type,
      exec::EvalCtx& context,
      VectorPtr& output) const override {
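    // Builds one chat-completion prompt per selected row, sends the requests
    // to the OpenAI API in parallel batches across numThreads_ worker
    // threads, and copies the generated messages into the output vector.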
    std::string promptSuffix = "";

    BaseVector::ensureWritable(rows, type, context.pool(), output);
    exec::LocalDecodedVector decodedStringHolder1(context, *args[0], rows);
    auto decodedStringInput1 = decodedStringHolder1.get();

    exec::LocalDecodedVector decodedStringHolder2(context, *args[1], rows);
    auto decodedStringInput2 = decodedStringHolder2.get();
    int numInput = rows.size();
    int numSelected = rows.countSelected();
    LOG(INFO) << "[INFO ChatGPTRecommender:] countSelected: " << numSelected
              << " numInput: " << numInput << std::endl;
    // An optional third argument overrides the prompt suffix. The suffix is
    // taken from the first row, so it is assumed constant across rows.
    if (args.size() == 3) {
      exec::LocalDecodedVector decodedStringHolder3(context, *args[2], rows);
      auto decodedStringInput3 = decodedStringHolder3.get();
      StringView val = decodedStringInput3->valueAt<StringView>(0);
      promptSuffix = std::string(val);
    }
    std::vector<std::string> results;

    // The variable name `header` is an assumption (its declaration was lost);
    // the initializer pairs are from the original code.
    cpr::Header header{
        {"Content-Type", "application/json"},
        {"Authorization", "Bearer " + apiKey_}};
    std::vector<std::thread> threads;
    std::vector<cpr::Response> responses(numInput);
    std::vector<int> numFailureVector(numInput);

    int numInputsPerThread = int(std::ceil(float(numSelected) / numThreads_));
    int processedInputCount = 0;
    std::vector<std::string> payloadsBatchVector;
    int processedIndex = 0;
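    // Each thread handles a ceil-divided share of the selected rows. As an
    // illustrative example (not values from this code): with numSelected =
    // 100 and numThreads_ = 8, each batch carries ceil(100 / 8) = 13 requests.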
    for (int i = 0; i < numInput; i++) {
      // Skip rows that are not selected.
      if (!rows.isValid(i)) {
        continue;
      }

      StringView val1 = decodedStringInput1->valueAt<StringView>(i);
      StringView val2 = decodedStringInput2->valueAt<StringView>(i);
      std::string valString =
          "Summarized user statistics data (preference): " + std::string(val1) +
          ". \n Summarized user movie metadata: " + std::string(val2) + ".\n" +
          promptSuffix;
      nlohmann::json messageArrays = nlohmann::json::array();
      messageArrays.push_back({{"role", "user"}, {"content", valString}});

      nlohmann::json payload = {
          {"model", model_},
          {"messages", messageArrays},
          {"max_tokens", 500}};
      payloadsBatchVector.push_back(payload.dump());
      processedInputCount++;
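      // The payload just serialized follows the OpenAI chat-completions
      // request shape, e.g.:
      //   {"model": "<model_>",
      //    "messages": [{"role": "user", "content": "<prompt>"}],
      //    "max_tokens": 500}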
      // Dispatch a batch once it reaches the per-thread size, or when the
      // last row is reached. The payloads, header, and response buffer passed
      // here are assumptions about sendRequestViaCprBatch's signature; only
      // the start index and failure counter survive in the original text.
      if (processedInputCount == numInputsPerThread || i == numInput - 1) {
        threads.emplace_back(
            sendRequestViaCprBatch,
            payloadsBatchVector,
            header,
            std::ref(responses),
            processedIndex - processedInputCount + 1,
            std::ref(numFailureVector));
        processedInputCount = 0;
        payloadsBatchVector.clear();
      }
      // processedIndex is this row's position in selection order; the batch
      // start index above is derived from it.
      processedIndex++;
    }
    // Wait for all outstanding request threads to finish.
    for (auto& thread : threads) {
      thread.join();
    }
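    // Each worker wrote its cpr::Response objects into `responses` starting
    // at its batch offset, so slots 0..numSelected-1 hold one response per
    // selected row, in selection order.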
    for (int i = 0; i < numSelected; i++) {
      if (responses[i].status_code == 200) {
        nlohmann::json response_json = nlohmann::json::parse(responses[i].text);
        std::string generated_message =
            response_json["choices"][0]["message"]["content"];
        results.push_back(generated_message);
        // apply() is const, so the accumulated counters are updated through
        // const_cast.
        const_cast<uint64_t&>(inputTokenNumber_) = inputTokenNumber_ +
            response_json["usage"]["prompt_tokens"].get<int>();
        const_cast<uint64_t&>(outputTokenNumber_) = outputTokenNumber_ +
            response_json["usage"]["completion_tokens"].get<int>();
        const_cast<uint64_t&>(numFailures_) =
            numFailures_ + numFailureVector[i];
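        // Note: prompt_tokens and completion_tokens above are the standard
        // fields of the OpenAI chat-completions `usage` object; their sum is
        // the total token count billed for the call.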
        if (numFailureVector[i] > 0) {
          LOG(WARNING)
              << "[WARNING] Failed to send request to OpenAI API. Number of retries: "
              << numFailureVector[i] << " numFailures_: " << numFailures_
              << std::endl;
        }
        LOG(INFO) << fmt::format(
            "[INFO] i: {} / {}, results: {}, numFailures: {}",
            i,
            numSelected,
            generated_message,
            numFailures_);
      } else {
        // Keep `results` aligned with selected rows when a request fails;
        // otherwise the flat vector built below would be too short.
        results.push_back("");
        LOG(ERROR) << "Error: " << responses[i].status_code << " - "
                   << responses[i].text << std::endl;
      }
    }
    VectorMaker maker{context.pool()};
    VectorPtr localResult = maker.flatVector<std::string>(results);

    context.moveOrCopyResult(localResult, rows, output);
  }
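
  // A minimal sketch of what sendRequestViaCprBatch could look like, for
  // reference only; the real implementation lives elsewhere in this file.
  // The endpoint URL, retry limit, parameter names, and parameter order are
  // assumptions, not the original signature.
  static void sendRequestViaCprBatchSketch(
      std::vector<std::string> payloads,
      cpr::Header header,
      std::vector<cpr::Response>& responses,
      int startIndex,
      std::vector<int>& numFailureVector) {
    for (size_t j = 0; j < payloads.size(); j++) {
      int numFailures = 0;
      cpr::Response response;
      // Retry failed requests up to an assumed limit of 3 attempts.
      do {
        response = cpr::Post(
            cpr::Url{"https://api.openai.com/v1/chat/completions"},
            header,
            cpr::Body{payloads[j]});
      } while (response.status_code != 200 && ++numFailures < 3);
      // Record the final response and the retry count at this row's slot.
      responses[startIndex + j] = response;
      numFailureVector[startIndex + j] = numFailures;
    }
  }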