
Merge pull request #57 from MacPaw/streaming
Add streaming session and ability to use streaming
Krivoblotsky authored May 15, 2023
2 parents eefb14b + 6bb1456 commit 12ba7f4
Showing 21 changed files with 474 additions and 45 deletions.
66 changes: 42 additions & 24 deletions Demo/DemoChat/Sources/ChatStore.swift
@@ -53,53 +53,71 @@ public final class ChatStore: ObservableObject {
}

@MainActor
func sendMessage(_ message: Message, conversationId: Conversation.ID) async {
func sendMessage(
_ message: Message,
conversationId: Conversation.ID,
model: Model
) async {
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}
conversations[conversationIndex].messages.append(message)

await completeChat(conversationId: conversationId)
await completeChat(
conversationId: conversationId,
model: model
)
}

@MainActor
func completeChat(conversationId: Conversation.ID) async {
func completeChat(
conversationId: Conversation.ID,
model: Model
) async {
guard let conversation = conversations.first(where: { $0.id == conversationId }) else {
return
}

conversationErrors[conversationId] = nil

do {
let response = try await openAIClient.chats(
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}

let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
query: ChatQuery(
model: .gpt3_5Turbo,
model: model,
messages: conversation.messages.map { message in
Chat(role: message.role, content: message.content)
}
)
)

guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}

let existingMessages = conversations[conversationIndex].messages

for completionMessage in response.choices.map(\.message) {
let message = Message(
id: response.id,
role: completionMessage.role,
content: completionMessage.content,
createdAt: Date(timeIntervalSince1970: TimeInterval(response.created))
)

if existingMessages.contains(message) {
continue

for try await partialChatResult in chatsStream {
for choice in partialChatResult.choices {
let existingMessages = conversations[conversationIndex].messages
let message = Message(
id: partialChatResult.id,
role: choice.delta.role ?? .assistant,
content: choice.delta.content ?? "",
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
)
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
// Meld into previous message
let previousMessage = existingMessages[existingMessageIndex]
let combinedMessage = Message(
id: message.id, // id stays the same for different deltas
role: message.role,
content: previousMessage.content + message.content,
createdAt: message.createdAt
)
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
} else {
conversations[conversationIndex].messages.append(message)
}
}
conversations[conversationIndex].messages.append(message)
}

} catch {
conversationErrors[conversationId] = error
}
5 changes: 3 additions & 2 deletions Demo/DemoChat/Sources/UI/ChatView.swift
@@ -46,7 +46,7 @@ public struct ChatView: View {
DetailView(
conversation: conversation,
error: store.conversationErrors[conversation.id],
sendMessage: { message in
sendMessage: { message, selectedModel in
Task {
await store.sendMessage(
Message(
@@ -55,7 +55,8 @@
content: message,
createdAt: dateProvider()
),
conversationId: conversation.id
conversationId: conversation.id,
model: selectedModel
)
}
}
56 changes: 53 additions & 3 deletions Demo/DemoChat/Sources/UI/DetailView.swift
@@ -10,15 +10,20 @@ import UIKit
#elseif os(macOS)
import AppKit
#endif
import OpenAI
import SwiftUI

struct DetailView: View {
@State var inputText: String = ""
@FocusState private var isFocused: Bool
@State private var showsModelSelectionSheet = false
@State private var selectedChatModel: Model = .gpt3_5Turbo

private let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4]

let conversation: Conversation
let error: Error?
let sendMessage: (String) -> Void
let sendMessage: (String, Model) -> Void

private var fillColor: Color {
#if os(iOS)
@@ -61,6 +66,51 @@ struct DetailView: View {
inputBar(scrollViewProxy: scrollViewProxy)
}
.navigationTitle("Chat")
.safeAreaInset(edge: .top) {
HStack {
Text(
"Model: \(selectedChatModel)"
)
.font(.caption)
.foregroundColor(.secondary)
Spacer()
}
.padding(.horizontal, 16)
.padding(.vertical, 8)
}
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button(action: {
showsModelSelectionSheet.toggle()
}) {
Image(systemName: "cpu")
}
}
}
.confirmationDialog(
"Select model",
isPresented: $showsModelSelectionSheet,
titleVisibility: .visible,
actions: {
ForEach(availableChatModels, id: \.self) { model in
Button {
selectedChatModel = model
} label: {
Text(model)
}
}

Button("Cancel", role: .cancel) {
showsModelSelectionSheet = false
}
},
message: {
Text(
"View https://platform.openai.com/docs/models/overview for details"
)
.font(.caption)
}
)
}
}
}
@@ -133,7 +183,7 @@ struct DetailView: View {
private func tapSendMessage(
scrollViewProxy: ScrollViewProxy
) {
sendMessage(inputText)
sendMessage(inputText, selectedChatModel)
inputText = ""

// if let lastMessage = conversation.messages.last {
@@ -206,7 +256,7 @@ struct DetailView_Previews: PreviewProvider {
]
),
error: nil,
sendMessage: { _ in }
sendMessage: { _, _ in }
)
}
}
2 changes: 1 addition & 1 deletion Demo/DemoChat/Sources/UI/ModerationChatView.swift
@@ -21,7 +21,7 @@ public struct ModerationChatView: View {
DetailView(
conversation: store.moderationConversation,
error: store.moderationConversationError,
sendMessage: { message in
sendMessage: { message, _ in
Task {
await store.sendModerationMessage(
Message(
78 changes: 76 additions & 2 deletions README.md
@@ -16,7 +16,9 @@ This repository contains Swift community-maintained implementation over [OpenAI]
- [Usage](#usage)
- [Initialization](#initialization)
- [Completions](#completions)
- [Completions Streaming](#completions-streaming)
- [Chats](#chats)
- [Chats Streaming](#chats-streaming)
- [Images](#images)
- [Audio](#audio)
- [Audio Transcriptions](#audio-transcriptions)
@@ -146,6 +148,43 @@ let result = try await openAI.completions(query: query)
- index : 0
```

#### Completions Streaming

Completions streaming is available via the `completionsStream` function. Tokens are sent one by one as they become available.

**Closures**
```swift
openAI.completionsStream(query: query) { partialResult in
switch partialResult {
case .success(let result):
print(result.choices)
case .failure(let error):
//Handle chunk error here
}
} completion: { error in
//Handle streaming error here
}
```

**Combine**

```swift
openAI
.completionsStream(query: query)
.sink { completion in
//Handle completion result here
} receiveValue: { result in
//Handle chunk here
}.store(in: &cancellables)
```

**Structured concurrency**
```swift
for try await result in openAI.completionsStream(query: query) {
//Handle result here
}
```
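
Putting it together, here is a minimal end-to-end sketch. The `CompletionsQuery` parameters below are an assumption modeled on the non-streaming example earlier in this README; check the initializer your version exposes:

```swift
let openAI = OpenAI(apiToken: "YOUR_TOKEN")

// Hypothetical query parameters for illustration — verify against CompletionsQuery.
let query = CompletionsQuery(model: .textDavinci_003, prompt: "Say this is a test", temperature: 0, maxTokens: 50, topP: 1, frequencyPenalty: 0, presencePenalty: 0, stop: nil, user: nil)

for try await result in openAI.completionsStream(query: query) {
    // Each chunk carries the next token(s) in choices[].text.
    result.choices.forEach { print($0.text, terminator: "") }
}
```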

Review [Completions Documentation](https://platform.openai.com/docs/api-reference/completions) for more info.

### Chats
@@ -175,8 +214,6 @@ Using the OpenAI Chat API, you can build your own applications with `gpt-3.5-turbo`
public let topP: Double?
/// How many chat completion choices to generate for each input message.
public let n: Int?
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only `server-sent events` as they become available, with the stream terminated by a data: [DONE] message.
public let stream: Bool?
/// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
public let stop: [String]?
/// The maximum number of tokens to generate in the completion.
@@ -244,6 +281,43 @@ let result = try await openAI.chats(query: query)
- total_tokens : 49
```

#### Chats Streaming

Chats streaming is available via the `chatsStream` function. Tokens are sent one by one as they become available.

**Closures**
```swift
openAI.chatsStream(query: query) { partialResult in
switch partialResult {
case .success(let result):
print(result.choices)
case .failure(let error):
//Handle chunk error here
}
} completion: { error in
//Handle streaming error here
}
```

**Combine**

```swift
openAI
.chatsStream(query: query)
.sink { completion in
//Handle completion result here
} receiveValue: { result in
//Handle chunk here
}.store(in: &cancellables)
```

**Structured concurrency**
```swift
for try await result in openAI.chatsStream(query: query) {
//Handle result here
}
```
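
As a usage sketch, the delta-melding logic from `ChatStore.swift` in this commit boils down to appending each chunk's `delta.content`; `ChatQuery` and `Chat` are used exactly as in the diff above:

```swift
let query = ChatQuery(
    model: .gpt3_5Turbo,
    messages: [Chat(role: .user, content: "Who wrote Harry Potter?")]
)

// Rebuild the full reply by accumulating partial deltas as they stream in.
var reply = ""
for try await partialResult in openAI.chatsStream(query: query) {
    for choice in partialResult.choices {
        reply += choice.delta.content ?? ""
    }
}
print(reply)
```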

Review [Chat Documentation](https://platform.openai.com/docs/guides/chat) for more info.

### Images
31 changes: 30 additions & 1 deletion Sources/OpenAI/OpenAI.swift
@@ -35,6 +35,7 @@ final public class OpenAI: OpenAIProtocol {
}

private let session: URLSessionProtocol
private var streamingSessions: [NSObject] = []

public let configuration: Configuration

@@ -59,6 +60,10 @@ final public class OpenAI: OpenAIProtocol {
performRequest(request: JSONRequest<CompletionsResult>(body: query, url: buildURL(path: .completions)), completion: completion)
}

public func completionsStream(query: CompletionsQuery, onResult: @escaping (Result<CompletionsResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
performSteamingRequest(request: JSONRequest<CompletionsResult>(body: query.makeStreamable(), url: buildURL(path: .completions)), onResult: onResult, completion: completion)
}

public func images(query: ImagesQuery, completion: @escaping (Result<ImagesResult, Error>) -> Void) {
performRequest(request: JSONRequest<ImagesResult>(body: query, url: buildURL(path: .images)), completion: completion)
}
@@ -71,6 +76,10 @@
performRequest(request: JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats)), completion: completion)
}

public func chatsStream(query: ChatQuery, onResult: @escaping (Result<ChatStreamResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
performSteamingRequest(request: JSONRequest<ChatResult>(body: query.makeStreamable(), url: buildURL(path: .chats)), onResult: onResult, completion: completion)
}

public func edits(query: EditsQuery, completion: @escaping (Result<EditsResult, Error>) -> Void) {
performRequest(request: JSONRequest<EditsResult>(body: query, url: buildURL(path: .edits)), completion: completion)
}
@@ -131,7 +140,27 @@ extension OpenAI {
task.resume()
} catch {
completion(.failure(error))
return
}
}

func performSteamingRequest<ResultType: Codable>(request: any URLRequestBuildable, onResult: @escaping (Result<ResultType, Error>) -> Void, completion: ((Error?) -> Void)?) {
do {
let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
let session = StreamingSession<ResultType>(urlRequest: request)
session.onReceiveContent = {_, object in
onResult(.success(object))
}
session.onProcessingError = {_, error in
onResult(.failure(error))
}
session.onComplete = { [weak self] object, error in
self?.streamingSessions.removeAll(where: { $0 == object })
completion?(error)
}
session.perform()
streamingSessions.append(session)
} catch {
completion?(error)
}
}
}
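
The `StreamingSession` type referenced above is among the 21 changed files but isn't rendered in this view. As a rough sketch of how such a session might parse OpenAI's data-only server-sent events — an assumption for illustration, not the actual `StreamingSession` implementation, and it ignores events split across delegate callbacks, among other simplifications:

```swift
import Foundation

// Hypothetical SSE-parsing session, mirroring the callback shape used in
// performSteamingRequest above (onReceiveContent / onProcessingError / onComplete).
final class SSESessionSketch<ResultType: Codable>: NSObject, URLSessionDataDelegate {
    var onReceiveContent: ((SSESessionSketch, ResultType) -> Void)?
    var onProcessingError: ((SSESessionSketch, Error) -> Void)?
    var onComplete: ((SSESessionSketch, Error?) -> Void)?

    private let urlRequest: URLRequest
    private lazy var urlSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil)

    init(urlRequest: URLRequest) {
        self.urlRequest = urlRequest
    }

    func perform() {
        urlSession.dataTask(with: urlRequest).resume()
    }

    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
        // Chunks arrive as "data: {json}\n\n" lines, terminated by "data: [DONE]".
        guard let text = String(data: data, encoding: .utf8) else { return }
        for line in text.components(separatedBy: "\n") where line.hasPrefix("data: ") {
            let payload = String(line.dropFirst(6))
            guard payload != "[DONE]" else { continue }
            do {
                onReceiveContent?(self, try JSONDecoder().decode(ResultType.self, from: Data(payload.utf8)))
            } catch {
                onProcessingError?(self, error)
            }
        }
    }

    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        onComplete?(self, error)
    }
}
```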
