Skip to content

Commit b6b6f3b

Browse files
authored
feat: generate RAG service YAML config resource (#313)
* feat: generate RAG service YAML config resource
1 parent acaaa66 commit b6b6f3b

8 files changed

Lines changed: 1072 additions & 11 deletions

File tree

server/internal/orchestrator/swarm/orchestrator.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,25 @@ func (o *Orchestrator) generateRAGInstanceResources(spec *database.ServiceInstan
704704
Keys: extractRAGAPIKeys(ragConfig),
705705
}
706706

707-
orchestratorResources = append(orchestratorResources, dataDir, keysResource)
707+
// RAG config resource — generates pgedge-rag-server.yaml in the data directory.
708+
var dbHost string
709+
var dbPort int
710+
if len(spec.DatabaseHosts) > 0 {
711+
dbHost = spec.DatabaseHosts[0].Host
712+
dbPort = spec.DatabaseHosts[0].Port
713+
}
714+
ragConfigRes := &RAGConfigResource{
715+
ServiceInstanceID: spec.ServiceInstanceID,
716+
ServiceID: spec.ServiceSpec.ServiceID,
717+
HostID: spec.HostID,
718+
DirResourceID: dataDirID,
719+
Config: ragConfig,
720+
DatabaseName: spec.DatabaseName,
721+
DatabaseHost: dbHost,
722+
DatabasePort: dbPort,
723+
}
724+
725+
orchestratorResources = append(orchestratorResources, dataDir, keysResource, ragConfigRes)
708726

709727
return o.buildServiceInstanceResources(spec, orchestratorResources)
710728
}
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
package swarm
2+
3+
import (
4+
"fmt"
5+
"path"
6+
7+
"github.com/goccy/go-yaml"
8+
9+
"github.com/pgEdge/control-plane/server/internal/database"
10+
)
11+
12+
// ragYAMLConfig mirrors the pgedge-rag-server Config struct for YAML generation.
13+
// Only the fields the control plane needs to set are included.
14+
type ragYAMLConfig struct {
15+
Server ragServerYAML `yaml:"server"`
16+
Pipelines []ragPipelineYAML `yaml:"pipelines"`
17+
Defaults *ragDefaultsYAML `yaml:"defaults,omitempty"`
18+
}
19+
20+
type ragServerYAML struct {
21+
ListenAddress string `yaml:"listen_address"`
22+
Port int `yaml:"port"`
23+
}
24+
25+
type ragPipelineYAML struct {
26+
Name string `yaml:"name"`
27+
Description string `yaml:"description,omitempty"`
28+
Database ragDatabaseYAML `yaml:"database"`
29+
Tables []ragTableYAML `yaml:"tables"`
30+
EmbeddingLLM ragLLMYAML `yaml:"embedding_llm"`
31+
RAGLLM ragLLMYAML `yaml:"rag_llm"`
32+
APIKeys *ragAPIKeysYAML `yaml:"api_keys,omitempty"`
33+
TokenBudget *int `yaml:"token_budget,omitempty"`
34+
TopN *int `yaml:"top_n,omitempty"`
35+
SystemPrompt string `yaml:"system_prompt,omitempty"`
36+
Search *ragSearchYAML `yaml:"search,omitempty"`
37+
}
38+
39+
type ragDatabaseYAML struct {
40+
Host string `yaml:"host"`
41+
Port int `yaml:"port"`
42+
Database string `yaml:"database"`
43+
Username string `yaml:"username"`
44+
Password string `yaml:"password"`
45+
SSLMode string `yaml:"ssl_mode"`
46+
}
47+
48+
type ragTableYAML struct {
49+
Table string `yaml:"table"`
50+
TextColumn string `yaml:"text_column"`
51+
VectorColumn string `yaml:"vector_column"`
52+
IDColumn string `yaml:"id_column,omitempty"`
53+
}
54+
55+
type ragLLMYAML struct {
56+
Provider string `yaml:"provider"`
57+
Model string `yaml:"model"`
58+
BaseURL string `yaml:"base_url,omitempty"`
59+
}
60+
61+
// ragAPIKeysYAML holds container-side file paths for each provider's API key.
62+
type ragAPIKeysYAML struct {
63+
Anthropic string `yaml:"anthropic,omitempty"`
64+
OpenAI string `yaml:"openai,omitempty"`
65+
Voyage string `yaml:"voyage,omitempty"`
66+
}
67+
68+
type ragSearchYAML struct {
69+
HybridEnabled *bool `yaml:"hybrid_enabled,omitempty"`
70+
VectorWeight *float64 `yaml:"vector_weight,omitempty"`
71+
}
72+
73+
type ragDefaultsYAML struct {
74+
TokenBudget *int `yaml:"token_budget,omitempty"`
75+
TopN *int `yaml:"top_n,omitempty"`
76+
}
77+
78+
// RAGConfigParams holds all inputs needed to generate pgedge-rag-server.yaml.
79+
type RAGConfigParams struct {
80+
Config *database.RAGServiceConfig
81+
DatabaseName string
82+
DatabaseHost string
83+
DatabasePort int
84+
Username string
85+
Password string
86+
// KeysDir is the container-side directory where API key files are mounted,
87+
// e.g. "/app/keys". Key filenames follow the {pipeline}_{embedding|rag}.key
88+
// convention produced by extractRAGAPIKeys.
89+
KeysDir string
90+
}
91+
92+
// GenerateRAGConfig generates the pgedge-rag-server.yaml content from the
93+
// given parameters. API key paths in the generated YAML reference files under
94+
// KeysDir so the RAG server reads them from the bind-mounted keys directory.
95+
func GenerateRAGConfig(params *RAGConfigParams) ([]byte, error) {
96+
pipelines := make([]ragPipelineYAML, 0, len(params.Config.Pipelines))
97+
for _, p := range params.Config.Pipelines {
98+
pl, err := buildRAGPipelineYAML(p, params)
99+
if err != nil {
100+
return nil, err
101+
}
102+
pipelines = append(pipelines, pl)
103+
}
104+
105+
var defaults *ragDefaultsYAML
106+
if params.Config.Defaults != nil {
107+
src := params.Config.Defaults
108+
if src.TokenBudget != nil || src.TopN != nil {
109+
defaults = &ragDefaultsYAML{
110+
TokenBudget: src.TokenBudget,
111+
TopN: src.TopN,
112+
}
113+
}
114+
}
115+
116+
cfg := &ragYAMLConfig{
117+
Server: ragServerYAML{
118+
ListenAddress: "0.0.0.0",
119+
Port: 8080,
120+
},
121+
Pipelines: pipelines,
122+
Defaults: defaults,
123+
}
124+
125+
data, err := yaml.Marshal(cfg)
126+
if err != nil {
127+
return nil, err
128+
}
129+
return data, nil
130+
}
131+
132+
func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) (ragPipelineYAML, error) {
133+
tables := make([]ragTableYAML, 0, len(p.Tables))
134+
for _, t := range p.Tables {
135+
tbl := ragTableYAML{
136+
Table: t.Table,
137+
TextColumn: t.TextColumn,
138+
VectorColumn: t.VectorColumn,
139+
}
140+
if t.IDColumn != nil {
141+
tbl.IDColumn = *t.IDColumn
142+
}
143+
tables = append(tables, tbl)
144+
}
145+
146+
embLLM := ragLLMYAML{
147+
Provider: p.EmbeddingLLM.Provider,
148+
Model: p.EmbeddingLLM.Model,
149+
}
150+
if p.EmbeddingLLM.BaseURL != nil {
151+
embLLM.BaseURL = *p.EmbeddingLLM.BaseURL
152+
}
153+
154+
ragLLM := ragLLMYAML{
155+
Provider: p.RAGLLM.Provider,
156+
Model: p.RAGLLM.Model,
157+
}
158+
if p.RAGLLM.BaseURL != nil {
159+
ragLLM.BaseURL = *p.RAGLLM.BaseURL
160+
}
161+
162+
apiKeys, err := buildRAGAPIKeysYAML(p, params.KeysDir)
163+
if err != nil {
164+
return ragPipelineYAML{}, err
165+
}
166+
167+
pipeline := ragPipelineYAML{
168+
Name: p.Name,
169+
Database: ragDatabaseYAML{
170+
Host: params.DatabaseHost,
171+
Port: params.DatabasePort,
172+
Database: params.DatabaseName,
173+
Username: params.Username,
174+
Password: params.Password,
175+
SSLMode: "prefer",
176+
},
177+
Tables: tables,
178+
EmbeddingLLM: embLLM,
179+
RAGLLM: ragLLM,
180+
APIKeys: apiKeys,
181+
}
182+
183+
if p.Description != nil {
184+
pipeline.Description = *p.Description
185+
}
186+
pipeline.TokenBudget = p.TokenBudget
187+
pipeline.TopN = p.TopN
188+
if p.SystemPrompt != nil {
189+
pipeline.SystemPrompt = *p.SystemPrompt
190+
}
191+
if p.Search != nil {
192+
pipeline.Search = &ragSearchYAML{
193+
HybridEnabled: p.Search.HybridEnabled,
194+
VectorWeight: p.Search.VectorWeight,
195+
}
196+
}
197+
198+
return pipeline, nil
199+
}
200+
201+
// buildRAGAPIKeysYAML maps each LLM provider that requires a key to the
202+
// corresponding bind-mounted key file path inside the container.
203+
// Embedding key: {keysDir}/{pipeline}_embedding.key
204+
// RAG key: {keysDir}/{pipeline}_rag.key
205+
// If embedding and RAG use the same provider, the RAG key path takes precedence
206+
// (both files contain the same value). Returns an error if both LLMs share a
207+
// provider but were configured with different API keys.
208+
func buildRAGAPIKeysYAML(p database.RAGPipeline, keysDir string) (*ragAPIKeysYAML, error) {
209+
// Reject mismatched keys for the same provider — the RAG server has a
210+
// single key slot per provider and cannot reconcile two different values.
211+
if p.EmbeddingLLM.Provider == p.RAGLLM.Provider &&
212+
p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" &&
213+
p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" &&
214+
*p.EmbeddingLLM.APIKey != *p.RAGLLM.APIKey {
215+
return nil, fmt.Errorf("pipeline %q: embedding_llm and rag_llm share provider %q but have different API keys",
216+
p.Name, p.EmbeddingLLM.Provider)
217+
}
218+
219+
keys := &ragAPIKeysYAML{}
220+
221+
// Embedding provider key
222+
if p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" {
223+
keyPath := path.Join(keysDir, p.Name+"_embedding.key")
224+
switch p.EmbeddingLLM.Provider {
225+
case "anthropic":
226+
keys.Anthropic = keyPath
227+
case "openai":
228+
keys.OpenAI = keyPath
229+
case "voyage":
230+
keys.Voyage = keyPath
231+
}
232+
}
233+
234+
// RAG provider key (overwrites if same provider as embedding)
235+
if p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" {
236+
keyPath := path.Join(keysDir, p.Name+"_rag.key")
237+
switch p.RAGLLM.Provider {
238+
case "anthropic":
239+
keys.Anthropic = keyPath
240+
case "openai":
241+
keys.OpenAI = keyPath
242+
}
243+
}
244+
245+
if keys.Anthropic == "" && keys.OpenAI == "" && keys.Voyage == "" {
246+
return nil, nil
247+
}
248+
return keys, nil
249+
}

0 commit comments

Comments
 (0)