# newsreport_agent_for_traffic/main.py
#
# NOTE: the code-hosting viewer flagged this file as containing Unicode
# characters that might be confused with other characters. If this is
# intentional, the warning can be safely ignored.

"""
交通新闻自动报表系统 - 主程序
"""
from __future__ import annotations
import argparse
import os
from collections import Counter
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from dotenv import load_dotenv
# Environment variables must be loaded before the other modules are imported.
load_dotenv()
from crawler import BaiduMapCrawler, CCGPCrawler, TrafficNewsCrawler, WeChatCrawler
from rag import RAGProcessor
from report import ReportGenerator
# Default look-back length, in days, of the report time window
# (default argument of build_report_time_window).
REPORT_LOOKBACK_DAYS = 15
# Default per-source quota used by select_balanced_news_by_source.
MIN_NEWS_PER_SOURCE = 5
def parse_selected_sources(raw_sources: str) -> List[str]:
    """Parse the ``--sources`` command-line value into a list of source keys.

    The value is a comma-separated list. Each entry is stripped of surrounding
    whitespace and lower-cased, empty entries are dropped, and duplicates are
    removed while keeping first-seen order. If any entry is ``all`` the full
    list of known sources is returned.

    Args:
        raw_sources: Raw argument text, e.g. ``"all"`` or ``"traffic, ccgp"``.
            ``None`` or an empty string yields an empty list.

    Returns:
        Source keys in the order they were given (or every known source for
        ``all``). Unknown keys are passed through unchanged.
    """
    every_source = ["traffic", "wechat", "baidu", "ccgp"]

    # Normalise first so that " all ", "ALL" or "Traffic" are not silently
    # treated as unknown sources.
    cleaned = [part.strip().lower() for part in (raw_sources or "").split(",")]
    tokens = [part for part in cleaned if part]

    if "all" in tokens:
        return every_source

    # dict.fromkeys keeps insertion order, giving an order-preserving de-dup.
    return list(dict.fromkeys(tokens))
def init_crawlers(selected_sources: List[str]) -> Tuple[Dict[str, object], List[str]]:
    """Create a crawler for every requested source that can actually run.

    Reads ``TARGET_URL`` and ``WECHAT_MP_COOKIE`` from the environment. The
    ``wechat`` and ``baidu`` sources both require a non-empty
    ``WECHAT_MP_COOKIE``; without it a warning is printed and the source is
    removed from the returned source list.

    Args:
        selected_sources: Source keys requested by the user.

    Returns:
        A ``(crawlers, sources)`` tuple: ``crawlers`` maps each of the four
        known source keys to a crawler instance or ``None``; ``sources`` is
        the requested list minus any source skipped for missing configuration.
    """
    site_url = os.getenv("TARGET_URL", "https://www.7its.com/")
    mp_cookie = os.getenv("WECHAT_MP_COOKIE", "").strip()

    registry: Dict[str, object] = dict.fromkeys(("traffic", "wechat", "baidu", "ccgp"))
    active = selected_sources

    if "traffic" in active:
        registry["traffic"] = TrafficNewsCrawler(base_url=site_url)

    # Both cookie-gated sources follow the same pattern; factories are wrapped
    # in lambdas so a crawler class is only touched when it is really needed.
    cookie_gated = (
        (
            "wechat",
            lambda: WeChatCrawler(),
            "\n[警告] 微信公众号(高德)配置缺失,跳过高德地图爬虫",
        ),
        (
            "baidu",
            lambda: BaiduMapCrawler(),
            "\n[警告] 微信公众号(百度地图)配置缺失,跳过百度地图爬虫",
        ),
    )
    for key, factory, warning in cookie_gated:
        if key not in active:
            continue
        if not mp_cookie:
            print(warning)
            print("请在.env文件中配置WECHAT_MP_COOKIE")
            active = [name for name in active if name != key]
            continue
        registry[key] = factory()

    if "ccgp" in active:
        registry["ccgp"] = CCGPCrawler()

    return registry, active
def crawl_selected_sources(
    selected_sources: List[str],
    crawlers: Dict[str, object],
    data_dir: str,
    args: argparse.Namespace,
    compact: bool = False,
) -> List[Dict]:
    """Run every selected crawler and gather the raw items into one list.

    Sources are always visited in the fixed order traffic, wechat, baidu,
    ccgp. A progress header is printed for each selected source; the crawl
    itself is skipped when no crawler object is registered for it.

    Args:
        selected_sources: Source keys to crawl.
        crawlers: Mapping of source key to crawler instance (or ``None``).
        data_dir: Directory handed to the crawlers' save methods.
        args: Parsed CLI arguments carrying the per-source limits/keywords.
        compact: Use the short ``[n]`` step prefix instead of ``[n/4]``.

    Returns:
        All items returned by the crawlers, in visiting order.
    """
    source_titles = (
        ("traffic", "赛文交通网"),
        ("wechat", "微信公众号(高德地图)"),
        ("baidu", "微信公众号(百度地图)"),
        ("ccgp", "政府采购网"),
    )
    collected: List[Dict] = []

    for position, (key, title) in enumerate(source_titles, start=1):
        if key not in selected_sources:
            continue

        label = f"[{position}/4]" if not compact else f" [{position}]"
        print(f"\n{label} 爬取{title}...")

        crawler = crawlers[key]
        if not crawler:
            # Requested, but no crawler was initialised for this source.
            continue

        if key == "traffic":
            # This crawler saves by itself, so the count is always reported.
            items = crawler.crawl_and_save(
                output_dir=data_dir,
                max_news=args.max_news,
            )
            print(f"[成功] 赛文交通网爬取了 {len(items)} 条新闻")
            collected.extend(items)
            continue

        if key == "wechat":
            items = crawler.crawl_articles(
                max_count=args.wechat_count,
                keyword=args.wechat_keyword,
            )
            if not items:
                continue
            crawler.save_articles(items, output_dir=data_dir)
            print(f"[成功] 微信公众号(高德)爬取了 {len(items)} 篇文章")
            collected.extend(items)
            continue

        if key == "baidu":
            items = crawler.crawl_articles(
                max_count=args.baidu_count,
                keyword=args.baidu_keyword,
            )
            if not items:
                continue
            crawler.save_articles(items, output_dir=data_dir)
            print(f"[成功] 微信公众号(百度地图)爬取了 {len(items)} 篇文章")
            collected.extend(items)
            continue

        if key == "ccgp":
            # Split the comma-separated keyword argument, dropping blanks.
            stripped = (word.strip() for word in args.ccgp_keywords.split(","))
            keywords = [word for word in stripped if word]
            items = crawler.crawl_by_keywords(
                keywords=keywords,
                max_per_keyword=args.ccgp_count,
            )
            if not items:
                continue
            crawler.save_results(items, output_dir=data_dir)
            print(f"[成功] 政府采购网爬取了 {len(items)} 条信息")
            collected.extend(items)

    return collected
def build_rag_news_items(raw_items: List[Dict]) -> List[Dict]:
    """Turn crawled raw items into the dict layout passed to the RAG processor.

    Items whose ``content`` is missing or empty are dropped. For the rest the
    title is prepended to the content, and missing fields fall back to
    defaults (empty title/url, ``未知来源`` source, timestamp ``0``).

    Args:
        raw_items: Raw item dicts as returned by the crawlers.

    Returns:
        One dict per kept item with the keys ``title``, ``content``, ``url``,
        ``source`` and ``std_timestamp``.
    """

    def to_rag_item(entry: Dict) -> Dict:
        heading = entry.get("title", "")
        body = entry.get("content")
        return {
            "title": heading,
            # The title is embedded in the text so it is part of the content.
            "content": f"标题: {heading}\n\n{body}",
            "url": entry.get("url", ""),
            "source": entry.get("source", "未知来源"),
            "std_timestamp": entry.get("std_timestamp", 0),
        }

    return [to_rag_item(entry) for entry in raw_items if entry.get("content")]
def print_db_stats(rag_processor: RAGProcessor) -> None:
    """Print a one-line summary of the vector database contents.

    Args:
        rag_processor: Object whose ``get_database_stats()`` returns a mapping
            with ``database_exists`` and, when that is truthy,
            ``unique_news`` and ``total_documents``.
    """
    stats = rag_processor.get_database_stats()

    if not stats["database_exists"]:
        print("\n[数据库状态] 向量数据库为空")
        return

    news_count = stats["unique_news"]
    chunk_count = stats["total_documents"]
    print(f"\n[数据库状态] 当前包含 {news_count} 条新闻, {chunk_count} 个文档片段")
def build_report_time_window(days: int = REPORT_LOOKBACK_DAYS) -> Dict:
    """Describe the time interval ending now and spanning ``days`` days.

    Args:
        days: Length of the look-back window in days.

    Returns:
        A dict with ``days``, the interval bounds as naive ``datetime``
        objects (``start_dt``/``end_dt``), as integer timestamps
        (``start_ts``/``end_ts``) and as ``YYYY-MM-DD HH:MM:SS`` strings
        (``start_str``/``end_str``).
    """
    stamp_format = "%Y-%m-%d %H:%M:%S"
    window_end = datetime.now()
    window_start = window_end - timedelta(days=days)

    return dict(
        days=days,
        start_dt=window_start,
        end_dt=window_end,
        start_ts=int(window_start.timestamp()),
        end_ts=int(window_end.timestamp()),
        start_str=window_start.strftime(stamp_format),
        end_str=window_end.strftime(stamp_format),
    )
def print_report_time_window(window: Dict) -> None:
    """Print the report time window in human-readable and timestamp form.

    Args:
        window: Mapping with the keys ``start_str``, ``end_str``,
            ``start_ts``, ``end_ts`` and ``days``.
    """
    begin_text = window["start_str"]
    finish_text = window["end_str"]
    begin_ts = window["start_ts"]
    finish_ts = window["end_ts"]
    span = window["days"]

    lines = (
        f"\n[时间窗口] 报告生成实时截止时间: {finish_text}",
        f"[时间窗口] 新闻筛选区间: {begin_text} {finish_text}",
        f"[时间窗口] std_timestamp 区间: {begin_ts} - {finish_ts}",
        f"[时间窗口] 回溯时长: {span}",
    )
    for line in lines:
        print(line)
def select_balanced_news_by_source(
    news_list: List[Dict],
    min_per_source: int = MIN_NEWS_PER_SOURCE,
) -> List[Dict]:
    """Order news so every source is represented before the remainder.

    The result contains every distinct item of ``news_list`` exactly once:

    1. For each source (in sorted source-name order) the newest
       ``min_per_source`` items are taken first, when that many exist.
    2. All remaining items follow, newest first.

    Items are de-duplicated by ``url``. Items without a URL cannot be compared
    that way, so they are tracked by object identity; this keeps an item
    picked in step 1 from being appended a second time in step 2.

    Args:
        news_list: News dicts; ``source``, ``url`` and ``std_timestamp`` are
            read, all optional.
        min_per_source: Per-source quota for step 1. Values ``<= 0`` skip
            step 1 entirely.

    Returns:
        The re-ordered, de-duplicated list of news dicts.
    """

    def timestamp_of(item: Dict) -> int:
        # Missing or None timestamps sort as the oldest possible value.
        return int(item.get("std_timestamp", 0) or 0)

    newest_first = sorted(news_list, key=timestamp_of, reverse=True)

    grouped: Dict[str, List[Dict]] = {}
    for news in newest_first:
        source = str(news.get("source", "未知来源") or "未知来源")
        grouped.setdefault(source, []).append(news)

    selected: List[Dict] = []
    seen_keys = set()

    def take(news: Dict) -> bool:
        """Append ``news`` unless already selected; report whether it was added."""
        url = news.get("url", "")
        # Tagged keys keep URL strings and object ids in separate namespaces.
        key = ("url", url) if url else ("id", id(news))
        if key in seen_keys:
            return False
        seen_keys.add(key)
        selected.append(news)
        return True

    # Step 1: guarantee the per-source quota.
    for source in sorted(grouped):
        picked = 0
        for news in grouped[source]:
            if picked >= min_per_source:
                break
            if take(news):
                picked += 1

    # Step 2: everything not selected yet, newest first.
    for news in newest_first:
        take(news)

    return selected
def print_selected_news_distribution(news_list: List[Dict]) -> None:
    """Print how many of the selected news items come from each source.

    Sources are listed in the order they first appear in ``news_list``; items
    without a ``source`` key are counted under ``未知来源``.

    Args:
        news_list: News dicts selected as report input.
    """
    tally = Counter()
    for entry in news_list:
        tally[entry.get("source", "未知来源")] += 1

    print("[信息] 综合报告新闻取样分布:")
    for label in tally:
        print(f" - {label}: {tally[label]}")
def generate_summary_report_from_db(
    rag_processor: RAGProcessor,
    report_generator: ReportGenerator,
    window: Dict,
) -> str:
    """Generate the summary report from the news stored in the vector DB.

    Loads the news metadata inside the time window, orders it with
    ``select_balanced_news_by_source``, runs a fixed-topic search over the
    same window, and passes both to the report generator.

    Args:
        rag_processor: Provides ``get_all_news_metadata`` and ``search``.
        report_generator: Provides ``generate_summary_report``.
        window: Time window; ``start_ts`` and ``end_ts`` are required.
            ``days`` is used in the messages and falls back to
            ``REPORT_LOOKBACK_DAYS`` when absent.

    Returns:
        Whatever ``report_generator.generate_summary_report`` returns.

    Raises:
        RuntimeError: If no news metadata exists inside the window.
    """
    start_ts = window["start_ts"]
    end_ts = window["end_ts"]
    # Report the window's real length instead of a hard-coded 15, so the
    # messages stay correct if the window was built with another length.
    days = window.get("days", REPORT_LOOKBACK_DAYS)

    news_metadata = rag_processor.get_all_news_metadata(
        start_ts=start_ts,
        end_ts=end_ts,
    )
    if not news_metadata:
        raise RuntimeError(f"向量数据库中最近{days}天没有数据请先运行爬取模式")

    selected_news = select_balanced_news_by_source(news_metadata)
    print(f"从向量数据库加载了最近{days}天的 {len(news_metadata)} 条新闻")
    print_selected_news_distribution(selected_news)

    # Fixed query used to retrieve supporting documents for the summary.
    relevant_docs = rag_processor.search(
        "交通信号控制 信控 绿波 智能交通 导航 路况",
        k=10,
        start_ts=start_ts,
        end_ts=end_ts,
    )
    return report_generator.generate_summary_report(selected_news, relevant_docs)
def _print_report_preview(report: str) -> None:
    """Print the generated report between separator banners."""
    print("\n" + "=" * 60)
    print("报表内容预览:")
    print("=" * 60)
    print(report)


def main() -> None:
    """Command-line entry point.

    Parses the arguments, reads configuration from the environment and runs
    one of four modes:

    * ``crawl``  - crawl the selected sources and store them in the vector DB.
    * ``report`` - build the summary report from data already in the DB.
    * ``topic``  - build a report for ``--topic`` from data already in the DB.
    * ``full``   - crawl, store, then build the summary report.

    The RAG processor is closed on every exit path, including failures while
    printing the DB stats or constructing the report generator.
    """
    parser = argparse.ArgumentParser(description="交通新闻自动报表系统")
    parser.add_argument(
        "--mode",
        type=str,
        default="full",
        choices=["crawl", "report", "full", "topic"],
        help="运行模式: crawl(仅爬取), report(仅生成报表), full(完整流程), topic(主题分析)",
    )
    parser.add_argument("--max-news", type=int, default=20, help="最大爬取新闻数量(赛文交通网)")
    parser.add_argument("--topic", type=str, default="", help="主题关键词用于topic模式")
    parser.add_argument(
        "--sources",
        type=str,
        default="all",
        help="数据源选择: all(全部), traffic(赛文交通网), wechat(微信-高德), baidu(微信-百度地图), ccgp(政府采购网),多个用逗号分隔",
    )
    parser.add_argument("--wechat-count", type=int, default=30, help="微信公众号(高德)爬取数量")
    parser.add_argument("--wechat-keyword", type=str, default="交通", help="微信公众号(高德)关键词过滤")
    parser.add_argument("--baidu-count", type=int, default=30, help="微信公众号(百度地图)爬取数量")
    parser.add_argument("--baidu-keyword", type=str, default="交通", help="微信公众号(百度地图)关键词过滤")
    parser.add_argument("--ccgp-keywords", type=str, default="信控,绿波", help="政府采购网关键词(逗号分隔)")
    parser.add_argument("--ccgp-count", type=int, default=30, help="政府采购网每个关键词爬取数量")
    args = parser.parse_args()

    qwen_api_key = os.getenv("QWEN_API_KEY")
    qwen_model = os.getenv("QWEN_MODEL", "qwen-max")
    data_dir = os.getenv("DATA_DIR", "./data")
    vector_db_dir = os.getenv("VECTOR_DB_DIR", "./vector_db")

    print("=" * 60)
    print("交通新闻自动报表系统 - 多源数据采集与分析")
    print("=" * 60)

    selected_sources = parse_selected_sources(args.sources)
    crawlers, selected_sources = init_crawlers(selected_sources)

    # NOTE(review): the window is fixed before any crawling starts, so in
    # "full" mode its end lies before the crawl finishes - confirm items
    # crawled in this run cannot carry a later std_timestamp.
    report_window = build_report_time_window(days=REPORT_LOOKBACK_DAYS)

    rag_processor = RAGProcessor(vector_db_dir=vector_db_dir)
    try:
        # Everything after the processor is opened runs inside the try, so a
        # failure here still reaches rag_processor.close() below.
        print_db_stats(rag_processor)

        report_generator = None
        if args.mode in {"report", "topic", "full"}:
            report_generator = ReportGenerator(api_key=qwen_api_key, model_name=qwen_model)

        if args.mode == "crawl":
            print(f"\n[模式] 仅爬取新闻 - 数据源: {', '.join(selected_sources)}")
            all_news = crawl_selected_sources(selected_sources, crawlers, data_dir, args, compact=False)
            if all_news:
                print(f"\n[RAG] 将 {len(all_news)} 条数据存入向量数据库...")
                rag_items = build_rag_news_items(all_news)
                rag_processor.process_news(rag_items, upsert=True)
                db_stats = rag_processor.get_database_stats()
                print(
                    f"[成功] 向量数据库现有 {db_stats['unique_news']} 条新闻, "
                    f"{db_stats['total_documents']} 个文档片段"
                )
            print(f"\n[完成] 总共爬取了 {len(all_news)} 条数据")

        elif args.mode == "report":
            print("\n[模式] 生成报表使用已有数据最近15天")
            print_report_time_window(report_window)
            try:
                report = generate_summary_report_from_db(
                    rag_processor,
                    report_generator,
                    report_window,
                )
            except RuntimeError as exc:
                print(f"错误: {exc}")
                return
            print("\n正在生成报表...")
            report_path = report_generator.save_report(
                report,
                report_type="summary",
                output_dir=data_dir,
            )
            print(f"\n[成功] 报表已生成: {report_path}")
            _print_report_preview(report)

        elif args.mode == "topic":
            if not args.topic:
                print("错误: topic模式需要指定 --topic 参数")
                return
            print(f"\n[模式] 主题分析最近15天: {args.topic}")
            print_report_time_window(report_window)
            print(f"正在检索最近15天关于'{args.topic}'的相关内容...")
            relevant_docs = rag_processor.search(
                args.topic,
                k=10,
                start_ts=report_window["start_ts"],
                end_ts=report_window["end_ts"],
            )
            if not relevant_docs:
                print(f"未找到关于'{args.topic}'的相关内容")
                return
            print(f"找到 {len(relevant_docs)} 条相关内容")
            print("\n正在生成主题分析报表...")
            report = report_generator.generate_topic_report(args.topic, relevant_docs)
            report_path = report_generator.save_report(
                report,
                report_type=f"topic_{args.topic}",
                output_dir=data_dir,
            )
            print(f"\n[成功] 主题报表已生成: {report_path}")
            _print_report_preview(report)

        else:  # full
            print("\n[模式] 完整流程(爬取 -> RAG处理 -> 生成报表)")
            print(f"数据源: {', '.join(selected_sources)}")
            print("\n[步骤 1/3] 爬取新闻...")
            all_news = crawl_selected_sources(selected_sources, crawlers, data_dir, args, compact=True)
            if not all_news:
                print("错误: 未能爬取到任何数据")
                return
            print(f"\n[汇总] 共爬取 {len(all_news)} 条数据")

            print("\n[步骤 2/3] RAG数据处理...")
            rag_items = build_rag_news_items(all_news)
            rag_processor.process_news(rag_items, upsert=True)
            print("[成功] 数据已存入向量数据库")
            db_stats = rag_processor.get_database_stats()
            print(
                f"[数据库状态] 当前共有 {db_stats['unique_news']} 条新闻, "
                f"{db_stats['total_documents']} 个文档片段"
            )

            print("\n[步骤 3/3] 生成最近15天分析报表...")
            print_report_time_window(report_window)
            try:
                report = generate_summary_report_from_db(
                    rag_processor,
                    report_generator,
                    report_window,
                )
            except RuntimeError as exc:
                print(f"错误: {exc}")
                return
            report_path = report_generator.save_report(
                report,
                report_type="summary",
                output_dir=data_dir,
            )
            print(f"[成功] 报表已生成: {report_path}")
            _print_report_preview(report)

        # Reached only when the chosen mode ran to completion (the early
        # returns above skip this banner).
        print("\n" + "=" * 60)
        print("任务完成!")
        print("=" * 60)
    finally:
        rag_processor.close()
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()