TimesFM (Time Series Foundation Model): Installation Overview
readme
Previous post: TimesFM (Time Series Foundation Model) Installation Overview (1), CSDN blog: https://blog.csdn.net/chenchihwen/article/details/144359861?spm=1001.2014.3001.5501
Installing and running on Windows produced the following error:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 8
      3 from jax._src import config
      4 config.update(
      5     "jax_platforms", {"cpu": "cpu", "gpu": "cuda", "tpu": ""}[timesfm_backend]
      6 )
----> 8 model = timesfm.TimesFm(
      9     context_len=512,
     10     horizon_len=128,
     11     input_patch_len=32,
     12     output_patch_len=128,
     13     num_layers=20,
     14     model_dims=1280,
     15     backend=timesfm_backend,
     16 )
     17 model.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

TypeError: TimesFmBase.__init__() got an unexpected keyword argument 'context_len'
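The message says the installed constructor does not know the `context_len` keyword, which usually means the snippet was written for a different release than the one installed. The benchmark script later in this post builds the model with `hparams=` and `checkpoint=` instead. A minimal sketch for checking which keywords a constructor accepts before calling it; `NewStyleModel` is a hypothetical stand-in, not the real `timesfm.TimesFm` class:

```python
import inspect


def accepts_keyword(callable_obj, name: str) -> bool:
    """Return True if callable_obj can be called with the keyword `name`."""
    params = inspect.signature(callable_obj).parameters
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    if name in params and params[name].kind in keyword_kinds:
        return True
    # A **kwargs parameter accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class NewStyleModel:
    """Hypothetical stand-in for a constructor taking hparams/checkpoint."""

    def __init__(self, hparams, checkpoint):
        self.hparams = hparams
        self.checkpoint = checkpoint


print(accepts_keyword(NewStyleModel, "context_len"))  # False
print(accepts_keyword(NewStyleModel, "hparams"))      # True
```

With timesfm installed, the same helper could be pointed at `timesfm.TimesFm` to see whether the old or the new keyword set applies.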
I therefore decided to install TimesFM on Linux (Ubuntu) instead.
The installation was done in a cloud environment on ide.cloud.tencent.com.
We recommend at least 16GB RAM to load TimesFM dependencies.
Choose the environment carefully and make sure it has at least 16 GB of RAM.
Install Conda with Python 3.10.
Important step: install pyenv and poetry.
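A quick way to check the memory of a cloud environment before installing anything is sketched below. It assumes a Linux or other Unix host where `os.sysconf` exposes `SC_PAGE_SIZE` and `SC_PHYS_PAGES`; the function names are my own. Note that a machine sold as "16 GB" typically reports slightly less than 16 GiB, so the threshold may need some slack.

```python
import os


def total_ram_gib() -> float:
    """Total physical memory in GiB, read through sysconf (Unix only)."""
    page_size = os.sysconf("SC_PAGE_SIZE")
    page_count = os.sysconf("SC_PHYS_PAGES")
    return page_size * page_count / 2**30


def has_enough_ram(required_gib: float = 16.0) -> bool:
    """Compare the detected memory with the required amount."""
    return total_ram_gib() >= required_gib


print(f"Detected RAM: {total_ram_gib():.1f} GiB")
print("Meets the 16 GiB threshold:", has_enough_ram())
```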
## Installation
### Local installation using poetry
We will be using `pyenv` and `poetry`. In order to set these things up please follow the instructions [here](https://substack.com/home/post/p-148747960?r=28a5lx&utm_campaign=post&utm_medium=web). Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:
Confirm that both are installed:
(base) root@VM-0-170-ubuntu:/workspace/timesfm# pyenv --version
pyenv 2.4.22
(base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry --version
Poetry (version 1.8.5)
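The same check can be scripted. This is a small sketch using only the standard library; `missing_tools` is a name I made up, and it only tests whether the executables are reachable through `PATH`, not which versions they are.

```python
import shutil


def missing_tools(tools):
    """Return the entries of `tools` that cannot be found on PATH."""
    return [tool for tool in tools if shutil.which(tool) is None]


missing = missing_tools(["pyenv", "poetry"])
if missing:
    print("Not on PATH:", ", ".join(missing))
else:
    print("pyenv and poetry are both on PATH")
```

If either tool is reported as missing even though it was installed, the `PATH` setting is the first thing to look at.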
If the version is not shown after installation, the environment variables still need to be set:
Add `export PATH="/root/.local/bin:$PATH"` to your shell configuration file
To edit the environment variables, run:
nano ~/.bash_profile
Then add the variables. /root/.bash_profile should contain the following:
export PYENV_ROOT="$HOME/.pyenv"
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"
export PATH="/root/.local/bin:$PATH"
After editing, press Ctrl-O to write out the file,
then press Ctrl-X to exit.
Reload the environment variables so that the changes take effect:
source ~/.bash_profile
Clone TimesFM:
git clone https://github.com/google-research/timesfm.git
Cloning into 'timesfm'...
remote: Enumerating objects: 667, done.
remote: Counting objects: 100% (665/665), done.
remote: Compressing objects: 100% (316/316), done.
remote: Total 667 (delta 353), reused 568 (delta 306), pack-reused 2 (from 1)
Receiving objects: 100% (667/667), 1.94 MiB | 3.76 MiB/s, done.
Resolving deltas: 100% (353/353), done.
In the timesfm directory, open pyproject.toml and append the Aliyun source at the end; otherwise the installation does not go through:
[[tool.poetry.source]]
name = "aliyun"
url = "https://mirrors.aliyun.com/pypi/simple/"
priority = "primary"
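If the edit has to be repeated on several machines, it can be scripted. The sketch below appends the block only when the mirror URL is not already present, so running it twice does not duplicate the section. The helper names are hypothetical, and the file path assumes the script runs from the timesfm directory.

```python
from pathlib import Path

# The source block shown above, with a leading blank line as separator.
ALIYUN_SOURCE = '''
[[tool.poetry.source]]
name = "aliyun"
url = "https://mirrors.aliyun.com/pypi/simple/"
priority = "primary"
'''


def with_aliyun_source(pyproject_text: str) -> str:
    """Return pyproject_text with the Aliyun source block appended once."""
    if "mirrors.aliyun.com/pypi/simple" in pyproject_text:
        return pyproject_text  # already configured, leave untouched
    return pyproject_text.rstrip("\n") + "\n" + ALIYUN_SOURCE


def add_source_to_file(path: str = "pyproject.toml") -> None:
    """Rewrite the file at `path` in place (assumed to be a pyproject.toml)."""
    file = Path(path)
    text = file.read_text(encoding="utf-8")
    file.write_text(with_aliyun_source(text), encoding="utf-8")
```

Calling `add_source_to_file()` from the repository root applies the change; `with_aliyun_source` can be tried on a string first to preview the result.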
Note that the PAX (or JAX) version needs to run on python 3.10.x and the PyTorch version can run on >=3.11.x. Therefore make sure you have two versions of python installed:
pyenv install 3.10
pyenv install 3.11
pyenv versions # to list the versions available (let's assume the versions are 3.10.15 and 3.11.10)
For PAX version installation do the following.
Run the following in the timesfm git directory:
pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E pax
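Running these commands in the timesfm git directory gave the following output (pyenv reports that 3.10.15 is not installed, so Poetry created a py3.11 virtualenv):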
➜ timesfm git:(master) pyenv local 3.10.15
poetry env use 3.10.15
poetry lock
poetry install -E pax
pyenv: version `3.10.15' not installed
Could not find the python executable python3.10
Creating virtualenv timesfm-jW3uZHTw-py3.11 in /root/.cache/pypoetry/virtualenvs
Updating dependencies
Resolving dependencies... (378.2s)
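The `not installed` message suggests that the exact patch version passed to `pyenv local` has to match what `pyenv versions` lists; the README only assumes 3.10.15 as an example. A minimal sketch for picking the installed 3.10.x version out of that listing follows. The sample text is illustrative (including the path in it), not output captured from this machine.

```python
def installed_versions(pyenv_versions_output: str):
    """Extract version names from the text printed by `pyenv versions`."""
    versions = []
    for line in pyenv_versions_output.splitlines():
        # The active version is marked with a leading '*'.
        tokens = line.replace("*", " ").split()
        if tokens:
            versions.append(tokens[0])
    return versions


def find_version(pyenv_versions_output: str, prefix: str):
    """Return the first installed version equal to or under `prefix`, else None."""
    for version in installed_versions(pyenv_versions_output):
        if version == prefix or version.startswith(prefix + "."):
            return version
    return None


# Illustrative listing: only a 3.11 interpreter is installed.
sample = """  system
* 3.11.10 (set by /workspace/timesfm/.python-version)
"""

print(find_version(sample, "3.10"))  # None
print(find_version(sample, "3.11"))  # 3.11.10
```

When the lookup for `3.10` returns nothing, `pyenv install 3.10` has to run first, and the version it actually installs is the one to pass to `pyenv local` and `poetry env use`.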
numpy also has to be installed along the way.
According to the README:
Please look into the README files in the respective benchmark directories within `experiments/` for instructions for running TimesFM on the respective benchmarks.
## Running TimesFM on the benchmark
We need to add the following packages for running these benchmarks. Follow the installation instructions till before `poetry lock`. Then,
```
poetry add git+https://github.com/awslabs/gluon-ts.git
poetry lock
poetry install --only <pax or pytorch>
```
To run the timesfm on the benchmark do:
```
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m(-pytorch) --backend="gpu"
```
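I read the `(-pytorch)` part of the model path as an optional suffix: `google/timesfm-1.0-200m` for the PAX/JAX checkpoint and `google/timesfm-1.0-200m-pytorch` for the PyTorch one. Under that assumption, a small sketch that assembles the command line for either variant looks like this (`benchmark_command` is a hypothetical helper, not part of the repository):

```python
import shlex


def benchmark_command(use_pytorch: bool, backend: str = "gpu"):
    """Build the argv list for the extended benchmark run."""
    suffix = "-pytorch" if use_pytorch else ""
    model_path = "google/timesfm-1.0-200m" + suffix
    return [
        "poetry", "run", "python3", "-m",
        "experiments.extended_benchmarks.run_timesfm",
        f"--model_path={model_path}",
        f"--backend={backend}",
    ]


# Print a copy-pasteable command for the PAX checkpoint on CPU.
print(shlex.join(benchmark_command(use_pytorch=False, backend="cpu")))
```

The list form can be passed directly to `subprocess.run` if the benchmark should be launched from Python.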
Run poetry shell:
(base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry shell
Spawning shell within /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10
. /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/bin/activate
bash: __vsc_prompt_cmd_original: command not found
(base) root@VM-0-170-ubuntu:/workspace/timesfm# . /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/bin/activate
bash: __vsc_prompt_cmd_original: command not found
(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm#
Run the command:
poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m
Before it runs successfully, jax also has to be installed.
The code of run_timesfm.py is as follows:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for timesfm."""

import os
import sys
import time

from absl import flags
import numpy as np
import pandas as pd
import timesfm

from .utils import ExperimentHandler

dataset_names = [
    "m1_monthly",
    "m1_quarterly",
    "m1_yearly",
    "m3_monthly",
    "m3_other",
    "m3_quarterly",
    "m3_yearly",
    "m4_quarterly",
    "m4_yearly",
    "tourism_monthly",
    "tourism_quarterly",
    "tourism_yearly",
    "nn5_daily_without_missing",
    "m5",
    "nn5_weekly",
    "traffic",
    "weather",
    "australian_electricity_demand",
    "car_parts_without_missing",
    "cif_2016",
    "covid_deaths",
    "ercot",
    "ett_small_15min",
    "ett_small_1h",
    "exchange_rate",
    "fred_md",
    "hospital",
]

context_dict = {
    "cif_2016": 32,
    "tourism_yearly": 64,
    "covid_deaths": 64,
    "tourism_quarterly": 64,
    "tourism_monthly": 64,
    "m1_monthly": 64,
    "m1_quarterly": 64,
    "m1_yearly": 64,
    "m3_monthly": 64,
    "m3_other": 64,
    "m3_quarterly": 64,
    "m3_yearly": 64,
    "m4_quarterly": 64,
    "m4_yearly": 64,
}

_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-1.0-200m",
                                  "Path to model")
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
_HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
_BACKEND = flags.DEFINE_string("backend", "gpu", "Backend")
_NUM_JOBS = flags.DEFINE_integer("num_jobs", 1, "Number of jobs")
_SAVE_DIR = flags.DEFINE_string("save_dir", "./results", "Save directory")

QUANTILES = list(np.arange(1, 10) / 10.0)


def main():
  results_list = []
  tfm = timesfm.TimesFm(
      hparams=timesfm.TimesFmHparams(
          backend=_BACKEND.value,
          per_core_batch_size=_BATCH_SIZE.value,
          horizon_len=_HORIZON.value,
      ),
      checkpoint=timesfm.TimesFmCheckpoint(
          huggingface_repo_id=_MODEL_PATH.value),
  )
  run_id = np.random.randint(100000)
  model_name = "timesfm"
  for dataset in dataset_names:
    print(f"Evaluating model {model_name} on dataset {dataset}", flush=True)
    exp = ExperimentHandler(dataset, quantiles=QUANTILES)

    if dataset in context_dict:
      context_len = context_dict[dataset]
    else:
      context_len = 512
    train_df = exp.train_df
    freq = exp.freq
    init_time = time.time()
    fcsts_df = tfm.forecast_on_df(
        inputs=train_df,
        freq=freq,
        value_name="y",
        model_name=model_name,
        forecast_context_len=context_len,
        num_jobs=_NUM_JOBS.value,
    )
    total_time = time.time() - init_time
    time_df = pd.DataFrame({"time": [total_time], "model": model_name})
    results = exp.evaluate_from_predictions(models=[model_name],
                                            fcsts_df=fcsts_df,
                                            times_df=time_df)
    print(results, flush=True)
    results_list.append(results)
    results_full = pd.concat(results_list)
    save_path = os.path.join(_SAVE_DIR.value, str(run_id))
    print(f"Saving results to {save_path}", flush=True)
    os.makedirs(save_path, exist_ok=True)
    results_full.to_csv(f"{save_path}/results.csv")


if __name__ == "__main__":
  FLAGS = flags.FLAGS
  FLAGS(sys.argv)
  main()
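Two details of the script are easy to miss in the listing: short datasets get a reduced context length from `context_dict`, with 512 as the fallback, and the forecast quantiles are 0.1 through 0.9. The sketch below reproduces just that logic without timesfm or numpy. The dictionary is abridged and `context_len_for` is my own name for the if/else in `main()`.

```python
# Abridged copy of context_dict from run_timesfm.py.
CONTEXT_DICT = {
    "cif_2016": 32,
    "tourism_yearly": 64,
    "m1_monthly": 64,
}
DEFAULT_CONTEXT_LEN = 512


def context_len_for(dataset: str) -> int:
    """Context length used for a dataset, falling back to the default."""
    return CONTEXT_DICT.get(dataset, DEFAULT_CONTEXT_LEN)


# Same values as list(np.arange(1, 10) / 10.0) in the script.
QUANTILES = [q / 10 for q in range(1, 10)]

print(context_len_for("cif_2016"))  # 32
print(context_len_for("traffic"))   # 512
print(QUANTILES[0], QUANTILES[-1])  # 0.1 0.9
```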
Execution result:
(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=google/timesfm-1.0-200m
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
/root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/lib/python3.10/site-packages/gluonts/json.py:102: UserWarning: Using `json`-module for json-handling. Consider installing one of `orjson`, `ujson` to speed up serialization and deserialization.
  warnings.warn(
Fetching 5 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 63937.56it/s]
Multiprocessing context has already been set.
Constructing model weights.
Constructed model weights in 2.83 seconds.
Restoring checkpoint from /root/.cache/huggingface/hub/models--google--timesfm-1.0-200m/snapshots/8775f7531211ac864b739fe776b0b255c277e2be/checkpoints.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide `train_state_unpadded_shape_dtype_struct` during checkpoint saving/restoring, to avoid potential silent bugs when loading checkpoints to incompatible unpadded shapes of TrainState.
Restored checkpoint in 1.50 seconds.
Jitting decoding.
Jitted decoding in 20.78 seconds.
Evaluating model timesfm on dataset m1_monthly
/root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10/lib/python3.10/site-packages/gluonts/time_feature/seasonality.py:47: FutureWarning: 'M' is deprecated and will be removed in a future version, please use 'ME' instead.
  offset = pd.tseries.frequencies.to_offset(freq)
Multiprocessing context has already been set.
/root/miniforge3/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
Processing dataframe with single process.
Finished preprocessing dataframe.
Finished forecasting.
Forecast results:
           dataset       metric    model        value
0  tourism_monthly          mae  timesfm  1970.148438
1  tourism_monthly         mase  timesfm     1.541883
2  tourism_monthly  scaled_crps  timesfm     0.121862
3  tourism_monthly        smape  timesfm     0.101539
4  tourism_monthly         time  timesfm     0.762044
---------------
             dataset       metric    model        value
0  tourism_quarterly          mae  timesfm  7439.246094
1  tourism_quarterly         mase  timesfm     1.731996
2  tourism_quarterly  scaled_crps  timesfm     0.087743
3  tourism_quarterly        smape  timesfm     0.083795
4  tourism_quarterly         time  timesfm     0.839042
-----------------
          dataset       metric    model         value
0  tourism_yearly          mae  timesfm  82434.085938
1  tourism_yearly         mase  timesfm      3.233205
2  tourism_yearly  scaled_crps  timesfm      0.129402
3  tourism_yearly        smape  timesfm      0.181012
4  tourism_yearly         time  timesfm      1.023866
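The script also writes these rows to `results.csv` under the save directory. The long format (one row per dataset and metric) is awkward to compare by eye, so here is a sketch that regroups it into one dictionary per dataset using only the standard library. The CSV layout is an assumption based on how pandas `to_csv` writes a frame with its default index; the sample rows are the tourism_monthly values printed above.

```python
import csv
import io
from collections import defaultdict


def pivot_results(csv_text: str):
    """Group long-format rows into {dataset: {metric: value}}."""
    table = defaultdict(dict)
    for row in csv.DictReader(io.StringIO(csv_text)):
        table[row["dataset"]][row["metric"]] = float(row["value"])
    return dict(table)


# Assumed file layout: unnamed index column, then dataset, metric, model, value.
sample = """,dataset,metric,model,value
0,tourism_monthly,mae,timesfm,1970.148438
1,tourism_monthly,mase,timesfm,1.541883
2,tourism_monthly,scaled_crps,timesfm,0.121862
3,tourism_monthly,smape,timesfm,0.101539
4,tourism_monthly,time,timesfm,0.762044
"""

for dataset, metrics in pivot_results(sample).items():
    print(dataset, metrics)
```

To use it on a real run, read the file contents with `open(path).read()` and pass the text to `pivot_results`.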