Skip to content

Commit 8225532

Browse files
authored
Feat/sort dependencies nullable and reasoning backfilling (#26)
## 🚪 Why? Zod schemas of type nullable were returned after creating a workflow as Any or Undefined for Sorbet, which wasn't the intended behavior, we expected nilable. Also, the order was not being considered, so dependencies and dependent structs were not working properly. Last, we missed to add the backfilling for the reasoning_details that was covered by Mastra chat, with this we aim to solve this too. ## 🔑 What? Refactored mastra client to handle reasoning_details, and the schema_struct_string to handle the sorting of dependencies, and having nullables translated the right way. Updated version to the 0.4.2.
1 parent 02e83fd commit 8225532

File tree

5 files changed

+354
-34
lines changed

5 files changed

+354
-34
lines changed

Gemfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
PATH
22
remote: .
33
specs:
4-
ai (0.4.1)
4+
ai (0.4.2)
55
actionpack (>= 7.1.3)
66
activesupport (>= 7.1.3)
77
json_schemer (~> 2.4.0)

lib/ai/clients/mastra.rb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def generate(agent_name, messages:, options: {})
7474
parsed_response['response']['body'] = parsed_response['response']['messages']
7575
end
7676

77+
if parsed_response['reasoning']
78+
parsed_response['reasoning_details'] = parsed_response['reasoning']
79+
end
80+
7781
parsed_response
7882
end
7983

lib/ai/schema_to_struct_string.rb

Lines changed: 147 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# typed: strict
2+
# rubocop:disable Sorbet/ForbidTUntyped
23

34
require 'json'
45
require 'active_support/inflector'
@@ -9,6 +10,10 @@ module Ai
910
#
1011
# The resulting definition is returned as a *string* so that it can be
1112
# injected into ERB templates when auto-generating files.
13+
#
14+
# Note: This class uses T.untyped for JSON schema structures as they are
15+
# inherently dynamic and come from external sources. Type safety is maintained
16+
# through runtime checks and the generated output is fully typed.
1217
class SchemaToStructString
1318
extend T::Sig
1419

@@ -23,24 +28,27 @@ def initialize(schema, class_name: 'Input')
2328
@root_class_name = class_name
2429
@generated_classes = T.let(Set.new, T::Set[String])
2530
@nested_definitions = T.let([], T::Array[String])
26-
@schema_definitions = T.let({}, T::Hash[String, T::Hash[String, T.untyped]]) # rubocop:disable Sorbet/ForbidTUntyped
27-
@resolved_refs = T.let({}, T::Hash[String, T::Hash[String, T.untyped]]) # rubocop:disable Sorbet/ForbidTUntyped
31+
@schema_definitions = T.let({}, T::Hash[String, T::Hash[String, T.untyped]])
32+
@resolved_refs = T.let({}, T::Hash[String, T::Hash[String, T.untyped]])
33+
@dependencies = T.let({}, T::Hash[String, T::Set[String]])
34+
@current_class = T.let(nil, T.nilable(String))
2835
end
2936

3037
sig { returns(String) }
3138
def convert
3239
main_definition = generate_struct(parsed_schema, @root_class_name)
33-
(@nested_definitions + [main_definition]).join("\n\n")
40+
sorted_definitions = topological_sort(@nested_definitions)
41+
(sorted_definitions + [main_definition]).join("\n\n")
3442
end
3543

36-
sig { returns(T::Hash[String, T.untyped]) } # rubocop:disable Sorbet/ForbidTUntyped
44+
sig { returns(T::Hash[String, T.untyped]) }
3745
def parsed_schema
3846
return @parsed_schema if @parsed_schema
3947

40-
full_schema = T.let(JSON.parse(@schema), T::Hash[String, T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped
48+
full_schema = T.let(JSON.parse(@schema), T::Hash[String, T.untyped])
4149

4250
if full_schema.key?('json')
43-
@parsed_schema = T.let(full_schema['json'], T.nilable(T::Hash[String, T.untyped])) # rubocop:disable Sorbet/ForbidTUntyped
51+
@parsed_schema = T.let(full_schema['json'], T.nilable(T::Hash[String, T.untyped]))
4452
elsif full_schema.key?('$defs') || full_schema.key?('definitions')
4553
@schema_definitions = full_schema['$defs'] || full_schema['definitions'] || {}
4654
@parsed_schema = full_schema
@@ -53,13 +61,11 @@ def parsed_schema
5361
raise ArgumentError, "Invalid JSON schema provided: #{e.message}"
5462
end
5563

56-
# rubocop:disable Sorbet/ForbidTUntyped
5764
sig do
5865
params(schema_hash: T::Hash[T.any(Symbol, String), T.untyped]).returns(
5966
T::Hash[T.any(Symbol, String), T.untyped]
6067
)
6168
end
62-
# rubocop:enable Sorbet/ForbidTUntyped
6369
def resolve_ref(schema_hash)
6470
ref = schema_hash['$ref']
6571
return schema_hash unless ref
@@ -84,18 +90,17 @@ def resolve_ref(schema_hash)
8490

8591
return schema_hash unless resolved
8692

87-
@resolved_refs[ref] = T.cast(resolved, T::Hash[String, T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped
93+
@resolved_refs[ref] = T.cast(resolved, T::Hash[String, T.untyped])
8894
resolved
8995
end
9096

9197
sig do
92-
params(
93-
schema: T.untyped, # rubocop:disable Sorbet/ForbidTUntyped
94-
parts: T::Array[String]
95-
).returns(T.nilable(T::Hash[T.any(Symbol, String), T.untyped])) # rubocop:disable Sorbet/ForbidTUntyped
98+
params(schema: T.untyped, parts: T::Array[String]).returns(
99+
T.nilable(T::Hash[T.any(Symbol, String), T.untyped])
100+
)
96101
end
97102
def navigate_schema_path(schema, parts)
98-
current = T.let(schema, T.untyped) # rubocop:disable Sorbet/ForbidTUntyped
103+
current = T.let(schema, T.untyped)
99104

100105
parts.each_with_index do |part, _index|
101106
return nil if current.nil?
@@ -129,53 +134,103 @@ def navigate_schema_path(schema, parts)
129134

130135
sig do
131136
params(
132-
schema_hash: T::Hash[T.any(Symbol, String), T.untyped], # rubocop:disable Sorbet/ForbidTUntyped
137+
schema_hash: T::Hash[T.any(Symbol, String), T.untyped],
133138
class_name: String,
134139
depth: Integer
135140
).returns(String)
136141
end
137142
def generate_struct(schema_hash, class_name, depth = 0)
138-
properties = T.let(schema_hash.fetch('properties', {}), T::Hash[String, T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped
143+
properties = T.let(schema_hash.fetch('properties', {}), T::Hash[String, T.untyped])
139144
required = T.let(schema_hash.fetch('required', []), T::Array[String])
140145

146+
previous_class = @current_class
147+
@current_class = class_name
148+
@dependencies[class_name] ||= Set.new
149+
141150
lines = []
142151
lines << "class #{class_name} < T::Struct"
143152

144153
properties.each do |prop_name, prop_schema|
145154
prop_type = sorbet_type(prop_name, prop_schema, depth)
146-
prop_type = "T.nilable(#{prop_type})" unless required.include?(prop_name) ||
147-
prop_type == 'T.untyped'
155+
156+
extract_class_dependencies(prop_type).each { |dep| add_dependency(dep) }
157+
158+
unless required.include?(prop_name) || prop_type == 'T.untyped' ||
159+
prop_type.start_with?('T.nilable(')
160+
prop_type = "T.nilable(#{prop_type})"
161+
end
148162

149163
comment = build_comment(prop_schema)
150164
lines << " #{comment}" if comment
151165
lines << " const :#{prop_name}, #{prop_type}"
152166
end
153167

154168
lines << 'end'
169+
170+
@current_class = previous_class
171+
155172
lines.join("\n")
156173
end
157174

158175
sig do
159176
params(
160177
prop_name: T.any(Symbol, String),
161-
prop_schema: T::Hash[T.any(Symbol, String), T.untyped], # rubocop:disable Sorbet/ForbidTUntyped
178+
prop_schema: T::Hash[T.any(Symbol, String), T.untyped],
162179
depth: Integer
163180
).returns(String)
164181
end
165182
def sorbet_type(prop_name, prop_schema, depth) # rubocop:disable Metrics/CyclomaticComplexity
166183
resolved_schema = resolve_ref(prop_schema)
167-
type = T.unsafe(resolved_schema['type'] || resolved_schema[:type]) # rubocop:disable Sorbet/ForbidTUnsafe
168-
169-
if type.is_a?(Array)
170-
non_null = type.reject { |t| t == 'null' }
171-
ruby_types =
172-
non_null
173-
.map { |t| sorbet_type(prop_name, resolved_schema.merge('type' => t), depth) }
174-
.uniq
175-
return "T.any(#{ruby_types.join(', ')})"
184+
185+
# Handle anyOf pattern for nullable types (e.g., from Zod's .nullable())
186+
any_of_value = resolved_schema['anyOf']
187+
if any_of_value.is_a?(Array)
188+
non_null_schemas = any_of_value.select { |s| s.is_a?(Hash) && s['type'] != 'null' }
189+
190+
if non_null_schemas.length == 1 && non_null_schemas.length < any_of_value.length
191+
# It's a nullable type: anyOf with exactly one non-null type
192+
first_schema = T.cast(non_null_schemas.first, T::Hash[T.any(Symbol, String), T.untyped])
193+
inner_type = sorbet_type(prop_name, first_schema, depth)
194+
return "T.nilable(#{inner_type})"
195+
elsif non_null_schemas.length > 1
196+
# Multiple non-null types in union
197+
ruby_types =
198+
non_null_schemas
199+
.map do |schema|
200+
sorbet_type(
201+
prop_name,
202+
T.cast(schema, T::Hash[T.any(Symbol, String), T.untyped]),
203+
depth
204+
)
205+
end
206+
.uniq
207+
base_type = "T.any(#{ruby_types.join(', ')})"
208+
has_null = any_of_value.any? { |s| s.is_a?(Hash) && s['type'] == 'null' }
209+
return has_null ? "T.nilable(#{base_type})" : base_type
210+
end
211+
end
212+
213+
# Get the type field, which can be a string or array
214+
type_value = resolved_schema['type'] || resolved_schema[:type]
215+
216+
if type_value.is_a?(Array)
217+
non_null = type_value.reject { |t| t == 'null' }
218+
219+
if non_null.length == 1 && non_null.length < type_value.length
220+
inner_type =
221+
sorbet_type(prop_name, resolved_schema.merge('type' => non_null.first), depth)
222+
return "T.nilable(#{inner_type})"
223+
elsif non_null.length > 1
224+
ruby_types =
225+
non_null
226+
.map { |t| sorbet_type(prop_name, resolved_schema.merge('type' => t), depth) }
227+
.uniq
228+
base_type = "T.any(#{ruby_types.join(', ')})"
229+
return non_null.length < type_value.length ? "T.nilable(#{base_type})" : base_type
230+
end
176231
end
177232

178-
case type
233+
case type_value
179234
when 'string'
180235
return 'Time' if resolved_schema['format'] == 'date-time'
181236
return 'String' unless resolved_schema.key?('enum')
@@ -204,7 +259,7 @@ def sorbet_type(prop_name, prop_schema, depth) # rubocop:disable Metrics/Cycloma
204259
end
205260
"T::Array[T.any(#{tuple_types.join(', ')})]"
206261
else
207-
items = T.cast(raw_items, T::Hash[T.any(Symbol, String), T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped
262+
items = T.cast(raw_items, T::Hash[T.any(Symbol, String), T.untyped])
208263
"T::Array[#{sorbet_type(prop_name.to_s.singularize, items, depth + 1)}]"
209264
end
210265
when 'object'
@@ -238,7 +293,7 @@ def generate_enum(class_name, values)
238293
lines.join("\n")
239294
end
240295

241-
sig { params(prop_schema: T::Hash[String, T.untyped]).returns(T.nilable(String)) } # rubocop:disable Sorbet/ForbidTUntyped
296+
sig { params(prop_schema: T::Hash[String, T.untyped]).returns(T.nilable(String)) }
242297
def build_comment(prop_schema)
243298
keys_in_order = %w[
244299
minLength
@@ -269,5 +324,66 @@ def build_comment(prop_schema)
269324

270325
"# #{entries.join(', ')}"
271326
end
327+
328+
sig { params(type_string: String).returns(T::Set[String]) }
329+
def extract_class_dependencies(type_string)
330+
dependencies = Set.new
331+
332+
type_string.scan(/(?<![T.])\b([A-Z][a-zA-Z0-9_]*(?:Enum)?)\b/) do |match|
333+
class_name = match[0]
334+
unless %w[String Integer Float Time Boolean NilClass Array Hash].include?(class_name)
335+
dependencies.add(class_name)
336+
end
337+
end
338+
339+
dependencies
340+
end
341+
342+
sig { params(dependency_class: String).void }
343+
def add_dependency(dependency_class)
344+
return unless @current_class
345+
346+
@dependencies[@current_class] ||= Set.new
347+
T.must(@dependencies[@current_class]).add(dependency_class)
348+
end
349+
350+
sig { params(definitions: T::Array[String]).returns(T::Array[String]) }
351+
def topological_sort(definitions)
352+
class_to_def = T.let({}, T::Hash[String, String])
353+
definitions.each do |defn|
354+
match = defn.match(/class\s+([A-Z][a-zA-Z0-9_]*)/)
355+
next unless match
356+
357+
class_name = T.must(match[1])
358+
class_to_def[class_name] = defn
359+
end
360+
361+
sorted = T.let([], T::Array[String])
362+
visited = T.let(Set.new, T::Set[String])
363+
visiting = T.let(Set.new, T::Set[String])
364+
365+
visit = T.let(nil, T.nilable(T.proc.params(class_name: String).void))
366+
visit =
367+
lambda do |class_name|
368+
next if visited.include?(class_name)
369+
370+
next if visiting.include?(class_name)
371+
372+
visiting.add(class_name)
373+
374+
deps = @dependencies[class_name] || Set.new
375+
deps.each { |dep| T.must(visit).call(dep) if class_to_def.key?(dep) }
376+
377+
visiting.delete(class_name)
378+
visited.add(class_name)
379+
defn = class_to_def[class_name]
380+
sorted << defn if defn
381+
end
382+
383+
class_to_def.keys.each { |class_name| visit.call(class_name) }
384+
385+
sorted
386+
end
272387
end
273388
end
389+
# rubocop:enable Sorbet/ForbidTUntyped

lib/ai/version.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
# frozen_string_literal: true
33

44
module Ai
5-
VERSION = '0.4.1'
5+
VERSION = '0.4.2'
66
end

0 commit comments

Comments
 (0)