Deep Learning in Survival Analysis: Implementation of Breslow Approximation for Tied Event Times from Scratch


In my previous post, Deep Learning in Survival Analysis: loss function, we explored the loss function in survival analysis. Following a reader's question about how to calculate the Breslow approximation for tied event times, I will dive deeper into this topic.

Tied event times occur frequently in clinical data, so the partial likelihood has to be constructed in a way that accommodates ties. In this update, I'll implement the Breslow method from scratch using Python and then check the result against R.

1. Algorithm

The Breslow approximation to the partial likelihood is:

L(\beta) \approx \prod_{j=1}^{D} \frac{\exp(\beta \sum_{l \in D_{j}} z_{l})}{[\sum_{l \in R_{j}} \exp(\beta z_{l})]^{d_{j}}}

Breslow's approximation.


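where D is the number of distinct event times; d_{j} is the number of events at the jth distinct event time; R_{j} is the risk set (the subjects still under observation at that time); D_{j} is the set of subjects with an event at that time; z_{l} is the covariate of subject l; and \beta is the coefficient.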
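
For training, it is usually more convenient to minimize the negative log of this expression. A short derivation follows, where N_e = \sum_{j} d_{j} denotes the total number of events (N_e is notation introduced here for illustration, and dividing by it is simply the normalization that the code in the next section uses):

```latex
\log L(\beta) \approx \sum_{j=1}^{D} \left[ \beta \sum_{l \in D_{j}} z_{l} - d_{j} \log \sum_{l \in R_{j}} \exp(\beta z_{l}) \right]

\mathrm{loss}(\beta) = -\frac{1}{N_e} \log L(\beta)
  = \frac{1}{N_e} \sum_{j=1}^{D} d_{j} \log \sum_{l \in R_{j}} \exp(\beta z_{l})
  - \frac{1}{N_e} \sum_{j=1}^{D} \beta \sum_{l \in D_{j}} z_{l}
```

The first term is the log-sum-exp over the risk set averaged across events, and the second is the risk score averaged across events.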

2. Implementation from scratch using Python and PyTorch

import torch

# data simulation: tied events at t=2 and t=8; the subject at t=5 is censored
time = [2, 2, 5, 8, 8, 9]
event = [1, 1, 0, 1, 1, 1]
X = [1, 2, 3, 4, 5, 6]
beta1 = 0.1

# convert to torch tensors
time1 = torch.tensor(time, dtype=torch.float64)
event1 = torch.tensor(event, dtype=torch.float64)
X1 = torch.tensor(X, dtype=torch.float64)
beta = torch.tensor(beta1, dtype=torch.float64)

# calculate the risk score (linear predictor beta * z)
risk_score = torch.multiply(X1, beta)
risk_score = risk_score.view(-1)
time = time1.view(-1)
event = event1.view(-1)

# sort by time in descending order, so a running sum accumulates the risk set
time, idx = torch.sort(time, descending=True)
event = event[idx]
event_idx = event.nonzero().flatten()
risk_score = risk_score[idx]

# calculate the Breslow approximation
# log of the running sum of exp(risk score), kept at event positions only
risk_logcumsumexp = torch.logcumsumexp(risk_score, dim=0)[event_idx]
# group tied event times; each group uses the value at its last position,
# which includes all events tied at that time
time_uni, inverse_idx, count = torch.unique_consecutive(time[event_idx], return_counts=True, return_inverse=True)
risk_logcumsumexp_tied = risk_logcumsumexp[count.cumsum(dim=0) - 1][inverse_idx]

# the loss function: negative log partial likelihood averaged over events
loss = risk_logcumsumexp_tied.mean() - risk_score[event_idx].mean()
loss

For this simple example, the loss is:

tensor(1.2633, dtype=torch.float64)        
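
As a sanity check on the vectorized version, the same quantity can be computed by looping over the distinct event times and applying the formula from section 1 term by term. This is a minimal sketch in plain Python, not part of the original implementation; the function name breslow_loss and its argument names are hypothetical, chosen here for illustration.

```python
import math

def breslow_loss(times, events, covariates, beta):
    """Negative log partial likelihood (Breslow ties), averaged over events."""
    # distinct times at which at least one event occurs
    event_times = sorted({t for t, e in zip(times, events) if e == 1})
    n_events = sum(events)
    log_lik = 0.0
    for t in event_times:
        # risk set R_j: every subject whose observed time is >= t
        risk_sum = sum(math.exp(beta * z)
                       for s, z in zip(times, covariates) if s >= t)
        # event set D_j: subjects with an event exactly at t
        tied = [z for s, e, z in zip(times, events, covariates)
                if s == t and e == 1]
        d_j = len(tied)
        log_lik += beta * sum(tied) - d_j * math.log(risk_sum)
    return -log_lik / n_events

# same toy data as above
loss = breslow_loss([2, 2, 5, 8, 8, 9], [1, 1, 0, 1, 1, 1],
                    [1, 2, 3, 4, 5, 6], 0.1)
print(round(loss, 4))  # 1.2633
```

The loop is slower than the tensor version but makes the risk set explicit: it is defined by time >= t, so a censored subject tied with an event time is always counted. The sorted, cumulative version depends on where the sort places such a subject, which does not matter for this data set because no censored time coincides with an event time.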

3. Comparison with R

library(survival)

times = c(2, 2, 5, 8, 8, 9)
status = c(1, 1, 0, 1, 1, 1)
X = c(1,2,3,4,5,6)
b1 = 0.1

beta1 <- coxph(Surv(times, status) ~ X, 
               init = b1, 
               ties = 'breslow',
               control=coxph.control(iter.max=1))

likelihood = beta1$loglik/sum(status==1)
likelihood
[1] -1.2632964 -0.7689787

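The first element of loglik is the log partial likelihood at the initial value beta = 0.1; divided by the number of events and negated, it gives 1.2633, which matches the loss from the Python implementation. The second element is the value at the coefficient reached after the single iteration allowed by iter.max=1. In this post, we demonstrated that our custom implementation from scratch yields results consistent with those obtained using the R package 'survival'. Stay tuned as we navigate this interesting aspect of survival analysis together!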



