newsreport_agent_for_traffic/rebuild_vector_db.py

238 lines
8.0 KiB
Python
Raw Permalink Normal View History

2026-05-09 10:46:52 +08:00
"""
Rebuild the local vector database from JSON files under data/.
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, List, Tuple
from dotenv import load_dotenv
from rag import RAGProcessor
# JSON files under the data directory that main() never ingests; they are
# matched by exact file name and reported under "skipped_files".
SKIP_JSON_FILES = {"managed_recipients.json", "scheduled_send_config.json"}
def _positive_int(raw: str) -> int:
    """Argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        number = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {raw!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the rebuild script.

    Options:
        --data-dir: directory holding the source JSON files. Defaults to the
            ``DATA_DIR`` environment variable, else ``./data``.
        --vector-db-dir: directory of the vector database to rebuild. Defaults
            to the ``VECTOR_DB_DIR`` environment variable, else ``./vector_db``.
        --batch-size: number of records ingested per batch (default 10).

    Returns:
        The parsed namespace with ``data_dir``, ``vector_db_dir`` and
        ``batch_size`` attributes.

    A non-positive ``--batch-size`` is rejected here with the usual argparse
    usage error (exit status 2). Zero would otherwise cause a division by
    zero when batches are counted, and a negative value would slice the
    record list incorrectly.
    """
    parser = argparse.ArgumentParser(description="Rebuild FAISS data from local JSON files")
    parser.add_argument(
        "--data-dir",
        default=os.getenv("DATA_DIR", "./data"),
        help="directory containing the source JSON files",
    )
    parser.add_argument(
        "--vector-db-dir",
        default=os.getenv("VECTOR_DB_DIR", "./vector_db"),
        help="directory of the vector database to rebuild",
    )
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=10,
        help="records ingested per batch (positive integer)",
    )
    return parser.parse_args()
def safe_int(value: object) -> int:
    """Convert *value* to ``int``, returning 0 whenever that is not possible.

    Falsy inputs (``None``, ``""``, ``0``, empty containers) map to 0. Any
    exception raised while testing or converting the value is swallowed and
    also yields 0, so this never raises for ordinary inputs.
    """
    try:
        # The truthiness test stays inside the try block on purpose: objects
        # whose bool() raises must fall back to 0 as well.
        candidate = value if value else 0
        converted = int(candidate)
    except Exception:
        return 0
    return converted
def infer_source(file_path: Path, payload: Dict) -> str:
    """Work out a human-readable source label for a data file.

    The lower-cased file name is checked against known prefixes first. If none
    matches and *payload* is a dict, ``payload["biz"]["nickname"]`` and then
    ``payload["source"]`` are tried (each only when truthy). When nothing
    applies the placeholder "未知来源" is returned.

    Args:
        file_path: path of the JSON file; only its name is inspected.
        payload: decoded JSON content of the file.

    Returns:
        The source label as a string.
    """
    lowered_name = file_path.name.lower()
    prefix_labels = (
        ("gaode_backend_probe", "高德地图公众号"),
        ("baidu_backend_probe", "百度地图公众号"),
        ("ccgp_", "中国政府采购网"),
    )
    for prefix, label in prefix_labels:
        if lowered_name.startswith(prefix):
            return label

    if isinstance(payload, dict):
        account = payload.get("biz")
        if isinstance(account, dict) and account.get("nickname"):
            return str(account["nickname"])
        if payload.get("source"):
            return str(payload["source"])

    return "未知来源"
def normalize_record(item: Dict, source: str) -> Dict | None:
    """Turn one raw JSON item into the record shape used for ingestion.

    Args:
        item: raw item; ``title``, ``url``/``link``, ``content``/``content_text``,
            ``std_timestamp`` and optionally ``source`` are read from it.
        source: fallback source label used when the item carries none.

    Returns:
        A dict with ``title``, ``content`` (the title prepended to the body),
        ``url``, ``source`` and ``std_timestamp``. ``None`` is returned when the
        title, URL, body or timestamp is missing, empty, or (for the
        timestamp) converts to 0.
    """
    headline = str(item.get("title", "") or "").strip()
    link = str(item.get("url") or item.get("link") or "").strip()
    body_text = str(item.get("content") or item.get("content_text") or "").strip()
    timestamp = safe_int(item.get("std_timestamp"))
    origin = str(item.get("source") or source or "未知来源").strip()

    # Every one of these must be truthy for the record to be usable.
    if not all((headline, link, body_text, timestamp)):
        return None

    normalized = {
        "title": headline,
        "content": f"标题: {headline}\n\n{body_text}",
        "url": link,
        "source": origin,
        "std_timestamp": timestamp,
    }
    return normalized
def extract_records(file_path: Path) -> Tuple[List[Dict], str]:
    """Read one JSON file and return its normalized records plus a layout tag.

    Supported layouts:
        * a top-level list of items -> tag ``"list"``. An item without its own
          ``source`` inherits the most recent non-empty source seen earlier in
          the same list (initially "未知来源").
        * a dict with a list under ``"articles"`` -> tag ``"articles"``.
        * a dict with a list under ``"items"`` -> tag ``"items"``
          (only consulted when ``"articles"`` is not a list).

    Other dicts yield ``([], "skipped")`` and any other top-level JSON value
    yields ``([], "unsupported")``. Non-dict entries and entries rejected by
    ``normalize_record`` are dropped silently.

    Args:
        file_path: JSON file to read (decoded as UTF-8).

    Returns:
        ``(records, tag)`` as described above.
    """
    payload = json.loads(file_path.read_text(encoding="utf-8"))

    if isinstance(payload, list):
        records: List[Dict] = []
        current_source = "未知来源"
        for entry in payload:
            if isinstance(entry, dict):
                # Remember the latest explicit source so later entries without
                # one fall back to it.
                current_source = str(entry.get("source") or current_source)
                normalized = normalize_record(entry, source=current_source)
                if normalized:
                    records.append(normalized)
        return records, "list"

    if not isinstance(payload, dict):
        return [], "unsupported"

    # "articles" takes precedence over "items" when both hold lists.
    container_key = next(
        (key for key in ("articles", "items") if isinstance(payload.get(key), list)),
        None,
    )
    if container_key is None:
        return [], "skipped"

    default_source = infer_source(file_path, payload)
    candidates = (
        normalize_record(entry, source=default_source)
        for entry in payload[container_key]
        if isinstance(entry, dict)
    )
    return [record for record in candidates if record], container_key
def dedupe_records(records: Iterable[Dict]) -> Tuple[List[Dict], int]:
    """Collapse records that share a URL, keeping the best one per URL.

    "Best" means the highest ``(std_timestamp, len(content))`` pair. On a tie
    the record seen later wins. Output order follows the first appearance of
    each URL in *records*.

    Args:
        records: normalized records carrying ``url``, ``std_timestamp`` and
            ``content`` keys.

    Returns:
        ``(unique_records, duplicate_count)`` where ``duplicate_count`` is the
        number of records whose URL had already been seen.
    """

    def rank(candidate: Dict) -> Tuple[int, int]:
        return candidate["std_timestamp"], len(candidate["content"])

    winners: Dict[str, Dict] = {}
    collisions = 0
    for candidate in records:
        link = candidate["url"]
        if link not in winners:
            winners[link] = candidate
            continue
        collisions += 1
        incumbent_rank = rank(winners[link])
        # ">=" (not ">") so that an equally ranked later record replaces the
        # earlier one.
        if rank(candidate) >= incumbent_rank:
            winners[link] = candidate
    return list(winners.values()), collisions
def ensure_safe_delete(target_dir: Path, workspace_root: Path) -> None:
    """Refuse to proceed unless *target_dir* is strictly inside *workspace_root*.

    Both paths are resolved before comparison.

    Args:
        target_dir: directory that is about to be deleted.
        workspace_root: directory that must contain *target_dir*.

    Raises:
        ValueError: if the target is the workspace root itself, is the
            filesystem root, or does not lie beneath the workspace root.
    """
    target = target_dir.resolve()
    workspace = workspace_root.resolve()

    if target == workspace:
        reason = f"Refusing to delete workspace root: {target}"
    elif target == Path(target.anchor):
        reason = f"Refusing to delete filesystem root: {target}"
    elif workspace not in target.parents:
        reason = f"Refusing to delete path outside workspace: {target}"
    else:
        return
    raise ValueError(reason)
def main() -> None:
    """Rebuild the vector database from the JSON files in the data directory.

    Steps:
        1. Load ``.env`` and parse the command-line options.
        2. Read every ``*.json`` file in the data directory (except those in
           ``SKIP_JSON_FILES``), normalize its records and de-duplicate them
           by URL.
        3. Delete the existing vector DB directory after a safety check.
        4. Ingest the records in batches through ``RAGProcessor`` and print a
           summary.

    Raises:
        ValueError: if the batch size is not positive, or if the vector DB
            directory fails the ``ensure_safe_delete`` check.
        FileNotFoundError: if the data directory does not exist.
    """
    load_dotenv()
    args = parse_args()

    # Validate before anything destructive happens: a batch size of 0 would
    # raise ZeroDivisionError only after the old database had been deleted,
    # and a negative one would slice the record list incorrectly.
    batch_size = args.batch_size
    if batch_size <= 0:
        raise ValueError(f"--batch-size must be a positive integer, got {batch_size}")

    workspace_root = Path.cwd()
    data_dir = Path(args.data_dir).resolve()
    vector_db_dir = Path(args.vector_db_dir).resolve()
    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")

    all_records: List[Dict] = []
    file_stats: List[Tuple[str, str, int]] = []
    skipped_files: List[str] = []
    # sorted() keeps the processing order (and thus tie-breaking during
    # de-duplication) deterministic across runs.
    for file_path in sorted(data_dir.glob("*.json")):
        if file_path.name in SKIP_JSON_FILES:
            skipped_files.append(file_path.name)
            continue
        records, kind = extract_records(file_path)
        if records:
            file_stats.append((file_path.name, kind, len(records)))
            all_records.extend(records)
        else:
            # Files that yield no usable record are reported as skipped.
            skipped_files.append(file_path.name)

    deduped_records, duplicate_count = dedupe_records(all_records)
    source_counter = Counter(record["source"] for record in deduped_records)

    # Destructive step: only reached once all inputs parsed successfully.
    ensure_safe_delete(vector_db_dir, workspace_root)
    if vector_db_dir.exists():
        shutil.rmtree(vector_db_dir)

    print(f"[info] parsed_records={len(all_records)}", flush=True)
    print(f"[info] deduped_records={len(deduped_records)}", flush=True)
    print(f"[info] duplicate_urls_removed={duplicate_count}", flush=True)
    print("[info] initializing RAGProcessor...", flush=True)
    rag_processor = RAGProcessor(vector_db_dir=str(vector_db_dir))
    try:
        print("[info] RAGProcessor ready", flush=True)
        # Ceiling division; at least one (possibly empty) batch is always
        # submitted, matching the behaviour for an empty data set.
        total_batches = max((len(deduped_records) + batch_size - 1) // batch_size, 1)
        for batch_index in range(total_batches):
            start = batch_index * batch_size
            batch = deduped_records[start:start + batch_size]
            print(
                f"[info] ingesting batch {batch_index + 1}/{total_batches} "
                f"with {len(batch)} records",
                flush=True,
            )
            rag_processor.process_news(batch, upsert=False)
            batch_stats = rag_processor.get_database_stats()
            print(
                f"[info] batch {batch_index + 1}/{total_batches} complete: "
                f"unique_news={batch_stats['unique_news']}, "
                f"total_documents={batch_stats['total_documents']}",
                flush=True,
            )
        db_stats = rag_processor.get_database_stats()
        recent_count = len(rag_processor.get_recent_news_metadata(days=7))
    finally:
        # Close the processor even when ingestion fails part-way through.
        rag_processor.close()

    print("=" * 60)
    print("VECTOR DB REBUILD SUMMARY")
    print("=" * 60)
    print(f"data_dir={data_dir}")
    print(f"vector_db_dir={vector_db_dir}")
    print(f"parsed_records={len(all_records)}")
    print(f"duplicate_urls_removed={duplicate_count}")
    print(f"ingested_unique_records={len(deduped_records)}")
    print(f"db_unique_news={db_stats['unique_news']}")
    print(f"db_total_documents={db_stats['total_documents']}")
    print(f"recent_7d_news={recent_count}")
    print("source_distribution:")
    for source, count in source_counter.most_common():
        print(f" - {source}: {count}")
    print("ingested_files:")
    for file_name, kind, count in file_stats:
        print(f" - {file_name} [{kind}] -> {count}")
    print("skipped_files:")
    for file_name in skipped_files:
        print(f" - {file_name}")
if __name__ == "__main__":
main()