Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions services/protocol/openai_v1_chat_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,26 @@ def stream_grok_chat_completion(body: dict[str, Any], spec, messages: list[dict[
return
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time())
response = grok.console_chat_completion(body, spec, messages)
if response.reasoning_content:
yield completion_chunk(model, {"role": "assistant", "reasoning_content": response.reasoning_content}, None, completion_id, created)
yield completion_chunk(model, {"content": response.content}, None, completion_id, created)
else:
yield completion_chunk(model, {"role": "assistant", "content": response.content}, None, completion_id, created)
sent_role = False
for event in grok.console_chat_completion_events(body, spec, messages):
delta = grok.extract_console_stream_delta(event)
if not delta.content and not delta.reasoning_content:
continue
if not sent_role:
sent_role = True
first_delta: dict[str, Any] = {"role": "assistant"}
if delta.reasoning_content:
first_delta["reasoning_content"] = delta.reasoning_content
else:
first_delta["content"] = delta.content
yield completion_chunk(model, first_delta, None, completion_id, created)
continue
if delta.reasoning_content:
yield completion_chunk(model, {"reasoning_content": delta.reasoning_content}, None, completion_id, created)
else:
yield completion_chunk(model, {"content": delta.content}, None, completion_id, created)
if not sent_role:
yield completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created)
yield completion_chunk(model, {}, "stop", completion_id, created)


Expand Down
226 changes: 218 additions & 8 deletions services/providers/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class GrokConsoleCompletion:
raw_response: dict[str, Any] | None = None


@dataclass(frozen=True)
class GrokConsoleStreamDelta:
    """Immutable text increment extracted from a single console stream event.

    Both fields default to the empty string, so a default-constructed
    instance represents "this event carried no text".
    """

    # Visible answer text carried by the event ("" when there is none).
    content: str = ""
    # Reasoning / thinking text carried by the event ("" when there is none).
    reasoning_content: str = ""


_THINKING_SUMMARY_RE = re.compile(
r"^\s*(?:\*\*)?\s*(?:思考摘要|思考总结|thinking\s+summary|thought\s+summary|reasoning\s+summary|thinking|reasoning)\s*(?:\*\*\s*[::]|[::]\s*(?:\*\*)?)\s*(.*)$",
re.IGNORECASE,
Expand Down Expand Up @@ -209,6 +215,125 @@ def extract_console_completion(payload: dict[str, Any]) -> GrokConsoleCompletion
)


def _text_field(value: object) -> str:
    """Return the text carried by *value*, or ``""`` when it carries none.

    A string is returned unchanged. For a dict, the first non-empty string
    found under the known text keys (checked in a fixed priority order) is
    returned. Any other type yields the empty string.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        text_keys = ("text", "content", "output_text", "reasoning_content", "summary_text")
        candidates = (value.get(name) for name in text_keys)
        # First candidate that is a non-empty string wins; "" if none qualifies.
        return next((item for item in candidates if isinstance(item, str) and item), "")
    return ""


def extract_console_stream_delta(event: dict[str, Any]) -> GrokConsoleStreamDelta:
    """Convert one console stream event into a text delta.

    Args:
        event: A decoded stream event. Its optional ``"type"`` is compared
            case-insensitively.

    Returns:
        A ``GrokConsoleStreamDelta`` with ``reasoning_content`` set when the
        event type mentions "reasoning" or "thinking", with ``content`` set
        for any other text delta, and empty when the event is not a text
        delta or carries no text.
    """
    event_type = str(event.get("type") or "").lower()
    if event_type:
        # Typed events only contribute text when they are delta events.
        if "delta" not in event_type:
            return GrokConsoleStreamDelta()
        # Tool-call argument, tool-input and raw audio deltas also carry a
        # string ``delta``, but it is JSON or base64 rather than assistant
        # text, so it must not be forwarded as message content.
        # NOTE(review): marker list follows the OpenAI-style Responses event
        # names; confirm against the events the upstream actually emits.
        non_text_markers = ("arguments", "tool_call_input", "audio.delta")
        if any(marker in event_type for marker in non_text_markers):
            return GrokConsoleStreamDelta()
    # Prefer the ``delta`` field; fall back to text keys on the event itself.
    text = _text_field(event.get("delta")) or _text_field(event)
    if not text:
        return GrokConsoleStreamDelta()
    if "reasoning" in event_type or "thinking" in event_type:
        return GrokConsoleStreamDelta(reasoning_content=text)
    return GrokConsoleStreamDelta(content=text)


def _parse_console_stream_payload(payload: str, current_event: str) -> dict[str, Any] | None:
if not payload:
return None
try:
event = json.loads(payload)
except json.JSONDecodeError:
logger.warning({"event": "grok_console_stream_invalid_json"})
return None
if not isinstance(event, dict):
return None
if current_event and not event.get("type"):
event = {"type": current_event, **event}
return event


def _iter_console_stream_events(lines: Iterable[object]) -> Iterator[dict[str, Any]]:
current_event = ""
data_lines: list[str] = []

def flush_data() -> dict[str, Any] | None:
nonlocal current_event
if not data_lines:
current_event = ""
return None
payload = "\n".join(data_lines).strip()
data_lines.clear()
event = _parse_console_stream_payload(payload, current_event)
current_event = ""
return event

for raw_line in lines:
if raw_line is None:
continue
line = raw_line.decode("utf-8", errors="replace") if isinstance(raw_line, bytes) else str(raw_line)
line = line.rstrip("\r\n")
if not line.strip():
event = flush_data()
if event is not None:
yield event
continue
if line.startswith(":"):
continue
if line.startswith("event:"):
event = flush_data()
if event is not None:
yield event
current_event = line[6:]
if current_event.startswith(" "):
current_event = current_event[1:]
current_event = current_event.strip()
continue
if line.startswith("data:"):
payload = line[5:]
if payload.startswith(" "):
payload = payload[1:]
if payload.strip() == "[DONE]":
event = flush_data()
if event is not None:
yield event
break
data_lines.append(payload)
continue
line = line.strip()
if line.startswith("{"):
event = flush_data()
if event is not None:
yield event
event = _parse_console_stream_payload(line, current_event)
if event is not None:
yield event

event = flush_data()
if event is not None:
yield event


def _raise_for_console_stream_event(event: dict[str, Any]) -> None:
    """Raise when a stream event reports an upstream failure.

    Events whose type is not one of the known failure types are ignored.
    For a failure event the message is taken, in order, from the event's
    ``error`` field, from ``response.error`` / ``response.incomplete_details``,
    and finally from ``message`` / ``code`` / ``reason`` on the event itself.
    The event type is used when nothing more specific is available.

    Args:
        event: A decoded stream event.

    Raises:
        GrokConsoleError: With status 502 when the event is a failure event.
    """
    event_type = str(event.get("type") or "").lower()
    if event_type not in {"error", "response.failed", "response.error", "response.incomplete", "response.cancelled"}:
        return
    error = event.get("error")
    response = event.get("response")
    if not error and isinstance(response, dict):
        error = response.get("error") or response.get("incomplete_details")
    if not error:
        # A bare ``error`` event can carry its details at the top level
        # rather than under an "error" key; without this fallback the
        # message would degrade to just the event type.
        error = {key: event[key] for key in ("message", "code", "reason") if event.get(key)}
    if isinstance(error, dict):
        message = str(error.get("message") or error.get("code") or error.get("reason") or event_type)
    elif error:
        message = str(error)
    else:
        message = event_type
    raise GrokConsoleError(f"Grok upstream stream error: {message}", 502)


def _grok_console_profile():
    """Build the Grok console profile from the current configuration data."""
    settings = config.data
    return build_grok_console_profile(settings)

Expand Down Expand Up @@ -260,6 +385,50 @@ def _feedback_status(upstream_status: int) -> str | None:
return None


def _console_upstream_error_detail(response: object | None) -> str:
if response is None:
return ""
json_data: object = None
json_method = getattr(response, "json", None)
if callable(json_method):
try:
json_data = json_method()
except Exception:
json_data = None
if isinstance(json_data, dict):
error = json_data.get("error")
if isinstance(error, dict):
for key in ("message", "code", "reason", "type"):
value = error.get(key)
if value:
return str(value)
elif error:
return str(error)
for key in ("message", "detail", "code", "reason"):
value = json_data.get(key)
if value:
return str(value)
text = getattr(response, "text", "") or ""
if not text:
content = getattr(response, "content", b"")
if isinstance(content, bytes):
text = content.decode("utf-8", errors="replace")
return str(text).strip()[:400]


def _raise_console_upstream_error(access_token: str, upstream_status: int, response: object | None = None) -> None:
    """Record account feedback for a failed upstream call, then raise.

    When ``_feedback_status`` maps the upstream status to an account status,
    the account identified by *access_token* is updated with it. The raised
    error message includes the upstream HTTP status and, when one can be
    extracted from *response*, a detail suffix.

    Raises:
        GrokConsoleError: Always; carries the mapped status from
            ``_openai_status`` and the original upstream status.
    """
    account_status = _feedback_status(upstream_status)
    if account_status:
        # Imported here rather than at module level, as in the original.
        from services.account_service import account_service

        account_service.update_account(access_token, {"status": account_status})

    summary = f"Grok upstream error (HTTP {upstream_status})"
    detail = _console_upstream_error_detail(response)
    message = f"{summary}: {detail}" if detail else summary
    raise GrokConsoleError(message, _openai_status(upstream_status), upstream_status)


class GrokConsoleClient:
def __init__(self, access_token: str) -> None:
self.access_token = access_token
Expand Down Expand Up @@ -310,19 +479,39 @@ def create_response(self, payload: dict[str, Any]) -> dict[str, Any]:
except requests.exceptions.RequestException as exc:
raise GrokConsoleError(f"Grok upstream request failed: {exc}", 502) from exc
if response.status_code >= 400:
status = int(response.status_code)
feedback_status = _feedback_status(status)
if feedback_status:
from services.account_service import account_service

account_service.update_account(self.access_token, {"status": feedback_status})
message = f"Grok upstream error (HTTP {status})"
raise GrokConsoleError(message, _openai_status(status), status)
_raise_console_upstream_error(self.access_token, int(response.status_code), response)
data = response.json()
if not isinstance(data, dict):
raise GrokConsoleError("Grok upstream returned an invalid response", 502)
return data

def stream_response(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]:
    """Send *payload* with streaming enabled and yield decoded stream events.

    The caller's payload is not modified; a copy with ``"stream": True`` is
    sent. The upstream response is closed when the generator finishes, is
    closed by the consumer, or fails, including on the HTTP-error path.

    Args:
        payload: Request body for the console responses endpoint.

    Yields:
        Each decoded event of the upstream stream.

    Raises:
        GrokConsoleError: When the request fails, the upstream answers with
            HTTP >= 400, the stream reports a failure event, or the
            connection breaks while the stream is being read.
    """
    stream_payload = dict(payload)
    stream_payload["stream"] = True
    try:
        response = self._call_with_retry(
            lambda: self.session.post(
                CONSOLE_RESPONSES_URL,
                headers=_headers(self.access_token),
                json=stream_payload,
                timeout=self.network_profile.timeout,
                stream=True,
            ),
            context="stream_response",
        )
    except requests.exceptions.RequestException as exc:
        raise GrokConsoleError(f"Grok upstream request failed: {exc}", 502) from exc
    try:
        # The status check lives inside the try so the streamed response is
        # closed even when the upstream answered with an error status.
        if response.status_code >= 400:
            _raise_console_upstream_error(self.access_token, int(response.status_code), response)
        for event in _iter_console_stream_events(response.iter_lines()):
            _raise_for_console_stream_event(event)
            yield event
    except requests.exceptions.RequestException as exc:
        # Reading the body can fail mid-stream (dropped connection, read
        # timeout, broken chunking). Callers only handle GrokConsoleError,
        # so convert it the same way as a failed request.
        raise GrokConsoleError(f"Grok upstream stream interrupted: {exc}", 502) from exc
    finally:
        close = getattr(response, "close", None)
        if callable(close):
            close()


def _cookie_items(cookie_header: str) -> list[tuple[str, str]]:
items: list[tuple[str, str]] = []
Expand Down Expand Up @@ -830,5 +1019,26 @@ def console_chat_completion(body: dict[str, Any], spec: ModelSpec, messages: lis
return completion


def console_chat_completion_events(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> Iterator[dict[str, Any]]:
    """Stream raw console events for a chat completion.

    Obtains a text access token for the Grok provider, builds the console
    payload and relays every event produced by the console client. The
    account is marked as used once at least one event was received or the
    stream ran to completion.

    Raises:
        HTTPException: 503 when no Grok account is available; otherwise the
            status code of the ``GrokConsoleError`` raised by the client.
    """
    from services.account_service import account_service

    token = account_service.get_text_access_token(provider=GROK_PROVIDER)
    if not token:
        raise HTTPException(status_code=503, detail={"error": "no available Grok account"})

    request_payload = build_console_payload(spec, body, messages)
    account_was_used = False
    try:
        with GrokConsoleClient(token) as client:
            for event in client.stream_response(request_payload):
                # Set before yielding so usage is recorded even if the
                # consumer stops reading after this event.
                account_was_used = True
                yield event
            else:
                # The stream finished normally, possibly without any event.
                account_was_used = True
    except GrokConsoleError as exc:
        raise HTTPException(status_code=exc.status_code, detail={"error": str(exc)}) from exc
    finally:
        if account_was_used:
            account_service.mark_text_used(token)


def chat_completion(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> str:
    """Run a console chat completion and return only its ``content`` text."""
    completion = console_chat_completion(body, spec, messages)
    return completion.content
Loading
Loading