-
Notifications
You must be signed in to change notification settings - Fork 11
/
gpt2.go
60 lines (49 loc) · 1.5 KB
/
gpt2.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
package gpt2
// #cgo CFLAGS: -I./ggml.cpp/include/ggml/ -I./ggml.cpp/examples/ -I./ggml.cpp/src/
// #cgo CXXFLAGS: -I./ggml.cpp/include/ggml/ -I./ggml.cpp/examples/ -I./ggml.cpp/src/
// #cgo darwin LDFLAGS: -framework Accelerate
// #cgo darwin CXXFLAGS: -std=c++17
// #cgo LDFLAGS: -ltransformers -lm -lstdc++
// #include <stdlib.h>
// #include <gpt2.h>
import "C"

import (
	"fmt"
	"strings"
	"unsafe"
)
// GPT2 is a handle to a model loaded through the native gpt2 bindings.
//
// The zero value carries a nil state pointer; obtain a usable handle with New.
type GPT2 struct {
	// state is the opaque pointer produced by the C allocator and handed back
	// to the C side on every call.
	state unsafe.Pointer
}
// New allocates a native state object and bootstraps it from the model file
// at the given path.
//
// It returns a *GPT2 wrapping that state, or an error when the native
// bootstrap call reports a non-zero result.
func New(model string) (*GPT2, error) {
	// C.CString allocates with malloc and is invisible to the Go GC, so the
	// copy must be released explicitly once the call has returned.
	// NOTE(review): assumes gpt2_bootstrap does not retain the path pointer
	// after returning — confirm against the C implementation.
	modelPath := C.CString(model)
	defer C.free(unsafe.Pointer(modelPath))

	state := C.gpt2_allocate_state()
	if result := C.gpt2_bootstrap(modelPath, state); result != 0 {
		// NOTE(review): state is not released on this path; whether
		// gpt2_free_model is safe on a state that failed to bootstrap is not
		// visible from here — confirm before adding cleanup.
		return nil, fmt.Errorf("failed loading model")
	}
	return &GPT2{state: state}, nil
}
// Predict runs inference on text using the loaded model and returns the
// generated output.
//
// Options are resolved through NewPredictOptions. A Tokens value of zero is
// replaced by a large default, which is used both as the token count passed
// to the native side and as the size of the output buffer. A negative Tokens
// value is rejected with an error.
//
// Before returning, the result has a single leading space, the echoed prompt,
// a single leading newline and a trailing "<|endoftext|>" marker removed, in
// that order. An error is returned when the native predict call reports a
// non-zero result.
func (l *GPT2) Predict(text string, opts ...PredictOption) (string, error) {
	// Token budget (and output buffer size) used when the caller sets none.
	const defaultTokens = 99999999

	po := NewPredictOptions(opts...)
	switch {
	case po.Tokens == 0:
		po.Tokens = defaultTokens
	case po.Tokens < 0:
		// make() would panic on a negative length; surface it as an error.
		return "", fmt.Errorf("invalid token count %v", po.Tokens)
	}

	// C.CString allocates with malloc; free it on every return path. Deferred
	// calls run in reverse order, so params is released before the prompt it
	// was built from.
	input := C.CString(text)
	defer C.free(unsafe.Pointer(input))

	// NOTE(review): the buffer is sized in tokens while the native side writes
	// bytes — confirm the C code cannot write more than po.Tokens bytes.
	out := make([]byte, po.Tokens)
	buf := (*C.char)(unsafe.Pointer(&out[0]))

	params := C.gpt2_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
		C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch))
	// Release params on the failure path as well as on success.
	defer C.gpt2_free_params(params)

	if ret := C.gpt2_predict(params, l.state, buf); ret != 0 {
		return "", fmt.Errorf("inference failed")
	}

	res := C.GoString(buf)
	res = strings.TrimPrefix(res, " ")
	res = strings.TrimPrefix(res, text)
	res = strings.TrimPrefix(res, "\n")
	res = strings.TrimSuffix(res, "<|endoftext|>")
	return res, nil
}
// Free releases the native model state held by l.
//
// The state pointer is cleared afterwards, so calling Free again (or on a nil
// or zero-value handle) is a no-op instead of passing an already-freed pointer
// to the native side. The handle must not be used for Predict after Free.
func (l *GPT2) Free() {
	if l == nil || l.state == nil {
		return
	}
	C.gpt2_free_model(l.state)
	l.state = nil
}