TabNet for fastai
This is an adaptation of TabNet (an attention-based network for tabular data) for the fastai (>=2.0) library. The original paper: https://arxiv.org/pdf/1908.07442.pdf
Installation
```
pip install fast_tabnet
```
How to use
```python
model = TabNetModel(emb_szs, n_cont, out_sz, embed_p=0., y_range=None,
                    n_d=8, n_a=8, n_steps=3, gamma=1.5,
                    n_independent=2, n_shared=2, epsilon=1e-15,
                    virtual_batch_size=128, momentum=0.02)
```
The parameters emb_szs, n_cont, out_sz, embed_p and y_range are the same as in fastai's TabularModel. The remaining parameters (a sample instantiation follows the list):

- n_d : int
  Dimension of the prediction layer (usually between 4 and 64)
- n_a : int
  Dimension of the attention layer (usually between 4 and 64)
- n_steps : int
  Number of successive steps in the network (usually between 3 and 10)
- gamma : float
  Float above 1, scaling factor for attention updates (usually between 1.0 and 2.0)
- momentum : float
  Float between 0 and 1, used for momentum in all batch norm layers
- n_independent : int
  Number of independent GLU layers in each GLU block (default 2)
- n_shared : int
  Number of shared GLU layers in each GLU block (default 2)
- epsilon : float
  Avoids log(0); this value should be kept very low
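For reference, here is a hypothetical regression-style instantiation (not part of the example below), assuming emb_szs has already been computed with get_emb_sz and the data has three continuous columns:

```python
# Hypothetical regression setup: a single continuous target, so out_sz=1 and
# y_range bounds the predictions, just like in fastai's TabularModel.
model = TabNetModel(emb_szs, n_cont=3, out_sz=1, y_range=(0., 10.),
                    n_d=16, n_a=16, n_steps=4, gamma=1.3)
```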
Example
Below is an example based on the fastai library, but with TabNet as the model.
```python
from fastai.basics import *
from fastai.tabular.all import *
from fast_tabnet.core import *
```
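The cell that produced the preview below is missing from this copy; a plausible sketch of it, assuming fastai's bundled ADULT_SAMPLE dataset (the names df, df_main and df_test are used by the later cells):

```python
import pandas as pd

# Sketch of the missing data-loading cell: download the adult sample shipped
# with fastai and split off a small test frame for later.
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
df_main, df_test = df.iloc[:-1000].copy(), df.iloc[-1000:].copy()
df_main.head()
```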
| | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 49 | Private | 101320 | Assoc-acdm | 12.0 | Married-civ-spouse | NaN | Wife | White | Female | 0 | 1902 | 40 | United-States | >=50k |
| 1 | 44 | Private | 236746 | Masters | 14.0 | Divorced | Exec-managerial | Not-in-family | White | Male | 10520 | 0 | 45 | United-States | >=50k |
| 2 | 38 | Private | 96185 | HS-grad | NaN | Divorced | NaN | Unmarried | Black | Female | 0 | 0 | 32 | United-States | <50k |
| 3 | 38 | Self-emp-inc | 112847 | Prof-school | 15.0 | Married-civ-spouse | Prof-specialty | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | United-States | >=50k |
| 4 | 42 | Self-emp-not-inc | 82297 | 7th-8th | NaN | Married-civ-spouse | Other-service | Wife | Black | Female | 0 | 0 | 50 | United-States | <50k |
```python
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship',
             'race', 'native-country', 'sex']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df_main))

to = TabularPandas(df_main, procs, cat_names, cont_names, y_names="salary",
                   y_block=CategoryBlock(), splits=splits)

dls = to.dataloaders(bs=32)
dls.valid.show_batch()
```
```python
to_tst = to.new(df_test)
to_tst.process()
to_tst.all_cols.head()
```
| | workclass | education | marital-status | occupation | relationship | race | native-country | sex | education-num_na | age | fnlwgt | education-num | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 31561 | 5 | 2 | 5 | 9 | 3 | 3 | 40 | 2 | 1 | -1.505833 | -0.559418 | -1.202170 | 0 |
| 31562 | 2 | 12 | 5 | 2 | 5 | 3 | 40 | 1 | 1 | -1.432653 | 0.421241 | -0.418032 | 0 |
| 31563 | 5 | 7 | 3 | 4 | 1 | 5 | 40 | 2 | 1 | -0.115406 | 0.132868 | -1.986307 | 0 |
| 31564 | 8 | 12 | 3 | 9 | 1 | 5 | 40 | 2 | 1 | 1.494561 | 0.749805 | -0.418032 | 0 |
| 31565 | 1 | 12 | 1 | 1 | 5 | 3 | 40 | 2 | 1 | -0.481308 | 7.529798 | -0.418032 | 0 |
```python
emb_szs = get_emb_sz(to)
```
This is how the model is used:
```python
model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=8, n_steps=5,
                    mask_type='entmax')

learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=Adam, lr=3e-2,
                metrics=[accuracy])

learn.lr_find()
```

```
SuggestedLRs(lr_min=0.2754228591918945, lr_steep=1.9054607491852948e-06)
```

```python
learn.fit_one_cycle(5)
```
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.446274 | 0.414451 | 0.817015 | 00:30 |
| 1 | 0.380002 | 0.393030 | 0.818916 | 00:30 |
| 2 | 0.371149 | 0.359802 | 0.832066 | 00:30 |
| 3 | 0.349027 | 0.352255 | 0.835868 | 00:30 |
| 4 | 0.355339 | 0.349360 | 0.836819 | 00:30 |
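The processed test frame prepared earlier (to_tst / df_test) is never scored in this example; a minimal sketch of how you could do it with fastai's standard prediction API (not part of the original):

```python
# Score the held-out df_test with the trained learner.
test_dl = learn.dls.test_dl(df_test)
preds, _ = learn.get_preds(dl=test_dl)
preds[:5]   # predicted class probabilities for the first five rows
```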
TabNet interpretability
```python
# feature importance for 2k rows
dl = learn.dls.test_dl(df.iloc[:2000], bs=256)
feature_importances = tabnet_feature_importances(learn.model, dl)

# per sample interpretation
dl = learn.dls.test_dl(df.iloc[:20], bs=20)
res_explain, res_masks = tabnet_explain(learn.model, dl)

plt.xticks(rotation='vertical')
plt.bar(dl.x_names, feature_importances, color='g')
plt.show()
```
```python
def plot_explain(masks, lbls, figsize=(12, 12)):
    "Plots masks with `lbls` (`dls.x_names`)"
    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    plt.yticks(np.arange(0, len(masks), 1.0))
    plt.xticks(np.arange(0, len(masks[0]), 1.0))
    ax.set_xticklabels(lbls, rotation=90)
    plt.ylabel('Sample Number')
    plt.xlabel('Variable')
    plt.imshow(masks)

plot_explain(res_explain, dl.x_names)
```
Hyperparameter search with Bayesian optimization
If your dataset is not very big, you can tune the hyperparameters of the tabular model with Bayesian optimization. If the metric is sensitive enough, this approach can optimize the metric directly (not the case in our example, where validation loss is used instead). You should also create a second validation set, because the first one effectively becomes the training set for the Bayesian optimization.
You may need to install the optimizer: `pip install bayesian-optimization`
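The setup cell is missing from this copy; a plausible sketch, assuming a larger batch size for the search and a second hold-out frame (df_search, df_holdout, the 3000-row size and the opt_func assignment are introduced here for illustration):

```python
# Sketch of the missing setup cell: keep a hold-out frame the search never
# sees, rebuild the dataloaders with a larger batch size, and pick an optimizer.
df_search, df_holdout = df_main.iloc[:-3000].copy(), df_main.iloc[-3000:].copy()
splits = RandomSplitter()(range_of(df_search))
to = TabularPandas(df_search, procs, cat_names, cont_names, y_names="salary",
                   y_block=CategoryBlock(), splits=splits)
dls = to.dataloaders(bs=1024)
emb_szs = get_emb_sz(to)
opt_func = Adam   # referenced by get_accuracy below
```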
```python
from functools import lru_cache

# The function we'll optimize
@lru_cache(1000)
def get_accuracy(n_d: Int, n_a: Int, n_steps: Int):
    model = TabNetModel(emb_szs, len(to.cont_names), dls.c,
                        n_d=n_d, n_a=n_a, n_steps=n_steps, gamma=1.5)
    learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=opt_func, lr=3e-2,
                    metrics=[accuracy])
    learn.fit_one_cycle(5)
    return float(learn.validate(dl=learn.dls.valid)[1])
```
This implementation of Bayesian optimization cannot handle discrete values natively. That is why we optimize a wrapper that rounds the continuous parameters to powers of two, and combine it with lru_cache so repeated discrete configurations are not retrained.
```python
def fit_accuracy(pow_n_d, pow_n_a, pow_n_steps):
    n_d, n_a, n_steps = map(lambda x: 2**int(x), (pow_n_d, pow_n_a, pow_n_steps))
    return get_accuracy(n_d, n_a, n_steps)
```
```python
from bayes_opt import BayesianOptimization

# Bounded region of parameter space
pbounds = {'pow_n_d': (0, 8), 'pow_n_a': (0, 8), 'pow_n_steps': (0, 4)}

optimizer = BayesianOptimization(
    f=fit_accuracy,
    pbounds=pbounds,
)
```
```python
optimizer.maximize(init_points=15, n_iter=100)
```
```
|   iter    |  target   |  pow_n_a  |  pow_n_d  | pow_n_... |
-------------------------------------------------------------

epoch  train_loss  valid_loss  accuracy  time
0      0.404888    0.432834    0.793885  00:10
1      0.367979    0.384840    0.818600  00:09
2      0.366444    0.372005    0.819708  00:09
3      0.362771    0.366949    0.823511  00:10
4      0.353682    0.367132    0.823511  00:10

|  1        |  0.8235   |  0.9408   |  1.898    |  1.652    |
```

The search prints one fastai training table and one results row like the above for each of its 115 evaluations; the rest of the log is omitted here. The best evaluation visible in the log reached a target of about 0.8314 at iteration 51, with pow_n_a ≈ 5.66, pow_n_d ≈ 2.63 and pow_n_steps ≈ 0.003.
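The follow-up cells are also missing from this copy; a short sketch of how the result of the search would typically be read back (optimizer.max is part of the bayes_opt API; the conversion mirrors fit_accuracy above):

```python
# Read the best configuration found by the search and undo the log2 scaling.
best = optimizer.max['params']
n_d, n_a, n_steps = (2 ** int(best[k]) for k in ('pow_n_d', 'pow_n_a', 'pow_n_steps'))
print(optimizer.max['target'], n_d, n_a, n_steps)
```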
Out-of-memory datasets
If your dataset is too big to fit in memory, you can load one chunk of it per epoch; a sketch of this idea follows.
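This is a minimal sketch of one way to do it, assuming the learner and the processed TabularPandas object `to` from above; csv_path and chunksize are hypothetical and this is not a built-in fast_tabnet API:

```python
import pandas as pd
from fastai.data.core import DataLoaders

csv_path, chunksize = 'big_train.csv', 100_000   # hypothetical file and chunk size

for chunk in pd.read_csv(csv_path, chunksize=chunksize):
    to_chunk = to.new(chunk)          # reuse the procs already fitted on `to`
    to_chunk.process()
    chunk_train = to_chunk.dataloaders(bs=1024).train
    learn.dls = DataLoaders(chunk_train, dls.valid)  # train on the chunk, keep the old valid set
    learn.fit_one_cycle(1)            # one cycle per chunk
```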