Skip to content

Commit

Permalink
Introduce a DecoratedField to alleviate feature envy
Browse files Browse the repository at this point in the history
  • Loading branch information
davebenvenuti committed Jan 29, 2025
1 parent 9e25aa6 commit a962cde
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 55 deletions.
76 changes: 25 additions & 51 deletions lib/protoboeuf/codegen.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
require "erb"
require "syntax_tree"
require_relative "codegen_type_helper"
require_relative "decorated_field"

module ProtoBoeuf
class CodeGen
Expand Down Expand Up @@ -93,7 +94,9 @@ def result(message, toplevel_enums, generate_types:, requires:, syntax:, options
def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, options:)
@message = message
@optional_field_bit_lut = []
@fields = @message.field
@fields = @message.field.map do |field|
DecoratedField.new(field:, message:, syntax:)
end
@enum_field_types = toplevel_enums.merge(message.enum_type.group_by(&:name))
@requires = requires
@generate_types = generate_types
Expand All @@ -109,8 +112,8 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt

optional_field_count = 0

message.field.each do |field|
if optional_field?(field)
@fields.each do |field|
if field.optional?
if field.type == :TYPE_ENUM
@enum_fields << field
else
Expand All @@ -137,11 +140,6 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt
end
end

def optional_field?(field)
proto3 = "proto3" == syntax
field.proto3_optional || (field.label == :LABEL_OPTIONAL && !proto3)
end

def result
"class #{message.name}\n" + class_body + "end\n"
end
Expand All @@ -163,7 +161,7 @@ def class_body

def conversion
fields = self.fields.reject do |field|
field.has_oneof_index? && !optional_field?(field)
field.has_oneof_index? && !field.optional?
end

oneofs = @oneof_selection_fields.map do |field|
Expand All @@ -181,7 +179,7 @@ def to_h
end

def convert_field(field)
if repeated?(field)
if field.repeated?
"result['#{field.name}'.to_sym] = #{iv_name(field)}"
elsif field.type == :TYPE_MESSAGE
"result['#{field.name}'.to_sym] = #{iv_name(field)}.to_h"
Expand All @@ -200,7 +198,7 @@ def encode

def encode_subtype(field, value_expr = iv_name(field), tagged = true)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
encode_map(field, value_expr, tagged)
else
encode_repeated(field, value_expr, tagged)
Expand Down Expand Up @@ -271,16 +269,16 @@ def encode_bool(field, value_expr, tagged)
end

def encode_map(field, value_expr, tagged)
map_type = self.map_type(field)
map_type = field.map_type

<<~RUBY
map = #{value_expr}
if map.size > 0
old_buff = buff
map.each do |key, value|
buff = new_buffer = +''
#{encode_subtype(map_type.field[0], "key", true)}
#{encode_subtype(map_type.field[1], "value", true)}
#{encode_subtype(map_type.key, "key", true)}
#{encode_subtype(map_type.value, "value", true)}
buff = old_buff
#{encode_tag_and_length(field, true, "new_buffer.bytesize")}
old_buff.concat(new_buffer)
Expand Down Expand Up @@ -786,7 +784,7 @@ def initialize_code
init_bitmask(message) +
initialize_oneofs +
fields.map { |field|
if field.has_oneof_index? && !optional_field?(field)
if field.has_oneof_index? && !field.optional?
initialize_oneof(field, message)
else
initialize_field(field)
Expand Down Expand Up @@ -815,7 +813,7 @@ def initialize_oneof(field, msg)
end

def initialize_field(field)
if optional_field?(field)
if field.optional?
initialize_optional_field(field)
elsif field.type == :TYPE_ENUM
initialize_enum_field(field)
Expand Down Expand Up @@ -1172,11 +1170,11 @@ def decode_from(buff, index, len)
found = false
<%- fields.each do |field| -%>
<%- if !field.has_oneof_index? || optional_field?(field) -%>
<%- if !field.has_oneof_index? || field.optional? -%>
if tag == <%= tag_for_field(field, field.number) %>
found = true
<%= decode_code(field) %>
<%= set_bitmask(field) if optional_field?(field) %>
<%= set_bitmask(field) if field.optional? %>
return self if index >= len
<%- if !reads_next_tag?(field) -%>
<%= pull_tag %>
Expand Down Expand Up @@ -1222,7 +1220,7 @@ def pull_tag

def default_for(field)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
"{}"
else
"[]"
Expand All @@ -1247,25 +1245,9 @@ def default_for(field)
end
end

def map_field?(field)
return false unless field.label == :LABEL_REPEATED

map_name = field.type_name.split(".").last
message.nested_type.any? { |type| type.name == map_name && type.options&.map_entry }
end

def map_type(field)
return false unless field.label == :LABEL_REPEATED

map_name = field.type_name.split(".").last
message.nested_type.find do |type|
type.name == map_name && type.options&.map_entry
end || raise(ArgumentError, "Not a map field")
end

def initialize_signature
fields.flat_map do |f|
if f.has_oneof_index? || optional_field?(f)
if f.has_oneof_index? || f.optional?
"#{lvar_name(f)}: nil"
else
"#{lvar_name(f)}: #{default_for(f)}"
Expand Down Expand Up @@ -1330,7 +1312,7 @@ def pull_fixed_int32(dest, operator)
end

def decode_map(field)
map_type = self.map_type(field)
map_type = field.map_type

<<~RUBY
## PULL_MAP
Expand All @@ -1339,9 +1321,9 @@ def decode_map(field)
#{pull_uint64("value", "=")}
index += 1 # skip the tag, assume it's the key
return self if index >= len
#{decode_subtype(map_type.field[0], map_type.field[0].type, "key", "=")}
#{decode_subtype(map_type.key, map_type.key.type, "key", "=")}
index += 1 # skip the tag, assume it's the value
#{decode_subtype(map_type.field[1], map_type.field[1].type, "map[key]", "=")}
#{decode_subtype(map_type.value, map_type.value.type, "map[key]", "=")}
return self if index >= len
#{pull_tag}
end
Expand Down Expand Up @@ -1573,12 +1555,12 @@ def translate_well_known(type)
end

def pull_message(type, dest, operator)
type = translate_well_known(type)
translated_type = translate_well_known(type)

<<~RUBY
## PULL_MESSAGE
#{pull_uint64("msg_len", "=")}
#{dest} #{operator} #{class_name(type)}.allocate.decode_from(buff, index, index += msg_len)
#{dest} #{operator} #{class_name(translated_type)}.allocate.decode_from(buff, index, index += msg_len)
## END PULL_MESSAGE
RUBY
end
Expand Down Expand Up @@ -1646,7 +1628,7 @@ def pull_boolean(dest, operator)

def decode_code(field)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
decode_map(field)
elsif CodeGen.packed?(field)
PACKED_REPEATED.result(binding)
Expand All @@ -1658,10 +1640,6 @@ def decode_code(field)
end
end

def required_fields(msg)
msg.fields.select(&:field?).reject(&:optional?)
end

def init_bitmask(msg)
optionals = optional_fields

Expand Down Expand Up @@ -1705,11 +1683,7 @@ def test_bitmask(field)
end

def reads_next_tag?(field)
map_field?(field) || (repeated?(field) && !CodeGen.packed?(field))
end

def repeated?(field)
field.label == :LABEL_REPEATED
field.map_field? || (field.repeated? && !CodeGen.packed?(field))
end
end

Expand Down
8 changes: 4 additions & 4 deletions lib/protoboeuf/codegen_type_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def convert_type(converted_type, optional: false, array: false)
end

def convert_field_type(field)
converted_type = if map_field?(field)
map_type = self.map_type(field)
"T::Hash[#{convert_field_type(map_type.field[0])}, #{convert_field_type(map_type.field[1])}]"
converted_type = if field.map_field?
map_type = field.map_type
"T::Hash[#{convert_field_type(map_type.key)}, #{convert_field_type(map_type.value)}]"
else
case field.type
when :TYPE_BOOL
Expand All @@ -87,7 +87,7 @@ def convert_field_type(field)
convert_type(
converted_type,
optional: field.label == :TYPE_OPTIONAL,
array: field.label == :LABEL_REPEATED && !map_field?(field),
array: field.label == :LABEL_REPEATED && !field.map_field?,
)
end

Expand Down
53 changes: 53 additions & 0 deletions lib/protoboeuf/decorated_field.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# frozen_string_literal: true

require "forwardable"

# Adds some convenience methods to Google::Protobuf::FieldDescriptorProto
class DecoratedField
attr_reader :original_field, :message, :syntax

extend Forwardable

def_delegators :@original_field, :name, :label, :type_name, :type, :number, :options, :oneof_index, :has_oneof_index?

def initialize(field:, message:, syntax:)
@original_field = field
@message = message
@syntax = syntax
end

def optional?
original_field.proto3_optional || (label == :LABEL_OPTIONAL && !proto3?)
end

def repeated?
label == :LABEL_REPEATED
end

def map_field?
return false unless repeated?

map_name = type_name.split(".").last
message.nested_type.any? { |type| type.name == map_name && type.options&.map_entry }
end

class MapType < Struct.new(:key, :value); end

def map_type
return false unless repeated?

map_name = type_name.split(".").last
message.nested_type.find { |type| type.name == map_name && type.options&.map_entry }.tap do |descriptor|
raise ArgumentError, "Not a map field" if descriptor.nil?

return MapType.new(
key: DecoratedField.new(field: descriptor.field[0], message:, syntax:),
value: DecoratedField.new(field: descriptor.field[1], message:, syntax:),
)
end
end

def proto3?
"proto3" == syntax
end
end
Loading

0 comments on commit a962cde

Please sign in to comment.