Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ErrorDetails an enum with associated values BadRequest/ErrorInfo #106

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 155 additions & 99 deletions Sources/GoogleAI/Errors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,144 @@

import Foundation

struct RPCError: Error {
let httpResponseCode: Int
let message: String
let status: RPCStatus
let details: [ErrorDetails]
struct ServerError: Error {
enum ErrorDetails {
case badRequest(BadRequest)
case errorInfo(ErrorInfo)
case unknown(String)

struct BadRequest {
static let type = "type.googleapis.com/google.rpc.BadRequest"

struct FieldViolation: Decodable {
let field: String?
let description: String?
}

let type: String
let fieldViolations: [FieldViolation]
}

struct ErrorInfo {
static let type = "type.googleapis.com/google.rpc.ErrorInfo"

let type: String
let reason: String?
let domain: String?
}
}

enum Status: String, Decodable {
// Not an error; returned on success.
case ok = "OK"

// The operation was cancelled, typically by the caller.
case cancelled = "CANCELLED"

// Unknown error.
case unknown = "UNKNOWN"

// The client specified an invalid argument.
case invalidArgument = "INVALID_ARGUMENT"

// The deadline expired before the operation could complete.
case deadlineExceeded = "DEADLINE_EXCEEDED"

// Some requested entity (e.g., file or directory) was not found.
case notFound = "NOT_FOUND"

// The entity that a client attempted to create (e.g., file or directory) already exists.
case alreadyExists = "ALREADY_EXISTS"

// The caller does not have permission to execute the specified operation.
case permissionDenied = "PERMISSION_DENIED"

// The request does not have valid authentication credentials for the operation.
case unauthenticated = "UNAUTHENTICATED"

// Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system
// is out of space.
case resourceExhausted = "RESOURCE_EXHAUSTED"

// The operation was rejected because the system is not in a state required for the operation's
// execution.
case failedPrecondition = "FAILED_PRECONDITION"

// The operation was aborted, typically due to a concurrency issue such as a sequencer check
// failure or transaction abort.
case aborted = "ABORTED"

// The operation was attempted past the valid range.
case outOfRange = "OUT_OF_RANGE"

// The operation is not implemented or is not supported/enabled in this service.
case unimplemented = "UNIMPLEMENTED"

// Internal errors.
case internalError = "INTERNAL"

private var errorInfo: ErrorDetails? {
return details.first { $0.isErrorInfo() }
// The service is currently unavailable.
case unavailable = "UNAVAILABLE"

// Unrecoverable data loss or corruption.
case dataLoss = "DATA_LOSS"
}

init(httpResponseCode: Int, message: String, status: RPCStatus, details: [ErrorDetails]) {
self.httpResponseCode = httpResponseCode
let code: Int
let message: String
let status: Status
let details: [ErrorDetails]

init(httpResponseCode: Int, message: String, status: Status, details: [ErrorDetails]) {
code = httpResponseCode
self.message = message
self.status = status
self.details = details
}

func isInvalidAPIKeyError() -> Bool {
return errorInfo?.reason == "API_KEY_INVALID"
return details.contains { errorDetails in
switch errorDetails {
case let .errorInfo(errorInfo):
return errorInfo.reason == "API_KEY_INVALID"
default:
return false
}
}
}

func isUnsupportedUserLocationError() -> Bool {
return message == RPCErrorMessage.unsupportedUserLocation.rawValue
}
}

extension RPCError: Decodable {
enum InvalidCandidateError: Error {
case emptyContent(underlyingError: Error)
case malformedContent(underlyingError: Error)
}

// MARK: - Decodable Conformance

extension ServerError: Decodable {
enum CodingKeys: CodingKey {
case error
}

struct ErrorStatus {
let code: Int?
let message: String?
let status: ServerError.Status?
let details: [ServerError.ErrorDetails]
}

init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let status = try container.decode(ErrorStatus.self, forKey: .error)

if let code = status.code {
httpResponseCode = code
self.code = code
} else {
httpResponseCode = -1
code = -1
}

if let message = status.message {
Expand All @@ -71,34 +170,7 @@ extension RPCError: Decodable {
}
}

struct ErrorStatus {
let code: Int?
let message: String?
let status: RPCStatus?
let details: [ErrorDetails]
}

struct ErrorDetails {
static let errorInfoType = "type.googleapis.com/google.rpc.ErrorInfo"

let type: String
let reason: String?
let domain: String?

func isErrorInfo() -> Bool {
return type == ErrorDetails.errorInfoType
}
}

extension ErrorDetails: Decodable, Equatable {
enum CodingKeys: String, CodingKey {
case type = "@type"
case reason
case domain
}
}

extension ErrorStatus: Decodable {
extension ServerError.ErrorStatus: Decodable {
enum CodingKeys: CodingKey {
case code
case message
Expand All @@ -111,79 +183,63 @@ extension ErrorStatus: Decodable {
code = try container.decodeIfPresent(Int.self, forKey: .code)
message = try container.decodeIfPresent(String.self, forKey: .message)
do {
status = try container.decodeIfPresent(RPCStatus.self, forKey: .status)
status = try container.decodeIfPresent(ServerError.Status.self, forKey: .status)
} catch {
status = .unknown
}
if container.contains(.details) {
details = try container.decode([ErrorDetails].self, forKey: .details)
details = try container.decode([ServerError.ErrorDetails].self, forKey: .details)
} else {
details = []
}
}
}

enum RPCStatus: String, Decodable {
// Not an error; returned on success.
case ok = "OK"

// The operation was cancelled, typically by the caller.
case cancelled = "CANCELLED"

// Unknown error.
case unknown = "UNKNOWN"

// The client specified an invalid argument.
case invalidArgument = "INVALID_ARGUMENT"

// The deadline expired before the operation could complete.
case deadlineExceeded = "DEADLINE_EXCEEDED"

// Some requested entity (e.g., file or directory) was not found.
case notFound = "NOT_FOUND"

// The entity that a client attempted to create (e.g., file or directory) already exists.
case alreadyExists = "ALREADY_EXISTS"

// The caller does not have permission to execute the specified operation.
case permissionDenied = "PERMISSION_DENIED"

// The request does not have valid authentication credentials for the operation.
case unauthenticated = "UNAUTHENTICATED"

// Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system
// is out of space.
case resourceExhausted = "RESOURCE_EXHAUSTED"

// The operation was rejected because the system is not in a state required for the operation's
// execution.
case failedPrecondition = "FAILED_PRECONDITION"

// The operation was aborted, typically due to a concurrency issue such as a sequencer check
// failure or transaction abort.
case aborted = "ABORTED"

// The operation was attempted past the valid range.
case outOfRange = "OUT_OF_RANGE"

// The operation is not implemented or is not supported/enabled in this service.
case unimplemented = "UNIMPLEMENTED"
extension ServerError.ErrorDetails: Decodable {
enum CodingKeys: String, CodingKey {
case type = "@type"
}

// Internal errors.
case internalError = "INTERNAL"
init(from decoder: Decoder) throws {
let errorDetailsContainer = try decoder.container(keyedBy: CodingKeys.self)
let type = try errorDetailsContainer.decode(String.self, forKey: .type)
if type == BadRequest.type {
let badRequestContainer = try decoder.singleValueContainer()
let badRequest = try badRequestContainer.decode(BadRequest.self)
self = ServerError.ErrorDetails.badRequest(badRequest)
} else if type == ErrorInfo.type {
let errorInfoContainer = try decoder.singleValueContainer()
let errorInfo = try errorInfoContainer.decode(ErrorInfo.self)
self = ServerError.ErrorDetails.errorInfo(errorInfo)
} else {
self = ServerError.ErrorDetails.unknown(type)
}
}
}

// The service is currently unavailable.
case unavailable = "UNAVAILABLE"
extension ServerError.ErrorDetails.BadRequest: Decodable {
enum CodingKeys: String, CodingKey {
case type = "@type"
case fieldViolations
}

// Unrecoverable data loss or corruption.
case dataLoss = "DATA_LOSS"
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
type = try container.decode(String.self, forKey: .type)
fieldViolations = try container.decode([FieldViolation].self, forKey: .fieldViolations)
}
}

enum RPCErrorMessage: String {
case unsupportedUserLocation = "User location is not supported for the API use."
extension ServerError.ErrorDetails.ErrorInfo: Decodable {
enum CodingKeys: String, CodingKey {
case type = "@type"
case reason
case domain
}
}

enum InvalidCandidateError: Error {
case emptyContent(underlyingError: Error)
case malformedContent(underlyingError: Error)
// MARK: - Private

private enum RPCErrorMessage: String {
case unsupportedUserLocation = "User location is not supported for the API use."
}
2 changes: 1 addition & 1 deletion Sources/GoogleAI/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ struct GenerativeAIService {

private func parseError(responseData: Data) -> Error {
do {
return try JSONDecoder().decode(RPCError.self, from: responseData)
return try JSONDecoder().decode(ServerError.self, from: responseData)
} catch {
// TODO: Return an error about an unrecognized error payload with the response body
return error
Expand Down
4 changes: 2 additions & 2 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ public final class GenerativeModel {
private static func generateContentError(from error: Error) -> GenerateContentError {
if let error = error as? GenerateContentError {
return error
} else if let error = error as? RPCError, error.isInvalidAPIKeyError() {
} else if let error = error as? ServerError, error.isInvalidAPIKeyError() {
return GenerateContentError.invalidAPIKey
} else if let error = error as? RPCError, error.isUnsupportedUserLocationError() {
} else if let error = error as? ServerError, error.isUnsupportedUserLocationError() {
return GenerateContentError.unsupportedUserLocation
}
return GenerateContentError.internalError(underlying: error)
Expand Down
16 changes: 8 additions & 8 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
} catch let GenerateContentError.internalError(underlying: rpcError as ServerError) {
XCTAssertEqual(rpcError.status, .invalidArgument)
XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
XCTAssertEqual(rpcError.code, expectedStatusCode)
XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
} catch {
XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
Expand Down Expand Up @@ -333,9 +333,9 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
} catch let GenerateContentError.internalError(underlying: rpcError as ServerError) {
XCTAssertEqual(rpcError.status, .notFound)
XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
XCTAssertEqual(rpcError.code, expectedStatusCode)
XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
} catch {
XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
Expand Down Expand Up @@ -670,8 +670,8 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertNotNil(content.text)
responseCount += 1
}
} catch let GenerateContentError.internalError(rpcError as RPCError) {
XCTAssertEqual(rpcError.httpResponseCode, 499)
} catch let GenerateContentError.internalError(rpcError as ServerError) {
XCTAssertEqual(rpcError.code, 499)
XCTAssertEqual(rpcError.status, .cancelled)

// Check the content count is correct.
Expand Down Expand Up @@ -814,8 +814,8 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.countTokens("Why is the sky blue?")
XCTFail("Request should not have succeeded.")
} catch let CountTokensError.internalError(rpcError as RPCError) {
XCTAssertEqual(rpcError.httpResponseCode, 404)
} catch let CountTokensError.internalError(rpcError as ServerError) {
XCTAssertEqual(rpcError.code, 404)
XCTAssertEqual(rpcError.status, .notFound)
XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
return
Expand Down
Loading