Skip to content

Commit 5b9113e

Browse files
committed
Initial implementation for audio support
1 parent 2f9a7a2 commit 5b9113e

9 files changed

Lines changed: 122 additions & 32 deletions

File tree

pkg/bbr/handlers/request.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics"
3636
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
3737
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
38+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
3839
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3940
)
4041

@@ -46,11 +47,21 @@ const (
4647
)
4748

4849
// HandleRequestBody handles request bodies.
49-
func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte, contentType string) ([]*eppb.ProcessingResponse, error) {
50+
func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte, contentType, path string) ([]*eppb.ProcessingResponse, error) {
5051
logger := log.FromContext(ctx)
5152
var ret []*eppb.ProcessingResponse
5253

5354
if strings.Contains(strings.ToLower(contentType), "multipart/form-data") {
55+
if !metadata.PathAllowedForMultipartModelExtraction(path) {
56+
if s.streaming {
57+
ret = append(ret, &eppb.ProcessingResponse{
58+
Response: &eppb.ProcessingResponse_RequestHeaders{RequestHeaders: &eppb.HeadersResponse{}},
59+
})
60+
ret = addStreamedBodyResponse(ret, requestBodyBytes)
61+
return ret, nil
62+
}
63+
return []*eppb.ProcessingResponse{{Response: &eppb.ProcessingResponse_RequestBody{RequestBody: &eppb.BodyResponse{}}}}, nil
64+
}
5465
model, parseErr := parseModelFromMultipart(requestBodyBytes, contentType)
5566
if parseErr != nil || model == "" {
5667
metrics.RecordModelNotInBodyCounter()
@@ -137,9 +148,7 @@ func parseModelFromMultipart(body []byte, contentType string) (string, error) {
137148
}
138149
if p.FormName() == "model" {
139150
var buf bytes.Buffer
140-
if _, err := buf.ReadFrom(p); err != nil {
141-
return "", err
142-
}
151+
buf.ReadFrom(p)
143152
return strings.TrimSpace(buf.String()), nil
144153
}
145154
}

pkg/bbr/handlers/request_test.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ func TestHandleRequestBody(t *testing.T) {
373373
t.Run(test.name, func(t *testing.T) {
374374
server := NewServer(test.streaming, &fakeDatastore{}, []framework.PayloadProcessor{})
375375
bodyBytes, _ := json.Marshal(test.body)
376-
resp, err := server.HandleRequestBody(ctx, bodyBytes, "application/json")
376+
resp, err := server.HandleRequestBody(ctx, bodyBytes, "application/json", "")
377377
if err != nil {
378378
if !test.wantErr {
379379
t.Fatalf("HandleRequestBody returned unexpected error: %v, want %v", err, test.wantErr)
@@ -415,7 +415,7 @@ func TestHandleRequestBodyWithPluginMetrics(t *testing.T) {
415415
"model": "bar",
416416
"prompt": "test",
417417
})
418-
_, err := server.HandleRequestBody(ctx, bodyBytes, "application/json")
418+
_, err := server.HandleRequestBody(ctx, bodyBytes, "application/json", "")
419419
if err != nil {
420420
t.Fatalf("HandleRequestBody returned unexpected error: %v", err)
421421
}
@@ -466,7 +466,7 @@ func TestHandleRequestBody_Multipart(t *testing.T) {
466466
contentType := "multipart/form-data; boundary=" + boundary
467467

468468
server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{})
469-
resp, err := server.HandleRequestBody(ctx, body, contentType)
469+
resp, err := server.HandleRequestBody(ctx, body, contentType, "/v1/audio/transcriptions")
470470
if err != nil {
471471
t.Fatalf("HandleRequestBody: %v", err)
472472
}
@@ -489,6 +489,26 @@ func TestHandleRequestBody_Multipart(t *testing.T) {
489489
}
490490
}
491491

492+
func TestHandleRequestBody_MultipartWrongPath(t *testing.T) {
493+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
494+
boundary := "----boundary"
495+
body := buildMultipartBody(t, boundary, "whisper-1", "audio.mp3", []byte("fake audio"))
496+
contentType := "multipart/form-data; boundary=" + boundary
497+
498+
server := NewServer(false, &fakeDatastore{}, []framework.PayloadProcessor{})
499+
resp, err := server.HandleRequestBody(ctx, body, contentType, "/v1/video/something")
500+
if err != nil {
501+
t.Fatalf("HandleRequestBody: %v", err)
502+
}
503+
if len(resp) != 1 {
504+
t.Fatalf("expected 1 response, got %d", len(resp))
505+
}
506+
br := resp[0].GetRequestBody()
507+
if br != nil && br.Response != nil && br.Response.HeaderMutation != nil && len(br.Response.HeaderMutation.SetHeaders) > 0 {
508+
t.Error("multipart on non-transcriptions path should not set model headers")
509+
}
510+
}
511+
492512
func buildMultipartBody(t *testing.T, boundary, model, filename string, fileContent []byte) []byte {
493513
var buf bytes.Buffer
494514
w := multipart.NewWriter(&buf)

pkg/bbr/handlers/server.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,11 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
117117
}
118118
if v.RequestHeaders != nil && v.RequestHeaders.Headers != nil {
119119
for _, h := range v.RequestHeaders.Headers.Headers {
120-
if strings.EqualFold(h.Key, "content-type") {
120+
switch {
121+
case strings.EqualFold(h.Key, "content-type"):
121122
streamedBody.contentType = requtil.GetHeaderValue(h)
122-
break
123+
case strings.EqualFold(h.Key, ":path"):
124+
streamedBody.path = requtil.GetHeaderValue(h)
123125
}
124126
}
125127
}
@@ -167,6 +169,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
167169
type streamedBody struct {
168170
body []byte
169171
contentType string
172+
path string
170173
}
171174

172175
func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, streamedBody *streamedBody, logger logr.Logger) ([]*extProcPb.ProcessingResponse, error) {
@@ -186,5 +189,5 @@ func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBod
186189
requestBodyBytes = body.GetBody()
187190
}
188191

189-
return s.HandleRequestBody(ctx, requestBodyBytes, streamedBody.contentType)
192+
return s.HandleRequestBody(ctx, requestBodyBytes, streamedBody.contentType, streamedBody.path)
190193
}

pkg/epp/metadata/consts.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616

1717
package metadata
1818

19+
import "strings"
20+
1921
const (
2022
// SubsetFilterNamespace is the key for the outer namespace struct in the metadata field of the extproc request that is used to wrap the subset filter.
2123
SubsetFilterNamespace = "envoy.lb.subset_hint"
@@ -36,4 +38,24 @@ const (
3638
ModelNameRewriteKey = "x-gateway-model-name-rewrite"
3739
// ModelNameKey is the header BBR sets from the request (JSON body or multipart form) for routing.
3840
ModelNameKey = "x-gateway-model-name"
41+
// AudioTranscriptionsPathPrefix is the path prefix for OpenAI-style audio transcriptions (multipart/form-data).
42+
AudioTranscriptionsPathPrefix = "/v1/audio/transcriptions"
3943
)
44+
45+
// MultipartModelExtractionPathPrefixes lists path prefixes for which multipart/form-data requests
46+
// get model extraction (same logic as transcriptions: parse form for "model", set headers, pass body through).
47+
// Add a prefix here to enable multipart model extraction for additional APIs (e.g. video).
48+
var MultipartModelExtractionPathPrefixes = []string{
49+
AudioTranscriptionsPathPrefix,
50+
}
51+
52+
// PathAllowedForMultipartModelExtraction reports whether the request path is allowed for
53+
// multipart model extraction (BBR and Director use this to gate the same multipart handling).
54+
func PathAllowedForMultipartModelExtraction(path string) bool {
55+
for _, prefix := range MultipartModelExtractionPathPrefixes {
56+
if strings.HasPrefix(path, prefix) {
57+
return true
58+
}
59+
}
60+
return false
61+
}

pkg/epp/metadata/path_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package metadata
18+
19+
import "testing"
20+
21+
func TestPathAllowedForMultipartModelExtraction(t *testing.T) {
22+
tests := []struct {
23+
path string
24+
allowed bool
25+
}{
26+
{"/v1/audio/transcriptions", true},
27+
{"/v1/completions", false},
28+
}
29+
for _, tt := range tests {
30+
got := PathAllowedForMultipartModelExtraction(tt.path)
31+
if got != tt.allowed {
32+
t.Errorf("PathAllowedForMultipartModelExtraction(%q) = %v, want %v", tt.path, got, tt.allowed)
33+
}
34+
}
35+
}

pkg/epp/requestcontrol/director.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,23 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
189189
return reqCtx, nil
190190
}
191191

192-
func getHeader(headers map[string]string, key string) string {
193-
for k, v := range headers {
194-
if strings.EqualFold(k, key) {
195-
return v
192+
func (d *Director) processRequestBody(ctx context.Context, reqCtx *handlers.RequestContext) (*fwksched.LLMRequestBody, error) {
193+
var ct, path, model string
194+
for k, v := range reqCtx.Request.Headers {
195+
switch {
196+
case strings.EqualFold(k, "content-type"):
197+
ct = v
198+
case strings.EqualFold(k, ":path"):
199+
path = v
196200
}
197201
}
198-
return ""
199-
}
200-
201-
func (d *Director) processRequestBody(ctx context.Context, reqCtx *handlers.RequestContext) (*fwksched.LLMRequestBody, error) {
202-
ct := getHeader(reqCtx.Request.Headers, "content-type")
203-
if strings.Contains(strings.ToLower(ct), "multipart/form-data") {
204-
model := getHeader(reqCtx.Request.Headers, metadata.ModelNameKey)
202+
if strings.Contains(strings.ToLower(ct), "multipart/form-data") && metadata.PathAllowedForMultipartModelExtraction(path) {
203+
for k, v := range reqCtx.Request.Headers {
204+
if strings.EqualFold(k, metadata.ModelNameKey) {
205+
model = v
206+
break
207+
}
208+
}
205209
if model == "" {
206210
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "multipart request missing x-gateway-model-name header"}
207211
}

pkg/epp/requestcontrol/director_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,8 @@ func TestDirector_HandleRequest_Multipart(t *testing.T) {
759759
Headers: map[string]string{
760760
requtil.RequestIdHeaderKey: "req-1",
761761
"content-type": "multipart/form-data; boundary=----boundary",
762-
metadata.ModelNameKey: "whisper-1",
762+
":path": metadata.AudioTranscriptionsPathPrefix,
763+
metadata.ModelNameKey: "whisper-1",
763764
},
764765
RawBody: []byte("raw multipart body"),
765766
Metadata: map[string]any{},

test/integration/bbr/hermetic_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,12 @@ func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) {
117117
{
118118
name: "audio transcriptions: multipart form sets model header and passes body through",
119119
reqs: func() []*extProcPb.ProcessingRequest {
120-
headers, body := integration.BuildMultipartTranscriptionsRequest("whisper-1", "audio.mp3", []byte("fake audio"))
121-
return integration.ReqRaw(headers, string(body))
120+
h, b := integration.BuildMultipartTranscriptionsRequest("whisper-1", "audio.mp3", []byte("fake audio"))
121+
return integration.ReqRaw(h, string(b))
122122
}(),
123123
wantResponses: func() []*extProcPb.ProcessingResponse {
124-
_, body := integration.BuildMultipartTranscriptionsRequest("whisper-1", "audio.mp3", []byte("fake audio"))
125-
return []*extProcPb.ProcessingResponse{
126-
ExpectBBRHeader("whisper-1"),
127-
ExpectBBRBodyPassThroughRaw(body),
128-
}
124+
_, b := integration.BuildMultipartTranscriptionsRequest("whisper-1", "audio.mp3", []byte("fake audio"))
125+
return []*extProcPb.ProcessingResponse{ExpectBBRHeader("whisper-1"), ExpectBBRBodyPassThroughRaw(b)}
129126
}(),
130127
},
131128
}

test/integration/util.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ func ReqLLMUnary(logger logr.Logger, prompt, model string) *extProcPb.Processing
6969
return GenerateRequest(logger, prompt, model, nil)
7070
}
7171

72-
// BuildMultipartTranscriptionsRequest builds headers and body for an OpenAI-style
73-
// /v1/audio/transcriptions request (multipart/form-data with "model" and "file").
74-
// Returns headers (including content-type with boundary) and the raw body for use with ReqRaw.
72+
// BuildMultipartTranscriptionsRequest builds multipart/form-data (model + file) for use with ReqRaw.
7573
func BuildMultipartTranscriptionsRequest(model, filename string, fileContent []byte) (headers map[string]string, body []byte) {
7674
boundary := "----boundary"
7775
var buf bytes.Buffer
@@ -87,6 +85,7 @@ func BuildMultipartTranscriptionsRequest(model, filename string, fileContent []b
8785
_ = w.Close()
8886
return map[string]string{
8987
"content-type": "multipart/form-data; boundary=" + boundary,
88+
":path": "/v1/audio/transcriptions",
9089
}, buf.Bytes()
9190
}
9291

0 commit comments

Comments
 (0)