Add AI search query expansion
test / pytest (push) Successful in 1m20s
docker-image / build-and-push (push) Successful in 5m6s

This commit is contained in:
2026-06-01 21:28:29 +02:00
parent d36b940981
commit 70b0cf08ee
10 changed files with 1064 additions and 123 deletions
+122 -25
View File
@@ -8,6 +8,7 @@ Public API:
- ``is_configured(cfg)`` — returns True when the client can make calls.
- ``test_connection(cfg)`` — minimal request to verify credentials.
- ``expand_query(cfg, query)`` — query-term expansion (step 3 consumer).
Returns ``ExpansionResult`` with ``terms`` and optional ``error``.
- ``analyze_image(...)`` — **reserved stub, not implemented**.
All calls go through ``_call_chat_completion()`` so tests can mock a single
@@ -16,6 +17,8 @@ boundary.
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Any
@@ -26,6 +29,18 @@ from app.settings_store import LLMConfig
# Sensible defaults
_TIMEOUT_SECONDS = 30
# ── Prompt for query expansion (Step 3) ──────────────────────────────────
# System prompt (written in Chinese) used as the base prompt for query
# expansion. It frames the model as a search assistant for a household /
# moving inventory and asks for related terms — synonyms, common aliases,
# broader categories, specific kinds — in the same language as the query,
# at most 8 of them, emitted as a single JSON array of strings with no
# explanation. The adjacent literals are concatenated into one string.
_EXPAND_QUERY_SYSTEM_PROMPT = (
"你是搬家物品搜索助手。用户在搜索自己打包的箱子与物品(家居/搬家场景)。"
"给定一个搜索词,列出用户可能用来命名同一类物品的相关词:"
"近义词、常见别称、上位类别、具体品类。"
"规则:用与查询相同的语言;"
"只给与该物品紧密相关、有助于在清单里找到它的词;"
"不要解释、不要造无关词;最多 8 个;"
"只输出一个 JSON 字符串数组,例如 "
'`["炒锅","平底锅","汤锅","厨具"]`。'
)
@dataclass
class LLMResult:
@@ -36,6 +51,20 @@ class LLMResult:
data: Any = None
@dataclass
class ExpansionResult:
    """Outcome of an ``expand_query`` call.

    Attributes:
        terms: The expanded search terms. Always a list; it may be empty.
        error: ``None`` when the call succeeded — an empty ``terms`` list can
            still be a legitimate success. On failure (timeout, network
            error, HTTP error) a human-friendly error message.
    """

    terms: list[str]
    error: str | None = None
def is_configured(cfg: LLMConfig) -> bool:
    """Tell whether the LLM client has everything it needs to make calls.

    The result is ``True`` only if ``cfg.enabled``, ``cfg.model`` and
    ``cfg.api_key`` are all truthy; the fields are checked in that order and
    the first falsy one short-circuits to ``False``.
    """
    if not cfg.enabled:
        return False
    if not cfg.model:
        return False
    return bool(cfg.api_key)
@@ -87,44 +116,109 @@ def test_connection(cfg: LLMConfig) -> LLMResult:
)
def expand_query(
    cfg: LLMConfig,
    query: str,
    extra_hints: str = "",
) -> ExpansionResult:
    """Expand a search query into related terms via the LLM.

    Args:
        cfg: LLM settings. When ``is_configured(cfg)`` is false no request
            is made and an empty, error-free result is returned.
        query: The user's search term; sent verbatim as the user message.
        extra_hints: Optional extra instructions. When non-blank, the
            stripped text is appended on a new line to the system prompt.

    Returns:
        An ``ExpansionResult``. On success ``terms`` holds the expanded
        terms (possibly empty) and ``error`` is ``None``. When the request
        itself fails (timeout, connection error, HTTP error, or any other
        exception) ``terms`` is ``[]`` and ``error`` carries a
        human-friendly message. A response that arrives but has no usable
        text content is treated as a successful empty expansion.
    """
    if not is_configured(cfg):
        return ExpansionResult(terms=[])

    system_prompt = _EXPAND_QUERY_SYSTEM_PROMPT
    hints = extra_hints.strip() if extra_hints else ""
    if hints:
        system_prompt += "\n" + hints

    try:
        response = _call_chat_completion(
            cfg,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query},
            ],
            max_tokens=200,
            temperature=0,
        )
    except httpx.TimeoutException:
        return ExpansionResult(
            terms=[],
            error="AI 搜索请求超时,请稍后再试。",
        )
    except httpx.ConnectError:
        return ExpansionResult(
            terms=[],
            error="无法连接到 AI 服务,请检查网络或设置。",
        )
    except httpx.HTTPStatusError:
        return ExpansionResult(
            terms=[],
            error="AI 服务返回错误,请检查配置。",
        )
    except Exception:  # noqa: BLE001 — graceful degradation
        return ExpansionResult(
            terms=[],
            error="AI 搜索暂时不可用,请稍后再试。",
        )

    # The response is handled outside the ``try`` above, so it must never
    # raise: extract the text defensively before handing it to the parser.
    content = _first_choice_content(response)
    return ExpansionResult(terms=_parse_json_string_array(content))


def _first_choice_content(response: Any) -> str:
    """Return the first choice's message text, or ``""`` if unavailable.

    Tolerates a non-dict response, missing or empty ``choices``, a missing
    or null ``message``, and a null or non-string ``content``.
    """
    if not isinstance(response, dict):
        return ""
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first = choices[0]
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    # ``.get("content", "")`` would still yield None for an explicit JSON
    # null, so the type is checked instead of relying on a default.
    return content if isinstance(content, str) else ""
# ── Constants for output contract enforcement ────────────────────────────
_MAX_EXPANSION_TERMS = 8
_MAX_TERM_LENGTH = 30


def _parse_json_string_array(content: str) -> list[str]:
    """Parse LLM output into a list of unique, cleaned strings.

    Strict contract enforcement:

    1. Strip a surrounding markdown code fence (the ``json`` tag is matched
       case-insensitively);
    2. ``json.loads`` the remainder and accept only a JSON **array whose
       items are all strings**;
    3. Anything else (non-string input, prose, JSON objects, bad JSON, an
       array containing a non-string) yields ``[]``.

    Accepted items are whitespace-stripped; blank items, items longer than
    ``_MAX_TERM_LENGTH`` characters and repeated items are dropped. At most
    ``_MAX_EXPANSION_TERMS`` terms are returned, in first-seen order.

    Because the contract is enforced here in code, whatever the model
    returns, only terms taken from a valid JSON string array get through.
    """
    # A chat response may carry a null content; treat it like empty output
    # instead of failing on ``.strip()``.
    if not isinstance(content, str):
        return []
    text = content.strip()
    if not text:
        return []

    # Strip markdown code fences
    text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s*```$", "", text)
    text = text.strip()

    # Attempt JSON parse — strictly require a list.
    # json.JSONDecodeError is a subclass of ValueError.
    try:
        parsed = json.loads(text)
    except ValueError:
        return []
    if not isinstance(parsed, list):
        return []

    # A single non-string item invalidates the whole array.
    if not all(isinstance(item, str) for item in parsed):
        return []

    terms: list[str] = []
    seen: set[str] = set()
    for item in parsed:
        cleaned = item.strip()
        if not cleaned or len(cleaned) > _MAX_TERM_LENGTH:
            continue
        # De-duplicate so repeated terms do not use up the capped slots.
        if cleaned in seen:
            continue
        seen.add(cleaned)
        terms.append(cleaned)

    # Cap total count
    return terms[:_MAX_EXPANSION_TERMS]
def analyze_image(cfg: LLMConfig, image_data: bytes, prompt: str) -> LLMResult:
@@ -151,6 +245,7 @@ def _call_chat_completion(
*,
messages: list[dict[str, str]],
max_tokens: int = 1,
temperature: float | None = None,
) -> dict:
"""Call the OpenAI-compatible ``/chat/completions`` endpoint.
@@ -164,6 +259,8 @@ def _call_chat_completion(
"messages": messages,
"max_tokens": max_tokens,
}
if temperature is not None:
payload["temperature"] = temperature
headers = {
"Authorization": f"Bearer {cfg.api_key}",
"Content-Type": "application/json",