diff --git a/src/mysql/connection.cr b/src/mysql/connection.cr index bc648d2..9cdf932 100644 --- a/src/mysql/connection.cr +++ b/src/mysql/connection.cr @@ -1,6 +1,29 @@ require "socket" class MySql::Connection < DB::Connection + class ConnectionError < Exception; end + + class UnexpectedPacketError < ConnectionError + getter status : UInt8 + + def initialize(status) + @status = status + super("unexpected packet #{status}") + end + end + + class UnexpectedPacketValueError < ConnectionError + getter attribute : String + getter value : UInt64 + + def initialize(attribute, value) + @attribute = attribute + @value = value + + super("unexpected value for #{attribute}: #{value}") + end + end + record Options, host : String, port : Int32, @@ -9,7 +32,7 @@ class MySql::Connection < DB::Connection initial_catalog : String?, charset : String do def self.from_uri(uri : URI) : Options - host = uri.hostname || raise "no host provided" + host = uri.hostname || raise ConnectionError.new("no host provided") port = uri.port || 3306 username = uri.user password = uri.password @@ -44,7 +67,7 @@ class MySql::Connection < DB::Connection end read_ok_or_err do |packet, status| - raise "packet #{status} not implemented" + raise NotImplementedError.new("packet #{status} not implemented") end rescue IO::Error raise DB::ConnectionRefused.new @@ -117,13 +140,13 @@ class MySql::Connection < DB::Connection # :nodoc: def handle_err_packet(packet) 8.times { packet.read_byte! } - raise packet.read_string + raise ConnectionError.new(packet.read_string) end # :nodoc: def raise_if_err_packet(packet) raise_if_err_packet(packet) do |status| - raise "unexpected packet #{status}" + raise UnexpectedPacketError.new(status) end end @@ -152,14 +175,14 @@ class MySql::Connection < DB::Connection name = packet.read_lenenc_string org_name = packet.read_lenenc_string next_length = packet.read_lenenc_int # length of fixed-length fields, always 0x0c - raise "Unexpected next_length value: #{next_length}." unless next_length == 0x0c + raise UnexpectedPacketValueError.new("next_length", next_length) unless next_length == 0x0c character_set = packet.read_fixed_int(2).to_u16! column_length = packet.read_fixed_int(4).to_u32! column_type = packet.read_fixed_int(1).to_u8! flags = packet.read_fixed_int(2).to_u16! decimal = packet.read_fixed_int(1).to_u8! filler = packet.read_fixed_int(2).to_u16! # filler [00] [00] - raise "Unexpected filler value #{filler}" unless filler == 0x0000 + raise UnexpectedPacketValueError.new("filler", filler) unless filler == 0x0000 target << ColumnSpec.new(catalog, schema, table, org_table, name, org_name, character_set, column_length, column_type, flags, decimal) end diff --git a/src/mysql/read_packet.cr b/src/mysql/read_packet.cr index 56745f4..4151860 100644 --- a/src/mysql/read_packet.cr +++ b/src/mysql/read_packet.cr @@ -1,4 +1,8 @@ class MySql::ReadPacket < IO + class EOFError < IO::EOFError; end + + class UnexpectedIntLengthError < Exception; end + @length : Int32 = 0 @remaining : Int32 = 0 @seq : UInt8 = 0u8 @@ -32,16 +36,16 @@ class MySql::ReadPacket < IO {% if compare_versions(Crystal::VERSION, "0.35.0") == 0 %} def write(slice) : Int64 - raise "not implemented" + raise NotImplementedError.new("not implemented") end {% else %} def write(slice) : Nil - raise "not implemented" + raise NotImplementedError.new("not implemented") end {% end %} def read_byte! - read_byte || raise "Unexpected EOF" + read_byte || raise EOFError.new("Unexpected EOF") end def read_string @@ -88,7 +92,7 @@ class MySql::ReadPacket < IO elsif h == 0xfe read_bytes(UInt64, IO::ByteFormat::LittleEndian) else - raise "Unexpected int length" + raise UnexpectedIntLengthError.new("Unexpected int length") end res.to_u64 diff --git a/src/mysql/types.cr b/src/mysql/types.cr index e90548a..c7aa194 100644 --- a/src/mysql/types.cr +++ b/src/mysql/types.cr @@ -67,29 +67,29 @@ abstract struct MySql::Type end def self.type_for(t) - raise "MySql::Type does not support #{t} values" + raise NotImplementedError.new("MySql::Type does not support #{t} values") end def self.db_any_type - raise "not implemented" + raise NotImplementedError.new("not implemented") end # Writes in packet the value in ProtocolBinary format. # Used when sending query params. def self.write(packet, v) - raise "not supported write" + raise NotImplementedError.new("not supported write") end # Reads from packet a value in ProtocolBinary format of the type # specified by self. def self.read(packet) - raise "not supported read" + raise NotImplementedError.new("not supported read") end # Parse from str a value in TextProtocol format of the type # specified by self. def self.parse(str : ::String) - raise "not supported" + raise NotImplementedError.new("not supported") end # :nodoc: diff --git a/src/mysql/unprepared_statement.cr b/src/mysql/unprepared_statement.cr index 5e87a26..587dc23 100644 --- a/src/mysql/unprepared_statement.cr +++ b/src/mysql/unprepared_statement.cr @@ -16,7 +16,7 @@ class MySql::UnpreparedStatement < DB::Statement end private def perform_exec_or_query(args : Enumerable) - raise "exec/query with args is not supported" if args.size > 0 + raise NotImplementedError.new("exec/query with args is not supported") if args.size > 0 conn = self.conn conn.write_packet do |packet| diff --git a/src/mysql/write_packet.cr b/src/mysql/write_packet.cr index e858452..0fa964b 100644 --- a/src/mysql/write_packet.cr +++ b/src/mysql/write_packet.cr @@ -3,7 +3,7 @@ class MySql::WritePacket < IO end def read(slice) - raise "not implemented" + raise NotImplementedError.new("not implemented") end {% if compare_versions(Crystal::VERSION, "0.35.0") == 0 %}