Skip to content

Commit

Permalink
refactor: serverless context (#822)
Browse files Browse the repository at this point in the history
# Description

the serverless context supports LLM read and write operations,
simplifying LLM function calls.
  • Loading branch information
venjiang committed May 22, 2024
1 parent 5c9d6ee commit 8960899
Show file tree
Hide file tree
Showing 22 changed files with 637 additions and 225 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,8 @@ Create a Stateful Serverless Function to get the IP and Latency of a domain:

```golang
func Handler(ctx serverless.Context) {
fc, _ := ai.ParseFunctionCallContext(ctx)

var msg Parameter
fc.UnmarshalArguments(&msg)
ctx.ReadLLMArguments(&msg)

// get ip of the domain
ips, _ := net.LookupIP(msg.Domain)
Expand All @@ -120,7 +118,7 @@ func Handler(ctx serverless.Context) {
stats := pinger.Statistics()

val := fmt.Sprintf("domain %s has ip %s with average latency %s", msg.Domain, ips[0], stats.AvgRtt)
fc.Write(val)
ctx.WriteLLMResult(val)
}

```
Expand Down
48 changes: 0 additions & 48 deletions ai/function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ai

import (
"encoding/json"
"fmt"

"github.com/yomorun/yomo/serverless"
)
Expand Down Expand Up @@ -58,50 +57,3 @@ func (fco *FunctionCall) FromBytes(b []byte) error {
fco.IsOK = obj.IsOK
return nil
}

// Write writes the result to zipper
func (fco *FunctionCall) Write(result string) error {
fco.Result = result
fco.IsOK = true
buf, err := fco.Bytes()
if err != nil {
return err
}
return fco.ctx.Write(ReducerTag, buf)
}

// WriteErrors writes the error to reducer
func (fco *FunctionCall) WriteErrors(err error) error {
fco.IsOK = false
fco.Error = err.Error()
return fco.Write("")
}

// UnmarshalArguments deserialize Arguments to the parameter object
func (fco *FunctionCall) UnmarshalArguments(v any) error {
return json.Unmarshal([]byte(fco.Arguments), v)
}

// JSONString returns the JSON string of FunctionCallObject
func (fco *FunctionCall) JSONString() string {
b, _ := json.Marshal(fco)
return string(b)
}

// ParseFunctionCallContext creates a new unctionCallObject from the given context
func ParseFunctionCallContext(ctx serverless.Context) (*FunctionCall, error) {
if ctx == nil {
return nil, fmt.Errorf("ai: ctx is nil")
}

if ctx.Data() == nil {
return nil, fmt.Errorf("ai: ctx.Data() is nil")
}

fco := &FunctionCall{
IsOK: true,
}
fco.ctx = ctx
err := fco.FromBytes(ctx.Data())
return fco, err
}
57 changes: 17 additions & 40 deletions ai/function_call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo/serverless/mock"
)

var jsonStr = "{\"req_id\":\"yYdzyl\",\"arguments\":\"{\\n \\\"sourceTimezone\\\": \\\"America/Los_Angeles\\\",\\n \\\"targetTimezone\\\": \\\"Asia/Singapore\\\",\\n \\\"timeString\\\": \\\"2024-03-25 07:00:00\\\"\\n}\",\"tool_call_id\":\"call_aZrtm5xcLs1qtP0SWo4CZi75\",\"function_name\":\"fn-timezone-converter\",\"is_ok\":false}"
Expand Down Expand Up @@ -41,68 +40,46 @@ func TestFunctionCallBytes(t *testing.T) {
assert.Equal(t, string(bytes), jsonStr, "Original and bytes should be equal")
}

func TestFunctionCallJSONString(t *testing.T) {
// Call JSONString
target := original.JSONString()
assert.Equal(t, jsonStr, target, "Original and target JSON strings should be equal")
}

func TestFunctionCallParseCallContext(t *testing.T) {
t.Run("ctx is nil", func(t *testing.T) {
_, err := ParseFunctionCallContext(nil)
assert.Error(t, err)
})

func TestReadFunctionCall(t *testing.T) {
t.Run("ctx.Data is nil", func(t *testing.T) {
ctx := mock.NewMockContext(nil, 0)
_, err := ParseFunctionCallContext(ctx)
ctx := NewMockContext(nil, 0)
fnCall := &FunctionCall{}
err := ctx.ReadLLMFunctionCall(fnCall)
assert.Error(t, err)
})

t.Run("ctx.Data is invalid", func(t *testing.T) {
ctx := mock.NewMockContext([]byte(errJSONStr), 0)
_, err := ParseFunctionCallContext(ctx)
ctx := NewMockContext([]byte(errJSONStr), 0)
fnCall := &FunctionCall{}
err := ctx.ReadLLMFunctionCall(&fnCall)
assert.Error(t, err)
})
}

func TestFunctionCallUnmarshalArguments(t *testing.T) {
// Unmarshal the arguments into a map
func TestReadLLMArguments(t *testing.T) {
ctx := NewMockContext([]byte(jsonStr), 0x10)
target := make(map[string]string)
err := original.UnmarshalArguments(&target)
err := ctx.ReadLLMArguments(&target)

assert.NoError(t, err)
assert.Equal(t, "America/Los_Angeles", target["sourceTimezone"])
assert.Equal(t, "Asia/Singapore", target["targetTimezone"])
assert.Equal(t, "2024-03-25 07:00:00", target["timeString"])
}

func TestFunctionCallWrite(t *testing.T) {
ctx := mock.NewMockContext([]byte(jsonStr), 0x10)
func TestWriteLLMResult(t *testing.T) {
ctx := NewMockContext([]byte(jsonStr), 0x10)

fco, err := ParseFunctionCallContext(ctx)
// read
target := make(map[string]string)
err := ctx.ReadLLMArguments(&target)
assert.NoError(t, err)

// Call Write
err = fco.Write("test result")
// write
err = ctx.WriteLLMResult("test result")
assert.NoError(t, err)

res := ctx.RecordsWritten()
assert.Equal(t, ReducerTag, res[0].Tag)
assert.Equal(t, jsonStrWithResult("test result"), string(res[0].Data))
}

func TestFunctionCallWriteErrors(t *testing.T) {
ctx := mock.NewMockContext([]byte(jsonStr), 0x10)

fco, err := ParseFunctionCallContext(ctx)
assert.NoError(t, err)

// Call WriteErrors
err = fco.WriteErrors(fmt.Errorf("test error"))
assert.NoError(t, err)

res := ctx.RecordsWritten()
assert.Equal(t, ReducerTag, res[0].Tag)
assert.Equal(t, jsonStrWithError("test error"), string(res[0].Data))
}
143 changes: 143 additions & 0 deletions ai/mock_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package ai

import (
"encoding/json"
"errors"
"sync"

"github.com/yomorun/yomo/serverless"
"github.com/yomorun/yomo/serverless/guest"
)

var _ serverless.Context = (*MockContext)(nil)

// WriteRecord composes the data, tag and target.
type WriteRecord struct {
Data []byte
Tag uint32
Target string
}

// MockContext mock context.
type MockContext struct {
data []byte
tag uint32
fnCall *FunctionCall

mu sync.Mutex
wrSlice []WriteRecord
}

// NewMockContext returns the mock context.
// the data is that returned by ctx.Data(), the tag is that returned by ctx.Tag().
func NewMockContext(data []byte, tag uint32) *MockContext {
return &MockContext{
data: data,
tag: tag,
}
}

// Data incoming data.
func (c *MockContext) Data() []byte {
return c.data
}

// Tag incoming tag.
func (c *MockContext) Tag() uint32 {
return c.tag
}

// Metadata returns the metadata by the given key.
func (c *MockContext) Metadata(_ string) (string, bool) {
panic("not implemented")
}

// HTTP returns the HTTP interface.H
func (m *MockContext) HTTP() serverless.HTTP {
return &guest.GuestHTTP{}
}

// Write writes the data with the given tag.
func (c *MockContext) Write(tag uint32, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: data,
Tag: tag,
})

return nil
}

// WriteWithTarget writes the data with the given tag and target.
func (c *MockContext) WriteWithTarget(tag uint32, data []byte, target string) error {
c.mu.Lock()
defer c.mu.Unlock()

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: data,
Tag: tag,
Target: target,
})

return nil
}

// ReadLLMArguments reads LLM function arguments.
func (c *MockContext) ReadLLMArguments(args any) error {
fnCall := &FunctionCall{}
err := fnCall.FromBytes(c.data)
if err != nil {
return err
}
// if success, assign the object to the given object
c.fnCall = fnCall
if len(fnCall.Arguments) == 0 && args != nil {
return errors.New("function arguments is empty, can't read to the given object")
}
return json.Unmarshal([]byte(fnCall.Arguments), args)
}

// WriteLLMResult writes LLM function result.
func (c *MockContext) WriteLLMResult(result string) error {
c.mu.Lock()
defer c.mu.Unlock()

if c.fnCall == nil {
return errors.New("no function call, can't write result")
}
// function call
c.fnCall.IsOK = true
c.fnCall.Result = result
buf, err := c.fnCall.Bytes()
if err != nil {
return err
}

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: buf,
Tag: ReducerTag,
})
return nil
}

// ReadLLMFunctionCall reads LLM function call.
func (c *MockContext) ReadLLMFunctionCall(fnCall any) error {
if c.data == nil {
return errors.New("ctx.Data() is nil")
}
fco, ok := fnCall.(*FunctionCall)
if !ok {
return errors.New("given object is not *ai.FunctionCall")
}
return fco.FromBytes(c.data)
}

// RecordsWritten returns the data records be written with `ctx.Write`.
func (c *MockContext) RecordsWritten() []WriteRecord {
c.mu.Lock()
defer c.mu.Unlock()

return c.wrSlice
}
1 change: 1 addition & 0 deletions cli/serverless/golang/templates/main.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
Run: func(cmd *cobra.Command, args []string) {
run(cmd, args)
},
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
}
)

Expand Down
2 changes: 2 additions & 0 deletions core/serverless/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package serverless

import (
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/metadata"
)
Expand All @@ -12,6 +13,7 @@ type Context struct {
tag uint32
md metadata.M
data []byte
fnCall *ai.FunctionCall
}

// NewContext creates a new serverless Context
Expand Down
50 changes: 50 additions & 0 deletions core/serverless/context_llm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package serverless

import (
"encoding/json"
"errors"

"github.com/yomorun/yomo/ai"
)

// ReadLLMArguments reads LLM function arguments
func (c *Context) ReadLLMArguments(args any) error {
fnCall := &ai.FunctionCall{}
err := fnCall.FromBytes(c.data)
if err != nil {
return err
}
// if success, assign the object to the given object
c.fnCall = fnCall
if len(fnCall.Arguments) == 0 && args != nil {
return errors.New("function arguments is empty, can't read to the given object")
}
return json.Unmarshal([]byte(fnCall.Arguments), args)
}

// WriteLLMResult writes LLM function result
func (c *Context) WriteLLMResult(result string) error {
if c.fnCall == nil {
return errors.New("no function call, can't write result")
}
// function call
c.fnCall.IsOK = true
c.fnCall.Result = result
buf, err := c.fnCall.Bytes()
if err != nil {
return err
}
return c.Write(ai.ReducerTag, buf)
}

// ReadLLMFunctionCall reads LLM function call
func (c *Context) ReadLLMFunctionCall(fnCall any) error {
if c.data == nil {
return errors.New("ctx.Data() is nil")
}
fco, ok := fnCall.(*ai.FunctionCall)
if !ok {
return errors.New("given object is not *ai.FunctionCall")
}
return fco.FromBytes(c.data)
}
Loading

0 comments on commit 8960899

Please sign in to comment.