Skip to content

Commit

Permalink
Introduce a DecoratedField to alleviate feature envy
Browse files Browse the repository at this point in the history
  • Loading branch information
davebenvenuti committed Jan 29, 2025
1 parent 9e25aa6 commit 99ad480
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 87 deletions.
135 changes: 52 additions & 83 deletions lib/protoboeuf/codegen.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
require "erb"
require "syntax_tree"
require_relative "codegen_type_helper"
require_relative "decorated_field"

module ProtoBoeuf
class CodeGen
Expand Down Expand Up @@ -93,7 +94,9 @@ def result(message, toplevel_enums, generate_types:, requires:, syntax:, options
def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, options:)
@message = message
@optional_field_bit_lut = []
@fields = @message.field
@fields = @message.field.map do |field|
DecoratedField.new(field:, message:, syntax:)
end
@enum_field_types = toplevel_enums.merge(message.enum_type.group_by(&:name))
@requires = requires
@generate_types = generate_types
Expand All @@ -109,8 +112,8 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt

optional_field_count = 0

message.field.each do |field|
if optional_field?(field)
@fields.each do |field|
if field.optional?
if field.type == :TYPE_ENUM
@enum_fields << field
else
Expand All @@ -133,15 +136,10 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt
# list only contains non-proto3_optional fields, so we're using that
# to only iterate over actual "oneof" fields.
@oneof_selection_fields = @oneof_fields.each_with_index.map do |item, i|
item && message.oneof_decl[i]
item && DecoratedField.new(message:, field: message.oneof_decl[i], syntax:)
end
end

def optional_field?(field)
proto3 = "proto3" == syntax
field.proto3_optional || (field.label == :LABEL_OPTIONAL && !proto3)
end

def result
"class #{message.name}\n" + class_body + "end\n"
end
Expand All @@ -163,7 +161,7 @@ def class_body

def conversion
fields = self.fields.reject do |field|
field.has_oneof_index? && !optional_field?(field)
field.has_oneof_index? && !field.optional?
end

oneofs = @oneof_selection_fields.map do |field|
Expand All @@ -181,12 +179,12 @@ def to_h
end

def convert_field(field)
if repeated?(field)
"result['#{field.name}'.to_sym] = #{iv_name(field)}"
if field.repeated?
"result['#{field.name}'.to_sym] = #{field.iv_name}"
elsif field.type == :TYPE_MESSAGE
"result['#{field.name}'.to_sym] = #{iv_name(field)}.to_h"
"result['#{field.name}'.to_sym] = #{field.iv_name}.to_h"
else
"result['#{field.name}'.to_sym] = #{iv_name(field)}"
"result['#{field.name}'.to_sym] = #{field.iv_name}"
end
end

Expand All @@ -198,9 +196,9 @@ def encode
"buff << @_unknown_fields if @_unknown_fields\nbuff\n end\n\n"
end

def encode_subtype(field, value_expr = iv_name(field), tagged = true)
def encode_subtype(field, value_expr = field.iv_name, tagged = true)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
encode_map(field, value_expr, tagged)
else
encode_repeated(field, value_expr, tagged)
Expand Down Expand Up @@ -271,16 +269,16 @@ def encode_bool(field, value_expr, tagged)
end

def encode_map(field, value_expr, tagged)
map_type = self.map_type(field)
map_type = field.map_type

<<~RUBY
map = #{value_expr}
if map.size > 0
old_buff = buff
map.each do |key, value|
buff = new_buffer = +''
#{encode_subtype(map_type.field[0], "key", true)}
#{encode_subtype(map_type.field[1], "value", true)}
#{encode_subtype(map_type.key, "key", true)}
#{encode_subtype(map_type.value, "value", true)}
buff = old_buff
#{encode_tag_and_length(field, true, "new_buffer.bytesize")}
old_buff.concat(new_buffer)
Expand All @@ -292,8 +290,8 @@ def encode_map(field, value_expr, tagged)
def encode_oneof(field, value_expr, tagged)
field.fields.map do |f|
<<~RUBY
if #{iv_name(field)} == :"#{f.name}"
#{encode_subtype(f, iv_name(f))}
if #{field.iv_name} == :"#{f.name}"
#{encode_subtype(f, f.iv_name)}
end
RUBY
end.join("\n")
Expand Down Expand Up @@ -646,7 +644,7 @@ def enum_readers

" # enum readers\n" +
fields.map { |field|
"def #{field.name}; #{enum_name(field)}.lookup(#{iv_name(field)}) || #{iv_name(field)}; end"
"def #{field.name}; #{enum_name(field)}.lookup(#{field.iv_name}) || #{field.iv_name}; end"
}.join("\n") + "\n"
end

Expand Down Expand Up @@ -722,7 +720,7 @@ def enum_writers

"# enum writers\n" +
fields.map do |field|
"def #{field.name}=(v); #{iv_name(field)} = #{enum_name(field)}.resolve(v) || v; end"
"def #{field.name}=(v); #{field.iv_name} = #{enum_name(field)}.resolve(v) || v; end"
end.join("\n") + "\n\n"
end

Expand All @@ -736,7 +734,7 @@ def required_writers
#{type_signature(params: { v: convert_field_type(field) })}
def #{field.name}=(v)
#{bounds_check(field, "v")}
#{iv_name(field)} = v
#{field.iv_name} = v
end
RUBY
end.join("\n") + "\n"
Expand All @@ -752,7 +750,7 @@ def optional_writers
def #{field.name}=(v)
#{bounds_check(field, "v")}
#{set_bitmask(field)}
#{iv_name(field)} = v
#{field.iv_name} = v
end
RUBY
}.join("\n") +
Expand All @@ -772,7 +770,7 @@ def oneof_writers
def #{field.name}=(v)
#{bounds_check(field, "v")}
@#{oneof.name} = :#{field.name}
#{iv_name(field)} = v
#{field.iv_name} = v
end
RUBY
end.join("\n")
Expand All @@ -786,7 +784,7 @@ def initialize_code
init_bitmask(message) +
initialize_oneofs +
fields.map { |field|
if field.has_oneof_index? && !optional_field?(field)
if field.has_oneof_index? && !field.optional?
initialize_oneof(field, message)
else
initialize_field(field)
Expand All @@ -796,26 +794,26 @@ def initialize_code

def initialize_oneofs
@oneof_selection_fields.map do |field|
"#{iv_name(field)} = nil # oneof field"
"#{field.iv_name} = nil # oneof field"
end.join("\n") + "\n"
end

def initialize_oneof(field, msg)
oneof = msg.oneof_decl[field.oneof_index]
oneof = DecoratedField.new(message: msg, field: msg.oneof_decl[field.oneof_index], syntax:)

<<~RUBY
if #{lvar_read(field)} == nil
#{iv_name(field)} = #{default_for(field)}
#{field.iv_name} = #{default_for(field)}
else
#{bounds_check(field, lvar_read(field))}
#{iv_name(oneof)} = :#{field.name}
#{iv_name(field)} = #{lvar_read(field)}
#{oneof.iv_name} = :#{field.name}
#{field.iv_name} = #{lvar_read(field)}
end
RUBY
end

def initialize_field(field)
if optional_field?(field)
if field.optional?
initialize_optional_field(field)
elsif field.type == :TYPE_ENUM
initialize_enum_field(field)
Expand All @@ -828,12 +826,12 @@ def initialize_optional_field(field)
set_field_to_var = if field.type == :TYPE_ENUM
initialize_enum_field(field)
else
"#{iv_name(field)} = #{lvar_read(field)}"
"#{field.iv_name} = #{lvar_read(field)}"
end

<<~RUBY
if #{lvar_read(field)} == nil
#{iv_name(field)} = #{default_for(field)}
#{field.iv_name} = #{default_for(field)}
else
#{bounds_check(field, lvar_read(field)).chomp}
#{set_bitmask(field)}
Expand Down Expand Up @@ -907,20 +905,15 @@ def lvar_name(field)
end
end

# Return an instance variable name for use in generated code
def iv_name(field)
"@#{field.name}"
end

def initialize_required_field(field)
<<~RUBY
#{bounds_check(field, lvar_read(field)).chomp}
#{iv_name(field)} = #{lvar_read(field)}
#{field.iv_name} = #{lvar_read(field)}
RUBY
end

def initialize_enum_field(field)
"#{iv_name(field)} = #{enum_name(field)}.resolve(#{field.name}) || #{lvar_read(field)}"
"#{field.iv_name} = #{enum_name(field)}.resolve(#{field.name}) || #{lvar_read(field)}"
end

def extra_api
Expand Down Expand Up @@ -1114,10 +1107,10 @@ def oneof_field_readers
def decode_from(buff, index, len)
<%= init_bitmask(message) %>
<%- for field in @oneof_selection_fields -%>
<%= iv_name(field) %> = nil # oneof field
<%= field.iv_name %> = nil # oneof field
<%- end -%>
<%- for field in fields -%>
<%= iv_name(field) %> = <%= default_for(field) %>
<%= field.iv_name %> = <%= default_for(field) %>
<%- end -%>
return self if index >= len
Expand Down Expand Up @@ -1172,11 +1165,11 @@ def decode_from(buff, index, len)
found = false
<%- fields.each do |field| -%>
<%- if !field.has_oneof_index? || optional_field?(field) -%>
<%- if !field.has_oneof_index? || field.optional? -%>
if tag == <%= tag_for_field(field, field.number) %>
found = true
<%= decode_code(field) %>
<%= set_bitmask(field) if optional_field?(field) %>
<%= set_bitmask(field) if field.optional? %>
return self if index >= len
<%- if !reads_next_tag?(field) -%>
<%= pull_tag %>
Expand All @@ -1201,7 +1194,7 @@ def decode_from(buff, index, len)
PACKED_REPEATED = ERB.new(<<~ERB)
<%= pull_uint64("value", "=") %>
goal = index + value
list = <%= iv_name(field) %>
list = <%= field.iv_name %>
while true
break if index >= goal
<%= decode_subtype(field, field.type, "list", "<<") %>
Expand All @@ -1222,7 +1215,7 @@ def pull_tag

def default_for(field)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
"{}"
else
"[]"
Expand All @@ -1247,25 +1240,9 @@ def default_for(field)
end
end

def map_field?(field)
return false unless field.label == :LABEL_REPEATED

map_name = field.type_name.split(".").last
message.nested_type.any? { |type| type.name == map_name && type.options&.map_entry }
end

def map_type(field)
return false unless field.label == :LABEL_REPEATED

map_name = field.type_name.split(".").last
message.nested_type.find do |type|
type.name == map_name && type.options&.map_entry
end || raise(ArgumentError, "Not a map field")
end

def initialize_signature
fields.flat_map do |f|
if f.has_oneof_index? || optional_field?(f)
if f.has_oneof_index? || f.optional?
"#{lvar_name(f)}: nil"
else
"#{lvar_name(f)}: #{default_for(f)}"
Expand Down Expand Up @@ -1330,18 +1307,18 @@ def pull_fixed_int32(dest, operator)
end

def decode_map(field)
map_type = self.map_type(field)
map_type = field.map_type

<<~RUBY
## PULL_MAP
map = #{iv_name(field)}
map = #{field.iv_name}
while tag == #{tag_for_field(field, field.number)}
#{pull_uint64("value", "=")}
index += 1 # skip the tag, assume it's the key
return self if index >= len
#{decode_subtype(map_type.field[0], map_type.field[0].type, "key", "=")}
#{decode_subtype(map_type.key, map_type.key.type, "key", "=")}
index += 1 # skip the tag, assume it's the value
#{decode_subtype(map_type.field[1], map_type.field[1].type, "map[key]", "=")}
#{decode_subtype(map_type.value, map_type.value.type, "map[key]", "=")}
return self if index >= len
#{pull_tag}
end
Expand All @@ -1351,7 +1328,7 @@ def decode_map(field)
def decode_repeated(field)
<<~RUBY
## DECODE REPEATED
list = #{iv_name(field)}
list = #{field.iv_name}
while true
#{decode_subtype(field, field.type, "list", "<<")}
return self if index >= len
Expand Down Expand Up @@ -1573,12 +1550,12 @@ def translate_well_known(type)
end

def pull_message(type, dest, operator)
type = translate_well_known(type)
translated_type = translate_well_known(type)

<<~RUBY
## PULL_MESSAGE
#{pull_uint64("msg_len", "=")}
#{dest} #{operator} #{class_name(type)}.allocate.decode_from(buff, index, index += msg_len)
#{dest} #{operator} #{class_name(translated_type)}.allocate.decode_from(buff, index, index += msg_len)
## END PULL_MESSAGE
RUBY
end
Expand Down Expand Up @@ -1646,22 +1623,18 @@ def pull_boolean(dest, operator)

def decode_code(field)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
decode_map(field)
elsif CodeGen.packed?(field)
PACKED_REPEATED.result(binding)
else
decode_repeated(field)
end
else
decode_subtype(field, field.type, iv_name(field), "=")
decode_subtype(field, field.type, field.iv_name, "=")
end
end

def required_fields(msg)
msg.fields.select(&:field?).reject(&:optional?)
end

def init_bitmask(msg)
optionals = optional_fields

Expand Down Expand Up @@ -1705,11 +1678,7 @@ def test_bitmask(field)
end

def reads_next_tag?(field)
map_field?(field) || (repeated?(field) && !CodeGen.packed?(field))
end

def repeated?(field)
field.label == :LABEL_REPEATED
field.map_field? || (field.repeated? && !CodeGen.packed?(field))
end
end

Expand Down
Loading

0 comments on commit 99ad480

Please sign in to comment.