Skip to content

Commit

Permalink
Merge branch 'main' of github.com:argilla-io/distilabel
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarobartt committed Apr 15, 2024
2 parents 6cbcced + 8063572 commit e50b559
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 44 deletions.
15 changes: 11 additions & 4 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,20 @@ class _BatchManagerStep(_Serializable):
seq_no: int = 0
last_batch_received: List[str] = field(default_factory=list)

def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
    """Add a batch of data from `batch.step_name` to the step. It will accumulate the
    data and keep track of the last batch received from the predecessors.

    Args:
        batch: The output batch of a step to be processed by the step.
        prepend: If `True`, the content of the batch will be added at the start of
            the buffer.
    """
    from_step = batch.step_name
    if prepend:
        # Re-queued batches go to the front so they are consumed before newer data.
        self.data[from_step] = batch.data[0] + self.data[from_step]
    else:
        self.data[from_step].extend(batch.data[0])
    # Remember which predecessors have already sent their final batch.
    if batch.last_batch:
        self.last_batch_received.append(from_step)

Expand Down Expand Up @@ -676,12 +681,14 @@ def register_batch(self, batch: _Batch) -> None:
def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
    """Return the last `_Batch` received from `step_name`, or `None` if no batch
    from that step has been registered yet."""
    return self._last_batch_received.get(step_name)

def add_batch(self, to_step: str, batch: _Batch) -> None:
def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None:
"""Add an output batch from `batch.step_name` to `to_step`.
Args:
to_step: The name of the step that will process the batch.
batch: The output batch of an step to be processed by `to_step`.
prepend: If `True`, the content of the batch will be added at the start of
the buffer.
Raises:
ValueError: If `to_step` is not found in the batch manager.
Expand All @@ -690,7 +697,7 @@ def add_batch(self, to_step: str, batch: _Batch) -> None:
raise ValueError(f"Step '{to_step}' not found in the batch manager.")

step = self._steps[to_step]
step.add_batch(batch)
step.add_batch(batch, prepend)

def get_batch(self, step_name: str) -> Union[_Batch, None]:
"""Get the next batch to be processed by the step.
Expand Down
38 changes: 21 additions & 17 deletions src/distilabel/pipeline/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
_STEPS_FINISHED = set()
_STEPS_FINISHED_LOCK = threading.Lock()

_STOP_LOOP = False


def _init_worker(queue: "Queue[Any]") -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN)
Expand Down Expand Up @@ -167,7 +165,7 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:
Args:
write_buffer: The write buffer to write the data from the leaf steps to disk.
"""
while self._batch_manager.can_generate() and not _STOP_LOOP: # type: ignore
while self._batch_manager.can_generate() and not _STOP_CALLED: # type: ignore
self._logger.debug("Waiting for output batch from step...")
if (batch := self.output_queue.get()) is None:
self._logger.debug("Received `None` from output queue. Breaking loop.")
Expand All @@ -176,13 +174,12 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:
if batch.step_name in self.dag.leaf_steps:
write_buffer.add_batch(batch)

# If `_STOP_LOOP` was set to `True` while waiting for the output queue, then
# If `_STOP_CALLED` was set to `True` while waiting for the output queue, then
# we need to handle the stop of the pipeline and break the loop to avoid
# propagating the batches through the pipeline and making the stop process
# slower.
if _STOP_LOOP:
if _STOP_CALLED:
self._handle_batch_on_stop(batch)
self._handle_stop(write_buffer)
break

self._logger.debug(
Expand All @@ -192,7 +189,7 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:

self._manage_batch_flow(batch)

if _STOP_LOOP:
if _STOP_CALLED:
self._handle_stop(write_buffer)

def _manage_batch_flow(self, batch: "_Batch") -> None:
Expand Down Expand Up @@ -266,15 +263,16 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None:
# Send `None` to the input queues of all the steps to notify them to stop
# processing batches.
for step_name in self.dag:
if input_queue := self._wait_step_input_queue_empty(step_name):
if self._check_step_not_loaded_or_finished(step_name):
if input_queue := self.dag.get_step(step_name).get("input_queue"):
while not input_queue.empty():
batch = input_queue.get()
self._batch_manager.add_batch( # type: ignore
to_step=step_name, batch=batch, prepend=True
)
self._logger.debug(
f"Step '{step_name}' not loaded or already finished. Skipping sending"
" sentinel `None`"
f"Adding batch back to the batch manager: {batch}"
)
continue
input_queue.put(None)
self._logger.debug(f"Send `None` to step '{step_name}' input queue.")

# Wait for the input queue to be empty, which means that all the steps finished
# processing the batches that were sent before the stop flag.
Expand Down Expand Up @@ -352,7 +350,7 @@ def _update_all_steps_loaded(steps_loaded: List[str]) -> None:

self._logger.info("⏳ Waiting for all the steps to load...")
previous_message = None
while True:
while not _STOP_CALLED:
with self.shared_info[_STEPS_LOADED_LOCK_KEY]:
steps_loaded = self.shared_info[_STEPS_LOADED_KEY]
num_steps_loaded = (
Expand All @@ -379,19 +377,27 @@ def _update_all_steps_loaded(steps_loaded: List[str]) -> None:

time.sleep(2.5)

return not _STOP_CALLED

def _request_initial_batches(self) -> None:
    """Requests the initial batches to the generator steps."""
    assert self._batch_manager, "Batch manager is not set"

    # Re-send any batch the batch manager already has buffered for a step.
    for step in self._batch_manager._steps.values():
        batch = step.get_batch()
        if not batch:
            continue
        self._logger.debug(
            f"Sending initial batch to '{step.step_name}' step: {batch}"
        )
        self._send_batch_to_step(batch)

    # Ask every generator (root) step for its next batch, continuing the
    # sequence number from the last batch received, if any.
    for step_name in self.dag.root_steps:
        last_batch = self._batch_manager.get_last_batch(step_name)
        seq_no = last_batch.seq_no + 1 if last_batch else 0
        batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=False)
        self._logger.debug(
            f"Requesting initial batch to '{step_name}' generator step: {batch}"
        )
        self._send_batch_to_step(batch)

def _send_batch_to_step(self, batch: "_Batch") -> None:
Expand Down Expand Up @@ -520,9 +526,7 @@ def _stop(self) -> None:
finished processing the batches that were sent before the stop flag. Then it will
send `None` to the output queue to notify the pipeline to stop."""

global _STOP_LOOP, _STOP_CALLED

_STOP_LOOP = True
global _STOP_CALLED

with _STOP_CALLED_LOCK:
if _STOP_CALLED:
Expand Down
22 changes: 17 additions & 5 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,37 @@ class Argilla(Step, ABC):
This class is not intended to be instanced directly, but via subclass.
Attributes:
dataset_name: The name of the dataset in Argilla.
dataset_name: The name of the dataset in Argilla where the records will be added.
dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
None, which means it will be created in the default workspace.
`None`, which means it will be created in the default workspace.
api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
the `ARGILLA_API_URL` environment variable.
api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
be read from the `ARGILLA_API_KEY` environment variable.
Runtime parameters:
- `dataset_name`: The name of the dataset in Argilla where the records will be
added.
- `dataset_workspace`: The workspace where the dataset will be created in Argilla.
Defaults to `None`, which means it will be created in the default workspace.
- `api_url`: The base URL to use for the Argilla API requests.
- `api_key`: The API key to authenticate the requests to the Argilla API.
Input columns:
- dynamic, based on the `inputs` value provided
"""

dataset_name: str
dataset_workspace: Optional[str] = None
dataset_name: RuntimeParameter[str] = Field(
default=None, description="The name of the dataset in Argilla."
)
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The workspace where the dataset will be created in Argilla. Defaults"
"to `None` which means it will be created in the default workspace.",
)

api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv("ARGILLA_BASE_URL"),
default_factory=lambda: os.getenv("ARGILLA_API_URL"),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
Expand Down Expand Up @@ -122,6 +132,8 @@ def load(self) -> None:
"""
super().load()

self._rg_init()

@property
@abstractmethod
def inputs(self) -> List[str]:
Expand Down
2 changes: 0 additions & 2 deletions src/distilabel/steps/argilla/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def load(self) -> None:
"""
super().load()

self._rg_init()

# Both `instruction` and `generations` will be used as the fields of the dataset
self._instruction = self.input_mappings.get("instruction", "instruction")
self._generations = self.input_mappings.get("generations", "generations")
Expand Down
2 changes: 0 additions & 2 deletions src/distilabel/steps/argilla/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def load(self) -> None:
"""
super().load()

self._rg_init()

self._instruction = self.input_mappings.get("instruction", "instruction")
self._generation = self.input_mappings.get("generation", "generation")

Expand Down
91 changes: 89 additions & 2 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,46 @@ def test_add_batch(self) -> None:
assert batch_manager_step.data["step1"] == [{"a": 1}, {"a": 2}, {"a": 3}]
assert batch_manager_step.last_batch_received == []

def test_add_batch_with_prepend(self) -> None:
    """Adding a batch with `prepend=True` inserts its rows before the buffered data."""
    batch_manager_step = _BatchManagerStep(
        step_name="step2",
        accumulate=False,
        input_batch_size=10,
        data={"step1": [{"a": value} for value in range(6, 11)]},
    )

    batch = _Batch(
        seq_no=0,
        step_name="step1",
        last_batch=False,
        data=[[{"a": value} for value in range(1, 6)]],
    )
    batch_manager_step.add_batch(batch, prepend=True)

    # Prepended rows come first, followed by the pre-existing buffer.
    assert batch_manager_step.data["step1"] == [
        {"a": value} for value in range(1, 11)
    ]
    # `last_batch=False`, so no predecessor is recorded as finished.
    assert batch_manager_step.last_batch_received == []

def test_add_batch_last_batch(self) -> None:
batch_manager_step = _BatchManagerStep(
step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []}
Expand Down Expand Up @@ -784,9 +824,56 @@ def test_add_batch(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
)
batch = batch_manager.add_batch(to_step="step3", batch=batch_from_step_1)
batch_manager.add_batch(to_step="step3", batch=batch_from_step_1)

assert batch is None
assert batch_manager._steps["step3"].data == {
"step1": [
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
],
"step2": [],
}

def test_add_batch_with_prepend(self) -> None:
    """`_BatchManager.add_batch` forwards `prepend=True` to the target step."""
    batch_manager = _BatchManager(
        steps={
            "step3": _BatchManagerStep(
                step_name="step3",
                accumulate=False,
                input_batch_size=5,
                data={
                    "step1": [{"a": value} for value in range(6, 11)],
                    "step2": [],
                },
            )
        },
        last_batch_received={"step3": None},
    )

    batch_from_step_1 = _Batch(
        seq_no=0,
        step_name="step1",
        last_batch=False,
        data=[[{"a": value} for value in range(1, 6)]],
    )
    batch_manager.add_batch(to_step="step3", batch=batch_from_step_1, prepend=True)

    # The new batch is inserted in front of the buffered data for "step1";
    # the "step2" buffer is untouched.
    assert batch_manager._steps["step3"].data == {
        "step1": [{"a": value} for value in range(1, 11)],
        "step2": [],
    }

def test_add_batch_enough_data(self) -> None:
batch_manager = _BatchManager(
Expand Down
24 changes: 12 additions & 12 deletions tests/unit/steps/argilla/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.argilla.base import Argilla
from distilabel.steps.base import StepInput
from pydantic import ValidationError

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput
Expand Down Expand Up @@ -82,17 +81,6 @@ def test_with_errors(self) -> None:
dataset_workspace="argilla",
)

with pytest.raises(
ValidationError, match="dataset_name\n Field required \\[type=missing"
):
CustomArgilla(
name="step",
api_url="https://example.com",
api_key="api.key", # type: ignore
dataset_workspace="argilla",
pipeline=Pipeline(name="unit-test-pipeline"),
)

with pytest.raises(
TypeError,
match="Can't instantiate abstract class Argilla with abstract methods inputs, process",
Expand Down Expand Up @@ -136,6 +124,18 @@ def test_serialization(self) -> None:
"name": "input_batch_size",
"optional": True,
},
{
"description": "The name of the dataset in Argilla.",
"name": "dataset_name",
"optional": False,
},
{
"description": "The workspace where the dataset will be created in Argilla. "
"Defaultsto `None` which means it will be created in the default "
"workspace.",
"name": "dataset_workspace",
"optional": True,
},
{
"name": "api_url",
"optional": True,
Expand Down
Loading

0 comments on commit e50b559

Please sign in to comment.