Echo-Ji/ST-SSL


ST-SSL: Spatio-Temporal Self-Supervised Learning for Traffic Prediction

This is a PyTorch implementation of ST-SSL, proposed in the paper "Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction" (AAAI 2023).

[Figure: the ST-SSL framework]

new 27/10/2023: This paper was featured by leading WeChat official accounts in the fields of data mining and transportation: 当交通遇上机器学习 | 时空实验室 | AI蜗牛车

new 22/04/2023: A post about this paper was selected as a headline tweet by PaperWeekly, a leading AI academic platform in China, and received nearly 7,000 reads.

new 09/02/2023: The video replay of the academic presentation at AAAI 2023 is available.

new 04/02/2023: J. Ji was invited to give a talk at the AAAI 2023 Beijing Pre-Conference on Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction.

Requirements

This project was built with Python 3.8 and the following packages:

numpy==1.21.2
pandas==1.3.5
PyYAML==6.0
torch==1.10.1
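
One way to confirm that an environment matches these pins is to compare the installed versions against them. The sketch below is only illustrative: `PINNED` and `find_mismatches` are names introduced here and are not part of the repository. It assumes that a local build tag on an installed version (for example a CUDA suffix such as `+cu113` on torch) should be ignored when comparing.

```python
from importlib import metadata

# Versions pinned in the list above.
PINNED = {
    "numpy": "1.21.2",
    "pandas": "1.3.5",
    "PyYAML": "6.0",
    "torch": "1.10.1",
}


def find_mismatches(pinned, get_version=metadata.version):
    """Return {package: (expected, found)} for every package whose installed
    version differs from the pin; `found` is None if it is not installed."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Assumption: ignore a local build tag such as "+cu113" when comparing.
        base = found.split("+")[0] if found is not None else None
        if base != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty result means every pinned package is installed at the listed version.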

Datasets

Four datasets are supported: NYCBike1, NYCBike2, NYCTaxi, and BJTaxi. You can download them from the GitHub repo, Beihang Cloud Drive, or Google Drive.

Each dataset is composed of 4 files, namely train.npz, val.npz, test.npz, and adj_mx.npz.

|----NYCBike1\
|    |----train.npz    # training data
|    |----adj_mx.npz   # predefined graph structure
|    |----test.npz     # test data
|    |----val.npz      # validation data

The train/val/test data is composed of 4 numpy.ndarray objects:

  • X: input data. It is a 4D tensor of shape (#samples, #lookback_window, #nodes, #flow_types), where # means "number of".
  • Y: data to be predicted. It is a 4D tensor of shape (#samples, #predict_horizon, #nodes, #flow_types). Note that X and Y are paired in the sample dimension. For instance, (X_i, Y_i) is the i-th data sample, with i indexing the sample dimension.
  • X_offset: a list indicating offsets of X's lookback window relative to the current time with offset 0.
  • Y_offset: a list indicating offsets of Y's prediction horizon relative to the current time with offset 0.
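
The layout described above can be checked by listing the arrays stored in a file together with their shapes. The following is a minimal sketch, not code from the repository: `describe_npz` is a hypothetical helper, and the demo file is synthetic with made-up sizes. The key names `X`, `Y`, `X_offset`, and `Y_offset` follow the list above; the names stored in the real files may differ (for example in case), which is why the helper reports whatever keys it finds.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Map each array name stored in an .npz file to its shape."""
    with np.load(path) as data:
        return {name: data[name].shape for name in data.files}


# Synthetic stand-in for train.npz (made-up sizes, assumed key names).
samples, lookback, horizon, nodes, flows = 8, 5, 1, 12, 2
demo_path = os.path.join(tempfile.mkdtemp(), "train.npz")
np.savez(
    demo_path,
    X=np.zeros((samples, lookback, nodes, flows), dtype=np.float32),
    Y=np.zeros((samples, horizon, nodes, flows), dtype=np.float32),
    X_offset=np.arange(-lookback + 1, 1),
    Y_offset=np.arange(1, horizon + 1),
)

shapes = describe_npz(demo_path)
# X and Y must be paired along the sample dimension.
assert shapes["X"][0] == shapes["Y"][0]
print(shapes)
```

Pointing `describe_npz` at a downloaded `train.npz`, `val.npz`, or `test.npz` should show the actual key names and dimensions.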

For all datasets, the input combines the flows from the 2 hours immediately before the predicted time with the flows from around the same time of day on the previous 3 days. The model uses this input to forecast flows for the next time step.

adj_mx.npz stores the graph adjacency matrix, which encodes the spatial relation between every pair of regions/nodes in the studied area.

⚠️ Note that all datasets have already been processed into sliding-window samples. The raw data of NYCBike1 and BJTaxi come from STResNet. The raw data of NYCBike2 and NYCTaxi come from STDN.
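
To illustrate what a sliding-window view with offsets looks like, here is a minimal sketch. It is not the preprocessing code used for these datasets: `sliding_window` is a hypothetical helper, and the hourly time step and the offsets in the example (the two most recent steps plus the same hour on the three preceding days) are assumptions chosen for illustration only.

```python
import numpy as np


def sliding_window(series, x_offsets, y_offsets):
    """Slice a (time, nodes, flow_types) series into paired samples.

    Returns X of shape (samples, len(x_offsets), nodes, flow_types) and
    Y of shape (samples, len(y_offsets), nodes, flow_types).
    """
    x_offsets = np.asarray(x_offsets)
    y_offsets = np.asarray(y_offsets)
    # Earliest "current time" t that still has a full history behind it.
    first_t = max(0, -int(x_offsets.min()))
    # Exclusive upper bound on t so that every target index stays in range.
    last_t = series.shape[0] - max(0, int(y_offsets.max()))
    xs = [series[t + x_offsets] for t in range(first_t, last_t)]
    ys = [series[t + y_offsets] for t in range(first_t, last_t)]
    return np.stack(xs), np.stack(ys)


# Toy series: 100 hourly steps, 3 nodes, 2 flow types; each value is its time index.
series = np.broadcast_to(np.arange(100)[:, None, None], (100, 3, 2)).copy()

# Illustrative offsets relative to the current time (offset 0), target at +1:
# the same hour on the 3 previous days, then the 2 most recent hours.
x_offsets = [-71, -47, -23, -1, 0]
y_offsets = [1]

X, Y = sliding_window(series, x_offsets, y_offsets)
print(X.shape, Y.shape)
```

In this toy setup the first sample has target time step 72, and its inputs are time steps 0, 24, and 48 (the three previous days) followed by 70 and 71 (the two preceding hours).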

Model Training and Evaluation

Once the environment is ready, run the following commands to train the model on one of the datasets {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}.

>> cd ST-SSL
>> ./runme 0 NYCBike1   # 0 specifies the GPU id, NYCBike1 gives the dataset

Note that this repo contains only the NYCBike1 data, because including all datasets would make the repo too large.

Cite

If you find the paper useful, please cite the following:

@article{ji2023spatio, 
  title={Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction}, 
  author={Ji, Jiahao and Wang, Jingyuan and Huang, Chao and Wu, Junjie and Xu, Boren and Wu, Zhenhe and Zhang, Junbo and Zheng, Yu},
  journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 
  volume={37},
  number={4},
  pages={4356-4364},
  year={2023}
}