Skip to content

Commit d0f3ad9

Browse files
markchuaMark Chua
and
Mark Chua
authored
Add header propagation and workflow middleware. (#226)
Co-authored-by: Mark Chua <[email protected]>
1 parent 5e9dd9e commit d0f3ad9

20 files changed

+251
-25
lines changed

examples/bin/trigger

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
require_relative '../init'
33

44
Dir[File.expand_path('../workflows/*.rb', __dir__)].each { |f| require f }
5+
Dir[File.expand_path('../middleware/*.rb', __dir__)].each { |f| require f }
56

67
workflow_class_name, *args = ARGV
78
workflow_class = Object.const_get(workflow_class_name)
@@ -10,5 +11,9 @@ workflow_id = SecureRandom.uuid
1011
# Convert integer strings to integers
1112
input = args.map { |arg| Integer(arg) rescue arg }
1213

14+
Temporal.configure do |config|
15+
config.add_header_propagator(SamplePropagator)
16+
end
17+
1318
run_id = Temporal.start_workflow(workflow_class, *input, options: { workflow_id: workflow_id })
1419
Temporal.logger.info "Started workflow", { workflow_id: workflow_id, run_id: run_id }

examples/bin/worker

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ if !ENV['USE_ERROR_SERIALIZATION_V2'].nil?
2424
end
2525
end
2626

27+
Temporal.configure do |config|
28+
config.add_header_propagator(SamplePropagator)
29+
end
30+
2731
worker = Temporal::Worker.new(binary_checksum: `git show HEAD -s --format=%H`.strip)
2832

2933
worker.register_workflow(AsyncActivityWorkflow)
@@ -94,5 +98,7 @@ worker.register_dynamic_activity(DelegatorActivity)
9498

9599
worker.add_workflow_task_middleware(LoggingMiddleware, 'EXAMPLE')
96100
worker.add_activity_middleware(LoggingMiddleware, 'EXAMPLE')
101+
worker.add_activity_middleware(SamplePropagator)
102+
worker.add_workflow_middleware(SamplePropagator)
97103

98104
worker.start
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class SamplePropagator
2+
def inject!(headers)
3+
headers['test-header'] = 'test'
4+
end
5+
6+
def call(metadata)
7+
Temporal.logger.info("Got headers!", headers: metadata.headers.to_h)
8+
yield
9+
end
10+
end

lib/temporal/client.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def start_workflow(workflow, *input, options: {}, **args)
6161
run_timeout: compute_run_timeout(execution_options),
6262
task_timeout: execution_options.timeouts[:task],
6363
workflow_id_reuse_policy: options[:workflow_id_reuse_policy],
64-
headers: execution_options.headers,
64+
headers: config.header_propagator_chain.inject(execution_options.headers),
6565
memo: execution_options.memo,
6666
search_attributes: Workflow::Context::Helpers.process_search_attributes(execution_options.search_attributes),
6767
)
@@ -78,7 +78,7 @@ def start_workflow(workflow, *input, options: {}, **args)
7878
run_timeout: compute_run_timeout(execution_options),
7979
task_timeout: execution_options.timeouts[:task],
8080
workflow_id_reuse_policy: options[:workflow_id_reuse_policy],
81-
headers: execution_options.headers,
81+
headers: config.header_propagator_chain.inject(execution_options.headers),
8282
memo: execution_options.memo,
8383
search_attributes: Workflow::Context::Helpers.process_search_attributes(execution_options.search_attributes),
8484
signal_name: signal_name,
@@ -127,7 +127,7 @@ def schedule_workflow(workflow, cron_schedule, *input, options: {}, **args)
127127
run_timeout: compute_run_timeout(execution_options),
128128
task_timeout: execution_options.timeouts[:task],
129129
workflow_id_reuse_policy: options[:workflow_id_reuse_policy],
130-
headers: execution_options.headers,
130+
headers: config.header_propagator_chain.inject(execution_options.headers),
131131
cron_schedule: cron_schedule,
132132
memo: execution_options.memo,
133133
search_attributes: Workflow::Context::Helpers.process_search_attributes(execution_options.search_attributes),

lib/temporal/configuration.rb

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
require 'temporal/logger'
22
require 'temporal/metrics_adapters/null'
3+
require 'temporal/middleware/header_propagator_chain'
4+
require 'temporal/middleware/entry'
35
require 'temporal/connection/converter/payload/nil'
46
require 'temporal/connection/converter/payload/bytes'
57
require 'temporal/connection/converter/payload/json'
@@ -12,7 +14,7 @@ class Configuration
1214
Execution = Struct.new(:namespace, :task_queue, :timeouts, :headers, :search_attributes, keyword_init: true)
1315

1416
attr_reader :timeouts, :error_handlers
15-
attr_accessor :connection_type, :converter, :use_error_serialization_v2, :host, :port, :credentials, :identity, :logger, :metrics_adapter, :namespace, :task_queue, :headers, :search_attributes
17+
attr_accessor :connection_type, :converter, :use_error_serialization_v2, :host, :port, :credentials, :identity, :logger, :metrics_adapter, :namespace, :task_queue, :headers, :search_attributes, :header_propagators
1618

1719
# See https://docs.temporal.io/blog/activity-timeouts/ for general docs.
1820
# We want an infinite execution timeout for cron schedules and other perpetual workflows.
@@ -58,6 +60,7 @@ def initialize
5860
@credentials = :this_channel_is_insecure
5961
@identity = nil
6062
@search_attributes = {}
63+
@header_propagators = []
6164
end
6265

6366
def on_error(&block)
@@ -96,6 +99,15 @@ def default_execution_options
9699
).freeze
97100
end
98101

102+
def add_header_propagator(propagator_class, *args)
103+
raise 'header propagator must implement `def inject!(headers)`' unless propagator_class.method_defined? :inject!
104+
@header_propagators << Middleware::Entry.new(propagator_class, args)
105+
end
106+
107+
def header_propagator_chain
108+
Middleware::HeaderPropagatorChain.new(header_propagators)
109+
end
110+
99111
private
100112

101113
def default_identity

lib/temporal/connection/serializer/schedule_activity.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def to_proto
3232
def serialize_headers(headers)
3333
return unless headers
3434

35-
Temporalio::Api::Common::V1::Header.new(fields: object.headers)
35+
Temporalio::Api::Common::V1::Header.new(fields: to_payload_map(headers))
3636
end
3737
end
3838
end
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module Temporal
2+
module Middleware
3+
class HeaderPropagatorChain
4+
def initialize(entries = [])
5+
@propagators = entries.map(&:init_middleware)
6+
end
7+
8+
def inject(headers)
9+
return headers if propagators.empty?
10+
h = headers.dup
11+
for propagator in propagators
12+
propagator.inject!(h)
13+
end
14+
h
15+
end
16+
17+
private
18+
19+
attr_reader :propagators
20+
end
21+
end
22+
end

lib/temporal/worker.rb

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def initialize(
3131
@activities = Hash.new { |hash, key| hash[key] = ExecutableLookup.new }
3232
@pollers = []
3333
@workflow_task_middleware = []
34+
@workflow_middleware = []
3435
@activity_middleware = []
3536
@shutting_down = false
3637
@activity_poller_options = {
@@ -90,6 +91,10 @@ def add_workflow_task_middleware(middleware_class, *args)
9091
@workflow_task_middleware << Middleware::Entry.new(middleware_class, args)
9192
end
9293

94+
def add_workflow_middleware(middleware_class, *args)
95+
@workflow_middleware << Middleware::Entry.new(middleware_class, args)
96+
end
97+
9398
def add_activity_middleware(middleware_class, *args)
9499
@activity_middleware << Middleware::Entry.new(middleware_class, args)
95100
end
@@ -128,14 +133,14 @@ def stop
128133

129134
attr_reader :config, :activity_poller_options, :workflow_poller_options,
130135
:activities, :workflows, :pollers,
131-
:workflow_task_middleware, :activity_middleware
136+
:workflow_task_middleware, :workflow_middleware, :activity_middleware
132137

133138
def shutting_down?
134139
@shutting_down
135140
end
136141

137142
def workflow_poller_for(namespace, task_queue, lookup)
138-
Workflow::Poller.new(namespace, task_queue, lookup.freeze, config, workflow_task_middleware, workflow_poller_options)
143+
Workflow::Poller.new(namespace, task_queue, lookup.freeze, config, workflow_task_middleware, workflow_middleware, workflow_poller_options)
139144
end
140145

141146
def activity_poller_for(namespace, task_queue, lookup)

lib/temporal/workflow/context.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def execute_activity(activity_class, *input, **args)
7979
task_queue: execution_options.task_queue,
8080
retry_policy: execution_options.retry_policy,
8181
timeouts: execution_options.timeouts,
82-
headers: execution_options.headers
82+
headers: config.header_propagator_chain.inject(execution_options.headers)
8383
)
8484

8585
target, cancelation_id = schedule_command(command)
@@ -136,7 +136,7 @@ def execute_workflow(workflow_class, *input, **args)
136136
retry_policy: execution_options.retry_policy,
137137
parent_close_policy: parent_close_policy,
138138
timeouts: execution_options.timeouts,
139-
headers: execution_options.headers,
139+
headers: config.header_propagator_chain.inject(execution_options.headers),
140140
cron_schedule: cron_schedule,
141141
memo: execution_options.memo,
142142
workflow_id_reuse_policy: workflow_id_reuse_policy,
@@ -261,7 +261,7 @@ def continue_as_new(*input, **args)
261261
input: input,
262262
timeouts: execution_options.timeouts,
263263
retry_policy: execution_options.retry_policy,
264-
headers: execution_options.headers,
264+
headers: config.header_propagator_chain.inject(execution_options.headers),
265265
memo: execution_options.memo,
266266
search_attributes: Helpers.process_search_attributes(execution_options.search_attributes)
267267
)

lib/temporal/workflow/executor.rb

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Executor
1616
# @param task_metadata [Metadata::WorkflowTask]
1717
# @param config [Configuration]
1818
# @param track_stack_trace [Boolean]
19-
def initialize(workflow_class, history, task_metadata, config, track_stack_trace)
19+
def initialize(workflow_class, history, task_metadata, config, track_stack_trace, middleware_chain)
2020
@workflow_class = workflow_class
2121
@dispatcher = Dispatcher.new
2222
@query_registry = QueryRegistry.new
@@ -25,6 +25,7 @@ def initialize(workflow_class, history, task_metadata, config, track_stack_trace
2525
@task_metadata = task_metadata
2626
@config = config
2727
@track_stack_trace = track_stack_trace
28+
@middleware_chain = middleware_chain
2829
end
2930

3031
def run
@@ -55,7 +56,8 @@ def process_queries(queries)
5556

5657
private
5758

58-
attr_reader :workflow_class, :dispatcher, :query_registry, :state_manager, :task_metadata, :history, :config, :track_stack_trace
59+
attr_reader :workflow_class, :dispatcher, :query_registry, :state_manager,
60+
:task_metadata, :history, :config, :track_stack_trace, :middleware_chain
5961

6062
def process_query(query)
6163
result = query_registry.handle(query.query_type, query.query_args)
@@ -70,7 +72,9 @@ def execute_workflow(input, workflow_started_event)
7072
context = Workflow::Context.new(state_manager, dispatcher, workflow_class, metadata, config, query_registry, track_stack_trace)
7173

7274
Fiber.new do
73-
workflow_class.execute_in_context(context, input)
75+
middleware_chain.invoke(metadata) do
76+
workflow_class.execute_in_context(context, input)
77+
end
7478
end.resume
7579
end
7680
end

lib/temporal/workflow/poller.rb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ class Poller
1414
binary_checksum: nil
1515
}.freeze
1616

17-
def initialize(namespace, task_queue, workflow_lookup, config, middleware = [], options = {})
17+
def initialize(namespace, task_queue, workflow_lookup, config, middleware = [], workflow_middleware = [], options = {})
1818
@namespace = namespace
1919
@task_queue = task_queue
2020
@workflow_lookup = workflow_lookup
2121
@config = config
2222
@middleware = middleware
23+
@workflow_middleware = workflow_middleware
2324
@shutting_down = false
2425
@options = DEFAULT_OPTIONS.merge(options)
2526
end
@@ -45,7 +46,7 @@ def wait
4546

4647
private
4748

48-
attr_reader :namespace, :task_queue, :connection, :workflow_lookup, :config, :middleware, :options, :thread
49+
attr_reader :namespace, :task_queue, :connection, :workflow_lookup, :config, :middleware, :workflow_middleware, :options, :thread
4950

5051
def connection
5152
@connection ||= Temporal::Connection.generate(config.for_connection)
@@ -96,8 +97,9 @@ def poll_for_task
9697

9798
def process(task)
9899
middleware_chain = Middleware::Chain.new(middleware)
100+
workflow_middleware_chain = Middleware::Chain.new(workflow_middleware)
99101

100-
TaskProcessor.new(task, namespace, workflow_lookup, middleware_chain, config, binary_checksum).process
102+
TaskProcessor.new(task, namespace, workflow_lookup, middleware_chain, workflow_middleware_chain, config, binary_checksum).process
101103
end
102104

103105
def thread_pool

lib/temporal/workflow/task_processor.rb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@ def query_args
2424
MAX_FAILED_ATTEMPTS = 1
2525
LEGACY_QUERY_KEY = :legacy_query
2626

27-
def initialize(task, namespace, workflow_lookup, middleware_chain, config, binary_checksum)
27+
def initialize(task, namespace, workflow_lookup, middleware_chain, workflow_middleware_chain, config, binary_checksum)
2828
@task = task
2929
@namespace = namespace
3030
@metadata = Metadata.generate_workflow_task_metadata(task, namespace)
3131
@task_token = task.task_token
3232
@workflow_name = task.workflow_type.name
3333
@workflow_class = workflow_lookup.find(workflow_name)
3434
@middleware_chain = middleware_chain
35+
@workflow_middleware_chain = workflow_middleware_chain
3536
@config = config
3637
@binary_checksum = binary_checksum
3738
end
@@ -53,7 +54,7 @@ def process
5354
track_stack_trace = queries.values.map(&:query_type).include?(StackTraceTracker::STACK_TRACE_QUERY_NAME)
5455

5556
# TODO: For sticky workflows we need to cache the Executor instance
56-
executor = Workflow::Executor.new(workflow_class, history, metadata, config, track_stack_trace)
57+
executor = Workflow::Executor.new(workflow_class, history, metadata, config, track_stack_trace, workflow_middleware_chain)
5758

5859
commands = middleware_chain.invoke(metadata) do
5960
executor.run
@@ -79,7 +80,7 @@ def process
7980
private
8081

8182
attr_reader :task, :namespace, :task_token, :workflow_name, :workflow_class,
82-
:middleware_chain, :metadata, :config, :binary_checksum
83+
:middleware_chain, :workflow_middleware_chain, :metadata, :config, :binary_checksum
8384

8485
def connection
8586
@connection ||= Temporal::Connection.generate(config.for_connection)

spec/unit/lib/temporal/client_spec.rb

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,35 @@ class TestStartWorkflow < Temporal::Workflow
3939

4040
before { allow(connection).to receive(:start_workflow_execution).and_return(temporal_response) }
4141

42+
context 'with header propagator' do
43+
class TestHeaderPropagator
44+
def inject!(header)
45+
header['test'] = 'asdf'
46+
end
47+
end
48+
49+
it 'updates the header' do
50+
config.add_header_propagator(TestHeaderPropagator)
51+
subject.start_workflow(TestStartWorkflow, 42)
52+
expect(connection)
53+
.to have_received(:start_workflow_execution)
54+
.with(
55+
namespace: 'default-test-namespace',
56+
workflow_id: an_instance_of(String),
57+
workflow_name: 'TestStartWorkflow',
58+
task_queue: 'default-test-task-queue',
59+
input: [42],
60+
task_timeout: Temporal.configuration.timeouts[:task],
61+
run_timeout: Temporal.configuration.timeouts[:run],
62+
execution_timeout: Temporal.configuration.timeouts[:execution],
63+
workflow_id_reuse_policy: nil,
64+
headers: { 'test' => 'asdf' },
65+
memo: {},
66+
search_attributes: {},
67+
)
68+
end
69+
end
70+
4271
context 'using a workflow class' do
4372
it 'returns run_id' do
4473
result = subject.start_workflow(TestStartWorkflow, 42)

spec/unit/lib/temporal/configuration_spec.rb

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
require 'temporal/configuration'
22

33
describe Temporal::Configuration do
4+
class TestHeaderPropagator
5+
def inject!(_); end
6+
end
7+
48
describe '#initialize' do
59
it 'initializes proper default workflow timeouts' do
610
timeouts = subject.timeouts
@@ -27,6 +31,25 @@
2731
end
2832
end
2933

34+
describe '#add_header_propagator' do
35+
let(:header_propagators) { subject.send(:header_propagators) }
36+
37+
it 'adds middleware entry to the list of middlewares' do
38+
subject.add_header_propagator(TestHeaderPropagator)
39+
subject.add_header_propagator(TestHeaderPropagator, 'arg1', 'arg2')
40+
41+
expect(header_propagators.size).to eq(2)
42+
43+
expect(header_propagators[0]).to be_an_instance_of(Temporal::Middleware::Entry)
44+
expect(header_propagators[0].klass).to eq(TestHeaderPropagator)
45+
expect(header_propagators[0].args).to eq([])
46+
47+
expect(header_propagators[1]).to be_an_instance_of(Temporal::Middleware::Entry)
48+
expect(header_propagators[1].klass).to eq(TestHeaderPropagator)
49+
expect(header_propagators[1].args).to eq(['arg1', 'arg2'])
50+
end
51+
end
52+
3053
describe '#for_connection' do
3154
let (:new_identity) { 'new_identity' }
3255

0 commit comments

Comments
 (0)