Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/MainDistributionPipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ jobs:
name: Build extension binaries
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main
with:
duckdb_version: v1.4.0
duckdb_version: v1.4.2
extension_name: flock
ci_tools_version: main
exclude_archs: 'wasm_mvp;wasm_threads;wasm_eh'

duckdb-stable-build:
name: Build extension binaries
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.4.0
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.4.2
with:
duckdb_version: v1.4.0
ci_tools_version: v1.4.0
duckdb_version: v1.4.2
ci_tools_version: v1.4.2
extension_name: flock
exclude_archs: 'wasm_mvp;wasm_threads;wasm_eh'
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 974 files
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_subdirectory(model_manager)
add_subdirectory(prompt_manager)
add_subdirectory(custom_parser)
add_subdirectory(secret_manager)
add_subdirectory(metrics)

set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/flock_extension.cpp ${EXTENSION_SOURCES}
Expand Down
3 changes: 0 additions & 3 deletions src/functions/aggregate/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

namespace flock {

nlohmann::json AggregateFunctionBase::model_details;
std::string AggregateFunctionBase::user_query;

void AggregateFunctionBase::ValidateArguments(duckdb::Vector inputs[], idx_t input_count) {
if (input_count != 3) {
throw std::runtime_error("Expected exactly 3 arguments for aggregate function, got " + std::to_string(input_count));
Expand Down
6 changes: 6 additions & 0 deletions src/functions/aggregate/aggregate_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ void AggregateFunctionState::Combine(const AggregateFunctionState& source) {
Initialize();
}

// Copy model_details and user_query from source if not already set
if (model_details.empty() && !source.model_details.empty()) {
model_details = source.model_details;
Copy link

Copilot AI Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition only checks if model_details is empty, but doesn't verify user_query. If model_details is non-empty but user_query is empty, the user_query from the source won't be copied. Consider checking both conditions separately or ensuring they're always set together.

Suggested change
model_details = source.model_details;
model_details = source.model_details;
}
if (user_query.empty() && !source.user_query.empty()) {

Copilot uses AI. Check for mistakes.
user_query = source.user_query;
}

if (source.value) {
auto idx = 0u;
for (auto& column: *source.value) {
Expand Down
47 changes: 46 additions & 1 deletion src/functions/aggregate/llm_first_or_last/implementation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "flock/core/config.hpp"
#include "flock/functions/aggregate/llm_first_or_last.hpp"
#include "flock/metrics/manager.hpp"

#include <chrono>
#include <vector>

namespace flock {

Expand Down Expand Up @@ -77,11 +82,40 @@ void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateIn
AggregateFunctionType function_type) {
const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));

// Map AggregateFunctionType to FunctionType
FunctionType metrics_function_type = (function_type == AggregateFunctionType::FIRST) ? FunctionType::LLM_FIRST : FunctionType::LLM_LAST;

auto db = Config::db;
std::vector<const void*> processed_state_ids;
std::string merged_model_name;
std::string merged_provider;

// Process each state individually
for (idx_t i = 0; i < count; i++) {
auto idx = i + offset;
auto* state = states_vector[idx];

if (state && !state->value->empty()) {
// Use model_details and user_query from the state (not static variables)
Model model(state->model_details);
auto model_details_obj = model.GetModelDetails();

// Get state ID for metrics
const void* state_id = static_cast<const void*>(state);
processed_state_ids.push_back(state_id);

// Start metrics tracking
MetricsManager::StartInvocation(db, state_id, metrics_function_type);
MetricsManager::SetModelInfo(model_details_obj.model_name, model_details_obj.provider_name);

// Store model info for merged metrics (use first non-empty)
if (merged_model_name.empty() && !model_details_obj.model_name.empty()) {
merged_model_name = model_details_obj.model_name;
merged_provider = model_details_obj.provider_name;
}

auto exec_start = std::chrono::high_resolution_clock::now();

auto tuples_with_ids = *state->value;
tuples_with_ids.push_back(nlohmann::json::object());
for (auto j = 0; j < static_cast<int>((*state->value)[0]["data"].size()); j++) {
Expand All @@ -93,12 +127,23 @@ void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateIn
}
LlmFirstOrLast function_instance;
function_instance.function_type = function_type;
function_instance.user_query = state->user_query;
function_instance.model_details = state->model_details;
auto response = function_instance.Evaluate(tuples_with_ids);

auto exec_end = std::chrono::high_resolution_clock::now();
double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count();
MetricsManager::AddExecutionTime(exec_duration_ms);

result.SetValue(idx, response.dump());
} else {
result.SetValue(idx, nullptr);// Empty JSON object for null/empty states
result.SetValue(idx, nullptr);
}
}

// Merge all metrics from processed states into a single metrics entry
MetricsManager::MergeAggregateMetrics(db, processed_state_ids, metrics_function_type,
merged_model_name, merged_provider);
}

}// namespace flock
47 changes: 44 additions & 3 deletions src/functions/aggregate/llm_reduce/implementation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "flock/core/config.hpp"
#include "flock/functions/aggregate/llm_reduce.hpp"
#include "flock/metrics/manager.hpp"

#include <chrono>
#include <vector>

namespace flock {

Expand Down Expand Up @@ -66,23 +71,59 @@ void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputDa
const AggregateFunctionType function_type) {
const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));

auto db = Config::db;
std::vector<const void*> processed_state_ids;
std::string merged_model_name;
std::string merged_provider;

// Process each state individually
for (idx_t i = 0; i < count; i++) {
auto idx = i + offset;
auto* state = states_vector[idx];

if (state && !state->value->empty()) {
if (state && state->value && !state->value->empty()) {
// Use model_details and user_query from the state
Model model(state->model_details);
auto model_details_obj = model.GetModelDetails();

// Get state ID for metrics
const void* state_id = static_cast<const void*>(state);
processed_state_ids.push_back(state_id);

// Start metrics tracking for this state
MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_REDUCE);
MetricsManager::SetModelInfo(model_details_obj.model_name, model_details_obj.provider_name);

// Store model info for merged metrics (use first non-empty)
if (merged_model_name.empty() && !model_details_obj.model_name.empty()) {
merged_model_name = model_details_obj.model_name;
merged_provider = model_details_obj.provider_name;
}

auto exec_start = std::chrono::high_resolution_clock::now();

LlmReduce reduce_instance;
reduce_instance.model = Model(model_details);
reduce_instance.model = Model(state->model_details);
reduce_instance.user_query = state->user_query;
auto response = reduce_instance.ReduceLoop(*state->value, function_type);

auto exec_end = std::chrono::high_resolution_clock::now();
double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count();
MetricsManager::AddExecutionTime(exec_duration_ms);

if (response.is_string()) {
result.SetValue(idx, response.get<std::string>());
} else {
result.SetValue(idx, response.dump());
}
} else {
result.SetValue(idx, nullptr);// Empty result for null/empty states
result.SetValue(idx, nullptr);
}
}

// Merge all metrics from processed states into a single metrics entry
MetricsManager::MergeAggregateMetrics(db, processed_state_ids, FunctionType::LLM_REDUCE,
merged_model_name, merged_provider);
}

}// namespace flock
44 changes: 43 additions & 1 deletion src/functions/aggregate/llm_rerank/implementation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "flock/core/config.hpp"
#include "flock/functions/aggregate/llm_rerank.hpp"
#include "flock/metrics/manager.hpp"

#include <chrono>
#include <vector>

namespace flock {

Expand Down Expand Up @@ -116,22 +121,59 @@ void LlmRerank::Finalize(duckdb::Vector& states, duckdb::AggregateInputData& agg
idx_t count, idx_t offset) {
const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));

auto db = Config::db;
std::vector<const void*> processed_state_ids;
std::string merged_model_name;
std::string merged_provider;

// Process each state individually
for (idx_t i = 0; i < count; i++) {
auto idx = i + offset;
auto* state = states_vector[idx];

if (state && !state->value->empty()) {
// Use model_details and user_query from the state (not static variables)
Model model(state->model_details);
auto model_details_obj = model.GetModelDetails();

// Get state ID for metrics
const void* state_id = static_cast<const void*>(state);
processed_state_ids.push_back(state_id);

// Start metrics tracking
MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_RERANK);
MetricsManager::SetModelInfo(model_details_obj.model_name, model_details_obj.provider_name);

// Store model info for merged metrics (use first non-empty)
if (merged_model_name.empty() && !model_details_obj.model_name.empty()) {
merged_model_name = model_details_obj.model_name;
merged_provider = model_details_obj.provider_name;
}

auto exec_start = std::chrono::high_resolution_clock::now();

auto tuples_with_ids = nlohmann::json::array();
for (auto j = 0; j < static_cast<int>(state->value->size()); j++) {
tuples_with_ids.push_back((*state->value)[j]);
}
LlmRerank function_instance;
function_instance.user_query = state->user_query;
function_instance.model_details = state->model_details;
auto reranked_tuples = function_instance.SlidingWindow(tuples_with_ids);

auto exec_end = std::chrono::high_resolution_clock::now();
double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count();
MetricsManager::AddExecutionTime(exec_duration_ms);

result.SetValue(idx, reranked_tuples.dump());
} else {
result.SetValue(idx, nullptr);// Empty result for null/empty states
result.SetValue(idx, nullptr);
}
}

// Merge all metrics from processed states into a single metrics entry
MetricsManager::MergeAggregateMetrics(db, processed_state_ids, FunctionType::LLM_RERANK,
merged_model_name, merged_provider);
}

}// namespace flock
23 changes: 23 additions & 0 deletions src/functions/scalar/llm_complete/implementation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#include "flock/functions/scalar/llm_complete.hpp"
#include "flock/metrics/manager.hpp"

#include <chrono>

namespace flock {

Expand All @@ -25,6 +28,11 @@ std::vector<std::string> LlmComplete::Operation(duckdb::DataChunk& args) {
// LlmComplete::ValidateArguments(args);
auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1);
Model model(model_details_json);

// Set model name and provider in metrics (context is already set in Execute)
auto model_details = model.GetModelDetails();
MetricsManager::SetModelInfo(model_details.model_name, model_details.provider_name);

auto prompt_context_json = CastVectorOfStructsToJson(args.data[1], args.size());
auto context_columns = nlohmann::json::array();
if (prompt_context_json.contains("context_columns")) {
Expand Down Expand Up @@ -63,6 +71,16 @@ std::vector<std::string> LlmComplete::Operation(duckdb::DataChunk& args) {
}

void LlmComplete::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) {
// Get database instance and state ID for metrics
auto& context = state.GetContext();
auto* db = context.db.get();
const void* state_id = static_cast<const void*>(&state);

// Start metrics tracking
MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);

// Start execution timing
auto exec_start = std::chrono::high_resolution_clock::now();

if (const auto results = LlmComplete::Operation(args); static_cast<int>(results.size()) == 1) {
auto empty_vec = duckdb::Vector(std::string());
Expand All @@ -75,6 +93,11 @@ void LlmComplete::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& stat
result.SetValue(index++, duckdb::Value(res));
}
}

// End execution timing and update metrics
auto exec_end = std::chrono::high_resolution_clock::now();
double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count();
MetricsManager::AddExecutionTime(exec_duration_ms);
}

}// namespace flock
22 changes: 22 additions & 0 deletions src/functions/scalar/llm_embedding/implementation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#include "flock/core/config.hpp"
#include "flock/functions/scalar/llm_embedding.hpp"
#include "flock/metrics/manager.hpp"

#include <chrono>

namespace flock {

Expand Down Expand Up @@ -32,6 +36,10 @@ std::vector<duckdb::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataC
auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1);
Model model(model_details_json);

// Set model name and provider in metrics (context is already set in Execute)
auto model_details = model.GetModelDetails();
MetricsManager::SetModelInfo(model_details.model_name, model_details.provider_name);

std::vector<std::string> prepared_inputs;
auto num_rows = inputs["context_columns"][0]["data"].size();
for (size_t row_idx = 0; row_idx < num_rows; row_idx++) {
Expand Down Expand Up @@ -71,12 +79,26 @@ std::vector<duckdb::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataC
}

// Scalar-function entry point for llm_embedding.
//
// Opens a metrics invocation keyed by the address of `state`, runs
// LlmEmbedding::Operation over `args`, writes each returned embedding into
// `result` as a LIST value (one output row per entry, in order), and finally
// reports the elapsed wall-clock time in milliseconds to MetricsManager.
//
// NOTE(review): if Operation throws, AddExecutionTime is never reached, so the
// invocation started here gets no execution time recorded — confirm this is
// acceptable for the metrics consumer.
void LlmEmbedding::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) {
    // The database handle and the expression state's address identify this
    // invocation to the metrics subsystem.
    auto* db = state.GetContext().db.get();
    const void* state_id = static_cast<const void*>(&state);
    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_EMBEDDING);

    const auto started_at = std::chrono::high_resolution_clock::now();

    auto embeddings = LlmEmbedding::Operation(args);

    // Emit one LIST value per embedding, preserving the order Operation produced.
    auto row = 0;
    for (const auto& embedding: embeddings) {
        result.SetValue(row, duckdb::Value::LIST(embedding));
        ++row;
    }

    // The timed span covers both the Operation call and the result writes.
    const auto finished_at = std::chrono::high_resolution_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = finished_at - started_at;
    MetricsManager::AddExecutionTime(elapsed.count());
}

}// namespace flock
Loading
Loading