A PyTorch implementation of an adapted Deep Recurrent Survival Analysis model.
PyTorch implementation of Deep Recurrent Survival Analysis
Documentation
This project features a PyTorch implementation of the Deep Recurrent Survival Analysis (DRSA) model. It is intended for uncensored sequential data in which the event is known to occur at the last time step of each observation. More specifically, the library is made up of two small modules:
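- `functions.py`: utilities for computing conventional survival analysis quantities (and the `event_time_loss` and `event_rate_loss` training losses) from a `torch.Tensor` of predicted conditional hazard rates.
- `model.py`: the `DRSA` class, a subclass of `torch.nn.Module` that can be extended with categorical embeddings, additional layers, or other PyTorch operations.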
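The conventional survival analysis quantities that DRSA builds on can all be derived from a sequence of predicted conditional hazard rates h_1, ..., h_T, where h_t is the probability that the event occurs at step t given that it has not occurred earlier. The sketch below follows the definitions in the DRSA paper: survival S(t) = prod_{i<=t}(1 - h_i), event-time probability p(t) = h_t * prod_{i<t}(1 - h_i), and event rate W(t) = 1 - S(t). It uses plain Python lists, and the function names are hypothetical illustrations, not this library's API.

```python
from typing import List


def survival(hazards: List[float]) -> List[float]:
    """S(t) = prod_{i<=t} (1 - h_i): probability the event has NOT occurred by step t."""
    out, running = [], 1.0
    for h in hazards:
        running *= 1.0 - h
        out.append(running)
    return out


def event_time_probs(hazards: List[float]) -> List[float]:
    """p(t) = h_t * prod_{i<t} (1 - h_i): probability the event occurs exactly at step t."""
    out, running = [], 1.0
    for h in hazards:
        out.append(h * running)  # survive through t-1, then the event happens at t
        running *= 1.0 - h
    return out


def event_rate(hazards: List[float]) -> List[float]:
    """W(t) = 1 - S(t): probability the event has occurred at or before step t."""
    return [1.0 - s for s in survival(hazards)]


# Example: four time steps with made-up hazard rates
hazards = [0.1, 0.2, 0.5, 0.5]
print(survival(hazards))
print(event_time_probs(hazards))
print(event_rate(hazards))
```

Summing p(t) over all steps up to T gives W(T), which is a convenient sanity check on any hazard sequence.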
Installation
$ pip install drsa
Usage
```python
# imports (module paths assumed from the library's functions.py / model.py layout)
from drsa.functions import event_time_loss, event_rate_loss
from drsa.model import DRSA
import torch
import torch.optim as optim

# generating random data
batch_size, seq_len, n_features = (64, 25, 10)


def data_gen(batch_size, seq_len, n_features):
    samples = []
    for _ in range(batch_size):
        # each feature column is drawn from N(mean=t, std=1) for t = 1..seq_len
        sample = torch.cat(
            [
                torch.normal(mean=torch.arange(1.0, float(seq_len) + 1)).unsqueeze(-1)
                for _ in range(n_features)
            ],
            dim=-1,
        )
        samples.append(sample.unsqueeze(0))
    return torch.cat(samples, dim=0)  # shape: (batch_size, seq_len, n_features)


data = data_gen(batch_size, seq_len, n_features)

# generating a random embedding index for each sequence (constant across its time steps)
n_embeddings = 10
embedding_idx = torch.mul(
    torch.ones(batch_size, seq_len, 1),
    torch.randint(low=0, high=n_embeddings, size=(batch_size, 1, 1)),
)

# concatenating embedding indices and features: shape (batch_size, seq_len, n_features + 1)
X = torch.cat([embedding_idx, data], dim=-1)

# instantiating embedding parameters
embedding_size = 5
embeddings = torch.nn.Embedding(n_embeddings, embedding_size)

# instantiating model
model = DRSA(
    n_features=n_features + 1,  # +1 for the embeddings
    hidden_dim=2,
    n_layers=1,
    embeddings=[embeddings],
)


# defining training loop
def training_loop(X, optimizer, alpha, epochs):
    for epoch in range(epochs):
        optimizer.zero_grad()
        preds = model(X)

        # weighted average of survival analysis losses
        evt_loss = event_time_loss(preds)
        evr_loss = event_rate_loss(preds)
        loss = (alpha * evt_loss) + ((1 - alpha) * evr_loss)

        # updating parameters
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"epoch: {epoch} - loss: {round(loss.item(), 4)}")


# running training loop
optimizer = optim.Adam(model.parameters())
training_loop(X, optimizer, alpha=0.5, epochs=1001)
```
```text
epoch: 0 - loss: 12.485
epoch: 100 - loss: 10.0184
epoch: 200 - loss: 6.5471
epoch: 300 - loss: 4.6741
epoch: 400 - loss: 3.9786
epoch: 500 - loss: 3.5133
epoch: 600 - loss: 3.1826
epoch: 700 - loss: 2.9421
epoch: 800 - loss: 2.7656
epoch: 900 - loss: 2.6355
epoch: 1000 - loss: 2.5397
```
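The training loop mixes two losses with the weight `alpha`. As a hedged sketch that follows the DRSA paper's definitions for uncensored data (not the library's source), for a sequence whose event is known to occur at its last step T, the event-time loss is -log p(T), the negative log-probability that the event happens exactly at T, and the event-rate loss is -log W(T) = -log(1 - S(T)), the negative log-probability that the event has happened by T. The helper names are hypothetical, and averaging over the batch is an assumption about the reduction.

```python
import math
from typing import List


def event_time_nll(hazards: List[float]) -> float:
    """-log p(T), where p(T) = h_T * prod_{t<T} (1 - h_t) and T is the last step."""
    log_surv_before = sum(math.log(1.0 - h) for h in hazards[:-1])
    return -(math.log(hazards[-1]) + log_surv_before)


def event_rate_nll(hazards: List[float]) -> float:
    """-log W(T), where W(T) = 1 - prod_{t<=T} (1 - h_t)."""
    survival_T = math.prod(1.0 - h for h in hazards)
    return -math.log(1.0 - survival_T)


def combined_loss(batch: List[List[float]], alpha: float) -> float:
    """Batch mean of alpha * event-time loss + (1 - alpha) * event-rate loss."""
    evt = sum(event_time_nll(h) for h in batch) / len(batch)
    evr = sum(event_rate_nll(h) for h in batch) / len(batch)
    return alpha * evt + (1 - alpha) * evr


# Two toy sequences of predicted hazard rates (event at the final step of each)
batch = [[0.1, 0.2, 0.5, 0.5], [0.05, 0.1, 0.9]]
print(round(combined_loss(batch, alpha=0.5), 4))
```

Because W(T) >= p(T), the event-rate term is never larger than the event-time term for the same sequence. The sketch works in log space for the product of survival terms; hazards of exactly 0 or 1 would still make a logarithm undefined.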