
Commit

update
xando committed Jun 12, 2024
1 parent 98a710c commit 7daa72a
Showing 4 changed files with 27 additions and 9 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -10,7 +10,6 @@ dependencies = [
     "google-cloud-bigquery>=3,<5",
     "google-cloud-bigquery-storage>=2,<3",
     "pyarrow>=16,<17",
-    "protobuf>=4,<5"
 ]
 
 [project.license]
24 changes: 19 additions & 5 deletions src/pyarrow/bigquery/read.py
@@ -10,6 +10,7 @@
 
 from google.cloud import bigquery_storage
 from google.cloud import bigquery
+from google.cloud.exceptions import NotFound
 
 import pyarrow as pa
 import pyarrow.feather as fa
@@ -23,8 +24,18 @@
 # NOTE: This is required for multiprocessing to correctly serialize the worker arguments
 multiprocessing.set_start_method("fork")
 
+def _bq_table_exists(project: str, location: str):
+    client = bigquery.Client(project=project)
+
+    try:
+        client.get_table(location)
+        logger.debug(f"Table {location} already exists")
+    except NotFound as e:
+        logger.debug(f"Table {location} is not found")
+        raise e
+
 
-def _generate_streams(
+def _bq_read_create_streams(
     read_client: bigquery_storage.BigQueryReadClient,
     parent: str,
     location: str,
@@ -55,14 +66,15 @@ def _generate_streams(
     return read_session.streams, schema
 
 
-def _read_streams(read_client, read_streams, table_schema, batch_size, queue_results, temp_dir):
+def _stream_worker(read_client, read_streams, table_schema, batch_size, queue_results, temp_dir):
     batches = []
 
     for stream in read_streams:
         t = time.time()
 
         for message in read_client.read_rows(stream.name):
             record_batch = pa.ipc.read_record_batch(message.arrow_record_batch.serialized_record_batch, table_schema)
+
             batches.append(record_batch)
 
             if sum(b.num_rows for b in batches) >= batch_size:
@@ -104,7 +116,9 @@ def reader(
     queue_results = multiprocessing.Queue()
     read_client = bigquery_storage.BigQueryReadClient()
 
-    streams, streams_schema = _generate_streams(
+    _bq_table_exists(project, source)
+
+    streams, streams_schema = _bq_read_create_streams(
         read_client=read_client,
         parent=project,
         location=source,
@@ -114,7 +128,7 @@
     )
     workers_done = 0
 
-    assert streams, "No streams to read"
+    assert streams, "No streams to read; table might be empty"
 
     logger.debug(f"Number of workers: {worker_count}, number of streams: {len(streams)}")
 
@@ -127,7 +141,7 @@
     try:
         for streams in some_itertools.to_split(streams, actual_worker_count):
             e = worker_type(
-                target=_read_streams,
+                target=_stream_worker,
                 args=(
                     read_client,
                     streams,
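
Note: the new _bq_table_exists pre-check follows the standard google-cloud-bigquery idiom of looking a table up and catching NotFound. A minimal standalone sketch of that pattern, with placeholder project and table IDs:

from google.cloud import bigquery
from google.cloud.exceptions import NotFound

def table_exists(project: str, table_id: str) -> bool:
    # get_table() accepts a "project.dataset.table" string and raises
    # NotFound if the table does not exist.
    client = bigquery.Client(project=project)
    try:
        client.get_table(table_id)
        return True
    except NotFound:
        return False

# Example with placeholder IDs:
# table_exists("my-project", "my-project.my_dataset.my_table")
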
4 changes: 2 additions & 2 deletions src/pyarrow/bigquery/write/__init__.py
@@ -53,7 +53,7 @@ def _bq_create_table(*, project, location, schema, expire, overwrite):
     logger.debug(f"Created BigQuery table '{location}'")
 
 
-def _bq_storage_create_stream(write_client: bigquery_storage_v1.BigQueryWriteClient, parent, protobuf_definition):
+def _bq_write_create_stream(write_client: bigquery_storage_v1.BigQueryWriteClient, parent, protobuf_definition):
     write_stream = write_client.create_write_stream(
         parent=parent,
         write_stream=bigquery_storage_v1.types.WriteStream(type=bigquery_storage_v1.types.WriteStream.Type.PENDING),
@@ -99,7 +99,7 @@ def _stream_worker(
     schema_protobuf,
     queue_results,
 ):
-    stream = _bq_storage_create_stream(write_client, parent, schema_protobuf)
+    stream = _bq_write_create_stream(write_client, parent, schema_protobuf)
 
     offset = 0
 
7 changes: 6 additions & 1 deletion src/pyarrow/bigquery/write/upload.py
@@ -1,4 +1,6 @@
 from google.cloud.bigquery_storage_v1 import types
+from google.api_core.exceptions import Unknown
+from google.api_core import retry
 
 from . import pa_to_pb
 
@@ -16,6 +18,9 @@ def upload_data(stream, pa_table, protobuf_definition, offset):
         request.offset = offset + local_offset
         request.proto_rows = proto_data
 
-        stream.append_rows_stream.send(request).result()
+        job = stream.append_rows_stream.send(request)
+        job.result(
+            retry=retry.Retry(predicate=retry.if_exception_type(Unknown))
+        )
 
         local_offset += len(serialized_rows)
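
Note: the added retry wiring relies on google.api_core's predicate-based Retry with its default exponential backoff. A minimal sketch of the same pattern applied to an arbitrary callable (names here are illustrative, not part of this library):

from google.api_core import retry
from google.api_core.exceptions import Unknown

# Retry only when the wrapped callable raises Unknown; any other exception
# propagates immediately. Backoff settings are the library defaults.
retry_on_unknown = retry.Retry(predicate=retry.if_exception_type(Unknown))

def call_with_retry(do_request):
    # A Retry object can wrap any callable; the wrapped callable is then
    # invoked with the retry policy applied.
    return retry_on_unknown(do_request)()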
