""" RAG processor backed by FAISS and local JSON metadata storage. """ from __future__ import annotations import json import os from datetime import datetime from typing import Dict, List import faiss import numpy as np from langchain_community.embeddings import OllamaEmbeddings from langchain_core.documents import Document from langchain_text_splitters import RecursiveCharacterTextSplitter class RAGProcessor: """Manage embedding, storage, and retrieval for crawled news.""" INDEX_FILENAME = "faiss.index" RECORDS_FILENAME = "records.json" def __init__(self, vector_db_dir: str = "./vector_db"): self.vector_db_dir = vector_db_dir os.makedirs(vector_db_dir, exist_ok=True) self.index_path = os.path.join(self.vector_db_dir, self.INDEX_FILENAME) self.records_path = os.path.join(self.vector_db_dir, self.RECORDS_FILENAME) self.ollama_endpoint = os.getenv( "OLLAMA_EMBEDDING_ENDPOINT", "http://111.198.29.205:11343/", ) self.ollama_model = os.getenv( "OLLAMA_EMBEDDING_MODEL", "paraphrase-multilingual:latest", ) print("正在初始化 Ollama embedding...") self.embeddings = OllamaEmbeddings( base_url=self.ollama_endpoint.rstrip("/"), model=self.ollama_model, ) self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=240, chunk_overlap=30, length_function=len, ) self.records: List[Dict] = self._load_records() self.dimension = self._determine_dimension() self.index = self._load_or_create_index() def _determine_dimension(self) -> int: if self.records: vector = self.records[0].get("embedding") or [] if vector: return len(vector) probe = self.embeddings.embed_query("维度探测") return len(probe) def _load_records(self) -> List[Dict]: if not os.path.exists(self.records_path): return [] with open(self.records_path, "r", encoding="utf-8") as file: return json.load(file) def _save_records(self) -> None: os.makedirs(self.vector_db_dir, exist_ok=True) with open(self.records_path, "w", encoding="utf-8") as file: json.dump(self.records, file, ensure_ascii=False, indent=2) def _load_or_create_index(self): if os.path.exists(self.index_path): with open(self.index_path, "rb") as file: buffer = np.frombuffer(file.read(), dtype=np.uint8) return faiss.deserialize_index(buffer) return faiss.IndexFlatIP(self.dimension) def _save_index(self) -> None: os.makedirs(self.vector_db_dir, exist_ok=True) buffer = faiss.serialize_index(self.index) with open(self.index_path, "wb") as file: file.write(buffer.tobytes()) def close(self) -> None: self._save_records() self._save_index() @staticmethod def _build_timestamp_predicate( start_ts: int | None = None, end_ts: int | None = None, ): def predicate(record: Dict) -> bool: std_timestamp = int(record.get("std_timestamp", 0) or 0) if start_ts is not None and std_timestamp < int(start_ts): return False if end_ts is not None and std_timestamp > int(end_ts): return False return True return predicate @staticmethod def _format_timestamp_date(std_timestamp: int | str | None) -> str: try: if std_timestamp in (None, "", 0, "0"): return "N/A" return datetime.fromtimestamp(int(std_timestamp)).strftime("%Y-%m-%d") except Exception: return "N/A" def _split_document_for_retry(self, document: Document) -> List[Document]: text = document.page_content or "" if len(text) <= 1: return [document] half = max(1, len(text) // 2) return [ Document(page_content=text[:half], metadata=document.metadata), Document(page_content=text[half:], metadata=document.metadata), ] def _normalize_vector(self, vector: List[float]) -> np.ndarray: array = np.array(vector, dtype=np.float32) norm = np.linalg.norm(array) if norm > 0: array = array / norm return array def _embed_text_with_retry(self, text: str) -> np.ndarray: embedding = self.embeddings.embed_query(text) return self._normalize_vector(embedding) def _rebuild_index_from_records(self) -> None: self.index = faiss.IndexFlatIP(self.dimension) if not self.records: self._save_index() return vectors = np.array( [self._normalize_vector(record["embedding"]) for record in self.records], dtype=np.float32, ) self.index.add(vectors) self._save_index() def _delete_documents_by_url(self, url: str) -> None: self.records = [record for record in self.records if record.get("url") != url] self._rebuild_index_from_records() def _get_existing_urls(self) -> set[str]: return {record.get("url", "") for record in self.records if record.get("url")} def _prepare_documents(self, news_list: List[Dict], upsert: bool) -> List[Document]: existing_urls = self._get_existing_urls() if upsert else set() documents: List[Document] = [] new_count = 0 updated_count = 0 skipped_count = 0 for news in news_list: content = news.get("content") if not content: skipped_count += 1 continue url = str(news.get("url", "") or "") metadata = { "title": str(news.get("title", "") or ""), "url": url, "source": str(news.get("source", "") or ""), "std_timestamp": int(news.get("std_timestamp", 0) or 0), } if upsert and url in existing_urls: self._delete_documents_by_url(url) updated_count += 1 else: new_count += 1 for text in self.text_splitter.split_text(content): documents.append(Document(page_content=text, metadata=metadata)) print(f"待写入 {len(documents)} 个文档片段") print(f" - 新增新闻: {new_count} 条") print(f" - 更新新闻: {updated_count} 条") if skipped_count > 0: print(f" - 跳过新闻: {skipped_count} 条(无正文)") return documents def process_news(self, news_list: List[Dict], upsert: bool = True) -> None: documents = self._prepare_documents(news_list, upsert=upsert) if not documents: print("没有可写入的文档") return pending = list(documents) added_records: List[Dict] = [] vectors: List[np.ndarray] = [] chunk_index_by_url: Dict[str, int] = {} for record in self.records: url = record.get("url", "") chunk_index_by_url[url] = max(chunk_index_by_url.get(url, 0), int(record.get("chunk_index", -1)) + 1) while pending: document = pending.pop(0) text = document.page_content or "" if not text.strip(): continue try: vector = self._embed_text_with_retry(text) except ValueError as exc: if "input length exceeds the context length" not in str(exc): raise smaller_docs = self._split_document_for_retry(document) if len(smaller_docs) == 1 and smaller_docs[0].page_content == text: raise pending = smaller_docs + pending continue url = document.metadata.get("url", "") chunk_index = chunk_index_by_url.get(url, 0) chunk_index_by_url[url] = chunk_index + 1 added_records.append( { "id": f"{url}::{chunk_index}", "chunk_index": chunk_index, "content": text, "title": document.metadata.get("title", ""), "url": url, "source": document.metadata.get("source", ""), "std_timestamp": int(document.metadata.get("std_timestamp", 0)), "embedding": vector.tolist(), } ) vectors.append(vector) if not added_records: print("没有成功生成 embedding 的文档") return self.records.extend(added_records) matrix = np.vstack(vectors).astype(np.float32) self.index.add(matrix) self._save_records() self._save_index() print("FAISS 向量数据库写入完成") def _search_record_indices( self, query: str, k: int, start_ts: int | None = None, end_ts: int | None = None, ) -> List[int]: if not self.records or self.index.ntotal == 0: return [] predicate = self._build_timestamp_predicate(start_ts=start_ts, end_ts=end_ts) allowed_indices = [idx for idx, record in enumerate(self.records) if predicate(record)] if not allowed_indices: return [] query_vector = self._embed_text_with_retry(query).reshape(1, -1) search_k = min(max(k * 5, k), len(self.records)) scores, indices = self.index.search(query_vector, search_k) ranked = [] for idx in indices[0]: if idx < 0: continue if idx in allowed_indices: ranked.append(int(idx)) if len(ranked) >= k: break return ranked def search( self, query: str, k: int = 5, start_ts: int | None = None, end_ts: int | None = None, ) -> List[Document]: ranked_indices = self._search_record_indices( query, k, start_ts=start_ts, end_ts=end_ts, ) return [ Document( page_content=self.records[idx].get("content", ""), metadata={ "title": self.records[idx].get("title", ""), "url": self.records[idx].get("url", ""), "source": self.records[idx].get("source", ""), "std_timestamp": self.records[idx].get("std_timestamp", 0), }, ) for idx in ranked_indices ] def search_recent(self, query: str, k: int = 5, days: int = 7) -> List[Document]: end_ts = int(datetime.now().timestamp()) start_ts = end_ts - days * 24 * 60 * 60 return self.search(query, k=k, start_ts=start_ts, end_ts=end_ts) def get_all_news_metadata( self, start_ts: int | None = None, end_ts: int | None = None, ) -> List[Dict]: predicate = self._build_timestamp_predicate(start_ts=start_ts, end_ts=end_ts) seen_urls = set() unique_news = [] for record in self.records: if not predicate(record): continue url = record.get("url", "") if not url or url in seen_urls: continue seen_urls.add(url) unique_news.append( { "title": record.get("title", ""), "url": url, "source": record.get("source", ""), "std_timestamp": record.get("std_timestamp", 0), "date": self._format_timestamp_date(record.get("std_timestamp")), } ) return unique_news def get_recent_news_metadata(self, days: int = 7) -> List[Dict]: end_ts = int(datetime.now().timestamp()) start_ts = end_ts - days * 24 * 60 * 60 return self.get_all_news_metadata(start_ts=start_ts, end_ts=end_ts) def get_database_stats(self) -> Dict: unique_urls = { record.get("url", "") for record in self.records if record.get("url") } return { "total_documents": len(self.records), "unique_news": len(unique_urls), "database_exists": len(self.records) > 0, } if __name__ == "__main__": processor = RAGProcessor() try: processor.process_news( [ { "title": "测试新闻", "content": "标题: 测试新闻\n\n这是一条用于验证 FAISS 写入的测试内容。", "url": "http://test.example/news", "source": "测试源", "std_timestamp": int(datetime.now().timestamp()), } ] ) print(processor.get_database_stats()) print(len(processor.search("测试", k=1))) finally: processor.close()