|
1 | 1 | package coord
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - "sort" |
5 |
| - "sync" |
| 4 | + "context" |
| 5 | + "errors" |
6 | 6 |
|
| 7 | + "github.com/lemon-mint/coord/pconf" |
7 | 8 | "github.com/lemon-mint/coord/provider"
|
8 | 9 | )
|
9 | 10 |
|
10 |
| -var ( |
11 |
| - llmProvidersMu sync.RWMutex |
12 |
| - llmProviders = make(map[string]provider.LLMProvider) |
| 11 | +var ErrNoSuchProvider = errors.New("coord: no such provider") |
13 | 12 |
|
14 |
| - ttsProvidersMu sync.RWMutex |
15 |
| - ttsProviders = make(map[string]provider.TTSProvider) |
16 |
| - |
17 |
| - embeddingProvidersMu sync.RWMutex |
18 |
| - embeddingProviders = make(map[string]provider.EmbeddingProvider) |
19 |
| -) |
20 |
| - |
21 |
| -// ListLLMProviders returns the names of the registered llm providers. |
22 |
| -func ListLLMProviders() []string { |
| 13 | +func NewLLMClient(ctx context.Context, provider string, configs ...pconf.Config) (provider.LLMClient, error) { |
23 | 14 | llmProvidersMu.RLock()
|
24 | 15 | defer llmProvidersMu.RUnlock()
|
25 |
| - list := make([]string, 0, len(llmProviders)) |
26 |
| - for name := range llmProviders { |
27 |
| - list = append(list, name) |
| 16 | + |
| 17 | + driver, ok := llmProviders[provider] |
| 18 | + if !ok { |
| 19 | + return nil, ErrNoSuchProvider |
28 | 20 | }
|
29 |
| - sort.Strings(list) |
30 |
| - return list |
31 |
| -} |
32 | 21 |
|
33 |
| -// RegisterLLMProvider registers a llm provider. |
34 |
| -func RegisterLLMProvider(name string, p provider.LLMProvider) { |
35 |
| - llmProvidersMu.Lock() |
36 |
| - defer llmProvidersMu.Unlock() |
37 |
| - llmProviders[name] = p |
| 22 | + return driver.NewLLMClient(ctx, configs...) |
38 | 23 | }
|
39 | 24 |
|
40 |
| -// RemoveLLMProvider removes a llm provider. |
41 |
| -func RemoveLLMProvider(name string) { |
42 |
| - llmProvidersMu.Lock() |
43 |
| - defer llmProvidersMu.Unlock() |
44 |
| - delete(llmProviders, name) |
45 |
| -} |
| 25 | +func NewEmbeddingClient(ctx context.Context, provider string, configs ...pconf.Config) (provider.EmbeddingClient, error) { |
| 26 | + embeddingProvidersMu.RLock() |
| 27 | + defer embeddingProvidersMu.RUnlock() |
46 | 28 |
|
47 |
| -// ListTTSProviders returns the names of the registered tts providers. |
48 |
| -func ListTTSProviders() []string { |
49 |
| - ttsProvidersMu.RLock() |
50 |
| - defer ttsProvidersMu.RUnlock() |
51 |
| - list := make([]string, 0, len(ttsProviders)) |
52 |
| - for name := range ttsProviders { |
53 |
| - list = append(list, name) |
| 29 | + driver, ok := embeddingProviders[provider] |
| 30 | + if !ok { |
| 31 | + return nil, ErrNoSuchProvider |
54 | 32 | }
|
55 |
| - sort.Strings(list) |
56 |
| - return list |
57 |
| -} |
58 | 33 |
|
59 |
| -// RegisterTTSProvider registers a tts provider. |
60 |
| -func RegisterTTSProvider(name string, p provider.TTSProvider) { |
61 |
| - ttsProvidersMu.Lock() |
62 |
| - defer ttsProvidersMu.Unlock() |
63 |
| - ttsProviders[name] = p |
| 34 | + return driver.NewEmbeddingClient(ctx, configs...) |
64 | 35 | }
|
65 | 36 |
|
66 |
| -// RemoveTTSProvider removes a tts provider. |
67 |
| -func RemoveTTSProvider(name string) { |
68 |
| - ttsProvidersMu.Lock() |
69 |
| - defer ttsProvidersMu.Unlock() |
70 |
| - delete(ttsProviders, name) |
71 |
| -} |
| 37 | +func NewTTSClient(ctx context.Context, provider string, configs ...pconf.Config) (provider.TTSClient, error) { |
| 38 | + ttsProvidersMu.RLock() |
| 39 | + defer ttsProvidersMu.RUnlock() |
72 | 40 |
|
73 |
| -// ListEmbeddingProviders returns the names of the registered embedding providers. |
74 |
| -func ListEmbeddingProviders() []string { |
75 |
| - embeddingProvidersMu.RLock() |
76 |
| - defer embeddingProvidersMu.RUnlock() |
77 |
| - list := make([]string, 0, len(embeddingProviders)) |
78 |
| - for name := range embeddingProviders { |
79 |
| - list = append(list, name) |
| 41 | + driver, ok := ttsProviders[provider] |
| 42 | + if !ok { |
| 43 | + return nil, ErrNoSuchProvider |
80 | 44 | }
|
81 |
| - sort.Strings(list) |
82 |
| - return list |
83 |
| -} |
84 |
| - |
85 |
| -// RegisterEmbeddingProvider registers an embedding provider. |
86 |
| -func RegisterEmbeddingProvider(name string, p provider.EmbeddingProvider) { |
87 |
| - embeddingProvidersMu.Lock() |
88 |
| - defer embeddingProvidersMu.Unlock() |
89 |
| - embeddingProviders[name] = p |
90 |
| -} |
91 | 45 |
|
92 |
| -// RemoveEmbeddingProvider removes an embedding provider. |
93 |
| -func RemoveEmbeddingProvider(name string) { |
94 |
| - embeddingProvidersMu.Lock() |
95 |
| - defer embeddingProvidersMu.Unlock() |
96 |
| - delete(embeddingProviders, name) |
| 46 | + return driver.NewTTSClient(ctx, configs...) |
97 | 47 | }
|
0 commit comments