455 lines
16 KiB
Python
455 lines
16 KiB
Python
"""
|
||
交通新闻自动报表系统 - 主程序
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
from collections import Counter
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Tuple
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
# 必须在导入其他模块之前加载环境变量
|
||
load_dotenv()
|
||
|
||
from crawler import BaiduMapCrawler, CCGPCrawler, TrafficNewsCrawler, WeChatCrawler
|
||
from rag import RAGProcessor
|
||
from report import ReportGenerator
|
||
|
||
|
||
# Look-back period, in days, of the time window used for report generation
# (default of build_report_time_window and the value main() passes to it).
REPORT_LOOKBACK_DAYS: int = 15

# Default per-source quota used by select_balanced_news_by_source.
MIN_NEWS_PER_SOURCE: int = 5
|
||
|
||
|
||
def parse_selected_sources(raw_sources: str) -> List[str]:
    """Parse the ``--sources`` argument into a list of source keys.

    The value is a comma-separated list. Each token is stripped and
    lower-cased, empty tokens are dropped and duplicates are removed (the
    first occurrence wins). The keyword ``all`` - alone or as one of the
    tokens, in any letter case - selects every known source.

    Args:
        raw_sources: Raw argument text, e.g. ``"all"`` or ``"traffic, ccgp"``.

    Returns:
        The selected source keys in the order given. Unrecognised names are
        passed through unchanged; this function does not validate them.
    """
    tokens = [part.strip().lower() for part in (raw_sources or "").split(",")]
    tokens = [token for token in tokens if token]

    if "all" in tokens:
        return ["traffic", "wechat", "baidu", "ccgp"]

    # dict.fromkeys keeps first-seen order while dropping repeated names.
    return list(dict.fromkeys(tokens))
|
||
|
||
|
||
def init_crawlers(selected_sources: List[str]) -> Tuple[Dict[str, object], List[str]]:
    """Create a crawler for each requested source that can actually run.

    Reads two environment variables: ``TARGET_URL`` (base URL handed to the
    traffic crawler, default ``https://www.7its.com/``) and
    ``WECHAT_MP_COOKIE``. When the cookie is blank, the ``wechat`` and
    ``baidu`` sources cannot be used: a warning is printed for each one that
    was requested and it is removed from the returned source list.

    Args:
        selected_sources: Source keys requested by the user.

    Returns:
        A pair ``(crawlers, usable_sources)``. ``crawlers`` always has the
        keys ``traffic``, ``wechat``, ``baidu`` and ``ccgp``; the value is
        ``None`` for every source that was not initialised.
    """
    base_url = os.getenv("TARGET_URL", "https://www.7its.com/")
    cookie = os.getenv("WECHAT_MP_COOKIE", "").strip()

    crawlers: Dict[str, object] = dict.fromkeys(("traffic", "wechat", "baidu", "ccgp"))

    if "traffic" in selected_sources:
        crawlers["traffic"] = TrafficNewsCrawler(base_url=base_url)

    # Both of these sources are gated on the same cookie. The factories are
    # wrapped in lambdas so a crawler class is only touched when it is needed.
    cookie_gated = (
        (
            "wechat",
            lambda: WeChatCrawler(),
            "\n[警告] 微信公众号(高德)配置缺失,跳过高德地图爬虫",
        ),
        (
            "baidu",
            lambda: BaiduMapCrawler(),
            "\n[警告] 微信公众号(百度地图)配置缺失,跳过百度地图爬虫",
        ),
    )
    for key, factory, warning in cookie_gated:
        if key not in selected_sources:
            continue
        if cookie:
            crawlers[key] = factory()
            continue
        print(warning)
        print("请在.env文件中配置WECHAT_MP_COOKIE")
        selected_sources = [name for name in selected_sources if name != key]

    if "ccgp" in selected_sources:
        crawlers["ccgp"] = CCGPCrawler()

    return crawlers, selected_sources
|
||
|
||
|
||
def _crawl_source(
    source_key: str,
    crawler: object,
    data_dir: str,
    args: argparse.Namespace,
) -> List[Dict]:
    """Run one crawler, persist what it found and return the crawled items."""
    if source_key == "traffic":
        # A falsy result is treated as "nothing crawled", matching how the
        # other sources are handled, instead of failing on len(None).
        news_list = crawler.crawl_and_save(
            output_dir=data_dir,
            max_news=args.max_news,
        ) or []
        print(f"[成功] 赛文交通网爬取了 {len(news_list)} 条新闻")
        return news_list

    if source_key in ("wechat", "baidu"):
        # The two WeChat-account sources share one call shape; only the CLI
        # option names (--<key>-count / --<key>-keyword) and the label differ.
        label = "微信公众号(高德)" if source_key == "wechat" else "微信公众号(百度地图)"
        articles = crawler.crawl_articles(
            max_count=getattr(args, f"{source_key}_count"),
            keyword=getattr(args, f"{source_key}_keyword"),
        )
        if not articles:
            return []
        crawler.save_articles(articles, output_dir=data_dir)
        print(f"[成功] {label}爬取了 {len(articles)} 篇文章")
        return articles

    if source_key == "ccgp":
        ccgp_keywords = [k.strip() for k in args.ccgp_keywords.split(",") if k.strip()]
        results = crawler.crawl_by_keywords(
            keywords=ccgp_keywords,
            max_per_keyword=args.ccgp_count,
        )
        if not results:
            return []
        crawler.save_results(results, output_dir=data_dir)
        print(f"[成功] 政府采购网爬取了 {len(results)} 条信息")
        return results

    raise ValueError(f"unknown source key: {source_key!r}")


def crawl_selected_sources(
    selected_sources: List[str],
    crawlers: Dict[str, object],
    data_dir: str,
    args: argparse.Namespace,
    compact: bool = False,
) -> List[Dict]:
    """Crawl all selected sources and return the combined raw items.

    Sources are always processed in the fixed order traffic, wechat, baidu,
    ccgp. The progress prefix counts only the sources that are actually
    selected, so a run with two sources prints ``[1/2]`` and ``[2/2]``.

    Args:
        selected_sources: Source keys to crawl; unknown keys are ignored.
        crawlers: Mapping from source key to crawler instance. A missing or
            ``None`` entry means the source is announced but not crawled.
        data_dir: Output directory passed to the crawlers' save methods.
        args: Parsed CLI options. Reads ``max_news``, ``wechat_count``,
            ``wechat_keyword``, ``baidu_count``, ``baidu_keyword``,
            ``ccgp_keywords`` and ``ccgp_count`` as needed.
        compact: When true, print the short `` [i]`` prefix instead of
            ``[i/total]``.

    Returns:
        Every item produced by the crawlers, in crawl order.
    """
    steps = [
        ("traffic", "赛文交通网"),
        ("wechat", "微信公众号(高德地图)"),
        ("baidu", "微信公众号(百度地图)"),
        ("ccgp", "政府采购网"),
    ]
    active_steps = [step for step in steps if step[0] in selected_sources]
    total = len(active_steps)

    all_news: List[Dict] = []
    for position, (source_key, source_title) in enumerate(active_steps, start=1):
        prefix = f" [{position}]" if compact else f"[{position}/{total}]"
        print(f"\n{prefix} 爬取{source_title}...")

        # .get() tolerates a crawler mapping that lacks some of the keys.
        crawler = crawlers.get(source_key)
        if not crawler:
            continue
        all_news.extend(_crawl_source(source_key, crawler, data_dir, args))

    return all_news
|
||
|
||
|
||
def build_rag_news_items(raw_items: List[Dict]) -> List[Dict]:
    """Convert crawled raw items into the dict layout handed to the RAG processor.

    Items whose ``content`` is missing or empty are dropped. For the rest,
    the title is prepended to the content and missing fields get defaults
    (empty title/url, source ``未知来源``, ``std_timestamp`` 0).

    Args:
        raw_items: Raw item dicts as returned by the crawlers.

    Returns:
        One dict per kept item with the keys ``title``, ``content``, ``url``,
        ``source`` and ``std_timestamp``.
    """
    with_body = ((entry, entry.get("content")) for entry in raw_items)
    return [
        {
            "title": entry.get("title", ""),
            "content": f"标题: {entry.get('title', '')}\n\n{body}",
            "url": entry.get("url", ""),
            "source": entry.get("source", "未知来源"),
            "std_timestamp": entry.get("std_timestamp", 0),
        }
        for entry, body in with_body
        if body
    ]
|
||
|
||
|
||
def print_db_stats(rag_processor: RAGProcessor) -> None:
    """Print a short status line describing the vector database.

    Args:
        rag_processor: Object whose ``get_database_stats()`` returns a mapping
            with ``database_exists`` and, when that is truthy, ``unique_news``
            and ``total_documents``.
    """
    stats = rag_processor.get_database_stats()

    if not stats["database_exists"]:
        print("\n[数据库状态] 向量数据库为空")
        return

    news_count = stats["unique_news"]
    chunk_count = stats["total_documents"]
    print(f"\n[数据库状态] 当前包含 {news_count} 条新闻, {chunk_count} 个文档片段")
|
||
|
||
|
||
def build_report_time_window(days: int = REPORT_LOOKBACK_DAYS) -> Dict:
    """Build the time interval used to filter news for report generation.

    The window ends at the moment of the call and starts ``days`` days
    earlier. ``datetime.now()`` is called without a timezone, so both bounds
    are naive local-time values.

    Args:
        days: Length of the look-back period in days.

    Returns:
        A dict with ``days``, the bounds as datetimes (``start_dt``,
        ``end_dt``), as integer epoch seconds (``start_ts``, ``end_ts``) and
        as ``YYYY-MM-DD HH:MM:SS`` strings (``start_str``, ``end_str``).
    """
    stamp_format = "%Y-%m-%d %H:%M:%S"

    window_end = datetime.now()
    window_start = window_end - timedelta(days=days)
    bounds = (window_start, window_end)

    start_ts, end_ts = (int(moment.timestamp()) for moment in bounds)
    start_str, end_str = (moment.strftime(stamp_format) for moment in bounds)

    return {
        "days": days,
        "start_dt": window_start,
        "end_dt": window_end,
        "start_ts": start_ts,
        "end_ts": end_ts,
        "start_str": start_str,
        "end_str": end_str,
    }
|
||
|
||
|
||
def print_report_time_window(window: Dict) -> None:
    """Print the concrete time interval used for report generation.

    Args:
        window: Mapping with the keys ``start_str``, ``end_str``,
            ``start_ts``, ``end_ts`` and ``days``, as produced by
            ``build_report_time_window``.
    """
    cutoff_text = window["end_str"]
    print(f"\n[时间窗口] 报告生成实时截止时间: {cutoff_text}")

    start_text = window["start_str"]
    print(f"[时间窗口] 新闻筛选区间: {start_text} 至 {cutoff_text}")

    start_epoch = window["start_ts"]
    end_epoch = window["end_ts"]
    print(f"[时间窗口] std_timestamp 区间: {start_epoch} - {end_epoch}")

    lookback = window["days"]
    print(f"[时间窗口] 回溯时长: {lookback} 天")
|
||
|
||
|
||
def select_balanced_news_by_source(
    news_list: List[Dict],
    min_per_source: int = MIN_NEWS_PER_SOURCE,
) -> List[Dict]:
    """Order news so that every source is represented at the front of the list.

    Items are ranked newest-first by ``std_timestamp`` (missing or falsy
    values count as 0). The result has two parts:

    1. For each source, visited in sorted name order, its newest items up to
       ``min_per_source``. The quota is checked after each pick, so at least
       one item per source is taken whenever one is available.
    2. Every remaining item, newest first.

    No input item is dropped except duplicates: items sharing a non-empty
    ``url`` are emitted once, and no item object is emitted twice. The
    second rule matters for items without a URL, which would otherwise be
    appended again by the second part after being picked in the first.

    Args:
        news_list: News metadata dicts; reads ``std_timestamp``, ``source``
            and ``url``.
        min_per_source: Per-source quota for the leading part of the result.

    Returns:
        A new list with the reordered, de-duplicated items.
    """
    # Sort once and reuse the ranking for both the grouping and the tail pass.
    newest_first = sorted(
        news_list,
        key=lambda item: int(item.get("std_timestamp", 0) or 0),
        reverse=True,
    )

    grouped: Dict[str, List[Dict]] = {}
    for news in newest_first:
        source = str(news.get("source", "未知来源") or "未知来源")
        grouped.setdefault(source, []).append(news)

    selected: List[Dict] = []
    seen_urls = set()
    # Identity of every emitted item; the only way to recognise URL-less
    # items that were already picked.
    seen_ids = set()

    def _take(news: Dict) -> bool:
        """Append *news* unless it (or its URL) was already emitted."""
        url = news.get("url", "")
        if id(news) in seen_ids or (url and url in seen_urls):
            return False
        selected.append(news)
        seen_ids.add(id(news))
        if url:
            seen_urls.add(url)
        return True

    for source in sorted(grouped):
        picked = 0
        for news in grouped[source]:
            if _take(news):
                picked += 1
                if picked >= min_per_source:
                    break

    for news in newest_first:
        _take(news)

    return selected
|
||
|
||
|
||
def print_selected_news_distribution(news_list: List[Dict]) -> None:
    """Print how many of the given news items come from each source.

    Items without a ``source`` key are counted under ``未知来源``. Sources
    are listed in the order in which they first appear in ``news_list``.
    """
    tally: Counter = Counter()
    for entry in news_list:
        tally[entry.get("source", "未知来源")] += 1

    print("[信息] 综合报告新闻取样分布:")
    for source_name in tally:
        print(f" - {source_name}: {tally[source_name]} 条")
|
||
|
||
|
||
def generate_summary_report_from_db(
    rag_processor: RAGProcessor,
    report_generator: ReportGenerator,
    window: Dict,
) -> str:
    """Generate the summary report from the news stored in the vector DB.

    Loads the news metadata inside the window, reorders it with
    ``select_balanced_news_by_source``, prints the per-source distribution,
    runs one fixed search query over the same window and hands both results
    to ``report_generator.generate_summary_report``.

    Args:
        rag_processor: Provides ``get_all_news_metadata`` and ``search``.
        report_generator: Provides ``generate_summary_report``.
        window: Time window with ``start_ts`` and ``end_ts``; ``days`` is
            used for the messages and falls back to ``REPORT_LOOKBACK_DAYS``.

    Returns:
        Whatever ``generate_summary_report`` returns (annotated as ``str``).

    Raises:
        RuntimeError: If the vector DB has no news inside the window.
    """
    # Take the look-back length from the window itself so the messages stay
    # correct when the window was not built with the default period.
    days = window.get("days", REPORT_LOOKBACK_DAYS)
    start_ts, end_ts = window["start_ts"], window["end_ts"]

    news_metadata = rag_processor.get_all_news_metadata(
        start_ts=start_ts,
        end_ts=end_ts,
    )
    if not news_metadata:
        raise RuntimeError(f"向量数据库中最近{days}天没有数据,请先运行爬取模式")

    selected_news = select_balanced_news_by_source(news_metadata)
    print(f"从向量数据库加载了最近{days}天的 {len(news_metadata)} 条新闻")
    print_selected_news_distribution(selected_news)

    relevant_docs = rag_processor.search(
        "交通信号控制 信控 绿波 智能交通 导航 路况",
        k=10,
        start_ts=start_ts,
        end_ts=end_ts,
    )
    return report_generator.generate_summary_report(selected_news, relevant_docs)
|
||
|
||
|
||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the reporting pipeline."""
    parser = argparse.ArgumentParser(description="交通新闻自动报表系统")
    parser.add_argument(
        "--mode",
        type=str,
        default="full",
        choices=["crawl", "report", "full", "topic"],
        help="运行模式: crawl(仅爬取), report(仅生成报表), full(完整流程), topic(主题分析)",
    )
    parser.add_argument("--max-news", type=int, default=20, help="最大爬取新闻数量(赛文交通网)")
    parser.add_argument("--topic", type=str, default="", help="主题关键词(用于topic模式)")
    parser.add_argument(
        "--sources",
        type=str,
        default="all",
        help="数据源选择: all(全部), traffic(赛文交通网), wechat(微信-高德), baidu(微信-百度地图), ccgp(政府采购网),多个用逗号分隔",
    )
    parser.add_argument("--wechat-count", type=int, default=30, help="微信公众号(高德)爬取数量")
    parser.add_argument("--wechat-keyword", type=str, default="交通", help="微信公众号(高德)关键词过滤")
    parser.add_argument("--baidu-count", type=int, default=30, help="微信公众号(百度地图)爬取数量")
    parser.add_argument("--baidu-keyword", type=str, default="交通", help="微信公众号(百度地图)关键词过滤")
    parser.add_argument("--ccgp-keywords", type=str, default="信控,绿波", help="政府采购网关键词(逗号分隔)")
    parser.add_argument("--ccgp-count", type=int, default=30, help="政府采购网每个关键词爬取数量")
    return parser


def _print_report_preview(report: str) -> None:
    """Print the report text between banner lines."""
    print("\n" + "=" * 60)
    print("报表内容预览:")
    print("=" * 60)
    print(report)


def _run_crawl_mode(
    args: argparse.Namespace,
    selected_sources: List[str],
    crawlers: Dict[str, object],
    rag_processor: RAGProcessor,
    data_dir: str,
) -> bool:
    """Crawl the selected sources and store the results in the vector DB."""
    print(f"\n[模式] 仅爬取新闻 - 数据源: {', '.join(selected_sources)}")
    all_news = crawl_selected_sources(selected_sources, crawlers, data_dir, args, compact=False)

    if all_news:
        print(f"\n[RAG] 将 {len(all_news)} 条数据存入向量数据库...")
        rag_items = build_rag_news_items(all_news)
        rag_processor.process_news(rag_items, upsert=True)
        db_stats = rag_processor.get_database_stats()
        print(
            f"[成功] 向量数据库现有 {db_stats['unique_news']} 条新闻, "
            f"{db_stats['total_documents']} 个文档片段"
        )

    print(f"\n[完成] 总共爬取了 {len(all_news)} 条数据")
    return True


def _run_report_mode(
    rag_processor: RAGProcessor,
    report_generator: ReportGenerator,
    data_dir: str,
) -> bool:
    """Generate and save the summary report from data already in the vector DB."""
    print(f"\n[模式] 生成报表(使用已有数据,最近{REPORT_LOOKBACK_DAYS}天)")
    report_window = build_report_time_window(days=REPORT_LOOKBACK_DAYS)
    print_report_time_window(report_window)
    try:
        report = generate_summary_report_from_db(rag_processor, report_generator, report_window)
    except RuntimeError as exc:
        print(f"错误: {exc}")
        return False

    print("\n正在生成报表...")
    report_path = report_generator.save_report(
        report,
        report_type="summary",
        output_dir=data_dir,
    )
    print(f"\n[成功] 报表已生成: {report_path}")
    _print_report_preview(report)
    return True


def _run_topic_mode(
    args: argparse.Namespace,
    rag_processor: RAGProcessor,
    report_generator: ReportGenerator,
    data_dir: str,
) -> bool:
    """Search the vector DB for ``args.topic`` and save a topic report."""
    if not args.topic:
        print("错误: topic模式需要指定 --topic 参数")
        return False

    print(f"\n[模式] 主题分析(最近{REPORT_LOOKBACK_DAYS}天): {args.topic}")
    report_window = build_report_time_window(days=REPORT_LOOKBACK_DAYS)
    print_report_time_window(report_window)
    print(f"正在检索最近{REPORT_LOOKBACK_DAYS}天关于'{args.topic}'的相关内容...")
    relevant_docs = rag_processor.search(
        args.topic,
        k=10,
        start_ts=report_window["start_ts"],
        end_ts=report_window["end_ts"],
    )
    if not relevant_docs:
        print(f"未找到关于'{args.topic}'的相关内容")
        return False

    print(f"找到 {len(relevant_docs)} 条相关内容")
    print("\n正在生成主题分析报表...")
    report = report_generator.generate_topic_report(args.topic, relevant_docs)
    # NOTE(review): the topic text becomes part of report_type; whether
    # save_report copes with characters such as "/" is not visible here.
    report_path = report_generator.save_report(
        report,
        report_type=f"topic_{args.topic}",
        output_dir=data_dir,
    )
    print(f"\n[成功] 主题报表已生成: {report_path}")
    _print_report_preview(report)
    return True


def _run_full_mode(
    args: argparse.Namespace,
    selected_sources: List[str],
    crawlers: Dict[str, object],
    rag_processor: RAGProcessor,
    report_generator: ReportGenerator,
    data_dir: str,
) -> bool:
    """Run the whole pipeline: crawl, store in the vector DB, generate the report."""
    print("\n[模式] 完整流程(爬取 -> RAG处理 -> 生成报表)")
    print(f"数据源: {', '.join(selected_sources)}")

    print("\n[步骤 1/3] 爬取新闻...")
    all_news = crawl_selected_sources(selected_sources, crawlers, data_dir, args, compact=True)
    if not all_news:
        print("错误: 未能爬取到任何数据")
        return False

    print(f"\n[汇总] 共爬取 {len(all_news)} 条数据")
    print("\n[步骤 2/3] RAG数据处理...")
    rag_items = build_rag_news_items(all_news)
    rag_processor.process_news(rag_items, upsert=True)
    print("[成功] 数据已存入向量数据库")

    db_stats = rag_processor.get_database_stats()
    print(
        f"[数据库状态] 当前共有 {db_stats['unique_news']} 条新闻, "
        f"{db_stats['total_documents']} 个文档片段"
    )

    print(f"\n[步骤 3/3] 生成最近{REPORT_LOOKBACK_DAYS}天分析报表...")
    # Build the window only now, after crawling, so its end is the moment the
    # report is generated rather than the moment the program started.
    report_window = build_report_time_window(days=REPORT_LOOKBACK_DAYS)
    print_report_time_window(report_window)
    try:
        report = generate_summary_report_from_db(rag_processor, report_generator, report_window)
    except RuntimeError as exc:
        print(f"错误: {exc}")
        return False

    report_path = report_generator.save_report(
        report,
        report_type="summary",
        output_dir=data_dir,
    )
    print(f"[成功] 报表已生成: {report_path}")
    _print_report_preview(report)
    return True


def main() -> None:
    """Command-line entry point.

    Parses the CLI options, reads ``QWEN_API_KEY``, ``QWEN_MODEL``,
    ``DATA_DIR`` and ``VECTOR_DB_DIR`` from the environment and runs the
    selected mode (``crawl``, ``report``, ``topic`` or ``full``). The
    completion banner is printed only when the mode finished without an
    early exit. The RAG processor is always closed, also on errors.
    """
    args = _build_arg_parser().parse_args()

    qwen_api_key = os.getenv("QWEN_API_KEY")
    qwen_model = os.getenv("QWEN_MODEL", "qwen-max")
    data_dir = os.getenv("DATA_DIR", "./data")
    vector_db_dir = os.getenv("VECTOR_DB_DIR", "./vector_db")

    print("=" * 60)
    print("交通新闻自动报表系统 - 多源数据采集与分析")
    print("=" * 60)

    # Crawlers are only needed by the modes that crawl; creating them for
    # report/topic runs would just print irrelevant configuration warnings.
    selected_sources: List[str] = []
    crawlers: Dict[str, object] = {}
    if args.mode in {"crawl", "full"}:
        selected_sources = parse_selected_sources(args.sources)
        crawlers, selected_sources = init_crawlers(selected_sources)

    rag_processor = RAGProcessor(vector_db_dir=vector_db_dir)
    try:
        # Everything after the processor exists runs inside the try block so
        # that close() is reached even if the set-up steps below fail.
        print_db_stats(rag_processor)

        report_generator = None
        if args.mode in {"report", "topic", "full"}:
            report_generator = ReportGenerator(api_key=qwen_api_key, model_name=qwen_model)

        if args.mode == "crawl":
            finished = _run_crawl_mode(args, selected_sources, crawlers, rag_processor, data_dir)
        elif args.mode == "report":
            finished = _run_report_mode(rag_processor, report_generator, data_dir)
        elif args.mode == "topic":
            finished = _run_topic_mode(args, rag_processor, report_generator, data_dir)
        else:  # full
            finished = _run_full_mode(
                args,
                selected_sources,
                crawlers,
                rag_processor,
                report_generator,
                data_dir,
            )

        if finished:
            print("\n" + "=" * 60)
            print("任务完成!")
            print("=" * 60)
    finally:
        rag_processor.close()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|