该存储库提供了TrustGeo框架的原始PyTorch实现。
一、基础用法
1、需求 Requirements
代码使用python 3.8.13、PyTorch 1.12.1、cudatoolkit 11.6.0和cudnn 7.6.5进行了测试。通过Anaconda安装依赖:
# create virtual environment
conda create --name TrustGeo python=3.8.13# activate environment
conda activate TrustGeo# install pytorch & cudatoolkit
conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c conda-forge
# install other requirements
conda install numpy pandas
pip install scikit-learn
2、运行代码
# Open the "TrustGeo" folder
cd TrustGeo# data preprocess (executing IP clustering).
python preprocess.py --dataset "New_York"
python preprocess.py --dataset "Los_Angeles"
python preprocess.py --dataset "Shanghai"# run the model TrustGeo
python main.py --dataset "New_York" --lr 5e-3 --dim_in 30 --lambda1 7e-3
python main.py --dataset "Los_Angeles" --lr 3e-3 --dim_in 30 --lambda1 7e-3
python main.py --dataset "Shanghai" --lr 0.0015 --dim_in 51 --lambda1 1e-3# load the checkpoint and then test
python test.py --dataset "New_York" --lr 5e-3 --dim_in 30 --lambda1 7e-3 --load_epoch 400
python test.py --dataset "Los_Angeles" --lr 3e-3 --dim_in 30 --lambda1 7e-3 --load_epoch 400
python test.py --dataset "Shanghai" --lr 0.0015 --dim_in 51 --lambda1 1e-3 --load_epoch 200