# Source: newsreport_agent_for_traffic/rag/rag_processor.py (378 lines, 13 KiB, Python)

"""
RAG processor backed by FAISS and local JSON metadata storage.
"""
from __future__ import annotations
import json
import os
from datetime import datetime
from typing import Dict, List
import faiss
import numpy as np
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
class RAGProcessor:
    """Manage embedding, storage, and retrieval for crawled news.

    Every stored record is a dict holding one text chunk, its article
    metadata (``title``, ``url``, ``source``, ``std_timestamp``) and the
    L2-normalised embedding of the chunk.  Records are persisted as JSON and
    are the source of truth; the FAISS inner-product index is kept
    position-aligned with ``self.records`` (index row ``i`` corresponds to
    ``self.records[i]``) and can always be rebuilt from the stored embeddings.
    """

    INDEX_FILENAME = "faiss.index"
    RECORDS_FILENAME = "records.json"

    def __init__(self, vector_db_dir: str = "./vector_db"):
        """Open (or create) the store located in ``vector_db_dir``.

        The embedding endpoint and model are read from the environment
        variables ``OLLAMA_EMBEDDING_ENDPOINT`` and ``OLLAMA_EMBEDDING_MODEL``.

        Args:
            vector_db_dir: Directory holding ``faiss.index`` and
                ``records.json``; created when missing.
        """
        self.vector_db_dir = vector_db_dir
        os.makedirs(vector_db_dir, exist_ok=True)
        self.index_path = os.path.join(self.vector_db_dir, self.INDEX_FILENAME)
        self.records_path = os.path.join(self.vector_db_dir, self.RECORDS_FILENAME)

        # NOTE(review): the fallback endpoint is a hard-coded address; prefer
        # setting OLLAMA_EMBEDDING_ENDPOINT explicitly in every deployment.
        self.ollama_endpoint = os.getenv(
            "OLLAMA_EMBEDDING_ENDPOINT",
            "http://111.198.29.205:11343/",
        )
        self.ollama_model = os.getenv(
            "OLLAMA_EMBEDDING_MODEL",
            "paraphrase-multilingual:latest",
        )
        print("正在初始化 Ollama embedding...")
        self.embeddings = OllamaEmbeddings(
            base_url=self.ollama_endpoint.rstrip("/"),
            model=self.ollama_model,
        )
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=240,
            chunk_overlap=30,
            length_function=len,
        )
        self.records: List[Dict] = self._load_records()
        self.dimension = self._determine_dimension()
        self.index = self._load_or_create_index()
        self._ensure_index_matches_records()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def _determine_dimension(self) -> int:
        """Return the embedding dimension used by this store.

        The dimension is taken from the first stored record when it carries a
        non-empty embedding; otherwise the embedding backend is probed once.
        """
        if self.records:
            vector = self.records[0].get("embedding") or []
            if vector:
                return len(vector)
        probe = self.embeddings.embed_query("维度探测")
        return len(probe)

    def _load_records(self) -> List[Dict]:
        """Load the persisted records, or return an empty list if none exist."""
        if not os.path.exists(self.records_path):
            return []
        with open(self.records_path, "r", encoding="utf-8") as file:
            return json.load(file)

    def _save_records(self) -> None:
        """Persist ``self.records`` as UTF-8 JSON.

        The data is written to a sibling temporary file and then moved over
        the real file, so an interrupted write cannot truncate the store.
        """
        os.makedirs(self.vector_db_dir, exist_ok=True)
        tmp_path = f"{self.records_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(self.records, file, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.records_path)

    def _load_or_create_index(self):
        """Deserialize the FAISS index from disk, or create an empty one."""
        if os.path.exists(self.index_path):
            with open(self.index_path, "rb") as file:
                buffer = np.frombuffer(file.read(), dtype=np.uint8)
            return faiss.deserialize_index(buffer)
        return faiss.IndexFlatIP(self.dimension)

    def _save_index(self) -> None:
        """Serialize the FAISS index to disk via a temporary file."""
        os.makedirs(self.vector_db_dir, exist_ok=True)
        buffer = faiss.serialize_index(self.index)
        tmp_path = f"{self.index_path}.tmp"
        with open(tmp_path, "wb") as file:
            file.write(buffer.tobytes())
        os.replace(tmp_path, self.index_path)

    def _ensure_index_matches_records(self) -> None:
        """Rebuild the index when it disagrees with the loaded records.

        Search maps index rows to ``self.records`` by position, so a row-count
        or dimension mismatch between the two files would return the wrong
        records.  The stored embeddings allow a full rebuild.
        """
        if self.index.ntotal == len(self.records) and self.index.d == self.dimension:
            return
        print("FAISS 索引与记录不一致,正在根据记录重建索引...")
        self._rebuild_index_from_records()

    def close(self) -> None:
        """Flush both the records and the index to disk."""
        self._save_records()
        self._save_index()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _build_timestamp_predicate(
        start_ts: int | None = None,
        end_ts: int | None = None,
    ):
        """Return ``predicate(record) -> bool`` for an inclusive time window.

        A bound that is ``None`` is not applied.  A record whose
        ``std_timestamp`` is missing or falsy is treated as timestamp ``0``.
        """
        lower = None if start_ts is None else int(start_ts)
        upper = None if end_ts is None else int(end_ts)

        def predicate(record: Dict) -> bool:
            std_timestamp = int(record.get("std_timestamp", 0) or 0)
            if lower is not None and std_timestamp < lower:
                return False
            if upper is not None and std_timestamp > upper:
                return False
            return True

        return predicate

    @staticmethod
    def _format_timestamp_date(std_timestamp: int | str | None) -> str:
        """Format a Unix timestamp as ``YYYY-MM-DD`` in local time.

        Returns ``"N/A"`` for missing, zero, or unparseable timestamps.
        """
        if std_timestamp in (None, "", 0, "0"):
            return "N/A"
        try:
            return datetime.fromtimestamp(int(std_timestamp)).strftime("%Y-%m-%d")
        except (TypeError, ValueError, OverflowError, OSError):
            return "N/A"

    @staticmethod
    def _recent_window(days: int) -> tuple[int, int]:
        """Return ``(start_ts, end_ts)`` covering the last ``days`` days up to now."""
        end_ts = int(datetime.now().timestamp())
        return end_ts - days * 24 * 60 * 60, end_ts

    def _split_document_for_retry(self, document: Document) -> List[Document]:
        """Split a document's text in half; both halves share its metadata.

        A document with at most one character cannot be split and is returned
        unchanged as a single-element list.
        """
        text = document.page_content or ""
        if len(text) <= 1:
            return [document]
        half = max(1, len(text) // 2)
        return [
            Document(page_content=text[:half], metadata=document.metadata),
            Document(page_content=text[half:], metadata=document.metadata),
        ]

    def _normalize_vector(self, vector: List[float]) -> np.ndarray:
        """Return ``vector`` as a unit-length float32 array.

        An all-zero vector is returned unchanged.  With unit vectors the
        inner-product index ranks by cosine similarity.
        """
        array = np.array(vector, dtype=np.float32)
        norm = np.linalg.norm(array)
        if norm > 0:
            array = array / norm
        return array

    def _embed_text_with_retry(self, text: str) -> np.ndarray:
        """Embed ``text`` and normalise the result.

        This method performs a single embedding call; the split-and-retry
        handling for over-long inputs lives in the ingestion loop.
        """
        embedding = self.embeddings.embed_query(text)
        return self._normalize_vector(embedding)

    # ------------------------------------------------------------------
    # Index maintenance
    # ------------------------------------------------------------------
    def _rebuild_index_from_records(self) -> None:
        """Recreate the index from the stored embeddings and save it."""
        self.index = faiss.IndexFlatIP(self.dimension)
        if self.records:
            vectors = np.array(
                [self._normalize_vector(record["embedding"]) for record in self.records],
                dtype=np.float32,
            )
            self.index.add(vectors)
        self._save_index()

    def _delete_documents_by_urls(self, urls) -> None:
        """Drop every record whose URL is in ``urls``, then rebuild the index once."""
        doomed = set(urls)
        self.records = [record for record in self.records if record.get("url") not in doomed]
        self._rebuild_index_from_records()

    def _delete_documents_by_url(self, url: str) -> None:
        """Drop every record belonging to ``url`` and rebuild the index."""
        self._delete_documents_by_urls([url])

    def _get_existing_urls(self) -> set[str]:
        """Return the set of non-empty URLs present in the stored records."""
        return {record.get("url", "") for record in self.records if record.get("url")}

    # ------------------------------------------------------------------
    # Ingestion
    # ------------------------------------------------------------------
    def _prepare_documents(self, news_list: List[Dict], upsert: bool) -> List[Document]:
        """Split news items into chunk documents, removing stale chunks on upsert.

        Items without content are skipped.  With ``upsert`` enabled, stored
        chunks of every URL that appears in ``news_list`` are removed; the
        removal is done in one batch after all items were split, so the index
        is rebuilt at most once and nothing is deleted if splitting fails.

        Args:
            news_list: Dicts with ``content`` and optional ``title``, ``url``,
                ``source``, ``std_timestamp``.
            upsert: Replace stored chunks of already-known URLs.

        Returns:
            The chunk documents to embed, in input order.
        """
        existing_urls = self._get_existing_urls() if upsert else set()
        stale_urls: set[str] = set()
        documents: List[Document] = []
        new_count = 0
        updated_count = 0
        skipped_count = 0

        for news in news_list:
            content = news.get("content")
            if not content:
                skipped_count += 1
                continue
            url = str(news.get("url", "") or "")
            metadata = {
                "title": str(news.get("title", "") or ""),
                "url": url,
                "source": str(news.get("source", "") or ""),
                "std_timestamp": int(news.get("std_timestamp", 0) or 0),
            }
            if upsert and url in existing_urls:
                stale_urls.add(url)
                updated_count += 1
            else:
                new_count += 1
            documents.extend(
                Document(page_content=chunk, metadata=metadata)
                for chunk in self.text_splitter.split_text(content)
            )

        if stale_urls:
            self._delete_documents_by_urls(stale_urls)

        print(f"待写入 {len(documents)} 个文档片段")
        print(f" - 新增新闻: {new_count}")
        print(f" - 更新新闻: {updated_count}")
        if skipped_count > 0:
            print(f" - 跳过新闻: {skipped_count} 条(无正文)")
        return documents

    def _embed_documents(self, documents: List[Document]):
        """Embed chunk documents and build their records.

        A chunk rejected by the backend for exceeding the context length is
        split in half and both halves are retried in place, preserving order.
        Whitespace-only chunks are skipped.

        Returns:
            ``(records, vectors)`` — two parallel lists.

        Raises:
            ValueError: Re-raised from the embedding backend when it is not a
                context-length error, or when the chunk cannot be split further.
        """
        # Continue chunk numbering after the highest index already stored per URL.
        next_chunk_index: Dict[str, int] = {}
        for record in self.records:
            url = record.get("url", "")
            next_chunk_index[url] = max(
                next_chunk_index.get(url, 0),
                int(record.get("chunk_index", -1)) + 1,
            )

        # Reversed list used as a stack: pop() yields documents in input order
        # and split halves can be pushed back in O(1).
        stack = list(reversed(documents))
        added_records: List[Dict] = []
        vectors: List[np.ndarray] = []
        while stack:
            document = stack.pop()
            text = document.page_content or ""
            if not text.strip():
                continue
            try:
                vector = self._embed_text_with_retry(text)
            except ValueError as exc:
                if "input length exceeds the context length" not in str(exc):
                    raise
                smaller_docs = self._split_document_for_retry(document)
                if len(smaller_docs) == 1 and smaller_docs[0].page_content == text:
                    raise
                stack.extend(reversed(smaller_docs))
                continue

            url = document.metadata.get("url", "")
            chunk_index = next_chunk_index.get(url, 0)
            next_chunk_index[url] = chunk_index + 1
            added_records.append(
                {
                    "id": f"{url}::{chunk_index}",
                    "chunk_index": chunk_index,
                    "content": text,
                    "title": document.metadata.get("title", ""),
                    "url": url,
                    "source": document.metadata.get("source", ""),
                    "std_timestamp": int(document.metadata.get("std_timestamp", 0) or 0),
                    "embedding": vector.tolist(),
                }
            )
            vectors.append(vector)
        return added_records, vectors

    def process_news(self, news_list: List[Dict], upsert: bool = True) -> None:
        """Chunk, embed, and store news items, then persist the store.

        If preparing or embedding fails after stale chunks were already
        removed, the previous records are restored and the index rebuilt
        before the error is re-raised, so a failed upsert does not lose the
        previously stored version of an article.

        Args:
            news_list: News dicts (see ``_prepare_documents``).
            upsert: Replace stored chunks of URLs that are already present.
        """
        snapshot = list(self.records)
        try:
            documents = self._prepare_documents(news_list, upsert=upsert)
            added_records, vectors = self._embed_documents(documents)
        except Exception:
            # _prepare_documents only ever removes records, so a length change
            # means stale chunks were dropped and must be restored.
            if len(self.records) != len(snapshot):
                self.records = snapshot
                self._rebuild_index_from_records()
            raise

        if not added_records:
            if len(self.records) != len(snapshot):
                # The rebuilt index is already on disk; keep records.json in step.
                self._save_records()
            if not documents:
                print("没有可写入的文档")
            else:
                print("没有成功生成 embedding 的文档")
            return

        # Add to the index first: if FAISS rejects the matrix, self.records is
        # left untouched and stays aligned with the index.
        matrix = np.vstack(vectors).astype(np.float32)
        self.index.add(matrix)
        self.records.extend(added_records)
        self._save_records()
        self._save_index()
        print("FAISS 向量数据库写入完成")

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------
    def _search_record_indices(
        self,
        query: str,
        k: int,
        start_ts: int | None = None,
        end_ts: int | None = None,
    ) -> List[int]:
        """Return positions in ``self.records`` of the best matches for ``query``.

        The time window is applied after the vector search.  The search starts
        with ``k * 5`` candidates and is widened until enough in-window hits
        are found or the whole index has been examined, so a narrow window
        cannot hide matches that rank low overall.

        Returns:
            Up to ``k`` record positions, best match first.
        """
        if k <= 0 or not self.records or self.index.ntotal == 0:
            return []
        predicate = self._build_timestamp_predicate(start_ts=start_ts, end_ts=end_ts)
        allowed = {idx for idx, record in enumerate(self.records) if predicate(record)}
        if not allowed:
            return []

        query_vector = self._embed_text_with_retry(query).reshape(1, -1)
        total = int(self.index.ntotal)
        wanted = min(k, len(allowed))
        search_k = min(k * 5, total)
        while True:
            _scores, indices = self.index.search(query_vector, search_k)
            # FAISS pads missing results with -1, which is never in ``allowed``.
            ranked = [int(idx) for idx in indices[0] if int(idx) in allowed][:k]
            if len(ranked) >= wanted or search_k >= total:
                return ranked
            search_k = min(search_k * 4, total)

    def search(
        self,
        query: str,
        k: int = 5,
        start_ts: int | None = None,
        end_ts: int | None = None,
    ) -> List[Document]:
        """Return up to ``k`` chunk documents most similar to ``query``.

        Args:
            query: Free-text query.
            k: Maximum number of chunks to return.
            start_ts: Optional inclusive lower bound on ``std_timestamp``.
            end_ts: Optional inclusive upper bound on ``std_timestamp``.
        """
        ranked_indices = self._search_record_indices(
            query,
            k,
            start_ts=start_ts,
            end_ts=end_ts,
        )
        results: List[Document] = []
        for idx in ranked_indices:
            record = self.records[idx]
            results.append(
                Document(
                    page_content=record.get("content", ""),
                    metadata={
                        "title": record.get("title", ""),
                        "url": record.get("url", ""),
                        "source": record.get("source", ""),
                        "std_timestamp": record.get("std_timestamp", 0),
                    },
                )
            )
        return results

    def search_recent(self, query: str, k: int = 5, days: int = 7) -> List[Document]:
        """Search only chunks whose timestamp falls within the last ``days`` days."""
        start_ts, end_ts = self._recent_window(days)
        return self.search(query, k=k, start_ts=start_ts, end_ts=end_ts)

    def get_all_news_metadata(
        self,
        start_ts: int | None = None,
        end_ts: int | None = None,
    ) -> List[Dict]:
        """Return one metadata dict per distinct URL inside the time window.

        Records without a URL are ignored; for each URL the first matching
        record in storage order supplies the metadata.
        """
        predicate = self._build_timestamp_predicate(start_ts=start_ts, end_ts=end_ts)
        seen_urls = set()
        unique_news = []
        for record in self.records:
            if not predicate(record):
                continue
            url = record.get("url", "")
            if not url or url in seen_urls:
                continue
            seen_urls.add(url)
            unique_news.append(
                {
                    "title": record.get("title", ""),
                    "url": url,
                    "source": record.get("source", ""),
                    "std_timestamp": record.get("std_timestamp", 0),
                    "date": self._format_timestamp_date(record.get("std_timestamp")),
                }
            )
        return unique_news

    def get_recent_news_metadata(self, days: int = 7) -> List[Dict]:
        """Return per-URL metadata for news from the last ``days`` days."""
        start_ts, end_ts = self._recent_window(days)
        return self.get_all_news_metadata(start_ts=start_ts, end_ts=end_ts)

    def get_database_stats(self) -> Dict:
        """Return chunk count, distinct-URL count, and whether any data exists."""
        return {
            "total_documents": len(self.records),
            "unique_news": len(self._get_existing_urls()),
            "database_exists": len(self.records) > 0,
        }
if __name__ == "__main__":
    # Manual smoke test: ingest one synthetic article, print the store
    # statistics and the number of hits for a one-result query, and always
    # flush the store to disk afterwards.
    processor = RAGProcessor()
    try:
        sample_news = {
            "title": "测试新闻",
            "content": "标题: 测试新闻\n\n这是一条用于验证 FAISS 写入的测试内容。",
            "url": "http://test.example/news",
            "source": "测试源",
            "std_timestamp": int(datetime.now().timestamp()),
        }
        processor.process_news([sample_news])

        stats = processor.get_database_stats()
        print(stats)

        hits = processor.search("测试", k=1)
        print(len(hits))
    finally:
        processor.close()