273 lines
8.7 KiB
Python
273 lines
8.7 KiB
Python
"""LLM client module — all network egress is concentrated here.
|
|
|
|
Uses ``httpx`` (already in requirements) to call OpenAI-compatible endpoints.
|
|
No ``openai`` SDK dependency. Sync functions are fine: FastAPI runs sync
|
|
handlers in a threadpool.
|
|
|
|
Public API:
|
|
- ``is_configured(cfg)`` — returns True when the client can make calls.
|
|
- ``test_connection(cfg)`` — minimal request to verify credentials.
|
|
- ``expand_query(cfg, query)`` — query-term expansion (step 3 consumer).
|
|
Returns ``ExpansionResult`` with ``terms`` and optional ``error``.
|
|
- ``analyze_image(...)`` — **reserved stub, not implemented**.
|
|
|
|
All calls go through ``_call_chat_completion()`` so tests can mock a single
|
|
boundary.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from app.settings_store import LLMConfig
|
|
|
|
# Sensible defaults
# Per-request timeout (seconds) handed to ``httpx.Client`` for every
# chat-completion call made by this module.
_TIMEOUT_SECONDS: int = 30
|
|
|
|
# ── Prompt for query expansion (Step 3) ──────────────────────────────────
# System prompt used by ``expand_query``. In Chinese, it asks the model for at
# most 8 terms related to the user's search word (synonyms, common aliases,
# broader categories, concrete kinds), in the same language as the query,
# and to reply with nothing but a JSON array of strings.
_EXPAND_QUERY_SYSTEM_PROMPT = "".join(
    (
        # Role and scenario: searching one's own packed moving boxes/items.
        "你是搬家物品搜索助手。用户在搜索自己打包的箱子与物品(家居/搬家场景)。",
        # Task: list related words naming the same kind of item.
        "给定一个搜索词,列出用户可能用来命名同一类物品的相关词:",
        "近义词、常见别称、上位类别、具体品类。",
        # Rules: same language, closely related only, no prose, at most 8.
        "规则:用与查询相同的语言;",
        "只给与该物品紧密相关、有助于在清单里找到它的词;",
        "不要解释、不要造无关词;最多 8 个;",
        # Output format: a bare JSON string array, with an example.
        "只输出一个 JSON 字符串数组,例如 ",
        '`["炒锅","平底锅","汤锅","厨具"]`。',
    )
)
|
|
|
|
|
|
@dataclass
class LLMResult:
    """Uniform result wrapper for LLM calls.

    Returned by ``test_connection`` and ``analyze_image`` in this module;
    failures are reported through ``success=False`` rather than raised.
    """

    # True when the call succeeded.
    success: bool
    # Human-readable status or error message (the messages built in this
    # module are in Chinese).
    message: str
    # Optional payload; ``test_connection`` stores the parsed response body
    # here on success.
    data: Any = None
|
|
|
|
|
|
@dataclass
class ExpansionResult:
    """Structured result from ``expand_query``.

    ``terms`` is always a list (may be empty).
    ``error`` is ``None`` on success (including legitimate empty results);
    on failure (timeout, network error, HTTP error) it contains a
    human-friendly error message.
    """

    # Expanded search terms; empty when nothing was found or the call failed.
    terms: list[str]
    # User-facing error message, or ``None`` when the call did not fail.
    error: str | None = None
|
|
|
|
|
|
def is_configured(cfg: LLMConfig) -> bool:
    """Tell whether ``cfg`` can be used to make LLM calls.

    The integration must be switched on (``enabled``), and both ``model``
    and ``api_key`` must be non-empty (truthy).
    """
    if not cfg.enabled:
        return False
    return bool(cfg.model) and bool(cfg.api_key)
|
|
|
|
|
|
def test_connection(cfg: LLMConfig) -> LLMResult:
    """Send a minimal chat-completion request to verify the config.

    The probe is a one-word user message with ``max_tokens=1`` to keep the
    cost negligible.

    Args:
        cfg: LLM settings to check.

    Returns:
        ``LLMResult`` with ``success=True`` and the parsed response body in
        ``data`` when the endpoint answers with a 2xx status. Otherwise it
        returns ``success=False`` with a human-readable message. Network,
        HTTP and other request errors are reported through the result
        rather than raised.
    """
    if not is_configured(cfg):
        return LLMResult(
            success=False,
            message="LLM 未配置或未启用(缺少 model 或 api_key)。",
        )

    try:
        response = _call_chat_completion(
            cfg,
            messages=[{"role": "user", "content": "Hi"}],
            max_tokens=1,
        )
    except httpx.HTTPStatusError as exc:
        status = exc.response.status_code
        return LLMResult(
            success=False,
            message=f"连接失败(HTTP {status})。请检查 base_url、model 和 api_key。",
        )
    except httpx.ConnectError:
        return LLMResult(
            success=False,
            message="无法连接到服务器。请检查 base_url 是否正确。",
        )
    except httpx.TimeoutException:
        return LLMResult(
            success=False,
            message="连接超时。请检查网络或 base_url 是否可达。",
        )
    except Exception as exc:  # noqa: BLE001 — graceful degradation
        detail = str(exc)
        # The raw exception text ends up in the returned message. Some
        # transport errors (e.g. an invalid header value) can echo request
        # headers, so never let the API key appear in it.
        secret = (cfg.api_key or "").strip()
        if secret:
            detail = detail.replace(secret, "***")
        return LLMResult(
            success=False,
            message=f"未知错误:{detail}",
        )

    return LLMResult(
        success=True,
        message=f"连接成功(模型:{cfg.model})。",
        data=response,
    )
|
|
|
|
|
|
def expand_query(
    cfg: LLMConfig,
    query: str,
    extra_hints: str = "",
) -> ExpansionResult:
    """Expand a search query into multiple related terms via the LLM.

    Args:
        cfg: LLM settings. When ``is_configured(cfg)`` is false, no request
            is made and an empty result (no error) is returned.
        query: The user's search text. A blank query also returns an empty
            result without making a request.
        extra_hints: Optional extra instructions. When non-blank, they are
            stripped and appended to the system prompt after a newline.

    Returns:
        An ``ExpansionResult``:

        - On success, ``terms`` contains the expanded terms (possibly empty)
          and ``error`` is ``None``.
        - On failure (network error, timeout, HTTP error, or any other
          exception raised by the request), ``terms`` is ``[]`` and
          ``error`` contains a human-friendly message.
        - A 2xx response whose body has an unexpected shape (no choices,
          ``null`` content, ...) yields empty ``terms`` and no error.
    """
    if not is_configured(cfg):
        return ExpansionResult(terms=[])

    cleaned_query = query.strip() if query else ""
    if not cleaned_query:
        # Nothing to expand: skip the paid round-trip. An empty user message
        # may also be rejected by the provider, which would surface as a
        # misleading "check your configuration" error.
        return ExpansionResult(terms=[])

    system_prompt = _EXPAND_QUERY_SYSTEM_PROMPT
    hints = extra_hints.strip() if extra_hints else ""
    if hints:
        system_prompt += "\n" + hints

    try:
        response = _call_chat_completion(
            cfg,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": cleaned_query},
            ],
            max_tokens=200,
            # Minimise randomness so the same query tends to expand the same way.
            temperature=0,
        )
    except httpx.TimeoutException:
        return ExpansionResult(
            terms=[],
            error="AI 搜索请求超时,请稍后再试。",
        )
    except httpx.ConnectError:
        return ExpansionResult(
            terms=[],
            error="无法连接到 AI 服务,请检查网络或设置。",
        )
    except httpx.HTTPStatusError:
        return ExpansionResult(
            terms=[],
            error="AI 服务返回错误,请检查配置。",
        )
    except Exception:  # noqa: BLE001 — graceful degradation
        return ExpansionResult(
            terms=[],
            error="AI 搜索暂时不可用,请稍后再试。",
        )

    # Body parsing runs outside the ``try`` above, so it must never raise.
    content = _first_choice_content(response)
    return ExpansionResult(terms=_parse_json_string_array(content))


def _first_choice_content(response: Any) -> str:
    """Return ``choices[0].message.content`` as a string, or ``""`` if absent.

    The chat-completions schema allows ``content`` to be ``null``, and a
    misbehaving gateway may return other shapes. Every unexpected shape maps
    to ``""`` instead of raising.
    """
    if not isinstance(response, dict):
        return ""
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first = choices[0]
    if not isinstance(first, dict):
        return ""
    message = first.get("message")
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    return content if isinstance(content, str) else ""
|
|
|
|
|
|
# ── Constants for output contract enforcement ────────────────────────────

_MAX_EXPANSION_TERMS = 8
_MAX_TERM_LENGTH = 30


def _parse_json_string_array(content: str) -> list[str]:
    """Parse LLM output into a list of strings.

    Strict contract enforcement:

    1. Strip markdown code fences or inline-code backticks;
    2. Try ``json.loads`` — only accept a JSON **array of strings**;
    3. Anything else (prose, JSON objects, bad JSON, non-string input)
       → return ``[]``.

    Accepted items are whitespace-stripped. Blank items and items longer
    than ``_MAX_TERM_LENGTH`` characters are dropped, and duplicates are
    removed (the first occurrence wins). At most ``_MAX_EXPANSION_TERMS``
    items are returned.

    This ensures the output contract is enforced by code: no matter what
    the model returns or what ``ai_search_extra_hints`` contains, only a
    valid JSON string array is accepted.
    """
    # Guard against ``None``/non-text content (the chat-completions schema
    # allows ``content: null``).
    if not isinstance(content, str):
        return []
    text = content.strip()
    if not text:
        return []

    # Strip fences. Besides ```json ... ``` blocks, accept a single-backtick
    # wrapper: the system prompt's own example is written as `[...]`, and
    # models tend to echo that formatting.
    text = re.sub(r"^`+(?:json)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s*`+$", "", text)
    text = text.strip()

    # Attempt JSON parse — strictly require a list.
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        return []

    if not isinstance(parsed, list):
        return []

    # Every element must be a string; a single bad element rejects the array.
    if not all(isinstance(item, str) for item in parsed):
        return []

    terms: list[str] = []
    seen: set[str] = set()
    for item in parsed:
        cleaned = item.strip()
        if not cleaned or len(cleaned) > _MAX_TERM_LENGTH or cleaned in seen:
            continue
        seen.add(cleaned)
        terms.append(cleaned)
        # Cap total count (after filtering and de-duplication).
        if len(terms) == _MAX_EXPANSION_TERMS:
            break
    return terms
|
|
|
|
|
|
def analyze_image(cfg: LLMConfig, image_data: bytes, prompt: str) -> LLMResult:
    """Placeholder for image analysis through an LLM vision API.

    Not implemented yet: no request is made and every call returns a failed
    ``LLMResult``. The signature is fixed so callers can already be written
    against it.
    """
    # TODO: implement image analysis in a future round.
    reason = "图片分析功能尚未实现。"
    return LLMResult(message=reason, success=False)
|
|
|
|
|
|
# ------------------------------------------------------------------
# Internal boundary — all network calls go through this single function
# ------------------------------------------------------------------


def _call_chat_completion(
    cfg: LLMConfig,
    *,
    messages: list[dict[str, str]],
    max_tokens: int = 1,
    temperature: float | None = None,
) -> dict:
    """Call the OpenAI-compatible ``/chat/completions`` endpoint.

    Args:
        cfg: Supplies ``base_url``, ``model`` and ``api_key``.
            - Surrounding whitespace and trailing slashes on ``base_url`` are
              ignored.
            - A ``base_url`` that already ends in ``/chat/completions`` is
              used as-is instead of getting the path appended twice.
            - Whitespace around ``model`` and ``api_key`` is stripped.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; omitted from the payload when
            ``None`` so the provider default applies.

    Returns:
        The parsed JSON response body on success (status 2xx).

    Raises:
        httpx.HTTPStatusError: On a non-2xx response.
        Other ``httpx`` exceptions: On network failures (connect errors,
            timeouts, ...). Callers handle these for graceful degradation.
        ValueError: If a 2xx body is not valid JSON.
    """
    endpoint = "/chat/completions"
    base = cfg.base_url.strip().rstrip("/")
    # Users often paste the full endpoint instead of the API root; avoid
    # building ".../chat/completions/chat/completions".
    url = base if base.endswith(endpoint) else base + endpoint

    payload: dict[str, Any] = {
        "model": cfg.model.strip(),
        "messages": messages,
        "max_tokens": max_tokens,
    }
    if temperature is not None:
        payload["temperature"] = temperature

    headers = {
        # A stray newline/space from copy-paste would make the header value
        # invalid (the request fails and the error text may contain the key).
        "Authorization": f"Bearer {cfg.api_key.strip()}",
        "Content-Type": "application/json",
    }

    with httpx.Client(timeout=_TIMEOUT_SECONDS) as client:
        response = client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response.json()
|