Skip to content

Commit 8ded837

Browse files
committed
Consistently use ModelOptions
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent e39d007 commit 8ded837

File tree

10 files changed, with 56 additions and 32 deletions.

10 files changed, with 56 additions and 32 deletions.

core/backend/embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
1414
var inferenceModel interface{}
1515
var err error
1616

17-
opts := modelOpts(backendConfig, appConfig, []model.Option{})
17+
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
1818

1919
if backendConfig.Backend == "" {
2020
inferenceModel, err = loader.GreedyLoader(opts...)

core/backend/image.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ import (
99

1010
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
1111

12-
opts := modelOpts(backendConfig, appConfig, []model.Option{})
12+
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
1313

1414
inferenceModel, err := loader.BackendLoader(
1515
opts...,

core/backend/llm.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
3737
var inferenceModel grpc.Backend
3838
var err error
3939

40-
opts := modelOpts(c, o, []model.Option{})
40+
opts := ModelOptions(c, o, []model.Option{})
4141

4242
if c.Backend != "" {
4343
opts = append(opts, model.WithBackendString(c.Backend))

core/backend/options.go

Lines changed: 48 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ import (
1111
"github.com/rs/zerolog/log"
1212
)
1313

14-
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
14+
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
1515
name := c.Name
1616
if name == "" {
1717
name = c.Model
@@ -35,9 +35,10 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode
3535
threads = so.Threads
3636
}
3737

38-
grpcOpts := GRPCModelOpts(c)
38+
c.Threads = &threads
39+
40+
grpcOpts := grpcModelOpts(c)
3941
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
40-
defOpts = append(defOpts, model.WithThreads(uint32(threads)))
4142

4243
if so.SingleBackend {
4344
defOpts = append(defOpts, model.WithSingleActiveBackend())
@@ -76,26 +77,62 @@ func getSeed(c config.BackendConfig) int32 {
7677
return seed
7778
}
7879

79-
func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
80+
func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
8081
b := 512
8182
if c.Batch != 0 {
8283
b = c.Batch
8384
}
85+
86+
f16 := false
87+
if c.F16 != nil {
88+
f16 = *c.F16
89+
}
90+
91+
embeddings := false
92+
if c.Embeddings != nil {
93+
embeddings = *c.Embeddings
94+
}
95+
96+
lowVRAM := false
97+
if c.LowVRAM != nil {
98+
lowVRAM = *c.LowVRAM
99+
}
100+
101+
mmap := false
102+
if c.MMap != nil {
103+
mmap = *c.MMap
104+
}
105+
106+
ctxSize := 1024
107+
if c.ContextSize != nil {
108+
ctxSize = *c.ContextSize
109+
}
110+
111+
mmlock := false
112+
if c.MMlock != nil {
113+
mmlock = *c.MMlock
114+
}
115+
116+
nGPULayers := 9999999
117+
if c.NGPULayers != nil {
118+
nGPULayers = *c.NGPULayers
119+
}
120+
84121
return &pb.ModelOptions{
85122
CUDA: c.CUDA || c.Diffusers.CUDA,
86123
SchedulerType: c.Diffusers.SchedulerType,
87124
PipelineType: c.Diffusers.PipelineType,
88125
CFGScale: c.Diffusers.CFGScale,
89126
LoraAdapter: c.LoraAdapter,
90127
LoraScale: c.LoraScale,
91-
F16Memory: *c.F16,
128+
F16Memory: f16,
92129
LoraBase: c.LoraBase,
93130
IMG2IMG: c.Diffusers.IMG2IMG,
94131
CLIPModel: c.Diffusers.ClipModel,
95132
CLIPSubfolder: c.Diffusers.ClipSubFolder,
96133
CLIPSkip: int32(c.Diffusers.ClipSkip),
97134
ControlNet: c.Diffusers.ControlNet,
98-
ContextSize: int32(*c.ContextSize),
135+
ContextSize: int32(ctxSize),
99136
Seed: getSeed(c),
100137
NBatch: int32(b),
101138
NoMulMatQ: c.NoMulMatQ,
@@ -117,16 +154,16 @@ func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
117154
YarnBetaSlow: c.YarnBetaSlow,
118155
NGQA: c.NGQA,
119156
RMSNormEps: c.RMSNormEps,
120-
MLock: *c.MMlock,
157+
MLock: mmlock,
121158
RopeFreqBase: c.RopeFreqBase,
122159
RopeScaling: c.RopeScaling,
123160
Type: c.ModelType,
124161
RopeFreqScale: c.RopeFreqScale,
125162
NUMA: c.NUMA,
126-
Embeddings: *c.Embeddings,
127-
LowVRAM: *c.LowVRAM,
128-
NGPULayers: int32(*c.NGPULayers),
129-
MMap: *c.MMap,
163+
Embeddings: embeddings,
164+
LowVRAM: lowVRAM,
165+
NGPULayers: int32(nGPULayers),
166+
MMap: mmap,
130167
MainGPU: c.MainGPU,
131168
Threads: int32(*c.Threads),
132169
TensorSplit: c.TensorSplit,

core/backend/rerank.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ import (
1111

1212
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
1313

14-
opts := modelOpts(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
14+
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
1515
rerankModel, err := loader.BackendLoader(opts...)
1616
if err != nil {
1717
return nil, err

core/backend/soundgeneration.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ func SoundGeneration(
2525
backendConfig config.BackendConfig,
2626
) (string, *proto.Result, error) {
2727

28-
opts := modelOpts(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
28+
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
2929

3030
soundGenModel, err := loader.BackendLoader(opts...)
3131
if err != nil {

core/backend/transcript.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
1818
backendConfig.Backend = model.WhisperBackend
1919
}
2020

21-
opts := modelOpts(backendConfig, appConfig, []model.Option{})
21+
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
2222

2323
transcriptionModel, err := ml.BackendLoader(opts...)
2424
if err != nil {

core/backend/tts.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ func ModelTTS(
2828
bb = model.PiperBackend
2929
}
3030

31-
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
31+
opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{
3232
model.WithBackendString(bb),
3333
model.WithModel(modelFile),
3434
})

core/startup/startup.go

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -160,13 +160,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
160160

161161
log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
162162

163-
grpcOpts := backend.GRPCModelOpts(*cfg)
164-
o := []model.Option{
165-
model.WithModel(cfg.Model),
166-
model.WithAssetDir(options.AssetsDestination),
167-
model.WithThreads(uint32(options.Threads)),
168-
model.WithLoadGRPCLoadModelOpts(grpcOpts),
169-
}
163+
o := backend.ModelOptions(*cfg, options, []model.Option{})
170164

171165
var backendErr error
172166
if cfg.Backend != "" {

pkg/model/options.go

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@ type Options struct {
1010
backendString string
1111
model string
1212
modelID string
13-
threads uint32
1413
assetDir string
1514
context context.Context
1615

@@ -69,12 +68,6 @@ func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
6968
}
7069
}
7170

72-
func WithThreads(threads uint32) Option {
73-
return func(o *Options) {
74-
o.threads = threads
75-
}
76-
}
77-
7871
func WithAssetDir(assetDir string) Option {
7972
return func(o *Options) {
8073
o.assetDir = assetDir

0 commit comments

Comments
 (0)