Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ impl ClientOptions {
client: TemporalServiceClient::new(svc),
options: Arc::new(self.clone()),
capabilities: None,
workers: Arc::new(ClientWorkerSet::new()),
workers: Arc::new(ClientWorkerSet::new(false)),
};
if !self.skip_get_system_info {
match client
Expand Down
9 changes: 9 additions & 0 deletions client/src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,15 @@ proxier! {
r.extensions_mut().insert(labels);
}
);
(
describe_worker,
DescribeWorkerRequest,
DescribeWorkerResponse,
|r| {
let labels = namespaced_request!(r);
r.extensions_mut().insert(labels);
}
);
(
record_worker_heartbeat,
RecordWorkerHeartbeatRequest,
Expand Down
61 changes: 27 additions & 34 deletions client/src/worker_registry/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,19 @@ struct ClientWorkerSetImpl {
all_workers: HashMap<Uuid, Arc<dyn ClientWorker + Send + Sync>>,
/// Maps namespace to shared worker for worker heartbeating
shared_worker: HashMap<String, Box<dyn SharedNamespaceWorkerTrait + Send + Sync>>,
/// Disables erroring when multiple workers on the same namespace+task queue are registered.
/// This is used with testing, where multiple tests run in parallel on the same client
disable_dupe_check: bool,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than this I was thinking something more like the test workers would be able to just not register themselves in unit tests by default?

}

impl ClientWorkerSetImpl {
/// Factory method.
fn new() -> Self {
fn new(disable_dupe_check: bool) -> Self {
Self {
slot_providers: Default::default(),
all_workers: Default::default(),
shared_worker: Default::default(),
disable_dupe_check,
}
}

Expand All @@ -81,7 +85,7 @@ impl ClientWorkerSetImpl {
worker.namespace().to_string(),
worker.task_queue().to_string(),
);
if self.slot_providers.contains_key(&slot_key) {
if self.slot_providers.contains_key(&slot_key) && !self.disable_dupe_check {
bail!(
"Registration of multiple workers on the same namespace and task queue for the same client not allowed: {slot_key:?}, worker_instance_key: {:?}.",
worker.worker_instance_key()
Expand Down Expand Up @@ -133,14 +137,8 @@ impl ClientWorkerSetImpl {

if let Some(w) = self.shared_worker.get_mut(worker.namespace()) {
let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key());
if let Some(cb) = callback {
if is_empty {
self.shared_worker.remove(worker.namespace());
}

// To maintain single ownership of the callback, we must re-register the callback
// back to the ClientWorker
worker.register_callback(cb);
if callback.is_some() && is_empty {
self.shared_worker.remove(worker.namespace());
}
}

Expand Down Expand Up @@ -188,16 +186,16 @@ pub struct ClientWorkerSet {

impl Default for ClientWorkerSet {
fn default() -> Self {
Self::new()
Self::new(false)
}
}

impl ClientWorkerSet {
/// Factory method.
pub fn new() -> Self {
pub fn new(disable_dupe_check: bool) -> Self {
Self {
worker_grouping_key: Uuid::new_v4(),
worker_manager: RwLock::new(ClientWorkerSetImpl::new()),
worker_manager: RwLock::new(ClientWorkerSetImpl::new(disable_dupe_check)),
}
}

Expand All @@ -212,14 +210,6 @@ impl ClientWorkerSet {
.try_reserve_wft_slot(namespace, task_queue)
}

/// Unregisters a local worker, typically when that worker starts shutdown.
pub fn unregister_worker(
&self,
worker_instance_key: Uuid,
) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
self.worker_manager.write().unregister(worker_instance_key)
}

/// Register a local worker that can provide WFT processing slots and potentially worker heartbeating.
pub fn register_worker(
&self,
Expand All @@ -228,6 +218,14 @@ impl ClientWorkerSet {
self.worker_manager.write().register(worker)
}

/// Unregisters a local worker, typically when that worker starts shutdown.
pub fn unregister_worker(
&self,
worker_instance_key: Uuid,
) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
self.worker_manager.write().unregister(worker_instance_key)
}

/// Returns the worker grouping key, which is unique for each worker.
pub fn worker_grouping_key(&self) -> Uuid {
self.worker_grouping_key
Expand Down Expand Up @@ -256,7 +254,7 @@ impl std::fmt::Debug for ClientWorkerSet {
}

/// Contains a worker heartbeat callback, wrapped for mocking
pub type HeartbeatCallback = Box<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
pub type HeartbeatCallback = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;

/// Represents a complete worker that can handle both slot management
/// and worker heartbeat functionality.
Expand All @@ -276,7 +274,7 @@ pub trait ClientWorker: Send + Sync {
fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;

/// Unique identifier for this worker instance.
/// This must be stable across the worker's lifetime but unique per instance.
/// This must be stable across the worker's lifetime and unique per instance.
fn worker_instance_key(&self) -> Uuid;

/// Indicates if worker heartbeating is enabled for this client worker.
Expand All @@ -289,9 +287,6 @@ pub trait ClientWorker: Send + Sync {
fn new_shared_namespace_worker(
&self,
) -> Result<Box<dyn SharedNamespaceWorkerTrait + Send + Sync>, anyhow::Error>;

/// Registers a worker heartbeat callback, typically when a worker is unregistered from a client
fn register_callback(&self, callback: HeartbeatCallback);
}

#[cfg(test)]
Expand Down Expand Up @@ -340,7 +335,7 @@ mod tests {

#[test]
fn registry_keeps_one_provider_per_namespace() {
let manager = ClientWorkerSet::new();
let manager = ClientWorkerSet::new(false);
let mut worker_keys = vec![];
let mut successful_registrations = 0;

Expand Down Expand Up @@ -453,7 +448,7 @@ mod tests {
if heartbeat_enabled {
mock_provider
.expect_heartbeat_callback()
.returning(|| Some(Box::new(WorkerHeartbeat::default)));
.returning(|| Some(Arc::new(WorkerHeartbeat::default)));

let namespace_clone = namespace.clone();
mock_provider
Expand All @@ -463,16 +458,14 @@ mod tests {
namespace_clone.clone(),
)))
});

mock_provider.expect_register_callback().returning(|_| {});
}

mock_provider
}

#[test]
fn duplicate_namespace_task_queue_registration_fails() {
let manager = ClientWorkerSet::new();
let manager = ClientWorkerSet::new(false);

let worker1 = new_mock_provider_with_heartbeat(
"test_namespace".to_string(),
Expand Down Expand Up @@ -511,7 +504,7 @@ mod tests {

#[test]
fn multiple_workers_same_namespace_share_heartbeat_manager() {
let manager = ClientWorkerSet::new();
let manager = ClientWorkerSet::new(false);

let worker1 = new_mock_provider_with_heartbeat(
"shared_namespace".to_string(),
Expand Down Expand Up @@ -544,7 +537,7 @@ mod tests {

#[test]
fn different_namespaces_get_separate_heartbeat_managers() {
let manager = ClientWorkerSet::new();
let manager = ClientWorkerSet::new(false);
let worker1 = new_mock_provider_with_heartbeat(
"namespace1".to_string(),
"queue1".to_string(),
Expand Down Expand Up @@ -572,7 +565,7 @@ mod tests {

#[test]
fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() {
let manager = ClientWorkerSet::new();
let manager = ClientWorkerSet::new(false);

// Create two workers with same namespace but different task queues
let worker1 = new_mock_provider_with_heartbeat(
Expand Down
1 change: 1 addition & 0 deletions core-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ tonic = { workspace = true }
tracing = "0.1"
tracing-core = "0.1"
url = "2.5"
uuid = { version = "1.18.1", features = ["v4"] }

[dependencies.temporal-sdk-core-protos]
path = "../sdk-core-protos"
Expand Down
9 changes: 9 additions & 0 deletions core-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use temporal_sdk_core_protos::coresdk::{
workflow_activation::WorkflowActivation,
workflow_completion::WorkflowActivationCompletion,
};
use uuid::Uuid;

/// This trait is the primary way by which language specific SDKs interact with the core SDK.
/// It represents one worker, which has a (potentially shared) client for connecting to the service
Expand Down Expand Up @@ -138,6 +139,10 @@ pub trait Worker: Send + Sync {
/// This should be called only after [Worker::shutdown] has resolved and/or both polling
/// functions have returned `ShutDown` errors.
async fn finalize_shutdown(self);

/// Unique identifier for this worker instance.
/// This must be stable across the worker's lifetime and unique per instance.
fn worker_instance_key(&self) -> Uuid;
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -205,6 +210,10 @@ where
async fn finalize_shutdown(self) {
panic!("Can't finalize shutdown on Arc'd worker")
}

fn worker_instance_key(&self) -> Uuid {
(**self).worker_instance_key()
}
}

macro_rules! dbg_panic {
Expand Down
Loading
Loading