
Commit

support connection string
aykut-bozkurt committed Jan 6, 2025
1 parent 6102ba9 commit 24ff523
Showing 10 changed files with 477 additions and 251 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -191,7 +191,6 @@ Alternatively, you can use the following environment variables when starting postgres

Supported S3 uri formats are shown below:
- s3:// \<bucket\> / \<path\>
- s3a:// \<bucket\> / \<path\>
- https:// \<bucket\>.s3.amazonaws.com / \<path\>
- https:// s3.amazonaws.com / \<bucket\> / \<path\>

@@ -209,6 +208,7 @@ key = Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/
Alternatively, you can use the following environment variables when starting postgres to configure the Azure Blob Storage client:
- `AZURE_STORAGE_ACCOUNT`: the storage account name of the Azure Blob
- `AZURE_STORAGE_KEY`: the storage key of the Azure Blob
- `AZURE_STORAGE_CONNECTION_STRING`: the connection string for the Azure Blob (this can be set instead of specifying the account name and key; see the sketch after this list)
- `AZURE_STORAGE_SAS_TOKEN`: the storage SAS token for the Azure Blob
- `AZURE_STORAGE_ENDPOINT`: the endpoint **(only via environment variables)**
- `AZURE_CONFIG_FILE`: an alternative location for the config file **(only via environment variables)**
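
A connection string bundles these settings into a single value of `Key=Value;` pairs, for example `DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net`. The sketch below only illustrates what such a string carries; it is not the extension's actual parsing code, which lives in files that are not loaded in this view. It pulls the account name and key out of the pairs:

```rust
// Illustrative only: extract the account name and key from a standard Azure
// connection string of semicolon-separated Key=Value pairs.
fn parse_connection_string(connection_string: &str) -> (Option<String>, Option<String>) {
    let mut account_name = None;
    let mut account_key = None;

    for pair in connection_string.split(';') {
        if let Some((key, value)) = pair.split_once('=') {
            match key {
                "AccountName" => account_name = Some(value.to_string()),
                "AccountKey" => account_key = Some(value.to_string()),
                _ => {}
            }
        }
    }

    (account_name, account_key)
}
```

With the Azurite development connection string, for instance, this returns the `devstoreaccount1` account and the well-known development key that also appears in the config sample above.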
3 changes: 2 additions & 1 deletion src/arrow_parquet/parquet_reader.rs
@@ -25,6 +25,7 @@ use crate::{
},
pgrx_utils::{collect_attributes_for, CollectAttributesFor},
type_compat::{geometry::reset_postgis_context, map::reset_map_context},
PG_BACKEND_TOKIO_RUNTIME,
};

use super::{
@@ -33,7 +34,7 @@ use super::{
schema_parser::{
ensure_file_schema_match_tupledesc_schema, parse_arrow_schema_from_attributes,
},
uri_utils::{parquet_reader_from_uri, PG_BACKEND_TOKIO_RUNTIME},
uri_utils::parquet_reader_from_uri,
};

pub(crate) struct ParquetReaderContext {
3 changes: 2 additions & 1 deletion src/arrow_parquet/parquet_writer.rs
@@ -15,10 +15,11 @@ use crate::{
schema_parser::{
parquet_schema_string_from_attributes, parse_arrow_schema_from_attributes,
},
uri_utils::{parquet_writer_from_uri, PG_BACKEND_TOKIO_RUNTIME},
uri_utils::parquet_writer_from_uri,
},
pgrx_utils::{collect_attributes_for, CollectAttributesFor},
type_compat::{geometry::reset_postgis_context, map::reset_map_context},
PG_BACKEND_TOKIO_RUNTIME,
};

use super::pg_to_arrow::{
243 changes: 8 additions & 235 deletions src/arrow_parquet/uri_utils.rs
@@ -1,20 +1,6 @@
use std::{
panic,
sync::{Arc, LazyLock},
};
use std::{panic, sync::Arc};

use arrow::datatypes::SchemaRef;
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
use home::home_dir;
use ini::Ini;
use object_store::{
aws::{AmazonS3, AmazonS3Builder},
azure::{AzureConfigKey, MicrosoftAzure, MicrosoftAzureBuilder},
local::LocalFileSystem,
path::Path,
ObjectStore, ObjectStoreScheme,
};
use parquet::{
arrow::{
arrow_to_parquet_schema,
@@ -29,229 +15,16 @@ use pgrx::{
ereport,
pg_sys::{get_role_oid, has_privs_of_role, superuser, AsPgCStr, GetUserId},
};
use tokio::runtime::Runtime;
use url::Url;

use crate::arrow_parquet::parquet_writer::DEFAULT_ROW_GROUP_SIZE;
use crate::{
arrow_parquet::parquet_writer::DEFAULT_ROW_GROUP_SIZE, object_store::create_object_store,
PG_BACKEND_TOKIO_RUNTIME,
};

const PARQUET_OBJECT_STORE_READ_ROLE: &str = "parquet_object_store_read";
const PARQUET_OBJECT_STORE_WRITE_ROLE: &str = "parquet_object_store_write";

// PG_BACKEND_TOKIO_RUNTIME creates a tokio runtime that uses the current thread
// to run the tokio reactor. This uses the same thread that is running the Postgres backend.
pub(crate) static PG_BACKEND_TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap_or_else(|e| panic!("failed to create tokio runtime: {}", e))
});

fn parse_azure_blob_container(uri: &Url) -> Option<String> {
let host = uri.host_str()?;

// az(ure)://{container}/key
if uri.scheme() == "az" || uri.scheme() == "azure" {
return Some(host.to_string());
}
// https://{account}.blob.core.windows.net/{container}
else if host.ends_with(".blob.core.windows.net") {
let path_segments: Vec<&str> = uri.path_segments()?.collect();

// Container name is the first part of the path
return Some(
path_segments
.first()
.expect("unexpected error during parsing azure blob uri")
.to_string(),
);
}

None
}

fn parse_s3_bucket(uri: &Url) -> Option<String> {
let host = uri.host_str()?;

// s3(a)://{bucket}/key
if uri.scheme() == "s3" || uri.scheme() == "s3a" {
return Some(host.to_string());
}
// https://s3.amazonaws.com/{bucket}/key
else if host == "s3.amazonaws.com" {
let path_segments: Vec<&str> = uri.path_segments()?.collect();

// Bucket name is the first part of the path
return Some(
path_segments
.first()
.expect("unexpected error during parsing s3 uri")
.to_string(),
);
}
// https://{bucket}.s3.amazonaws.com/key
else if host.ends_with(".s3.amazonaws.com") {
let bucket_name = host.split('.').next()?;
return Some(bucket_name.to_string());
}

None
}

fn object_store_with_location(uri: &Url, copy_from: bool) -> (Arc<dyn ObjectStore>, Path) {
let (scheme, path) =
ObjectStoreScheme::parse(uri).unwrap_or_else(|_| panic!("unrecognized uri {}", uri));

// object_store crate can recognize a bunch of different schemes and paths, but we only support
// local, azure, and s3 schemes with a subset of all supported paths.
match scheme {
ObjectStoreScheme::AmazonS3 => {
let bucket_name = parse_s3_bucket(uri).unwrap_or_else(|| {
panic!("unsupported s3 uri: {}", uri);
});

let storage_container = PG_BACKEND_TOKIO_RUNTIME
.block_on(async { Arc::new(get_s3_object_store(&bucket_name).await) });

(storage_container, path)
}
ObjectStoreScheme::MicrosoftAzure => {
let container_name = parse_azure_blob_container(uri).unwrap_or_else(|| {
panic!("unsupported azure blob storage uri: {}", uri);
});

let storage_container = PG_BACKEND_TOKIO_RUNTIME
.block_on(async { Arc::new(get_azure_object_store(&container_name).await) });

(storage_container, path)
}
ObjectStoreScheme::Local => {
let uri = uri_as_string(uri);

if !copy_from {
// create or overwrite the local file
std::fs::OpenOptions::new()
.write(true)
.truncate(true)
.create(true)
.open(&uri)
.unwrap_or_else(|e| panic!("{}", e));
}

let storage_container = Arc::new(LocalFileSystem::new());

let path = Path::from_filesystem_path(&uri).unwrap_or_else(|e| panic!("{}", e));

(storage_container, path)
}
_ => {
panic!("unsupported scheme {} in uri {}", uri.scheme(), uri);
}
}
}

// get_s3_object_store creates an AmazonS3 object store with the given bucket name.
// It is configured via environment variables, with the aws config files as a fallback.
// We need to read the config files ourselves to make the fallback work, since object_store
// does not provide a way to read them. Currently, we only support extracting
// "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "AWS_ENDPOINT_URL",
// and "AWS_REGION" from the config files.
async fn get_s3_object_store(bucket_name: &str) -> AmazonS3 {
let mut aws_s3_builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name);

// first tries environment variables and then the config files
let sdk_config = aws_config::defaults(BehaviorVersion::v2024_03_28())
.load()
.await;

if let Some(credential_provider) = sdk_config.credentials_provider() {
if let Ok(credentials) = credential_provider.provide_credentials().await {
// AWS_ACCESS_KEY_ID
aws_s3_builder = aws_s3_builder.with_access_key_id(credentials.access_key_id());

// AWS_SECRET_ACCESS_KEY
aws_s3_builder = aws_s3_builder.with_secret_access_key(credentials.secret_access_key());

if let Some(token) = credentials.session_token() {
// AWS_SESSION_TOKEN
aws_s3_builder = aws_s3_builder.with_token(token);
}
}
}

// AWS_ENDPOINT_URL
if let Some(aws_endpoint_url) = sdk_config.endpoint_url() {
aws_s3_builder = aws_s3_builder.with_endpoint(aws_endpoint_url);
}

// AWS_REGION
if let Some(aws_region) = sdk_config.region() {
aws_s3_builder = aws_s3_builder.with_region(aws_region.as_ref());
}

aws_s3_builder.build().unwrap_or_else(|e| panic!("{}", e))
}

async fn get_azure_object_store(container_name: &str) -> MicrosoftAzure {
let mut azure_builder = MicrosoftAzureBuilder::from_env().with_container_name(container_name);

// ~/.azure/config
let azure_config_file_path = std::env::var("AZURE_CONFIG_FILE").unwrap_or(
home_dir()
.expect("failed to get home directory")
.join(".azure")
.join("config")
.to_str()
.expect("failed to convert path to string")
.to_string(),
);

let azure_config_content = Ini::load_from_file(&azure_config_file_path).ok();

// storage account
let azure_blob_account = match std::env::var("AZURE_STORAGE_ACCOUNT") {
Ok(account) => Some(account),
Err(_) => azure_config_content
.as_ref()
.and_then(|ini| ini.section(Some("storage")))
.and_then(|section| section.get("account"))
.map(|account| account.to_string()),
};

if let Some(azure_blob_account) = azure_blob_account {
azure_builder = azure_builder.with_account(azure_blob_account);
}

// storage key
let azure_blob_key = match std::env::var("AZURE_STORAGE_KEY") {
Ok(key) => Some(key),
Err(_) => azure_config_content
.as_ref()
.and_then(|ini| ini.section(Some("storage")))
.and_then(|section| section.get("key"))
.map(|key| key.to_string()),
};

if let Some(azure_blob_key) = azure_blob_key {
azure_builder = azure_builder.with_access_key(azure_blob_key);
}

// sas token
let azure_blob_sas_token = match std::env::var("AZURE_STORAGE_SAS_TOKEN") {
Ok(token) => Some(token),
Err(_) => azure_config_content
.as_ref()
.and_then(|ini| ini.section(Some("storage")))
.and_then(|section| section.get("sas_token"))
.map(|token| token.to_string()),
};

if let Some(azure_blob_sas_token) = azure_blob_sas_token {
azure_builder = azure_builder.with_config(AzureConfigKey::SasKey, azure_blob_sas_token);
}

azure_builder.build().unwrap_or_else(|e| panic!("{}", e))
}

pub(crate) fn parse_uri(uri: &str) -> Url {
if !uri.contains("://") {
// local file
@@ -285,7 +58,7 @@ pub(crate) fn parquet_schema_from_uri(uri: &Url) -> SchemaDescriptor {

pub(crate) fn parquet_metadata_from_uri(uri: &Url) -> Arc<ParquetMetaData> {
let copy_from = true;
let (parquet_object_store, location) = object_store_with_location(uri, copy_from);
let (parquet_object_store, location) = create_object_store(uri, copy_from);

PG_BACKEND_TOKIO_RUNTIME.block_on(async {
let object_store_meta = parquet_object_store
@@ -308,7 +81,7 @@ pub(crate) fn parquet_metadata_from_uri(uri: &Url) -> Arc<ParquetMetaData> {

pub(crate) fn parquet_reader_from_uri(uri: &Url) -> ParquetRecordBatchStream<ParquetObjectReader> {
let copy_from = true;
let (parquet_object_store, location) = object_store_with_location(uri, copy_from);
let (parquet_object_store, location) = create_object_store(uri, copy_from);

PG_BACKEND_TOKIO_RUNTIME.block_on(async {
let object_store_meta = parquet_object_store
@@ -340,7 +113,7 @@ pub(crate) fn parquet_writer_from_uri(
writer_props: WriterProperties,
) -> AsyncArrowWriter<ParquetObjectWriter> {
let copy_from = false;
let (parquet_object_store, location) = object_store_with_location(uri, copy_from);
let (parquet_object_store, location) = create_object_store(uri, copy_from);

let parquet_object_writer = ParquetObjectWriter::new(parquet_object_store, location);

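The object-store construction itself now lives in a new `src/object_store` module, which is not among the files loaded in this view. Judging from the call sites above and the removed `object_store_with_location`, the relocated `create_object_store` presumably keeps the same contract: take the parsed URI and a `copy_from` flag and return the store together with the object path. The sketch below reduces that assumed contract to the local-filesystem case only; the s3 and azure branches, including the new connection-string handling, are not shown in this diff.

```rust
use std::sync::Arc;

use object_store::{local::LocalFileSystem, path::Path, ObjectStore};
use url::Url;

// Sketch of the assumed contract, reduced to the local-filesystem case.
// The real create_object_store also dispatches to the s3 and azure stores.
fn create_local_object_store(uri: &Url, copy_from: bool) -> (Arc<dyn ObjectStore>, Path) {
    let file_path = uri.path();

    if !copy_from {
        // create or overwrite the local file before writing to it
        std::fs::OpenOptions::new()
            .write(true)
            .truncate(true)
            .create(true)
            .open(file_path)
            .unwrap_or_else(|e| panic!("{}", e));
    }

    let store: Arc<dyn ObjectStore> = Arc::new(LocalFileSystem::new());

    let location = Path::from_filesystem_path(file_path).unwrap_or_else(|e| panic!("{}", e));

    (store, location)
}
```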
13 changes: 13 additions & 0 deletions src/lib.rs
@@ -1,8 +1,12 @@
use std::sync::LazyLock;

use parquet_copy_hook::hook::{init_parquet_copy_hook, ENABLE_PARQUET_COPY_HOOK};
use parquet_copy_hook::pg_compat::MarkGUCPrefixReserved;
use pgrx::{prelude::*, GucContext, GucFlags, GucRegistry};
use tokio::runtime::Runtime;

mod arrow_parquet;
mod object_store;
mod parquet_copy_hook;
mod parquet_udfs;
#[cfg(any(test, feature = "pg_test"))]
@@ -20,6 +24,15 @@ pgrx::pg_module_magic!();

extension_sql_file!("../sql/bootstrap.sql", name = "role_setup", bootstrap);

// PG_BACKEND_TOKIO_RUNTIME creates a tokio runtime that uses the current thread
// to run the tokio reactor. This uses the same thread that is running the Postgres backend.
pub(crate) static PG_BACKEND_TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap_or_else(|e| panic!("failed to create tokio runtime: {}", e))
});
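
The reader and writer paths in `uri_utils.rs` drive their async object-store calls to completion on this shared runtime via `block_on`, as in `parquet_metadata_from_uri` above. A minimal sketch of that pattern follows; the helper name and body are illustrative and not taken from the diff.

```rust
use std::sync::Arc;

use object_store::{path::Path, ObjectStore};

use crate::PG_BACKEND_TOKIO_RUNTIME;

// Illustrative helper, not part of the diff: resolve an object's metadata by
// driving the async call to completion on the backend-local runtime.
fn head_object(store: Arc<dyn ObjectStore>, location: &Path) -> object_store::ObjectMeta {
    PG_BACKEND_TOKIO_RUNTIME
        .block_on(async { store.head(location).await })
        .unwrap_or_else(|e| panic!("failed to fetch object metadata: {}", e))
}
```

Keeping the runtime current-thread means all async work stays on the thread that runs the Postgres backend, so no extra worker threads are spawned inside the server process.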

#[pg_guard]
pub extern "C" fn _PG_init() {
GucRegistry::define_bool_guc(