Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: serverless context #822

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,8 @@ Create a Stateful Serverless Function to get the IP and Latency of a domain:

```golang
func Handler(ctx serverless.Context) {
fc, _ := ai.ParseFunctionCallContext(ctx)

var msg Parameter
fc.UnmarshalArguments(&msg)
ctx.ReadLLMArguments(&msg)

// get ip of the domain
ips, _ := net.LookupIP(msg.Domain)
Expand All @@ -120,7 +118,7 @@ func Handler(ctx serverless.Context) {
stats := pinger.Statistics()

val := fmt.Sprintf("domain %s has ip %s with average latency %s", msg.Domain, ips[0], stats.AvgRtt)
fc.Write(val)
ctx.WriteLLMResult(val)
}

```
Expand Down
48 changes: 0 additions & 48 deletions ai/function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ai

import (
"encoding/json"
"fmt"

"github.com/yomorun/yomo/serverless"
)
Expand Down Expand Up @@ -58,50 +57,3 @@ func (fco *FunctionCall) FromBytes(b []byte) error {
fco.IsOK = obj.IsOK
return nil
}

// Write writes the result to zipper
func (fco *FunctionCall) Write(result string) error {
fco.Result = result
fco.IsOK = true
buf, err := fco.Bytes()
if err != nil {
return err
}
return fco.ctx.Write(ReducerTag, buf)
}

// WriteErrors writes the error to reducer
func (fco *FunctionCall) WriteErrors(err error) error {
fco.IsOK = false
fco.Error = err.Error()
return fco.Write("")
}

// UnmarshalArguments deserialize Arguments to the parameter object
func (fco *FunctionCall) UnmarshalArguments(v any) error {
return json.Unmarshal([]byte(fco.Arguments), v)
}

// JSONString returns the JSON string of FunctionCallObject
func (fco *FunctionCall) JSONString() string {
b, _ := json.Marshal(fco)
return string(b)
}

// ParseFunctionCallContext creates a new unctionCallObject from the given context
func ParseFunctionCallContext(ctx serverless.Context) (*FunctionCall, error) {
if ctx == nil {
return nil, fmt.Errorf("ai: ctx is nil")
}

if ctx.Data() == nil {
return nil, fmt.Errorf("ai: ctx.Data() is nil")
}

fco := &FunctionCall{
IsOK: true,
}
fco.ctx = ctx
err := fco.FromBytes(ctx.Data())
return fco, err
}
57 changes: 17 additions & 40 deletions ai/function_call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo/serverless/mock"
)

var jsonStr = "{\"req_id\":\"yYdzyl\",\"arguments\":\"{\\n \\\"sourceTimezone\\\": \\\"America/Los_Angeles\\\",\\n \\\"targetTimezone\\\": \\\"Asia/Singapore\\\",\\n \\\"timeString\\\": \\\"2024-03-25 07:00:00\\\"\\n}\",\"tool_call_id\":\"call_aZrtm5xcLs1qtP0SWo4CZi75\",\"function_name\":\"fn-timezone-converter\",\"is_ok\":false}"
Expand Down Expand Up @@ -41,68 +40,46 @@ func TestFunctionCallBytes(t *testing.T) {
assert.Equal(t, string(bytes), jsonStr, "Original and bytes should be equal")
}

func TestFunctionCallJSONString(t *testing.T) {
// Call JSONString
target := original.JSONString()
assert.Equal(t, jsonStr, target, "Original and target JSON strings should be equal")
}

func TestFunctionCallParseCallContext(t *testing.T) {
t.Run("ctx is nil", func(t *testing.T) {
_, err := ParseFunctionCallContext(nil)
assert.Error(t, err)
})

func TestReadFunctionCall(t *testing.T) {
t.Run("ctx.Data is nil", func(t *testing.T) {
ctx := mock.NewMockContext(nil, 0)
_, err := ParseFunctionCallContext(ctx)
ctx := NewMockContext(nil, 0)
fnCall := &FunctionCall{}
err := ctx.ReadLLMFunctionCall(fnCall)
assert.Error(t, err)
})

t.Run("ctx.Data is invalid", func(t *testing.T) {
ctx := mock.NewMockContext([]byte(errJSONStr), 0)
_, err := ParseFunctionCallContext(ctx)
ctx := NewMockContext([]byte(errJSONStr), 0)
fnCall := &FunctionCall{}
err := ctx.ReadLLMFunctionCall(&fnCall)
assert.Error(t, err)
})
}

func TestFunctionCallUnmarshalArguments(t *testing.T) {
// Unmarshal the arguments into a map
func TestReadLLMArguments(t *testing.T) {
ctx := NewMockContext([]byte(jsonStr), 0x10)
target := make(map[string]string)
err := original.UnmarshalArguments(&target)
err := ctx.ReadLLMArguments(&target)

assert.NoError(t, err)
assert.Equal(t, "America/Los_Angeles", target["sourceTimezone"])
assert.Equal(t, "Asia/Singapore", target["targetTimezone"])
assert.Equal(t, "2024-03-25 07:00:00", target["timeString"])
}

func TestFunctionCallWrite(t *testing.T) {
ctx := mock.NewMockContext([]byte(jsonStr), 0x10)
func TestWriteLLMResult(t *testing.T) {
ctx := NewMockContext([]byte(jsonStr), 0x10)

fco, err := ParseFunctionCallContext(ctx)
// read
target := make(map[string]string)
err := ctx.ReadLLMArguments(&target)
assert.NoError(t, err)

// Call Write
err = fco.Write("test result")
// write
err = ctx.WriteLLMResult("test result")
assert.NoError(t, err)

res := ctx.RecordsWritten()
assert.Equal(t, ReducerTag, res[0].Tag)
assert.Equal(t, jsonStrWithResult("test result"), string(res[0].Data))
}

func TestFunctionCallWriteErrors(t *testing.T) {
ctx := mock.NewMockContext([]byte(jsonStr), 0x10)

fco, err := ParseFunctionCallContext(ctx)
assert.NoError(t, err)

// Call WriteErrors
err = fco.WriteErrors(fmt.Errorf("test error"))
assert.NoError(t, err)

res := ctx.RecordsWritten()
assert.Equal(t, ReducerTag, res[0].Tag)
assert.Equal(t, jsonStrWithError("test error"), string(res[0].Data))
}
143 changes: 143 additions & 0 deletions ai/mock_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package ai

import (
"encoding/json"
"errors"
"sync"

"github.com/yomorun/yomo/serverless"
"github.com/yomorun/yomo/serverless/guest"
)

var _ serverless.Context = (*MockContext)(nil)

// WriteRecord composes the data, tag and target.
type WriteRecord struct {
Data []byte
Tag uint32
Target string
}

// MockContext mock context.
type MockContext struct {
data []byte
tag uint32
fnCall *FunctionCall

mu sync.Mutex
wrSlice []WriteRecord
}

// NewMockContext returns the mock context.
// the data is that returned by ctx.Data(), the tag is that returned by ctx.Tag().
func NewMockContext(data []byte, tag uint32) *MockContext {
return &MockContext{
data: data,
tag: tag,
}
}

// Data incoming data.
func (c *MockContext) Data() []byte {
return c.data

Check warning on line 42 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L41-L42

Added lines #L41 - L42 were not covered by tests
}

// Tag incoming tag.
func (c *MockContext) Tag() uint32 {
return c.tag

Check warning on line 47 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L46-L47

Added lines #L46 - L47 were not covered by tests
}

// Metadata returns the metadata by the given key.
func (c *MockContext) Metadata(_ string) (string, bool) {
panic("not implemented")

Check warning on line 52 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L51-L52

Added lines #L51 - L52 were not covered by tests
}

// HTTP returns the HTTP interface.H
func (m *MockContext) HTTP() serverless.HTTP {
return &guest.GuestHTTP{}

Check warning on line 57 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L56-L57

Added lines #L56 - L57 were not covered by tests
}

// Write writes the data with the given tag.
func (c *MockContext) Write(tag uint32, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()

Check warning on line 63 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L61-L63

Added lines #L61 - L63 were not covered by tests

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: data,
Tag: tag,
})

Check warning on line 68 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L65-L68

Added lines #L65 - L68 were not covered by tests

return nil

Check warning on line 70 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L70

Added line #L70 was not covered by tests
}

// WriteWithTarget writes the data with the given tag and target.
func (c *MockContext) WriteWithTarget(tag uint32, data []byte, target string) error {
c.mu.Lock()
defer c.mu.Unlock()

Check warning on line 76 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L74-L76

Added lines #L74 - L76 were not covered by tests

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: data,
Tag: tag,
Target: target,
})

Check warning on line 82 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L78-L82

Added lines #L78 - L82 were not covered by tests

return nil

Check warning on line 84 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L84

Added line #L84 was not covered by tests
}

// ReadLLMArguments reads LLM function arguments.
func (c *MockContext) ReadLLMArguments(args any) error {
fnCall := &FunctionCall{}
err := fnCall.FromBytes(c.data)
if err != nil {
return err

Check warning on line 92 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L92

Added line #L92 was not covered by tests
}
// if success, assign the object to the given object
c.fnCall = fnCall
if len(fnCall.Arguments) == 0 && args != nil {
return errors.New("function arguments is empty, can't read to the given object")

Check warning on line 97 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L97

Added line #L97 was not covered by tests
}
return json.Unmarshal([]byte(fnCall.Arguments), args)
}

// WriteLLMResult writes LLM function result.
func (c *MockContext) WriteLLMResult(result string) error {
c.mu.Lock()
defer c.mu.Unlock()

if c.fnCall == nil {
return errors.New("no function call, can't write result")

Check warning on line 108 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L108

Added line #L108 was not covered by tests
}
// function call
c.fnCall.IsOK = true
c.fnCall.Result = result
buf, err := c.fnCall.Bytes()
if err != nil {
return err

Check warning on line 115 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L115

Added line #L115 was not covered by tests
}

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: buf,
Tag: ReducerTag,
})
return nil
}

// ReadLLMFunctionCall reads LLM function call.
func (c *MockContext) ReadLLMFunctionCall(fnCall any) error {
if c.data == nil {
return errors.New("ctx.Data() is nil")
}
fco, ok := fnCall.(*FunctionCall)
if !ok {
return errors.New("given object is not *ai.FunctionCall")
}
return fco.FromBytes(c.data)

Check warning on line 134 in ai/mock_context.go

View check run for this annotation

Codecov / codecov/patch

ai/mock_context.go#L134

Added line #L134 was not covered by tests
}

// RecordsWritten returns the data records be written with `ctx.Write`.
func (c *MockContext) RecordsWritten() []WriteRecord {
c.mu.Lock()
defer c.mu.Unlock()

return c.wrSlice
}
1 change: 1 addition & 0 deletions cli/serverless/golang/templates/main.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
Run: func(cmd *cobra.Command, args []string) {
run(cmd, args)
},
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
}
)

Expand Down
2 changes: 2 additions & 0 deletions core/serverless/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package serverless

import (
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/metadata"
)
Expand All @@ -12,6 +13,7 @@ type Context struct {
tag uint32
md metadata.M
data []byte
fnCall *ai.FunctionCall
}

// NewContext creates a new serverless Context
Expand Down
50 changes: 50 additions & 0 deletions core/serverless/context_llm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package serverless

import (
"encoding/json"
"errors"

"github.com/yomorun/yomo/ai"
)

// ReadLLMArguments reads LLM function arguments
func (c *Context) ReadLLMArguments(args any) error {
fnCall := &ai.FunctionCall{}
err := fnCall.FromBytes(c.data)
if err != nil {
return err
}
// if success, assign the object to the given object
c.fnCall = fnCall
if len(fnCall.Arguments) == 0 && args != nil {
return errors.New("function arguments is empty, can't read to the given object")
}
return json.Unmarshal([]byte(fnCall.Arguments), args)
}

// WriteLLMResult writes LLM function result
func (c *Context) WriteLLMResult(result string) error {
if c.fnCall == nil {
return errors.New("no function call, can't write result")
}
// function call
c.fnCall.IsOK = true
c.fnCall.Result = result
buf, err := c.fnCall.Bytes()
if err != nil {
return err
}
return c.Write(ai.ReducerTag, buf)
}

// ReadLLMFunctionCall reads LLM function call
func (c *Context) ReadLLMFunctionCall(fnCall any) error {
if c.data == nil {
return errors.New("ctx.Data() is nil")
}
fco, ok := fnCall.(*ai.FunctionCall)
if !ok {
return errors.New("given object is not *ai.FunctionCall")
}
return fco.FromBytes(c.data)
}
Loading