378 lines
13 KiB
Python
378 lines
13 KiB
Python
"""
|
|
RAG processor backed by FAISS and local JSON metadata storage.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Dict, List
|
|
|
|
import faiss
|
|
import numpy as np
|
|
from langchain_community.embeddings import OllamaEmbeddings
|
|
from langchain_core.documents import Document
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
|
|
|
class RAGProcessor:
    """Manage embedding, storage, and retrieval for crawled news.

    Storage consists of two files inside ``vector_db_dir``:

    * ``records.json`` - one dict per text chunk (content, metadata and the
      normalized embedding). This file is treated as the source of truth.
    * ``faiss.index`` - a serialized ``faiss.IndexFlatIP`` whose row ``i``
      corresponds to ``self.records[i]``.

    Vectors are L2-normalized before being added, so the inner-product
    scores returned by the index are cosine similarities.
    """

    INDEX_FILENAME = "faiss.index"
    RECORDS_FILENAME = "records.json"

    def __init__(self, vector_db_dir: str = "./vector_db"):
        """Create the storage directory, the embedding client and load state.

        Args:
            vector_db_dir: Directory holding the index and records files.
                It is created when missing.
        """
        self.vector_db_dir = vector_db_dir
        os.makedirs(vector_db_dir, exist_ok=True)

        self.index_path = os.path.join(self.vector_db_dir, self.INDEX_FILENAME)
        self.records_path = os.path.join(self.vector_db_dir, self.RECORDS_FILENAME)

        # NOTE(review): the fallback endpoint is a hard-coded address; set
        # OLLAMA_EMBEDDING_ENDPOINT in deployments that must not depend on it.
        self.ollama_endpoint = os.getenv(
            "OLLAMA_EMBEDDING_ENDPOINT",
            "http://111.198.29.205:11343/",
        )
        self.ollama_model = os.getenv(
            "OLLAMA_EMBEDDING_MODEL",
            "paraphrase-multilingual:latest",
        )

        print("正在初始化 Ollama embedding...")
        self.embeddings = OllamaEmbeddings(
            base_url=self.ollama_endpoint.rstrip("/"),
            model=self.ollama_model,
        )
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=240,
            chunk_overlap=30,
            length_function=len,
        )

        # Order matters: the dimension is derived from the records, and the
        # index is validated against both.
        self.records: List[Dict] = self._load_records()
        self.dimension = self._determine_dimension()
        self.index = self._load_or_create_index()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _determine_dimension(self) -> int:
        """Return the embedding dimension.

        The length of the first stored embedding is used when one exists;
        otherwise a probe string is embedded and its length is returned.
        """
        for record in self.records:
            vector = record.get("embedding") or []
            if vector:
                return len(vector)
        probe = self.embeddings.embed_query("维度探测")
        return len(probe)

    def _load_records(self) -> List[Dict]:
        """Load the records file, or return an empty list when it is absent."""
        if not os.path.exists(self.records_path):
            return []
        with open(self.records_path, "r", encoding="utf-8") as file:
            return json.load(file)

    def _save_records(self) -> None:
        """Write ``self.records`` to disk as UTF-8 JSON.

        The data is written to a sibling temporary file and then moved over
        the target with ``os.replace``, so an interrupted write cannot leave
        a truncated records file behind.
        """
        os.makedirs(self.vector_db_dir, exist_ok=True)
        tmp_path = self.records_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(self.records, file, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.records_path)

    def _records_are_indexable(self) -> bool:
        """Return True when every record stores an embedding of ``self.dimension``."""
        return all(
            len(record.get("embedding") or []) == self.dimension
            for record in self.records
        )

    def _build_index_from_records(self):
        """Return a fresh ``IndexFlatIP`` holding one row per record, in order."""
        index = faiss.IndexFlatIP(self.dimension)
        if self.records:
            vectors = np.array(
                [self._normalize_vector(record["embedding"]) for record in self.records],
                dtype=np.float32,
            )
            index.add(vectors)
        return index

    def _load_or_create_index(self):
        """Load the index from disk, repairing it when it disagrees with the records.

        Search maps FAISS row ids straight onto ``self.records`` positions,
        so an index with a different row count or dimension would return the
        wrong (or no) records. In that case, and when the index file is
        missing, the index is rebuilt from the stored embeddings if every
        record has one. Otherwise the loaded index (or an empty one) is
        returned unchanged.
        """
        loaded = None
        if os.path.exists(self.index_path):
            with open(self.index_path, "rb") as file:
                buffer = np.frombuffer(file.read(), dtype=np.uint8)
            loaded = faiss.deserialize_index(buffer)
            if loaded.d == self.dimension and loaded.ntotal == len(self.records):
                return loaded

        if self._records_are_indexable():
            return self._build_index_from_records()
        if loaded is not None:
            return loaded
        return faiss.IndexFlatIP(self.dimension)

    def _save_index(self) -> None:
        """Serialize the index to disk via a temporary file and ``os.replace``."""
        os.makedirs(self.vector_db_dir, exist_ok=True)
        buffer = faiss.serialize_index(self.index)
        tmp_path = self.index_path + ".tmp"
        with open(tmp_path, "wb") as file:
            file.write(buffer.tobytes())
        os.replace(tmp_path, self.index_path)

    def close(self) -> None:
        """Persist both the records and the index."""
        self._save_records()
        self._save_index()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_timestamp_predicate(
        start_ts: int | None = None,
        end_ts: int | None = None,
    ):
        """Return a ``record -> bool`` filter for an inclusive timestamp range.

        Args:
            start_ts: Lower bound (inclusive), or None for no lower bound.
            end_ts: Upper bound (inclusive), or None for no upper bound.

        A record whose ``std_timestamp`` is missing or falsy is treated as 0.
        """
        # Convert the bounds once instead of on every record.
        lower = None if start_ts is None else int(start_ts)
        upper = None if end_ts is None else int(end_ts)

        def predicate(record: Dict) -> bool:
            std_timestamp = int(record.get("std_timestamp", 0) or 0)
            if lower is not None and std_timestamp < lower:
                return False
            if upper is not None and std_timestamp > upper:
                return False
            return True

        return predicate

    @staticmethod
    def _format_timestamp_date(std_timestamp: int | str | None) -> str:
        """Format a Unix timestamp as ``YYYY-MM-DD`` (local time).

        Returns ``"N/A"`` for empty/zero values and for values that cannot be
        converted to a date.
        """
        try:
            if std_timestamp in (None, "", 0, "0"):
                return "N/A"
            return datetime.fromtimestamp(int(std_timestamp)).strftime("%Y-%m-%d")
        except (TypeError, ValueError, OverflowError, OSError):
            # int() rejects non-numeric input; fromtimestamp() rejects
            # out-of-range values.
            return "N/A"

    @staticmethod
    def _recent_window(days: int) -> tuple[int, int]:
        """Return ``(start_ts, end_ts)`` covering the last ``days`` days up to now."""
        end_ts = int(datetime.now().timestamp())
        return end_ts - days * 24 * 60 * 60, end_ts

    def _split_document_for_retry(self, document: Document) -> List[Document]:
        """Split a document's text in half so each part can be embedded separately.

        Returns the document itself (in a one-element list) when its text has
        at most one character and therefore cannot be split further.
        """
        text = document.page_content or ""
        if len(text) <= 1:
            return [document]

        half = len(text) // 2
        return [
            Document(page_content=text[:half], metadata=document.metadata),
            Document(page_content=text[half:], metadata=document.metadata),
        ]

    def _normalize_vector(self, vector: List[float]) -> np.ndarray:
        """Return ``vector`` as a float32 array scaled to unit L2 norm.

        A zero vector is returned unchanged to avoid dividing by zero.
        """
        array = np.array(vector, dtype=np.float32)
        norm = np.linalg.norm(array)
        if norm > 0:
            array = array / norm
        return array

    def _embed_text_with_retry(self, text: str) -> np.ndarray:
        """Embed ``text`` and return the normalized vector.

        This method performs a single embedding call; the split-and-retry
        handling for over-long input lives in ``process_news``.
        """
        embedding = self.embeddings.embed_query(text)
        return self._normalize_vector(embedding)

    def _rebuild_index_from_records(self) -> None:
        """Recreate the index from ``self.records`` and save it."""
        self.index = self._build_index_from_records()
        self._save_index()

    def _delete_documents_by_urls(self, urls: set[str]) -> None:
        """Drop every record whose URL is in ``urls`` and rebuild the index once."""
        if not urls:
            return
        self.records = [record for record in self.records if record.get("url") not in urls]
        self._rebuild_index_from_records()

    def _delete_documents_by_url(self, url: str) -> None:
        """Drop every record with the given URL and rebuild the index."""
        self._delete_documents_by_urls({url})

    def _get_existing_urls(self) -> set[str]:
        """Return the set of non-empty URLs present in the records."""
        return {record["url"] for record in self.records if record.get("url")}

    # ------------------------------------------------------------------
    # Ingestion
    # ------------------------------------------------------------------

    def _collect_documents(
        self,
        news_list: List[Dict],
        upsert: bool,
    ) -> tuple[List[Document], set[str]]:
        """Split news items into chunk documents without touching stored data.

        Args:
            news_list: News dicts; ``content``, ``url``, ``title``, ``source``
                and ``std_timestamp`` are read.
            upsert: When True, an item whose URL is already stored is marked
                for replacement, and a URL repeated inside ``news_list``
                keeps only its last occurrence.

        Returns:
            ``(documents, replaced_urls)`` where ``replaced_urls`` are the
            stored URLs whose old chunks should be removed.
        """
        existing_urls = self._get_existing_urls() if upsert else set()
        batch_urls: set[str] = set()
        replaced_urls: set[str] = set()
        documents: List[Document] = []
        new_count = 0
        updated_count = 0
        skipped_count = 0

        for news in news_list:
            content = news.get("content")
            chunks = self.text_splitter.split_text(content) if content else []
            if not chunks:
                skipped_count += 1
                continue

            url = str(news.get("url", "") or "")
            metadata = {
                "title": str(news.get("title", "") or ""),
                "url": url,
                "source": str(news.get("source", "") or ""),
                "std_timestamp": int(news.get("std_timestamp", 0) or 0),
            }

            if upsert and url and url in batch_urls:
                # Same URL seen earlier in this batch: the later item wins,
                # otherwise both versions would be stored side by side.
                documents = [doc for doc in documents if doc.metadata.get("url") != url]
                updated_count += 1
            elif upsert and url in existing_urls:
                replaced_urls.add(url)
                updated_count += 1
            else:
                new_count += 1
            batch_urls.add(url)

            documents.extend(
                Document(page_content=text, metadata=metadata) for text in chunks
            )

        print(f"待写入 {len(documents)} 个文档片段")
        print(f" - 新增新闻: {new_count} 条")
        print(f" - 更新新闻: {updated_count} 条")
        if skipped_count > 0:
            print(f" - 跳过新闻: {skipped_count} 条(无正文)")
        return documents, replaced_urls

    def _prepare_documents(self, news_list: List[Dict], upsert: bool) -> List[Document]:
        """Split news into chunk documents and eagerly drop stored chunks being replaced.

        Kept for callers that rely on the eager deletion; ``process_news``
        uses ``_collect_documents`` so that old chunks survive a failed
        embedding run.
        """
        documents, replaced_urls = self._collect_documents(news_list, upsert=upsert)
        self._delete_documents_by_urls(replaced_urls)
        return documents

    def process_news(self, news_list: List[Dict], upsert: bool = True) -> None:
        """Chunk, embed and store a batch of news items.

        Args:
            news_list: News dicts (see ``_collect_documents``).
            upsert: When True, chunks already stored for an incoming URL are
                replaced instead of duplicated.

        Raises:
            ValueError: Propagated from the embedding client, except for the
                "input length exceeds the context length" error, which is
                handled by splitting the chunk in half and retrying until
                the text cannot be split any further.

        Stored data is only modified after all embeddings succeeded, so a
        failure part-way through leaves the previous state intact.
        """
        documents, replaced_urls = self._collect_documents(news_list, upsert=upsert)
        if not documents:
            print("没有可写入的文档")
            return

        added_records: List[Dict] = []
        vectors: List[np.ndarray] = []

        # Next free chunk index per URL. Records about to be replaced are
        # ignored so that a replaced article is renumbered from 0.
        chunk_index_by_url: Dict[str, int] = {}
        for record in self.records:
            url = record.get("url", "")
            if url in replaced_urls:
                continue
            chunk_index_by_url[url] = max(
                chunk_index_by_url.get(url, 0),
                int(record.get("chunk_index", -1)) + 1,
            )

        # Work stack kept in reverse order: pop()/extend() at the tail are
        # O(1), and the documents are still processed in their original order.
        pending = documents[::-1]
        while pending:
            document = pending.pop()
            text = document.page_content or ""
            if not text.strip():
                continue

            try:
                vector = self._embed_text_with_retry(text)
            except ValueError as exc:
                if "input length exceeds the context length" not in str(exc):
                    raise
                smaller_docs = self._split_document_for_retry(document)
                if len(smaller_docs) == 1 and smaller_docs[0].page_content == text:
                    # Cannot be split any further; give up on this chunk.
                    raise
                pending.extend(reversed(smaller_docs))
                continue

            url = document.metadata.get("url", "")
            chunk_index = chunk_index_by_url.get(url, 0)
            chunk_index_by_url[url] = chunk_index + 1

            added_records.append(
                {
                    "id": f"{url}::{chunk_index}",
                    "chunk_index": chunk_index,
                    "content": text,
                    "title": document.metadata.get("title", ""),
                    "url": url,
                    "source": document.metadata.get("source", ""),
                    "std_timestamp": int(document.metadata.get("std_timestamp", 0)),
                    "embedding": vector.tolist(),
                }
            )
            vectors.append(vector)

        if not added_records:
            print("没有成功生成 embedding 的文档")
            return

        # Only discard old chunks for URLs that actually received new ones.
        stale_urls = replaced_urls & {record["url"] for record in added_records}
        if stale_urls:
            self.records = [
                record for record in self.records if record.get("url") not in stale_urls
            ]
            self.records.extend(added_records)
            # Row positions changed, so the index is rebuilt - once per batch.
            self.index = self._build_index_from_records()
        else:
            self.records.extend(added_records)
            matrix = np.vstack(vectors).astype(np.float32)
            self.index.add(matrix)

        self._save_records()
        self._save_index()
        print("FAISS 向量数据库写入完成")

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def _search_record_indices(
        self,
        query: str,
        k: int,
        start_ts: int | None = None,
        end_ts: int | None = None,
    ) -> List[int]:
        """Return up to ``k`` record positions ranked by similarity to ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of positions to return; ``k <= 0`` yields ``[]``.
            start_ts: Optional inclusive lower bound on ``std_timestamp``.
            end_ts: Optional inclusive upper bound on ``std_timestamp``.

        The timestamp filter is applied after the vector search. The
        candidate window starts at ``k * 5`` and is doubled until enough
        matching records are found or the whole index has been searched, so
        a narrow time range cannot hide matches that rank low overall.
        """
        if k <= 0 or not self.records or self.index.ntotal == 0:
            return []

        predicate = self._build_timestamp_predicate(start_ts=start_ts, end_ts=end_ts)
        allowed = {idx for idx, record in enumerate(self.records) if predicate(record)}
        if not allowed:
            return []

        query_vector = self._embed_text_with_retry(query).reshape(1, -1)
        total = self.index.ntotal
        wanted = min(k, len(allowed))
        search_k = min(k * 5, total)

        while True:
            _scores, indices = self.index.search(query_vector, search_k)
            # Padding ids (-1) and ids without a record are never in `allowed`.
            ranked = [int(idx) for idx in indices[0] if int(idx) in allowed][:k]
            if len(ranked) >= wanted or search_k >= total:
                return ranked
            search_k = min(search_k * 2, total)

    def _record_to_document(self, record: Dict) -> Document:
        """Convert a stored record into a ``Document`` without the embedding."""
        return Document(
            page_content=record.get("content", ""),
            metadata={
                "title": record.get("title", ""),
                "url": record.get("url", ""),
                "source": record.get("source", ""),
                "std_timestamp": record.get("std_timestamp", 0),
            },
        )

    def search(
        self,
        query: str,
        k: int = 5,
        start_ts: int | None = None,
        end_ts: int | None = None,
    ) -> List[Document]:
        """Return up to ``k`` chunks most similar to ``query``.

        Args:
            query: Search text.
            k: Maximum number of chunks to return.
            start_ts: Optional inclusive lower bound on ``std_timestamp``.
            end_ts: Optional inclusive upper bound on ``std_timestamp``.
        """
        ranked_indices = self._search_record_indices(
            query,
            k,
            start_ts=start_ts,
            end_ts=end_ts,
        )
        return [self._record_to_document(self.records[idx]) for idx in ranked_indices]

    def search_recent(self, query: str, k: int = 5, days: int = 7) -> List[Document]:
        """Search only chunks whose timestamp falls within the last ``days`` days."""
        start_ts, end_ts = self._recent_window(days)
        return self.search(query, k=k, start_ts=start_ts, end_ts=end_ts)

    def get_all_news_metadata(
        self,
        start_ts: int | None = None,
        end_ts: int | None = None,
    ) -> List[Dict]:
        """Return one metadata dict per distinct URL, in first-seen order.

        Records without a URL are skipped. Each dict has ``title``, ``url``,
        ``source``, ``std_timestamp`` and a formatted ``date``.

        Args:
            start_ts: Optional inclusive lower bound on ``std_timestamp``.
            end_ts: Optional inclusive upper bound on ``std_timestamp``.
        """
        predicate = self._build_timestamp_predicate(start_ts=start_ts, end_ts=end_ts)
        seen_urls: set[str] = set()
        unique_news: List[Dict] = []

        for record in self.records:
            if not predicate(record):
                continue
            url = record.get("url", "")
            if not url or url in seen_urls:
                continue
            seen_urls.add(url)
            unique_news.append(
                {
                    "title": record.get("title", ""),
                    "url": url,
                    "source": record.get("source", ""),
                    "std_timestamp": record.get("std_timestamp", 0),
                    "date": self._format_timestamp_date(record.get("std_timestamp")),
                }
            )
        return unique_news

    def get_recent_news_metadata(self, days: int = 7) -> List[Dict]:
        """Return per-URL metadata for news from the last ``days`` days."""
        start_ts, end_ts = self._recent_window(days)
        return self.get_all_news_metadata(start_ts=start_ts, end_ts=end_ts)

    def get_database_stats(self) -> Dict:
        """Return the chunk count, the distinct-URL count and whether any data exists."""
        return {
            "total_documents": len(self.records),
            "unique_news": len(self._get_existing_urls()),
            "database_exists": len(self.records) > 0,
        }
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: store one sample article, then print the database
    # statistics and the number of hits for a one-result search.
    sample_news = {
        "title": "测试新闻",
        "content": "标题: 测试新闻\n\n这是一条用于验证 FAISS 写入的测试内容。",
        "url": "http://test.example/news",
        "source": "测试源",
        "std_timestamp": int(datetime.now().timestamp()),
    }

    processor = RAGProcessor()
    try:
        processor.process_news([sample_news])
        stats = processor.get_database_stats()
        print(stats)
        hits = processor.search("测试", k=1)
        print(len(hits))
    finally:
        # Persist records and index even if the smoke test fails part-way.
        processor.close()
|