Skip to content

Support vendor params in protocol tests #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion gems/smithy-client/lib/smithy-client/dynamic_errors.rb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def error_class(error_code)
# @return [Symbol] Returns a symbolized constant name for the given `error_code`.
def error_class_constant(error_code)
constant = error_code.to_s
constant = constant.gsub(/https?:.*$/, '')
constant = constant.gsub(/[^a-zA-Z0-9]/, '')
constant = "Error#{constant}" unless constant.match(/^[a-z]/i)
constant = constant[0].upcase + constant[1..]
Expand Down
5 changes: 0 additions & 5 deletions gems/smithy-client/spec/smithy-client/dynamic_errors_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ module Client
expect(mod.error_class('My.Error')).to be(mod::MyError)
end

it 'removes http schemas from the error code' do
expect(mod.error_class('ErrorClass:http://example.com')).to be(mod::ErrorClass)
expect(mod.error_class('ErrorClass:https://example.com')).to be(mod::ErrorClass)
end

it 'ensures the error class name starts with a letter' do
expect(mod.error_class('123Code')).to be(mod::Error123Code)
end
Expand Down
23 changes: 13 additions & 10 deletions gems/smithy/lib/smithy/templates/client/protocol_spec.erb
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ module <%= module_name %>
<% unless operation_tests.request_tests.empty? -%>
describe 'requests' do
<% operation_tests.request_tests.each do |test| -%>
<% test.comments.each do |line| -%>
<% test.docstrings.each do |line| -%>
# <%= line %>
<% end -%>
<% if test.skip? -%>
it '<%= test.id %>', skip: '<%= test.skip_reason %>' do
it '<%= test['id'] %>', skip: '<%= test.skip_reason %>' do
<% else -%>
it '<%= test.id %>' do
it '<%= test['id'] %>' do
<% end -%>
<% if test['host'] -%>
client = Client.new(client_options.merge(endpoint: '<%= test.endpoint %>'))
Expand Down Expand Up @@ -76,13 +76,13 @@ module <%= module_name %>
<% unless operation_tests.response_tests.empty? -%>
describe 'responses' do
<% operation_tests.response_tests.each do |test| -%>
<% test.comments.each do |line| -%>
<% test.docstrings.each do |line| -%>
# <%= line %>
<% end -%>
<% if test.skip? -%>
it '<%= test.id %>', skip: '<%= test.skip_reason %>' do
it '<%= test['id'] %>', skip: '<%= test.skip_reason %>' do
<% else -%>
it '<%= test.id %>' do
it '<%= test['id'] %>' do
<% end -%>
response = { status_code: <%= test['code'] %> }
<% if test['headers'] -%>
Expand Down Expand Up @@ -110,13 +110,13 @@ module <%= module_name %>
<% unless operation_tests.error_tests.empty? -%>
describe 'response errors' do
<% operation_tests.error_tests.each do |test| -%>
<% test.comments.each do |line| -%>
<% test.docstrings.each do |line| -%>
# <%= line %>
<% end -%>
<% if test.skip? -%>
it '<%= test.id %>: <%= test.error_name %>', skip: '<%= test.skip_reason %>' do
it '<%= test['id'] %>: <%= test.error_name %>', skip: '<%= test.skip_reason %>' do
<% else -%>
it '<%= test.id %>: <%= test.error_name %>' do
it '<%= test['id'] %>: <%= test.error_name %>' do
<% end -%>
response = { status_code: <%= test['code'] %> }
<% if test['headers'] -%>
Expand All @@ -132,7 +132,10 @@ module <%= module_name %>
client.stub_responses(:<%= operation_tests.name %>, response)
expect { client.<%= operation_tests.name %> }.to raise_error do |e|
expect(e).to be_a(Errors::<%= test.error_name %>)
<%= test.data_expect %>
expect(e.data.to_h).to match_data(<%= test.params %>)
<% test.vendor_code(:error_expect_code).each do |line| -%>
<%= line %>
<% end -%>
end
end
<% end -%>
Expand Down
87 changes: 49 additions & 38 deletions gems/smithy/lib/smithy/views/client/protocol_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def initialize(plan)
Model::ServiceIndex
.new(@model)
.operations_for(@plan.service)
.map { |id, o| OperationTests.new(@model, id, o) }
.map { |id, o| OperationTests.new(@model, id, o, vendor_code) }
.reject(&:empty?)
super()
end
Expand All @@ -29,12 +29,20 @@ def additional_requires
Set.new(@all_operation_tests.map(&:additional_requires).flatten)
end

def vendor_code
@plan
.welds
.map(&:protocol_test_vendor_code)
.reduce({}, :merge)
end

# @api private
class OperationTests
def initialize(model, id, operation)
def initialize(model, id, operation, vendor_code)
@model = model
@id = id
@operation = operation
@vendor_code = vendor_code
@request_tests = build_request_tests
@response_tests = build_response_tests
@error_tests = build_error_tests
Expand All @@ -61,15 +69,15 @@ def build_response_tests
.fetch('traits', {})
.fetch('smithy.test#httpResponseTests', [])
.select { |t| t.fetch('appliesTo', 'client') == 'client' }
.map { |t| ResponseTest.new(@model, @operation, t) }
.map { |t| ResponseTest.new(@model, @operation, t, @vendor_code) }
end

def build_request_tests
@operation
.fetch('traits', {})
.fetch('smithy.test#httpRequestTests', [])
.select { |t| t.fetch('appliesTo', 'client') == 'client' }
.map { |t| RequestTest.new(@model, @operation, t) }
.map { |t| RequestTest.new(@model, @operation, t, @vendor_code) }
end

def build_error_tests
Expand All @@ -85,39 +93,45 @@ def tests_for_error(error)
error_shape
.fetch('traits', {})
.fetch('smithy.test#httpResponseTests', [])
.map { |t| ErrorTest.new(@model, @operation, error['target'], t) }
.map { |t| ErrorTest.new(@model, @operation, error['target'], t, @vendor_code) }
end
end

# @api private
class TestCase
def initialize(model, operation, test_case)
def initialize(model, operation, test_case, vendor_code)
@model = model
@operation = operation
@test_case = test_case
@vendor_code = vendor_code
@input_shape = Model.shape(@model, @operation['input']['target'])
@output_shape = Model.shape(@model, @operation['output']['target'])
end

attr_reader :test_case
def vendor_code(method)
return [] unless @test_case['vendorParamsShape'] && @test_case['vendorParams']

def [](key)
test_case[key]
vendor_code = @vendor_code[@test_case['vendorParamsShape']]
unless vendor_code.respond_to?(method)
raise "Unhandled protocol test vendor code for shape: '#{@test_case['vendorParamsShape']}. '" \
"Please implement a class that responds to :#{method} and register it with a weld."
end
vendor_code.send(method, @test_case['vendorParams']).split("\n")
end

def comments
test_case.fetch('documentation', '').split("\n")
def [](key)
@test_case[key]
end

def id
test_case['id']
def docstrings
@test_case.fetch('documentation', '').split("\n")
end

def additional_requires
requires = []
if test_case['bodyMediaType']
if @test_case['bodyMediaType']
requires +=
case test_case['bodyMediaType']
case @test_case['bodyMediaType']
when 'application/cbor'
%w[base64]
when 'application/json'
Expand All @@ -133,49 +147,50 @@ def skip?
@operation
.fetch('traits', {})
.fetch('smithy.ruby#skipTests', [])
.any? { |skip| skip['id'] == id }
.any? { |skip| skip['id'] == @test_case['id'] }
end

def skip_reason
@operation
.fetch('traits', {})
.fetch('smithy.ruby#skipTests', [])
.find { |skip| skip['id'] == id }
.find { |skip| skip['id'] == @test_case['id'] }
&.fetch('reason', 'skipped')
end
end

# @api private
class RequestTest < TestCase
def params
ShapeToHash.transform_value(@model, test_case.fetch('params', {}), @input_shape)
ShapeToHash.transform_value(@model, @test_case.fetch('params', {}), @input_shape)
end

def endpoint
"https://#{test_case.fetch('host', '127.0.0.1')}"
"https://#{@test_case.fetch('host', '127.0.0.1')}"
end

def body_expect
return nil unless test_case['body']
return nil unless @test_case['body']

case test_case['bodyMediaType']
case @test_case['bodyMediaType']
when 'application/cbor'
'expect(Smithy::CBOR.decode(request.body.read)).' \
"to match_data(Smithy::CBOR.decode(::Base64.decode64('#{test_case['body']}')))"
"to match_data(Smithy::CBOR.decode(::Base64.decode64('#{@test_case['body']}')))"
when 'application/json'
"expect(JSON.parse(request.body.read)).to eq(JSON.parse('#{test_case['body']}'))"
"expect(JSON.parse(request.body.read)).to eq(JSON.parse('#{@test_case['body']}'))"
else
"expect(request.body.read).to eq('#{test_case['body']}')"
"expect(request.body.read).to eq('#{@test_case['body']}')"
end
end

def query_expect?
test_case['queryParams'] || test_case['forbidQueryParams'] || test_case['requireQueryParams']
@test_case['queryParams'] || @test_case['forbidQueryParams'] || @test_case['requireQueryParams']
end

def idempotency_token_trait?
@input_shape.fetch('members', {})
.any? { |_name, shape| shape.fetch('traits', {}).key?('smithy.api#idempotencyToken') }
@input_shape
.fetch('members', {})
.any? { |_name, shape| shape.fetch('traits', {}).key?('smithy.api#idempotencyToken') }
end
end

Expand Down Expand Up @@ -207,16 +222,16 @@ def required?(traits)
end

def stub_body
case test_case['bodyMediaType']
case @test_case['bodyMediaType']
when 'application/cbor'
"::Base64.decode64('#{test_case['body']}')"
"::Base64.decode64('#{@test_case['body']}')"
else
"'#{test_case['body']}'"
"'#{@test_case['body']}'"
end
end

def data_expect
output = ShapeToHash.transform_value(@model, test_case.fetch('params', {}), @output_shape)
output = ShapeToHash.transform_value(@model, @test_case.fetch('params', {}), @output_shape)
"expect(response.data.to_h).to match_data(#{output})"
end

Expand All @@ -230,8 +245,8 @@ def streaming_member

# @api private
class ErrorTest < ResponseTest
def initialize(model, operation, error_id, test_case)
super(model, operation, test_case)
def initialize(model, operation, error_id, test_case, vendor_code)
super(model, operation, test_case, vendor_code)
@error_id = error_id
@error_shape = Model.shape(@model, error_id)
end
Expand All @@ -241,11 +256,7 @@ def error_name
end

def params
ShapeToHash.transform_value(@model, test_case.fetch('params', {}), @error_shape)
end

def data_expect
"expect(e.data.to_h).to match_data(#{params})"
ShapeToHash.transform_value(@model, @test_case.fetch('params', {}), @error_shape)
end
end
end
Expand Down
8 changes: 8 additions & 0 deletions gems/smithy/lib/smithy/weld.rb
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,13 @@ def add_auth_schemes
def remove_auth_schemes
[]
end

# Called when generating protocol tests. The key should be the same as the vendor params shape
# in a protocol test, and the value should be a class that responds to one of the following methods:
# * error_expect_code(params) - returns a string that is rendered inside a rescue block (with error rescued as `e`).
# Protocol tests are run with RSpec and expectations should be used.
def protocol_test_vendor_code
{}
end
end
end