Abstract
In time-to-event prediction problems, a standard approach to estimating an interpretable model is to use Cox proportional hazards, with features selected via lasso regularization or stepwise regression. However, these Cox-based models do not learn how different features relate. As an alternative, we present an interpretable neural network approach that jointly learns a survival model to predict time-to-event outcomes while simultaneously learning how features relate in terms of a topic model. In particular, we model each subject as a distribution over “topics”, which are learned from clinical features so as to help predict a time-to-event outcome. From a technical standpoint, we extend existing neural topic modeling approaches to also minimize a survival analysis loss function. We study the effectiveness of this approach on seven healthcare datasets for predicting time until death as well as hospital ICU length of stay, where we find that neural survival-supervised topic models achieve competitive accuracy with existing approaches while yielding interpretable clinical “topics” that explain feature relationships.
Notes
- 1.
We use the UNOS Standard Transplant and Analysis Research data from the Organ Procurement and Transplantation Network as of September 2019, requested at: https://www.unos.org/data/.
References
Antolini, L., Boracchi, P., Biganzoli, E.: A time-dependent discrimination index for survival data. Stat. Med. 24(24), 3927–3944 (2005)
Beran, R.: Nonparametric regression with randomly censored survival data. Technical report, University of California, Berkeley (1981)
Blei, D.M., Ng, A.Y., Jordan, M.I.: Latent Dirichlet allocation. J. Mach. Learn. Res. 3, 993–1022 (2003)
Breslow, N.: Discussion of the paper by D. R. Cox cited below. J. R. Stat. Soc. Ser. B 34(2), 216–217 (1972)
Card, D., Tan, C., Smith, N.A.: Neural models for documents with metadata. In: Proceedings of Association for Computational Linguistics (2018)
Cox, D.R.: Regression models and life-tables. J. R. Stat. Soc. Ser. B 34(2), 187–202 (1972)
Curtis, C., et al.: The genomic and transcriptomic architecture of 2,000 breast tumours reveals novel subgroups. Nature 486(7403), 346 (2012)
Dawson, J.A., Kendziorski, C.: Survival-supervised latent Dirichlet allocation models for genomic analysis of time-to-event outcomes. arXiv preprint arXiv:1202.5999 (2012)
Dieng, A.B., Ruiz, F.J., Blei, D.M.: Topic modeling in embedding spaces. arXiv preprint arXiv:1907.04907 (2019)
Eisenstein, J., Ahmed, A., Xing, E.P.: Sparse additive generative models of text. In: International Conference on Machine Learning, pp. 1041–1048 (2011)
Harrell, F.E.: Regression Modeling Strategies: With Applications to Linear Models, Logistic and Ordinal Regression, and Survival Analysis. Springer, Cham (2015)
Harrell, F.E., Lee, K.L., Califf, R.M., Pryor, D.B., Rosati, R.A.: Regression modelling strategies for improved prognostic prediction. Stat. Med. 3(2), 143–152 (1984)
Ishwaran, H., Kogalur, U.B., Blackstone, E.H., Lauer, M.S.: Random survival forests. Ann. Appl. Stat. 2(3), 841–860 (2008)
Johnson, A.E., et al.: MIMIC-III, a freely accessible critical care database. Sci. Data 3(1), 1–9 (2016)
Kalbfleisch, J.D., Prentice, R.L.: The Statistical Analysis of Failure Time Data, 2nd edn. Wiley, Hoboken (2002)
Katzman, J.L., Shaham, U., Cloninger, A., Bates, J., Jiang, T., Kluger, Y.: DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Med. Res. Methodol. 18(1), 24 (2018)
Kingma, D.P., Ba, J.: Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
Kingma, D.P., Welling, M.: Auto-encoding variational Bayes. In: International Conference on Learning Representations (2014)
Knaus, W.A., et al.: The SUPPORT prognostic model: objective estimates of survival for seriously ill hospitalized adults. Ann. Intern. Med. 122(3), 191–203 (1995)
Lafferty, J.D., Blei, D.M.: Correlated topic models. In: Advances in Neural Information Processing Systems, pp. 147–154 (2006)
Lee, C., Zame, W.R., Yoon, J., van der Schaar, M.: DeepHit: a deep learning approach to survival analysis with competing risks. In: AAAI Conference on Artificial Intelligence (2018)
Lowsky, D.J., et al.: A \(K\)-nearest neighbors survival probability prediction method. Stat. Med. 32(12), 2062–2069 (2013)
McAuliffe, J.D., Blei, D.M.: Supervised topic models. In: Advances in Neural Information Processing Systems, pp. 121–128 (2008)
Rezende, D.J., Mohamed, S., Wierstra, D.: Stochastic backpropagation and approximate inference in deep generative models. In: International Conference on Machine Learning (2014)
Simon, N., Friedman, J., Hastie, T., Tibshirani, R.: Regularization paths for Cox’s proportional hazards model via coordinate descent. J. Stat. Softw. 39(5), 1 (2011)
Acknowledgments
This work was supported in part by Health Resources and Services Administration contract 234-2005-370011C. The content is the responsibility of the authors alone and does not necessarily reflect the views or policies of the Department of Health and Human Services, nor does mention of trade names, commercial products, or organizations imply endorsement by the U.S. Government.
A Interpreting Topic Heatmaps
In this appendix, we explain how to interpret our topic heatmaps (Fig. 1 and additional plots in our code repository). For many topic models including LDA, a topic is represented as a distribution over d vocabulary words. scholar [5] (and also our survival-supervised version survScholar) reparameterizes these topic distributions; borrowing from SAGE [10], scholar represents a topic as a deviation from a background log-frequency vector. This vector accommodates common words that have similar frequencies across data points. When we visualize a topic, we take this modeling approach into account and highlight only features that have positive log-deviations from the background. Within a topic, having a positive log-deviation is analogous to having a higher conditional probability in the classic topic modeling case, except that it is explicitly relative to background word frequencies (rather than being a raw topic word probability).
To fill in the details, in step 2(a) of survScholar’s generative process (stated in Sect. 3), each word is drawn from the conditional distribution \(\text{softmax}(\gamma + w^T B)\), where \(\gamma \in \mathbb{R}^d\) is the background log-frequency vector, \(w\in \mathbb{R}^k\) contains a sample’s topic membership weights, and \(B\in \mathbb{R}^{k\times d}\) encodes (per topic) every vocabulary word’s log-deviation from the word’s background frequency. This reparameterizes the standard LDA encoding, in which each word is drawn from the conditional distribution \(\text{softmax}(w^T H)\) for \(H\in \mathbb{R}^{k\times d}\). In particular, note that \(H_g = \gamma + B_g\) for every topic \(g\in \{1,2,\dots ,k\}\). The background log-frequency vector \(\gamma\) is learned during neural net training. Note that SAGE [10] further encourages sparsity in B by adding \(\ell_1\) regularization on B.
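To make this correspondence concrete, here is a minimal numerical sketch (the numpy variable names are ours for illustration and not from the authors’ implementation) of how \(\gamma\), \(B\), and \(H\) relate:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

k, d = 5, 100                       # number of topics, vocabulary size
rng = np.random.default_rng(0)
gamma = rng.normal(size=d)          # background log-frequency vector (learned in training)
B = rng.normal(size=(k, d))         # per-topic log-deviations from background

# Equivalent LDA-style topic-word matrix: H_g = gamma + B_g for every topic g.
H = gamma + B                       # broadcasting gamma across the k topics

# For topic membership weights w on the probability simplex,
# w @ H = (sum of w) * gamma + w @ B = gamma + w @ B, so the two
# parameterizations induce the same word distribution.
w = softmax(rng.normal(size=k))     # example membership weights
word_dist = softmax(gamma + w @ B)  # distribution over the d vocabulary words
assert np.allclose(word_dist, softmax(w @ H))
```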
We found ranking words within a topic by their raw probabilities (\(A_g\) in Eq. (2.1)) to be less interpretable than ranking words based on their deviations from their background frequencies (\(B_g\)) precisely because commonly occurring background words make interpretation difficult. In fact, when Dawson and Kendziorski [8] introduced survLDA, they used an ad hoc pre-processing step to identify background words to exclude from analysis altogether. We avoid this pre-processing and use log-deviations from background frequencies instead.
In heatmaps such as the one in Fig. 1, each column corresponds to a topic. For the g-th topic, instead of plotting its raw log-deviations (encoded in \(B_g\in \mathbb{R}^d\)), which are harder to interpret, we exponentiate each word’s log-deviation to get the word’s multiplicative ratio relative to its background frequency (i.e., we compute \(\exp(B_g)\)); the color bar intensity values are precisely these multiplicative ratios of how often a word appears relative to the word’s background frequency.
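In code, computing these intensities amounts to exponentiating \(B\) and masking non-positive log-deviations; a hedged sketch (variable names are again hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
B = rng.normal(size=(5, 100))  # per-topic log-deviations, as in the sketch above

# Heatmap intensity for word j in topic g: exp(B[g, j]), the multiplicative
# ratio of the word's frequency in the topic relative to its background
# frequency (a value of 1.0 means no deviation from background).
ratios = np.exp(B)             # shape (k, d); topic g gives column g of the heatmap

# Per the highlighting rule above, show only positive log-deviations
# (ratio > 1) and mask the rest for display.
highlighted = np.where(B > 0, ratios, np.nan)
```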
To highlight features that distinguish topics from one another, we also sort the heatmap rows by descending difference between the largest and smallest values in each row. Thus, features whose deviations vary greatly across topics tend to show up at the top. A technical detail is that we sort with respect to the original features, rather than the one-hot encoded or binned features, so that, for example, all bins under mean blood pressure stay together. For features associated with multiple rows in the heatmap, we compute the difference between the largest and smallest values for each row and use the largest difference (across rows) for sorting, as in the sketch below.
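A minimal sketch of this row-sorting rule, under the assumption that each heatmap row is tagged with its original (pre-binning) feature name; the helper sort_heatmap_rows is hypothetical, not the authors’ actual code:

```python
import numpy as np

def sort_heatmap_rows(row_ratios, row_to_feature):
    """Order heatmap rows by descending (max - min) spread across topics,
    keeping all rows (e.g., bins) of the same original feature together.

    row_ratios: (num_rows, k) array of per-topic multiplicative ratios.
    row_to_feature: per-row name of the original (pre-binning) feature.
    """
    spread = row_ratios.max(axis=1) - row_ratios.min(axis=1)
    # Each original feature is scored by the largest spread among its rows.
    feature_spread = {}
    for r, feat in enumerate(row_to_feature):
        feature_spread[feat] = max(feature_spread.get(feat, -np.inf), spread[r])
    # Sort by descending feature spread; tie-breaking on the feature name keeps
    # each feature's rows contiguous, and the final key preserves bin order.
    return sorted(range(len(row_to_feature)),
                  key=lambda r: (-feature_spread[row_to_feature[r]],
                                 row_to_feature[r], r))

# Example: both "mean blood pressure" bins stay together and come first,
# since that feature has the largest spread across the two topics.
ratios = np.array([[1.2, 0.8], [3.0, 0.5], [1.1, 0.9], [2.0, 1.0]])
features = ["age", "mean blood pressure", "mean blood pressure", "age"]
order = sort_heatmap_rows(ratios, features)  # -> [1, 2, 0, 3]
```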
Copyright information
© 2020 Springer Nature Switzerland AG
About this paper
Cite this paper
Li, L., Zuo, R., Coston, A., Weiss, J.C., Chen, G.H. (2020). Neural Topic Models with Survival Supervision: Jointly Predicting Time-to-Event Outcomes and Learning How Clinical Features Relate. In: Michalowski, M., Moskovitch, R. (eds) Artificial Intelligence in Medicine. AIME 2020. Lecture Notes in Computer Science, vol. 12299. Springer, Cham. https://doi.org/10.1007/978-3-030-59137-3_33
DOI: https://doi.org/10.1007/978-3-030-59137-3_33
Publisher Name: Springer, Cham
Print ISBN: 978-3-030-59136-6
Online ISBN: 978-3-030-59137-3
eBook Packages: Computer Science, Computer Science (R0)