Skip to content

Commit 5efd4ee

Browse files
add support for pyarrow adls file io
1 parent 494f2fe commit 5efd4ee

File tree

4 files changed

+281
-46
lines changed

4 files changed

+281
-46
lines changed

pyiceberg/io/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@
8282
ADLS_CLIENT_ID = "adls.client-id"
8383
ADLS_CLIENT_SECRET = "adls.client-secret"
8484
ADLS_ACCOUNT_HOST = "adls.account-host"
85+
ADLS_BLOB_STORAGE_AUTHORITY = "adls.blob-storage-authority"
86+
ADLS_DFS_STORAGE_AUTHORITY = "adls.dfs-storage-authority"
87+
# Property key for the Azure Blob Storage endpoint scheme.
# Must be a plain str: a trailing comma here would turn it into a 1-tuple.
ADLS_BLOB_STORAGE_SCHEME = "adls.blob-storage-scheme"
88+
# Property key for the Azure Data Lake (DFS) endpoint scheme.
# Must be a plain str: a trailing comma here would turn it into a 1-tuple.
ADLS_DFS_STORAGE_SCHEME = "adls.dfs-storage-scheme"
8589
GCS_TOKEN = "gcs.oauth2.token"
8690
GCS_TOKEN_EXPIRES_AT_MS = "gcs.oauth2.token-expires-at"
8791
GCS_PROJECT_ID = "gcs.project-id"

pyiceberg/io/pyarrow.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@
8585
)
8686
from pyiceberg.expressions.visitors import visit as boolean_expression_visit
8787
from pyiceberg.io import (
88+
ADLS_ACCOUNT_NAME,
89+
ADLS_ACCOUNT_KEY,
90+
ADLS_BLOB_STORAGE_AUTHORITY,
91+
ADLS_DFS_STORAGE_AUTHORITY,
92+
ADLS_BLOB_STORAGE_SCHEME,
93+
ADLS_DFS_STORAGE_SCHEME,
94+
ADLS_SAS_TOKEN,
8895
AWS_ACCESS_KEY_ID,
8996
AWS_REGION,
9097
AWS_ROLE_ARN,
@@ -394,6 +401,9 @@ def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSyste
394401
elif scheme in {"gs", "gcs"}:
395402
return self._initialize_gcs_fs()
396403

404+
elif scheme in {"abfs", "abfss", "wasb", "wasbs"}:
405+
return self._initialize_azure_fs()
406+
397407
elif scheme in {"file"}:
398408
return self._initialize_local_fs()
399409

@@ -475,6 +485,34 @@ def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem:
475485

476486
return S3FileSystem(**client_kwargs)
477487

488+
def _initialize_azure_fs(self) -> FileSystem:
    """Build a PyArrow ``AzureFileSystem`` from the ADLS entries in ``self.properties``.

    Each supported property is forwarded as the matching keyword argument,
    but only when it is present with a truthy value; absent or empty
    properties are simply not passed.

    Returns:
        An ``AzureFileSystem`` constructed from the collected keyword arguments.
    """
    from pyarrow.fs import AzureFileSystem

    # (property key, AzureFileSystem keyword) pairs, applied in this order.
    property_to_kwarg = (
        (ADLS_ACCOUNT_NAME, "account_name"),
        (ADLS_ACCOUNT_KEY, "account_key"),
        (ADLS_BLOB_STORAGE_AUTHORITY, "blob_storage_authority"),
        (ADLS_DFS_STORAGE_AUTHORITY, "dfs_storage_authority"),
        (ADLS_BLOB_STORAGE_SCHEME, "blob_storage_scheme"),
        (ADLS_DFS_STORAGE_SCHEME, "dfs_storage_scheme"),
        (ADLS_SAS_TOKEN, "sas_token"),
    )

    client_kwargs: Dict[str, str] = {}
    for property_key, kwarg_name in property_to_kwarg:
        value = self.properties.get(property_key)
        if value:
            client_kwargs[kwarg_name] = value

    return AzureFileSystem(**client_kwargs)
515+
478516
def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem:
479517
from pyarrow.fs import HadoopFileSystem
480518

tests/conftest.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@
5656
GCS_SERVICE_HOST,
5757
GCS_TOKEN,
5858
GCS_TOKEN_EXPIRES_AT_MS,
59+
ADLS_ACCOUNT_NAME,
60+
ADLS_ACCOUNT_KEY,
61+
ADLS_BLOB_STORAGE_AUTHORITY,
62+
ADLS_DFS_STORAGE_SCHEME,
63+
ADLS_BLOB_STORAGE_SCHEME,
64+
ADLS_DFS_STORAGE_AUTHORITY,
5965
fsspec,
6066
load_file_io,
6167
)
@@ -348,6 +354,11 @@ def table_schema_with_all_types() -> Schema:
348354
)
349355

350356

357+
@pytest.fixture(params=["abfss", "wasbs"])
def adls_scheme(request: pytest.FixtureRequest) -> str:
    """Return the ADLS URI scheme for the current parametrized run.

    The fixture is parametrized, so a test that depends on it is executed
    once for each scheme listed in ``params``.
    """
    # request.param is untyped; the params above are all strings.
    return str(request.param)
360+
361+
351362
@pytest.fixture(scope="session")
352363
def pyarrow_schema_simple_without_ids() -> "pa.Schema":
353364
import pyarrow as pa
@@ -2089,7 +2100,27 @@ def fsspec_fileio_gcs(request: pytest.FixtureRequest) -> FsspecFileIO:
20892100

20902101

20912102
@pytest.fixture
2092-
def pyarrow_fileio_gcs(request: pytest.FixtureRequest) -> "PyArrowFileIO":
2103+
def adls_fsspec_fileio(request: pytest.FixtureRequest) -> Generator[FsspecFileIO, None, None]:
    """Yield an ``FsspecFileIO`` bound to the ADLS test endpoint.

    The endpoint, account name and account key are read from the
    ``--adls.endpoint``, ``--adls.account-name`` and ``--adls.account-key``
    pytest options. A ``tests`` container is created before the FileIO is
    yielded and deleted afterwards.

    Cleanup is exception-safe: the container is removed once it has been
    created, and the blob client is always closed, even if an earlier
    setup or teardown step raises.
    """
    from azure.storage.blob import BlobServiceClient

    azurite_url = request.config.getoption("--adls.endpoint")
    azurite_account_name = request.config.getoption("--adls.account-name")
    azurite_account_key = request.config.getoption("--adls.account-key")
    azurite_connection_string = (
        f"DefaultEndpointsProtocol=http;AccountName={azurite_account_name};"
        f"AccountKey={azurite_account_key};BlobEndpoint={azurite_url}/{azurite_account_name};"
    )
    properties = {
        "adls.connection-string": azurite_connection_string,
        "adls.account-name": azurite_account_name,
    }

    bbs = BlobServiceClient.from_connection_string(conn_str=azurite_connection_string)
    try:
        bbs.create_container("tests")
        try:
            yield fsspec.FsspecFileIO(properties=properties)
        finally:
            # Only reached once the container exists, so a failed create
            # does not trigger a spurious delete.
            bbs.delete_container("tests")
    finally:
        bbs.close()
2120+
2121+
2122+
@pytest.fixture
2123+
def pyarrow_fileio_gcs(request: pytest.FixtureRequest) -> 'PyArrowFileIO':
20932124
from pyiceberg.io.pyarrow import PyArrowFileIO
20942125

20952126
properties = {
@@ -2101,6 +2132,33 @@ def pyarrow_fileio_gcs(request: pytest.FixtureRequest) -> "PyArrowFileIO":
21012132
return PyArrowFileIO(properties=properties)
21022133

21032134

2135+
@pytest.fixture
def pyarrow_fileio_adls(request: pytest.FixtureRequest) -> Generator[Any, None, None]:
    """Yield a ``PyArrowFileIO`` configured against the ADLS test endpoint.

    The endpoint URL from ``--adls.endpoint`` is split at ``://`` into a
    scheme and an authority, and both are used for the blob and the DFS
    storage settings. Account name and key come from ``--adls.account-name``
    and ``--adls.account-key``. A ``warehouse`` container is created before
    the FileIO is yielded and deleted afterwards.

    Cleanup is exception-safe: the container is removed once it has been
    created, and the blob client is always closed, even if an earlier
    setup or teardown step raises.
    """
    from azure.storage.blob import BlobServiceClient

    from pyiceberg.io.pyarrow import PyArrowFileIO

    azurite_url = request.config.getoption("--adls.endpoint")
    azurite_scheme, azurite_authority = azurite_url.split("://", 1)

    azurite_account_name = request.config.getoption("--adls.account-name")
    azurite_account_key = request.config.getoption("--adls.account-key")
    azurite_connection_string = (
        f"DefaultEndpointsProtocol=http;AccountName={azurite_account_name};"
        f"AccountKey={azurite_account_key};BlobEndpoint={azurite_url}/{azurite_account_name};"
    )
    properties = {
        ADLS_ACCOUNT_NAME: azurite_account_name,
        ADLS_ACCOUNT_KEY: azurite_account_key,
        ADLS_BLOB_STORAGE_AUTHORITY: azurite_authority,
        ADLS_DFS_STORAGE_AUTHORITY: azurite_authority,
        ADLS_BLOB_STORAGE_SCHEME: azurite_scheme,
        ADLS_DFS_STORAGE_SCHEME: azurite_scheme,
    }

    bbs = BlobServiceClient.from_connection_string(conn_str=azurite_connection_string)
    try:
        bbs.create_container("warehouse")
        try:
            yield PyArrowFileIO(properties=properties)
        finally:
            # Only reached once the container exists, so a failed create
            # does not trigger a spurious delete.
            bbs.delete_container("warehouse")
    finally:
        bbs.close()
2160+
2161+
21042162
def aws_credentials() -> None:
21052163
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
21062164
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
@@ -2162,26 +2220,6 @@ def fixture_dynamodb(_aws_credentials: None) -> Generator[boto3.client, None, No
21622220
yield boto3.client("dynamodb", region_name="us-east-1")
21632221

21642222

2165-
@pytest.fixture
2166-
def adls_fsspec_fileio(request: pytest.FixtureRequest) -> Generator[FsspecFileIO, None, None]:
2167-
from azure.storage.blob import BlobServiceClient
2168-
2169-
azurite_url = request.config.getoption("--adls.endpoint")
2170-
azurite_account_name = request.config.getoption("--adls.account-name")
2171-
azurite_account_key = request.config.getoption("--adls.account-key")
2172-
azurite_connection_string = f"DefaultEndpointsProtocol=http;AccountName={azurite_account_name};AccountKey={azurite_account_key};BlobEndpoint={azurite_url}/{azurite_account_name};"
2173-
properties = {
2174-
"adls.connection-string": azurite_connection_string,
2175-
"adls.account-name": azurite_account_name,
2176-
}
2177-
2178-
bbs = BlobServiceClient.from_connection_string(conn_str=azurite_connection_string)
2179-
bbs.create_container("tests")
2180-
yield fsspec.FsspecFileIO(properties=properties)
2181-
bbs.delete_container("tests")
2182-
bbs.close()
2183-
2184-
21852223
@pytest.fixture(scope="session")
21862224
def empty_home_dir_path(tmp_path_factory: pytest.TempPathFactory) -> str:
21872225
home_path = str(tmp_path_factory.mktemp("home"))

0 commit comments

Comments
 (0)