介绍 — Bigscity-LibCity 文档 (bigscity-libcity-docs.readthedocs.io)
1 介绍
- 一个统一、全面、可扩展的代码库,为交通预测领域提供了一个可靠的实验工具和便捷的开发框架
- 目前支持
-
交通状态预测
-
交通流量预测
-
交通速度预测
-
交通需求预测
-
起点-终点(OD)矩阵预测
-
交通事故预测
-
-
轨迹下一跳预测
-
到达时间预测
-
路网匹配
-
路网表征学习
-
2 安装
2.1 创建一个conda 环境
conda create --name Libcity
conda activate Libcity
2.2 获取源代码
git clone https://github.com/LibCity/Bigscity-LibCity
cd Bigscity-LibCity
2.3 安装pytorch
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch
2.4 安装依赖包
pip install -r requirements.txt
3 数据集下载
- LibCity中使用的数据集以原子文件的统一数据存储格式存储【后面会有笔记】
- 假设已经下载好了他的数据
- Standard Dataset in LibCity - Google 云端硬盘
- METR_LA 文件夹放在Bigscity-LibCity/raw_data/目录底下
4 运行模型
用于训练和评测一个模型的脚本run_model.py
位于项目根目录中,它提供了一系列命令行参数,使用户可以调整运行参数配置
python run_model.py --task traffic_state_pred --model GRU --dataset METR_LA
4.0 命令行显示内容
首先,记录模块提示模型运行的记录会保存在libcity/log
文件夹下的相应文件中,并开始模型训练的流水线。
然后,执行模块会加载原子文件,创建数据集,划分训练集、验证集、测试集,并将划分好的数据集保存在libcity/cache/dataset_cache
下。下次在相同参数的数据集上运行模型时,不需要重新进行数据预处理操作。
输出模型的结构和参数量、优化器和学习率调整机制
模型训练
结束训练
4.1 遇到的报错&解决方案
我一开始报错如下:
Traceback (most recent call last):File "run_model.py", line 7, in <module>from libcity.pipeline import run_modelFile "/home_nfs/liushuai/Bigscity-LibCity/libcity/pipeline/__init__.py", line 1, in <module>from libcity.pipeline.pipeline import run_model, hyper_parameter, objective_functionFile "/home_nfs/liushuai/Bigscity-LibCity/libcity/pipeline/pipeline.py", line 2, in <module>from ray import tuneFile "/home_nfs/liushuai/anaconda3/envs/Libcity/lib/python3.8/site-packages/ray/__init__.py", line 63, in <module>import ray._raylet # noqa: E402File "python/ray/_raylet.pyx", line 102, in init ray._rayletFile "/home_nfs/liushuai/anaconda3/envs/Libcity/lib/python3.8/site-packages/ray/exceptions.py", line 7, in <module>from ray.core.generated.common_pb2 import RayException, Language, PYTHONFile "/home_nfs/liushuai/anaconda3/envs/Libcity/lib/python3.8/site-packages/ray/core/generated/common_pb2.py", line 33, in <module>_descriptor.EnumValueDescriptor(File "/home_nfs/liushuai/anaconda3/envs/Libcity/lib/python3.8/site-packages/google/protobuf/descriptor.py", line 789, in __new___message.Message._CheckCalledFromGeneratedFile()
TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:1. Downgrade the protobuf package to 3.20.x or lower.2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
- 这个错误是由于
protobuf
库版本与ray
库或代码生成的.pb2.py
文件不兼容造成的 - 解决方法就是降级protobuf库
-
pip install protobuf==3.20.0
-
4.2 脚本支持的命令行参数
task | 要执行的任务名,包括:
默认为 |
model | 要执行的模型名。默认为 |
dataset | 要执行的数据集。默认为 |
config_file | 用户自定义的配置文件名。默认为
|
saved_model | 是否保存训练好的模型。默认为True |
train | 如果模型已经预训练过了,是否要重新训练模型。默认为True |
batch_size | 训练集和验证集的批次大小 |
train_rate | 训练集在整个数据集中所占的比例。(划分的顺序是训练集、验证集、测试集) |
eval_rate | 验证集在整个数据集中所占的比例 |
learning_rate | 学习率。默认值因模型而异 |
max_epoch | 最大的训练轮数。默认值因模型而异 |
gpu | 是否使用GPU。默认为True |
gpu_id | 使用的GPU的ID。默认为0 |
4.3 使用tensorboard
tensorboard --logdir 'libcity/cache'
我是在服务器上执行的,所以如果本地需要看的话,可以:
ssh -N -f -L localhost:6006:localhost:6006 -p 22 liushuai@172.21.32.121
然后本地浏览器访问地址 http://localhost:6006/,就可以看到可视化的结果
参考内容:快速入门 — Bigscity-LibCity 文档 (bigscity-libcity-docs.readthedocs.io)