# Authenticate with Weights & Biases.
# Kaggle-only part: the W&B API key is read from Kaggle's secret store
# (the kaggle_secrets lines can be dropped when running outside Kaggle).
from kaggle_secrets import UserSecretsClient
import wandb

user_secrets = UserSecretsClient()
# The Kaggle secret labelled "wandb_key" holds the W&B API key.
secret_value_0 = user_secrets.get_secret("wandb_key")
wandb.login(key=secret_value_0)
from wandb.integration.keras import WandbMetricsLogger

# Start a W&B run; the "open_problems" project is created automatically
# if it does not exist yet. save_code=True uploads the notebook source.
run = wandb.init(
    project='open_problems',
    save_code=True,
    name='tabtransformer',
)

try:
    # ---- Insert data preparation / feature engineering code here ----
    # NOTE(review): TabTransformer, Adam, nu, X_train, y_train, X_val and
    # y_val are not defined in this snippet; they must come from the code
    # inserted above or from earlier notebook cells.
    tabTransformer = TabTransformer(
        categories=nu,  # number of unique elements in each categorical feature
        num_continuous=5,  # number of numerical features
        dim=16,  # embedding/transformer dimension
        dim_out=35,  # dimension of the model output
        depth=6,  # number of transformer layers in the stack
        heads=8,  # number of attention heads
        attn_dropout=0.1,  # attention layer dropout in transformers
        ff_dropout=0.1,  # feed-forward layer dropout in transformers
        mlp_hidden=[(32, 'relu'), (16, 'relu')],  # mlp layer dimensions and activations
    )

    # Adam with learning rate 0.001; train on and report mean absolute error.
    tabTransformer.compile(Adam(0.001), 'mae', metrics=['mae'])
    # WandbMetricsLogger logs the Keras training metrics to the active W&B run.
    tabTransformer.fit(
        X_train,
        y_train,
        validation_data=(X_val, y_val),
        batch_size=32,
        epochs=30,
        callbacks=[WandbMetricsLogger()],
    )
finally:
    # End the run even if setup or training raises, so it is not left open.
    run.finish()
# Reference: 🚀🚀 [Keras] TabTransformer + W&B | Kaggle