Skip to content

Commit b79ea7b

Browse files
authored
fix: add allowlist validation for RAG pipeline names (#337)
* fix: add allowlist validation for RAG pipeline names
1 parent 8b87c6b commit b79ea7b

2 files changed

Lines changed: 59 additions & 1 deletion

File tree

server/internal/database/rag_service_config.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,23 @@ import (
44
"bytes"
55
"encoding/json"
66
"fmt"
7+
"regexp"
78
"slices"
89
"sort"
910
"strings"
1011
)
1112

13+
// ragPipelineNamePatternText is the allowlist pattern for RAG pipeline names.
14+
// It is kept as a const so that the compiled regexp and the error message both
15+
// reference the same literal and cannot drift apart.
16+
const ragPipelineNamePatternText = `^[a-z0-9_][a-z0-9_-]*$`
17+
18+
// ragPipelineNamePattern restricts pipeline names to lowercase alphanumeric
19+
// characters, hyphens, and underscores. The first character must not be a
20+
// hyphen so that names are safe as filename components and cannot be
21+
// misinterpreted as CLI flags if ever passed to a command.
22+
var ragPipelineNamePattern = regexp.MustCompile(ragPipelineNamePatternText)
23+
1224
// RAGPipelineLLMConfig represents LLM configuration for an embedding or RAG step.
1325
type RAGPipelineLLMConfig struct {
1426
Provider string `json:"provider"`
@@ -126,9 +138,11 @@ func validateRAGPipeline(p RAGPipeline, i int, seenNames map[string]bool) []erro
126138
var errs []error
127139
prefix := fmt.Sprintf("pipelines[%d]", i)
128140

129-
// name (required, unique)
141+
// name (required, allowlist, unique)
130142
if p.Name == "" {
131143
errs = append(errs, fmt.Errorf("%s.name is required", prefix))
144+
} else if !ragPipelineNamePattern.MatchString(p.Name) {
145+
errs = append(errs, fmt.Errorf("%s.name %q is invalid: must match %s", prefix, p.Name, ragPipelineNamePatternText))
132146
} else if seenNames[p.Name] {
133147
errs = append(errs, fmt.Errorf("pipelines contains duplicate name %q", p.Name))
134148
} else {

server/internal/database/rag_service_config_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,50 @@ func TestParseRAGServiceConfig_MissingRAGLLM(t *testing.T) {
384384
assert.Contains(t, errs[0].Error(), "rag_llm.provider")
385385
}
386386

387+
func TestParseRAGServiceConfig_PipelineNameAllowlist(t *testing.T) {
388+
validNames := []string{
389+
"default",
390+
"my-pipeline",
391+
"my_pipeline",
392+
"pipeline-1",
393+
"a",
394+
"abc123",
395+
"a-b_c-1",
396+
}
397+
for _, name := range validNames {
398+
t.Run("valid/"+name, func(t *testing.T) {
399+
config := minimalRAGConfig()
400+
config["pipelines"].([]any)[0].(map[string]any)["name"] = name
401+
_, errs := database.ParseRAGServiceConfig(config, false)
402+
assert.Empty(t, errs, "name %q should be valid", name)
403+
})
404+
}
405+
406+
invalidNames := []string{
407+
"My Pipeline", // uppercase + space
408+
"pipeline name", // space
409+
"pipeline/name", // slash
410+
"../etc/passwd", // path traversal
411+
"UPPER", // uppercase
412+
"pipe🔥line", // unicode emoji
413+
"pipeline.name", // dot
414+
"-pipeline", // leading hyphen (could be misread as a CLI flag)
415+
"", // empty (covered separately, but included for completeness)
416+
}
417+
for _, name := range invalidNames {
418+
if name == "" {
419+
continue // empty name is a separate "required" error
420+
}
421+
t.Run("invalid/"+name, func(t *testing.T) {
422+
config := minimalRAGConfig()
423+
config["pipelines"].([]any)[0].(map[string]any)["name"] = name
424+
_, errs := database.ParseRAGServiceConfig(config, false)
425+
require.NotEmpty(t, errs, "name %q should be invalid", name)
426+
assert.Contains(t, errs[0].Error(), "must match ^[a-z0-9_][a-z0-9_-]*$")
427+
})
428+
}
429+
}
430+
387431
func TestParseRAGServiceConfig_MultiplePipelines(t *testing.T) {
388432
config := map[string]any{
389433
"pipelines": []any{

0 commit comments

Comments
 (0)