|
| 1 | +package swarm |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "path" |
| 6 | + |
| 7 | + "github.com/goccy/go-yaml" |
| 8 | + |
| 9 | + "github.com/pgEdge/control-plane/server/internal/database" |
| 10 | +) |
| 11 | + |
| 12 | +// ragYAMLConfig mirrors the pgedge-rag-server Config struct for YAML generation. |
| 13 | +// Only the fields the control plane needs to set are included. |
| 14 | +type ragYAMLConfig struct { |
| 15 | + Server ragServerYAML `yaml:"server"` |
| 16 | + Pipelines []ragPipelineYAML `yaml:"pipelines"` |
| 17 | + Defaults *ragDefaultsYAML `yaml:"defaults,omitempty"` |
| 18 | +} |
| 19 | + |
| 20 | +type ragServerYAML struct { |
| 21 | + ListenAddress string `yaml:"listen_address"` |
| 22 | + Port int `yaml:"port"` |
| 23 | +} |
| 24 | + |
| 25 | +type ragPipelineYAML struct { |
| 26 | + Name string `yaml:"name"` |
| 27 | + Description string `yaml:"description,omitempty"` |
| 28 | + Database ragDatabaseYAML `yaml:"database"` |
| 29 | + Tables []ragTableYAML `yaml:"tables"` |
| 30 | + EmbeddingLLM ragLLMYAML `yaml:"embedding_llm"` |
| 31 | + RAGLLM ragLLMYAML `yaml:"rag_llm"` |
| 32 | + APIKeys *ragAPIKeysYAML `yaml:"api_keys,omitempty"` |
| 33 | + TokenBudget *int `yaml:"token_budget,omitempty"` |
| 34 | + TopN *int `yaml:"top_n,omitempty"` |
| 35 | + SystemPrompt string `yaml:"system_prompt,omitempty"` |
| 36 | + Search *ragSearchYAML `yaml:"search,omitempty"` |
| 37 | +} |
| 38 | + |
| 39 | +type ragDatabaseYAML struct { |
| 40 | + Host string `yaml:"host"` |
| 41 | + Port int `yaml:"port"` |
| 42 | + Database string `yaml:"database"` |
| 43 | + Username string `yaml:"username"` |
| 44 | + Password string `yaml:"password"` |
| 45 | + SSLMode string `yaml:"ssl_mode"` |
| 46 | +} |
| 47 | + |
| 48 | +type ragTableYAML struct { |
| 49 | + Table string `yaml:"table"` |
| 50 | + TextColumn string `yaml:"text_column"` |
| 51 | + VectorColumn string `yaml:"vector_column"` |
| 52 | + IDColumn string `yaml:"id_column,omitempty"` |
| 53 | +} |
| 54 | + |
| 55 | +type ragLLMYAML struct { |
| 56 | + Provider string `yaml:"provider"` |
| 57 | + Model string `yaml:"model"` |
| 58 | + BaseURL string `yaml:"base_url,omitempty"` |
| 59 | +} |
| 60 | + |
| 61 | +// ragAPIKeysYAML holds container-side file paths for each provider's API key. |
| 62 | +type ragAPIKeysYAML struct { |
| 63 | + Anthropic string `yaml:"anthropic,omitempty"` |
| 64 | + OpenAI string `yaml:"openai,omitempty"` |
| 65 | + Voyage string `yaml:"voyage,omitempty"` |
| 66 | +} |
| 67 | + |
| 68 | +type ragSearchYAML struct { |
| 69 | + HybridEnabled *bool `yaml:"hybrid_enabled,omitempty"` |
| 70 | + VectorWeight *float64 `yaml:"vector_weight,omitempty"` |
| 71 | +} |
| 72 | + |
| 73 | +type ragDefaultsYAML struct { |
| 74 | + TokenBudget *int `yaml:"token_budget,omitempty"` |
| 75 | + TopN *int `yaml:"top_n,omitempty"` |
| 76 | +} |
| 77 | + |
| 78 | +// RAGConfigParams holds all inputs needed to generate pgedge-rag-server.yaml. |
| 79 | +type RAGConfigParams struct { |
| 80 | + Config *database.RAGServiceConfig |
| 81 | + DatabaseName string |
| 82 | + DatabaseHost string |
| 83 | + DatabasePort int |
| 84 | + Username string |
| 85 | + Password string |
| 86 | + // KeysDir is the container-side directory where API key files are mounted, |
| 87 | + // e.g. "/app/keys". Key filenames follow the {pipeline}_{embedding|rag}.key |
| 88 | + // convention produced by extractRAGAPIKeys. |
| 89 | + KeysDir string |
| 90 | +} |
| 91 | + |
| 92 | +// GenerateRAGConfig generates the pgedge-rag-server.yaml content from the |
| 93 | +// given parameters. API key paths in the generated YAML reference files under |
| 94 | +// KeysDir so the RAG server reads them from the bind-mounted keys directory. |
| 95 | +func GenerateRAGConfig(params *RAGConfigParams) ([]byte, error) { |
| 96 | + pipelines := make([]ragPipelineYAML, 0, len(params.Config.Pipelines)) |
| 97 | + for _, p := range params.Config.Pipelines { |
| 98 | + pl, err := buildRAGPipelineYAML(p, params) |
| 99 | + if err != nil { |
| 100 | + return nil, err |
| 101 | + } |
| 102 | + pipelines = append(pipelines, pl) |
| 103 | + } |
| 104 | + |
| 105 | + var defaults *ragDefaultsYAML |
| 106 | + if params.Config.Defaults != nil { |
| 107 | + src := params.Config.Defaults |
| 108 | + if src.TokenBudget != nil || src.TopN != nil { |
| 109 | + defaults = &ragDefaultsYAML{ |
| 110 | + TokenBudget: src.TokenBudget, |
| 111 | + TopN: src.TopN, |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + cfg := &ragYAMLConfig{ |
| 117 | + Server: ragServerYAML{ |
| 118 | + ListenAddress: "0.0.0.0", |
| 119 | + Port: 8080, |
| 120 | + }, |
| 121 | + Pipelines: pipelines, |
| 122 | + Defaults: defaults, |
| 123 | + } |
| 124 | + |
| 125 | + data, err := yaml.Marshal(cfg) |
| 126 | + if err != nil { |
| 127 | + return nil, err |
| 128 | + } |
| 129 | + return data, nil |
| 130 | +} |
| 131 | + |
| 132 | +func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) (ragPipelineYAML, error) { |
| 133 | + tables := make([]ragTableYAML, 0, len(p.Tables)) |
| 134 | + for _, t := range p.Tables { |
| 135 | + tbl := ragTableYAML{ |
| 136 | + Table: t.Table, |
| 137 | + TextColumn: t.TextColumn, |
| 138 | + VectorColumn: t.VectorColumn, |
| 139 | + } |
| 140 | + if t.IDColumn != nil { |
| 141 | + tbl.IDColumn = *t.IDColumn |
| 142 | + } |
| 143 | + tables = append(tables, tbl) |
| 144 | + } |
| 145 | + |
| 146 | + embLLM := ragLLMYAML{ |
| 147 | + Provider: p.EmbeddingLLM.Provider, |
| 148 | + Model: p.EmbeddingLLM.Model, |
| 149 | + } |
| 150 | + if p.EmbeddingLLM.BaseURL != nil { |
| 151 | + embLLM.BaseURL = *p.EmbeddingLLM.BaseURL |
| 152 | + } |
| 153 | + |
| 154 | + ragLLM := ragLLMYAML{ |
| 155 | + Provider: p.RAGLLM.Provider, |
| 156 | + Model: p.RAGLLM.Model, |
| 157 | + } |
| 158 | + if p.RAGLLM.BaseURL != nil { |
| 159 | + ragLLM.BaseURL = *p.RAGLLM.BaseURL |
| 160 | + } |
| 161 | + |
| 162 | + apiKeys, err := buildRAGAPIKeysYAML(p, params.KeysDir) |
| 163 | + if err != nil { |
| 164 | + return ragPipelineYAML{}, err |
| 165 | + } |
| 166 | + |
| 167 | + pipeline := ragPipelineYAML{ |
| 168 | + Name: p.Name, |
| 169 | + Database: ragDatabaseYAML{ |
| 170 | + Host: params.DatabaseHost, |
| 171 | + Port: params.DatabasePort, |
| 172 | + Database: params.DatabaseName, |
| 173 | + Username: params.Username, |
| 174 | + Password: params.Password, |
| 175 | + SSLMode: "prefer", |
| 176 | + }, |
| 177 | + Tables: tables, |
| 178 | + EmbeddingLLM: embLLM, |
| 179 | + RAGLLM: ragLLM, |
| 180 | + APIKeys: apiKeys, |
| 181 | + } |
| 182 | + |
| 183 | + if p.Description != nil { |
| 184 | + pipeline.Description = *p.Description |
| 185 | + } |
| 186 | + pipeline.TokenBudget = p.TokenBudget |
| 187 | + pipeline.TopN = p.TopN |
| 188 | + if p.SystemPrompt != nil { |
| 189 | + pipeline.SystemPrompt = *p.SystemPrompt |
| 190 | + } |
| 191 | + if p.Search != nil { |
| 192 | + pipeline.Search = &ragSearchYAML{ |
| 193 | + HybridEnabled: p.Search.HybridEnabled, |
| 194 | + VectorWeight: p.Search.VectorWeight, |
| 195 | + } |
| 196 | + } |
| 197 | + |
| 198 | + return pipeline, nil |
| 199 | +} |
| 200 | + |
| 201 | +// buildRAGAPIKeysYAML maps each LLM provider that requires a key to the |
| 202 | +// corresponding bind-mounted key file path inside the container. |
| 203 | +// Embedding key: {keysDir}/{pipeline}_embedding.key |
| 204 | +// RAG key: {keysDir}/{pipeline}_rag.key |
| 205 | +// If embedding and RAG use the same provider, the RAG key path takes precedence |
| 206 | +// (both files contain the same value). Returns an error if both LLMs share a |
| 207 | +// provider but were configured with different API keys. |
| 208 | +func buildRAGAPIKeysYAML(p database.RAGPipeline, keysDir string) (*ragAPIKeysYAML, error) { |
| 209 | + // Reject mismatched keys for the same provider — the RAG server has a |
| 210 | + // single key slot per provider and cannot reconcile two different values. |
| 211 | + if p.EmbeddingLLM.Provider == p.RAGLLM.Provider && |
| 212 | + p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" && |
| 213 | + p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" && |
| 214 | + *p.EmbeddingLLM.APIKey != *p.RAGLLM.APIKey { |
| 215 | + return nil, fmt.Errorf("pipeline %q: embedding_llm and rag_llm share provider %q but have different API keys", |
| 216 | + p.Name, p.EmbeddingLLM.Provider) |
| 217 | + } |
| 218 | + |
| 219 | + keys := &ragAPIKeysYAML{} |
| 220 | + |
| 221 | + // Embedding provider key |
| 222 | + if p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" { |
| 223 | + keyPath := path.Join(keysDir, p.Name+"_embedding.key") |
| 224 | + switch p.EmbeddingLLM.Provider { |
| 225 | + case "anthropic": |
| 226 | + keys.Anthropic = keyPath |
| 227 | + case "openai": |
| 228 | + keys.OpenAI = keyPath |
| 229 | + case "voyage": |
| 230 | + keys.Voyage = keyPath |
| 231 | + } |
| 232 | + } |
| 233 | + |
| 234 | + // RAG provider key (overwrites if same provider as embedding) |
| 235 | + if p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" { |
| 236 | + keyPath := path.Join(keysDir, p.Name+"_rag.key") |
| 237 | + switch p.RAGLLM.Provider { |
| 238 | + case "anthropic": |
| 239 | + keys.Anthropic = keyPath |
| 240 | + case "openai": |
| 241 | + keys.OpenAI = keyPath |
| 242 | + } |
| 243 | + } |
| 244 | + |
| 245 | + if keys.Anthropic == "" && keys.OpenAI == "" && keys.Voyage == "" { |
| 246 | + return nil, nil |
| 247 | + } |
| 248 | + return keys, nil |
| 249 | +} |
0 commit comments