Deep Learning in Survival Analysis: Implementation of Breslow Approximation for Tied Event Times from Scratch
In my previous post, "Deep Learning in Survival Analysis: loss function", we explored the loss function used in survival analysis. A reader then asked how the Breslow approximation is calculated when event times are tied, so this post takes a closer look at that question.
Tied event times are common in clinical data. The Cox partial likelihood assumes that no two events occur at the same time, so ties require an approximation, and Breslow's method is the simplest of the standard choices. In this post, I implement the Breslow approximation from scratch in Python with PyTorch, then check the result against the R package survival.
1. Algorithm
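The Breslow approximation to the Cox partial likelihood is:
% LaTeX equation
L(\beta)\approx \prod_{j=1}^{D}\frac{\exp\left(\beta\sum_{l \in D_{j}}z_{l}\right)}{\left[\sum_{l\in R_{j}}\exp(\beta z_{l})\right]^{d_{j}}}
where t_{1} < t_{2} < ... < t_{D} are the D distinct event times and D_{j} is the event set, i.e. the subjects with an event at t_{j}. d_{j} is the number of events at t_{j}, i.e. the size of D_{j}. R_{j} is the risk set, i.e. the subjects whose observed time (event or censoring) is at least t_{j}. z_{l} is the covariate of subject l and β is its coefficient. Breslow treats every tied event at t_{j} as if it faced the full risk set R_{j}, which is why the denominator is raised to the power d_{j}.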
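To make the formula concrete, here is a minimal, deliberately slow sketch in plain Python that transcribes it term by term for a single covariate. It loops over the distinct event times, builds D_{j} and R_{j} explicitly, and sums the log of each factor. The function name breslow_log_partial_likelihood is my own (hypothetical). It is only meant as a reference for checking vectorized code.

```python
import math


def breslow_log_partial_likelihood(time, event, z, beta):
    # Hypothetical reference helper: log of the Breslow partial likelihood
    # for one covariate, written directly from the formula.
    n = len(time)
    # Distinct event times t_1 < ... < t_D (censored times are not event times).
    event_times = sorted({t for t, e in zip(time, event) if e == 1})
    loglik = 0.0
    for t_j in event_times:
        D_j = [i for i in range(n) if time[i] == t_j and event[i] == 1]  # event set
        R_j = [i for i in range(n) if time[i] >= t_j]                     # risk set
        d_j = len(D_j)
        numerator = beta * sum(z[i] for i in D_j)
        log_denominator = math.log(sum(math.exp(beta * z[i]) for i in R_j))
        # log of one factor: numerator minus d_j times the log risk-set sum
        loglik += numerator - d_j * log_denominator
    return loglik


time = [2, 2, 5, 8, 8, 9]
event = [1, 1, 0, 1, 1, 1]
X = [1, 2, 3, 4, 5, 6]

ll = breslow_log_partial_likelihood(time, event, X, 0.1)
n_events = sum(event)
print(ll, ll / n_events)
```

Working on the log scale turns the product into a sum and the exponent d_{j} into a multiplier on the log of the risk-set sum. Dividing by the number of events gives a per-event log-likelihood, the same scaling used in the comparison with R in section 3.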
2. Implementation from scratch using Python and PyTorch
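import torch
# toy data: six subjects with tied event times at t = 2 and t = 8
time = [2, 2, 5, 8, 8, 9]
event = [1, 1, 0, 1, 1, 1]  # 1 = event, 0 = censored
X = [1, 2, 3, 4, 5, 6]      # a single covariate z
beta1 = 0.1
# convert to torch tensors (float64)
time1 = torch.tensor(time, dtype=torch.float64)
event1 = torch.tensor(event, dtype=torch.float64)
X1 = torch.tensor(X, dtype=torch.float64)
beta = torch.tensor(beta1, dtype=torch.float64)
# calculate the risk score beta * z
risk_score = torch.multiply(X1, beta)
risk_score = risk_score.view(-1)
time = time1.view(-1)
event = event1.view(-1)
# sort by time in descending order, so a cumulative sum up to position i
# covers every subject whose time is >= time[i]
time, idx = torch.sort(time, descending=True)
event = event[idx]
event_idx = event.nonzero().flatten()
risk_score = risk_score[idx]
# calculate the Breslow estimation:
# log of the cumulative sum of exp(risk score), read off at each event
risk_logcumsumexp = torch.logcumsumexp(risk_score, dim=0)[event_idx]
# group tied event times; every event in a tied group uses the cumulative sum
# at the LAST event of the group, i.e. the full risk set R_j
# (assumes no censored subject shares a tied event time, as in this data)
time_uni, inverse_idx, count = torch.unique_consecutive(time[event_idx], return_counts=True, return_inverse=True)
risk_logcumsumexp_tied = risk_logcumsumexp[count.cumsum(dim=0) - 1][inverse_idx]
# the loss function: negative Breslow log partial likelihood, averaged over events
loss = risk_logcumsumexp_tied.mean() - risk_score[event_idx].mean()
print(loss)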
For this toy data set, the loss is:
tensor(1.2633, dtype=torch.float64)
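The tie correction above reads the cumulative sum at the last event in each tied group. Suppose a censored subject shared a tied event time and torch.sort happened to place it after those events. It would then drop out of the risk set, even though R_{j} includes subjects censored at t_{j}. Below is a minimal sketch of one way to guard against this, assuming the same tensor inputs. It computes the group boundaries over all subjects rather than over events only. The function name breslow_loss is my own (hypothetical), not a PyTorch API.

```python
import torch


def breslow_loss(risk_score, time, event):
    # Hypothetical helper: negative Breslow log partial likelihood, averaged over events.
    risk_score = risk_score.view(-1)
    time = time.view(-1)
    event = event.view(-1)

    # Descending time order: a running sum up to position i covers subjects with time >= time[i].
    time_sorted, idx = torch.sort(time, descending=True)
    risk_sorted = risk_score[idx]
    event_sorted = event[idx]

    # log(sum_{k <= i} exp(risk_k)) at every sorted position i.
    log_cum = torch.logcumsumexp(risk_sorted, dim=0)

    # Group ALL subjects (events and censored) that share a time value.
    _, group, counts = torch.unique_consecutive(
        time_sorted, return_inverse=True, return_counts=True
    )
    # Last position of each group marks the end of the full risk set R_j.
    last_in_group = counts.cumsum(dim=0) - 1
    log_risk_set = log_cum[last_in_group][group]

    is_event = event_sorted == 1
    return (log_risk_set[is_event] - risk_sorted[is_event]).mean()


# Toy data from the example above.
time = torch.tensor([2, 2, 5, 8, 8, 9], dtype=torch.float64)
event = torch.tensor([1, 1, 0, 1, 1, 1], dtype=torch.float64)
X = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float64)
beta = torch.tensor(0.1, dtype=torch.float64)

loss = breslow_loss(X * beta, time, event)
print(loss)

# A censored subject tied with an event: both belong to the risk set at t = 1.
tied_loss = breslow_loss(
    torch.zeros(2, dtype=torch.float64),
    torch.tensor([1.0, 1.0], dtype=torch.float64),
    torch.tensor([1.0, 0.0], dtype=torch.float64),
)
print(tied_loss)
```

On the toy data, no censored time is tied with an event, so this gives the same 1.2633 as the code above. In the second call, the event-only indexing could drop the censored subject depending on sort order. This version keeps both subjects in R_{j} and returns log 2, about 0.6931.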
3. Comparison with R
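library(survival)
times = c(2, 2, 5, 8, 8, 9)
status = c(1, 1, 0, 1, 1, 1)
X = c(1, 2, 3, 4, 5, 6)
b1 = 0.1
# start at beta = 0.1, handle ties with Breslow, and stop after one iteration
beta1 <- coxph(Surv(times, status) ~ X,
               init = b1,
               ties = 'breslow',
               control = coxph.control(iter.max = 1))
# log partial likelihood divided by the number of events
likelihood = beta1$loglik / sum(status == 1)
likelihood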
> likelihood
[1] -1.2632964 -0.7689787
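As a hand check, the same number can be derived on the log scale. Label the six subjects 1 to 6 so that z_{l} = l, and take β = 0.1. Time 5 is censored, so it is not an event time, but subject 3 still belongs to the risk set at t = 2. Taking logs of the Breslow formula, the three distinct event times contribute:

```latex
\ell(\beta) = \sum_{j=1}^{D}\Big[\beta\sum_{l\in D_j} z_l - d_j\log\sum_{l\in R_j}\exp(\beta z_l)\Big]

\begin{aligned}
t_j = 9:&\quad D_j=\{6\},\ R_j=\{6\}: && 0.6 - \log e^{0.6} = 0\\
t_j = 8:&\quad D_j=\{4,5\},\ R_j=\{4,5,6\}: && 0.9 - 2\log\left(e^{0.4}+e^{0.5}+e^{0.6}\right) \approx 0.9 - 2(1.601943) = -2.303886\\
t_j = 2:&\quad D_j=\{1,2\},\ R_j=\{1,\dots,6\}: && 0.3 - 2\log\textstyle\sum_{l=1}^{6}e^{0.1\,l} \approx 0.3 - 2(2.156298) = -4.012596\\
\ell(0.1) &\approx -6.316482, \qquad \ell(0.1)/5 \approx -1.263296
\end{aligned}
```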
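The first value, -1.2632964, is the Breslow log partial likelihood per event at the initial value β = 0.1. Its negation matches the PyTorch loss of 1.2633. The second value is the log-likelihood at the coefficient reached after the single iteration allowed by iter.max = 1, so it is not part of the comparison. In this post, we showed that the from-scratch implementation yields the same result as the R package survival. Stay tuned as we navigate this interesting aspect of survival analysis together!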