""" 交通新闻自动报表系统 - 主程序 """ from __future__ import annotations import argparse import os from collections import Counter from datetime import datetime, timedelta from typing import Dict, List, Tuple from dotenv import load_dotenv # 必须在导入其他模块之前加载环境变量 load_dotenv() from crawler import BaiduMapCrawler, CCGPCrawler, TrafficNewsCrawler, WeChatCrawler from rag import RAGProcessor from report import ReportGenerator REPORT_LOOKBACK_DAYS = 15 MIN_NEWS_PER_SOURCE = 5 def parse_selected_sources(raw_sources: str) -> List[str]: """Parse sources argument.""" if raw_sources == "all": return ["traffic", "wechat", "baidu", "ccgp"] return [s.strip() for s in raw_sources.split(",") if s.strip()] def init_crawlers(selected_sources: List[str]) -> Tuple[Dict[str, object], List[str]]: """Initialize crawlers and filter unavailable sources.""" target_url = os.getenv("TARGET_URL", "https://www.7its.com/") wechat_mp_cookie = os.getenv("WECHAT_MP_COOKIE", "").strip() crawlers: Dict[str, object] = { "traffic": None, "wechat": None, "baidu": None, "ccgp": None, } if "traffic" in selected_sources: crawlers["traffic"] = TrafficNewsCrawler(base_url=target_url) if "wechat" in selected_sources: if wechat_mp_cookie: crawlers["wechat"] = WeChatCrawler() else: print("\n[警告] 微信公众号(高德)配置缺失,跳过高德地图爬虫") print("请在.env文件中配置WECHAT_MP_COOKIE") selected_sources = [s for s in selected_sources if s != "wechat"] if "baidu" in selected_sources: if wechat_mp_cookie: crawlers["baidu"] = BaiduMapCrawler() else: print("\n[警告] 微信公众号(百度地图)配置缺失,跳过百度地图爬虫") print("请在.env文件中配置WECHAT_MP_COOKIE") selected_sources = [s for s in selected_sources if s != "baidu"] if "ccgp" in selected_sources: crawlers["ccgp"] = CCGPCrawler() return crawlers, selected_sources def crawl_selected_sources( selected_sources: List[str], crawlers: Dict[str, object], data_dir: str, args: argparse.Namespace, compact: bool = False, ) -> List[Dict]: """Crawl all selected sources and return unified raw items.""" all_news: List[Dict] = [] steps = [ ("traffic", "赛文交通网"), ("wechat", "微信公众号(高德地图)"), ("baidu", "微信公众号(百度地图)"), ("ccgp", "政府采购网"), ] for idx, (source_key, source_title) in enumerate(steps, start=1): if source_key not in selected_sources: continue prefix = f" [{idx}]" if compact else f"[{idx}/4]" print(f"\n{prefix} 爬取{source_title}...") if source_key == "traffic" and crawlers["traffic"]: news_list = crawlers["traffic"].crawl_and_save( output_dir=data_dir, max_news=args.max_news, ) print(f"[成功] 赛文交通网爬取了 {len(news_list)} 条新闻") all_news.extend(news_list) elif source_key == "wechat" and crawlers["wechat"]: articles = crawlers["wechat"].crawl_articles( max_count=args.wechat_count, keyword=args.wechat_keyword, ) if articles: crawlers["wechat"].save_articles(articles, output_dir=data_dir) print(f"[成功] 微信公众号(高德)爬取了 {len(articles)} 篇文章") all_news.extend(articles) elif source_key == "baidu" and crawlers["baidu"]: articles = crawlers["baidu"].crawl_articles( max_count=args.baidu_count, keyword=args.baidu_keyword, ) if articles: crawlers["baidu"].save_articles(articles, output_dir=data_dir) print(f"[成功] 微信公众号(百度地图)爬取了 {len(articles)} 篇文章") all_news.extend(articles) elif source_key == "ccgp" and crawlers["ccgp"]: ccgp_keywords = [k.strip() for k in args.ccgp_keywords.split(",") if k.strip()] results = crawlers["ccgp"].crawl_by_keywords( keywords=ccgp_keywords, max_per_keyword=args.ccgp_count, ) if results: crawlers["ccgp"].save_results(results, output_dir=data_dir) print(f"[成功] 政府采购网爬取了 {len(results)} 条信息") all_news.extend(results) return all_news def build_rag_news_items(raw_items: List[Dict]) -> List[Dict]: """Convert crawled raw items into RAG processor format.""" news_list: List[Dict] = [] for item in raw_items: content = item.get("content") if not content: continue news_list.append( { "title": item.get("title", ""), "content": f"标题: {item.get('title', '')}\n\n{content}", "url": item.get("url", ""), "source": item.get("source", "未知来源"), "std_timestamp": item.get("std_timestamp", 0), } ) return news_list def print_db_stats(rag_processor: RAGProcessor) -> None: """Print vector database stats.""" db_stats = rag_processor.get_database_stats() if db_stats["database_exists"]: print( f"\n[数据库状态] 当前包含 {db_stats['unique_news']} 条新闻, " f"{db_stats['total_documents']} 个文档片段" ) else: print("\n[数据库状态] 向量数据库为空") def build_report_time_window(days: int = REPORT_LOOKBACK_DAYS) -> Dict: """Build the exact std_timestamp interval used for report generation.""" end_dt = datetime.now() start_dt = end_dt - timedelta(days=days) return { "days": days, "start_dt": start_dt, "end_dt": end_dt, "start_ts": int(start_dt.timestamp()), "end_ts": int(end_dt.timestamp()), "start_str": start_dt.strftime("%Y-%m-%d %H:%M:%S"), "end_str": end_dt.strftime("%Y-%m-%d %H:%M:%S"), } def print_report_time_window(window: Dict) -> None: """Print the concrete time interval used for report generation.""" print(f"\n[时间窗口] 报告生成实时截止时间: {window['end_str']}") print( f"[时间窗口] 新闻筛选区间: {window['start_str']} " f"至 {window['end_str']}" ) print( f"[时间窗口] std_timestamp 区间: " f"{window['start_ts']} - {window['end_ts']}" ) print(f"[时间窗口] 回溯时长: {window['days']} 天") def select_balanced_news_by_source( news_list: List[Dict], min_per_source: int = MIN_NEWS_PER_SOURCE, ) -> List[Dict]: """Select at least ``min_per_source`` items for each source when available.""" grouped: Dict[str, List[Dict]] = {} for news in sorted( news_list, key=lambda item: int(item.get("std_timestamp", 0) or 0), reverse=True, ): source = str(news.get("source", "未知来源") or "未知来源") grouped.setdefault(source, []).append(news) selected: List[Dict] = [] seen_urls = set() for source in sorted(grouped.keys()): picked = 0 for news in grouped[source]: url = news.get("url", "") if url and url in seen_urls: continue selected.append(news) if url: seen_urls.add(url) picked += 1 if picked >= min_per_source: break for news in sorted( news_list, key=lambda item: int(item.get("std_timestamp", 0) or 0), reverse=True, ): url = news.get("url", "") if url and url in seen_urls: continue selected.append(news) if url: seen_urls.add(url) return selected def print_selected_news_distribution(news_list: List[Dict]) -> None: """Print selected source distribution for report input.""" source_counts = Counter(news.get("source", "未知来源") for news in news_list) print("[信息] 综合报告新闻取样分布:") for source, count in source_counts.items(): print(f" - {source}: {count} 条") def generate_summary_report_from_db( rag_processor: RAGProcessor, report_generator: ReportGenerator, window: Dict, ) -> str: """Generate the summary report from vector DB using the aligned selection logic.""" news_metadata = rag_processor.get_all_news_metadata( start_ts=window["start_ts"], end_ts=window["end_ts"], ) if not news_metadata: raise RuntimeError("向量数据库中最近15天没有数据,请先运行爬取模式") selected_news = select_balanced_news_by_source(news_metadata) print(f"从向量数据库加载了最近15天的 {len(news_metadata)} 条新闻") print_selected_news_distribution(selected_news) relevant_docs = rag_processor.search( "交通信号控制 信控 绿波 智能交通 导航 路况", k=10, start_ts=window["start_ts"], end_ts=window["end_ts"], ) return report_generator.generate_summary_report(selected_news, relevant_docs) def main() -> None: """主函数""" parser = argparse.ArgumentParser(description="交通新闻自动报表系统") parser.add_argument( "--mode", type=str, default="full", choices=["crawl", "report", "full", "topic"], help="运行模式: crawl(仅爬取), report(仅生成报表), full(完整流程), topic(主题分析)", ) parser.add_argument("--max-news", type=int, default=20, help="最大爬取新闻数量(赛文交通网)") parser.add_argument("--topic", type=str, default="", help="主题关键词(用于topic模式)") parser.add_argument( "--sources", type=str, default="all", help="数据源选择: all(全部), traffic(赛文交通网), wechat(微信-高德), baidu(微信-百度地图), ccgp(政府采购网),多个用逗号分隔", ) parser.add_argument("--wechat-count", type=int, default=30, help="微信公众号(高德)爬取数量") parser.add_argument("--wechat-keyword", type=str, default="交通", help="微信公众号(高德)关键词过滤") parser.add_argument("--baidu-count", type=int, default=30, help="微信公众号(百度地图)爬取数量") parser.add_argument("--baidu-keyword", type=str, default="交通", help="微信公众号(百度地图)关键词过滤") parser.add_argument("--ccgp-keywords", type=str, default="信控,绿波", help="政府采购网关键词(逗号分隔)") parser.add_argument("--ccgp-count", type=int, default=30, help="政府采购网每个关键词爬取数量") args = parser.parse_args() qwen_api_key = os.getenv("QWEN_API_KEY") qwen_model = os.getenv("QWEN_MODEL", "qwen-max") data_dir = os.getenv("DATA_DIR", "./data") vector_db_dir = os.getenv("VECTOR_DB_DIR", "./vector_db") print("=" * 60) print("交通新闻自动报表系统 - 多源数据采集与分析") print("=" * 60) selected_sources = parse_selected_sources(args.sources) crawlers, selected_sources = init_crawlers(selected_sources) report_window = build_report_time_window(days=REPORT_LOOKBACK_DAYS) rag_processor = RAGProcessor(vector_db_dir=vector_db_dir) print_db_stats(rag_processor) report_generator = None if args.mode in {"report", "topic", "full"}: report_generator = ReportGenerator(api_key=qwen_api_key, model_name=qwen_model) try: if args.mode == "crawl": print(f"\n[模式] 仅爬取新闻 - 数据源: {', '.join(selected_sources)}") all_news = crawl_selected_sources(selected_sources, crawlers, data_dir, args, compact=False) if all_news: print(f"\n[RAG] 将 {len(all_news)} 条数据存入向量数据库...") rag_items = build_rag_news_items(all_news) rag_processor.process_news(rag_items, upsert=True) db_stats = rag_processor.get_database_stats() print( f"[成功] 向量数据库现有 {db_stats['unique_news']} 条新闻, " f"{db_stats['total_documents']} 个文档片段" ) print(f"\n[完成] 总共爬取了 {len(all_news)} 条数据") elif args.mode == "report": print("\n[模式] 生成报表(使用已有数据,最近15天)") print_report_time_window(report_window) try: report = generate_summary_report_from_db( rag_processor, report_generator, report_window, ) except RuntimeError as exc: print(f"错误: {exc}") return print("\n正在生成报表...") report_path = report_generator.save_report( report, report_type="summary", output_dir=data_dir, ) print(f"\n[成功] 报表已生成: {report_path}") print("\n" + "=" * 60) print("报表内容预览:") print("=" * 60) print(report) elif args.mode == "topic": if not args.topic: print("错误: topic模式需要指定 --topic 参数") return print(f"\n[模式] 主题分析(最近15天): {args.topic}") print_report_time_window(report_window) print(f"正在检索最近15天关于'{args.topic}'的相关内容...") relevant_docs = rag_processor.search( args.topic, k=10, start_ts=report_window["start_ts"], end_ts=report_window["end_ts"], ) if not relevant_docs: print(f"未找到关于'{args.topic}'的相关内容") return print(f"找到 {len(relevant_docs)} 条相关内容") print("\n正在生成主题分析报表...") report = report_generator.generate_topic_report(args.topic, relevant_docs) report_path = report_generator.save_report( report, report_type=f"topic_{args.topic}", output_dir=data_dir, ) print(f"\n[成功] 主题报表已生成: {report_path}") print("\n" + "=" * 60) print("报表内容预览:") print("=" * 60) print(report) else: # full print("\n[模式] 完整流程(爬取 -> RAG处理 -> 生成报表)") print(f"数据源: {', '.join(selected_sources)}") print("\n[步骤 1/3] 爬取新闻...") all_news = crawl_selected_sources(selected_sources, crawlers, data_dir, args, compact=True) if not all_news: print("错误: 未能爬取到任何数据") return print(f"\n[汇总] 共爬取 {len(all_news)} 条数据") print("\n[步骤 2/3] RAG数据处理...") rag_items = build_rag_news_items(all_news) rag_processor.process_news(rag_items, upsert=True) print("[成功] 数据已存入向量数据库") db_stats = rag_processor.get_database_stats() print( f"[数据库状态] 当前共有 {db_stats['unique_news']} 条新闻, " f"{db_stats['total_documents']} 个文档片段" ) print("\n[步骤 3/3] 生成最近15天分析报表...") print_report_time_window(report_window) try: report = generate_summary_report_from_db( rag_processor, report_generator, report_window, ) except RuntimeError as exc: print(f"错误: {exc}") return report_path = report_generator.save_report( report, report_type="summary", output_dir=data_dir, ) print(f"[成功] 报表已生成: {report_path}") print("\n" + "=" * 60) print("报表内容预览:") print("=" * 60) print(report) print("\n" + "=" * 60) print("任务完成!") print("=" * 60) finally: rag_processor.close() if __name__ == "__main__": main()