"""
|
||
Семантический поиск по Яндекс Вики через OpenAI embeddings + pgvector.
|
||
"""

import json
import os
import re
from pathlib import Path

import pendulum
from dotenv import load_dotenv
from openai import BadRequestError, OpenAI

load_dotenv(Path(__file__).parent / '.env')

EMBED_MODEL = 'text-embedding-3-small'  # 1536 dims, fast and cheap

_client = None
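
# The queries below assume a pgvector-backed table roughly like this.
# This is a sketch inferred from the INSERT/ON CONFLICT and search SQL in
# this module, not copied from actual migrations; column types are guesses:
#
#   CREATE EXTENSION IF NOT EXISTS vector;
#   CREATE TABLE wiki_embeddings (
#       pg_load_dttm timestamptz,
#       slug         text PRIMARY KEY,  -- ON CONFLICT (slug) needs a unique key
#       title        text,
#       content_text text,
#       content_hash text,
#       embedding    vector(1536)       -- text-embedding-3-small dimension
#   );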


def _openai() -> OpenAI:
    """Return a lazily created, module-cached OpenAI client."""
    global _client
    if _client is None:
        _client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
    return _client


def clean_wiki_markup(text: str) -> str:
    """Strip wiki markup, leaving plain text suitable for embedding."""
    if not text:
        return ''
    # {% ... %} blocks
    text = re.sub(r'\{%[^%]*%\}', '', text)
    # [[link|label]] → keep the label after |
    text = re.sub(r'\[\[([^\]|]*)\|([^\]]*)\]\]', r'\2', text)
    text = re.sub(r'\[\[([^\]]*)\]\]', r'\1', text)
    # ((url label)) → label
    text = re.sub(r'\(\(https?://\S+\s+([^)]+)\)\)', r'\1', text)
    text = re.sub(r'\(\(https?://\S+\)\)', '', text)
    # Markdown emphasis
    text = re.sub(r'\*{1,3}([^*]+)\*{1,3}', r'\1', text)
    text = re.sub(r'_{1,2}([^_]+)_{1,2}', r'\1', text)
    # Tables and special wiki characters
    text = re.sub(r'[#|]{2,}', '\n', text)
    text = re.sub(r'^\s*[#>]+\s*', '', text, flags=re.MULTILINE)
    # Collapse extra blank lines
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()
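
# A hand-traced example of clean_wiki_markup (checked against the regexes
# above, not taken from the repo's tests):
#
#   clean_wiki_markup('{%note%}**Bold**: ((https://example.com Doc)) and [[page|link]]')
#   -> 'Bold: Doc and link'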


def embed_text(text: str) -> list[float]:
    """Get an embedding for the text via the OpenAI API."""
    # text-embedding-3-small accepts roughly 8k input tokens: start at
    # 15,000 characters and halve the cut-off whenever the API reports
    # the input exceeds its maximum context length.
    limit = 15000
    while limit >= 1000:
        try:
            resp = _openai().embeddings.create(model=EMBED_MODEL, input=text[:limit])
            return resp.data[0].embedding
        except BadRequestError as e:
            if 'maximum context length' in str(e):
                limit //= 2
                continue
            raise
    raise ValueError('Failed to fit within the token limit even at 1000 characters')


def embed_page(page: dict) -> list[float]:
    """Generate an embedding for a page (title + content)."""
    title = page.get('title', '')
    content = clean_wiki_markup(page.get('content', '') or '')
    combined = f'{title}\n\n{content}'
    return embed_text(combined)


def upsert_embeddings(db, pages: list[dict]) -> dict:
    """
    Generate and store embeddings for new/changed pages.

    Pages whose content_hash is unchanged are skipped, as are pages
    with empty content.

    Returns:
        dict: {'embedded': int, 'skipped': int}
    """
    if not pages:
        return {'embedded': 0, 'skipped': 0}

    # Fetch the currently stored hashes from wiki_embeddings
    db.cursor.execute('SELECT slug, content_hash FROM wiki_embeddings')
    stored = {row[0]: row[1] for row in db.cursor.fetchall()}

    embedded = 0
    skipped = 0

    for page in pages:
        slug = page['slug']
        new_hash = page['content_hash']
        content = page.get('content', '') or ''

        if not content.strip():
            skipped += 1
            continue

        if stored.get(slug) == new_hash:
            skipped += 1
            continue

        print(f' ↑ embedding: {slug}')
        # embed_page() re-cleans the raw content itself; the cleaned text
        # is also stored in content_text for display in search results.
        content_text = clean_wiki_markup(content)
        vector = embed_page({'title': page.get('title', ''), 'content': content})

        db.cursor.execute("""
            INSERT INTO wiki_embeddings
                (pg_load_dttm, slug, title, content_text, content_hash, embedding)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON CONFLICT (slug) DO UPDATE SET
                pg_load_dttm = EXCLUDED.pg_load_dttm,
                title = EXCLUDED.title,
                content_text = EXCLUDED.content_text,
                content_hash = EXCLUDED.content_hash,
                embedding = EXCLUDED.embedding
        """, (
            pendulum.now('Europe/Moscow'),
            slug,
            page.get('title', ''),
            content_text,
            new_hash,
            json.dumps(vector),
        ))
        # Commit per page so a crash mid-run keeps already-embedded pages
        db.conn.commit()
        embedded += 1

    return {'embedded': embedded, 'skipped': skipped}


def search(db, query: str, limit: int = 5) -> list[dict]:
    """
    Semantic search over wiki_embeddings.

    Args:
        db: a connected SupabaseManager
        query: a text query in any language
        limit: number of results

    Returns:
        A list of dicts with 'slug', 'title', 'similarity', 'content_text'
    """
    query_vec = embed_text(query)

    # <=> is pgvector's cosine-distance operator, so 1 - distance is cosine
    # similarity; ordering by distance returns the nearest pages first.
    db.cursor.execute("""
        SELECT
            slug,
            title,
            content_text,
            1 - (embedding <=> %s::vector) AS similarity
        FROM wiki_embeddings
        ORDER BY embedding <=> %s::vector
        LIMIT %s
    """, (json.dumps(query_vec), json.dumps(query_vec), limit))

    rows = db.cursor.fetchall()
    return [
        {
            'slug': row[0],
            'title': row[1],
            'content_text': row[2],
            'similarity': round(float(row[3]), 4),
        }
        for row in rows
    ]
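

# A minimal usage sketch (hypothetical: the SupabaseManager import path and
# its cursor/conn attributes are assumed from the docstrings above, and the
# pages list would normally come from the wiki sync step):
#
#   from db import SupabaseManager  # hypothetical module path
#
#   db = SupabaseManager()
#   upsert_embeddings(db, pages=[
#       {'slug': 'hr/onboarding', 'title': 'Onboarding',
#        'content': '...', 'content_hash': 'abc123'},
#   ])
#   for hit in search(db, 'how do I request vacation?', limit=3):
#       print(hit['similarity'], hit['slug'], hit['title'])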