SCVI — Inference-based optional integration
The batch-integrating SCVI model is a conditional variational autoencoder (cVAE, Sohn et al. 2015), whose generative model is conditioned on the batch label of each cell.
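As a sketch, in simplified notation (omitting the library-size handling and zero-inflation options of the full scVI model), and with s_n denoting the batch label of cell n, the generative model can be written as

$$
\begin{aligned}
z_n &\sim \mathrm{Normal}(0, I), \\
\rho_n &= f(z_n, s_n), \\
y_{ng} &\sim \mathrm{NegativeBinomial}\!\left(\ell_n \, \rho_{ng}, \theta_g\right),
\end{aligned}
$$

where z_n is the latent representation of cell n, f is a decoder neural network, \ell_n is a cell-specific size factor, and \theta_g is a gene-wise dispersion parameter.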
I usually focus on this generative part of the model, since it is what we use to interpret the model. Variational autoencoders (Kingma & Welling, 2013) are divided into two parts: the generative model written above, and an inference model that is used to learn posterior distributions of z from the data. In my mind, the generative model is the part that matters the most.
Technically, you don’t need an inference model: you could take just the generative model and find posterior distributions for all the latent z variables from the observed data and the likelihood, using Bayesian posterior sampling methods. But this is computationally intractable for datasets of any realistic size.
Stepping back a bit, the posterior distributions of the z latent variables could, slightly more efficiently, be estimated by variational distributions: simplified parametric distributions with the goal of being as similar to the true posterior distributions as possible.
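As a sketch, assuming the usual Gaussian choice, each cell n would get its own variational distribution over its representation,

$$
q(z_n) = \mathrm{Normal}\!\left(\mu_n, \operatorname{diag}(\sigma_n^2)\right),
$$

with the per-cell means \mu_n and variances \sigma_n^2 as the variational parameters to optimize.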
Variational inference is the strategy of approximating posterior distributions, which is an integration problem, by converting it to an optimization problem, where you are optimizing the parameters of the variational distributions.
If you have a single-cell RNA-seq dataset with 100,000 cells, and you want to learn 10-dimensional representations of your cells, you would need to optimize 2 * 10 * 100,000 = 2,000,000 variational parameters, which ends up being a very difficult optimization problem.
One of the revolutionary insights in machine learning is that this optimization problem can be solved with neural networks (Kingma & Welling, 2013). Instead of optimizing the variational parameters directly, you optimize the weights of neural networks that output the variational parameters given the observed data.
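As a sketch, with encoder networks g_mu and g_sigma shared across all cells, the per-cell variational parameters become

$$
\mu_n = g_\mu(y_n), \qquad \sigma_n = g_\sigma(y_n),
$$

so the quantities being optimized are the network weights rather than the millions of per-cell variational parameters.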
This is known as ‘autoencoding variational Bayes’ or ‘amortized inference’, and is the other half of the variational autoencoder. We are very good at training neural networks, so turning a problem into a neural network training task ends up being a good solution. These neural networks represent the inference model in the variational autoencoder model.
In the case of the conditional variational autoencoder, there are two valid options for implementing the inference models:
In option (1), the inference model is not aware of the batches to integrate; only the generative model will be aware of the batches. The inference networks g will nevertheless learn to map observed data to appropriate representations in an unsupervised manner. A strength of this choice is that you can infer z representations for new data without needing to match the new data to any of the known batch categories.
In option (2), the inference model is explicitly informed of which batch the data comes from. The inference networks can use interactions between the values in the observed data and the batch labels, potentially learning more efficiently how to map observations to the appropriate representations. A strength of this choice is that you can obtain counterfactual representations. A weakness is that the model will not know how to embed data from a new batch without performing architecture surgery.
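In this notation, the difference between the two options is whether the batch label s_n is an input to the encoder networks (a sketch, reusing g_mu and g_sigma from above):

$$
\text{(1)}\quad \mu_n = g_\mu(y_n),\; \sigma_n = g_\sigma(y_n)
\qquad\qquad
\text{(2)}\quad \mu_n = g_\mu(y_n, s_n),\; \sigma_n = g_\sigma(y_n, s_n)
$$

In both options, the decoder f(z_n, s_n) of the generative model is conditioned on the batch; the options only differ in the inference model.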
The default in the SCVI models from scvi-tools with batch integration is option (1), but option (2) can be enabled with the encode_covariates = True option.
For a demonstration, we can replicate the previous post and compare representational embeddings between an SCVI model that does not integrate out the batches and an SCVI model that integrates out batches, but in this case using the encode_covariates = True option.
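A minimal sketch of this comparison with scvi-tools, assuming an AnnData object adata with raw counts in adata.X and donor IDs in adata.obs['donor'] (the column name is a placeholder for whatever the AIDA object uses), and leaving training settings at their defaults:

```python
import scvi

# Unintegrated model: no batch information given to the model at all.
adata_plain = adata.copy()
scvi.model.SCVI.setup_anndata(adata_plain)
model_plain = scvi.model.SCVI(adata_plain, n_latent=10)
model_plain.train()
z_plain = model_plain.get_latent_representation()

# Batch-integrating model with a batch-aware encoder:
# encode_covariates=True passes the batch label to the inference networks too.
adata_batch = adata.copy()
scvi.model.SCVI.setup_anndata(adata_batch, batch_key="donor")
model_batch = scvi.model.SCVI(adata_batch, n_latent=10, encode_covariates=True)
model_batch.train()
z_batch = model_batch.get_latent_representation()
```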
This can be illustrated as in the previous post using the AIDA data (Tian et al. 2024).
The batch-integrating version of the SCVI model with batch-aware encoding leaves out donor-to-donor variation in the representations. Encoding the batches in the inference model has the same effect as the default integration explored in the previous post.
Optional integration using a conditional inference model
In the previous post, we discussed how the choice of learning representations with or without contribution from known factors depends on the questions you aim to answer. In particular, we discussed how the MrVI model lets us work in both frameworks with a single model.
We can think about this problem in terms of the SCVI model with encoded batches described above.
Imagine having the option to ‘switch off’ the s_n contribution to the encoders g and decoder f. Then this batch-integrated model would turn into the unintegrated model. Of course, the model will not know what to do if you simply remove the s_n input to the models.
We can give the model this ability by augmenting the training data. The original data can be expanded so that each Y observation is used to train the model twice in an epoch: once with the observed batch category, and once with a new dummy batch category we can call 'unintegrated'.
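A minimal sketch of this augmentation, again assuming donor IDs in adata.obs['donor'] (a placeholder name) and reusing the scvi-tools setup from the earlier sketch:

```python
import anndata as ad
import scvi

# Duplicate the data and give the copy the dummy batch label, so each
# observation is seen twice per epoch: once with its real donor label,
# and once labeled 'unintegrated'.
adata_dummy = adata.copy()
adata_dummy.obs["donor"] = "unintegrated"
adata_aug = ad.concat([adata, adata_dummy], index_unique="-")

# Train a batch-encoding SCVI model on the augmented data.
scvi.model.SCVI.setup_anndata(adata_aug, batch_key="donor")
model_aug = scvi.model.SCVI(adata_aug, n_latent=10, encode_covariates=True)
model_aug.train()
```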
Now, for the added data points where the batch category is 'unintegrated', the model and loss will be equivalent to those of an unintegrated model.
After training the model with this expanded training data, we can create two sets of cell embeddings by running the observed data through the encoder g_mu with different settings for the batch categories: once with each cell's observed batch category, and once with every cell assigned the 'unintegrated' category.
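A sketch of producing the two sets of embeddings, assuming the adata_aug and model_aug names from the previous sketch (get_latent_representation returns the posterior mean by default, i.e. the output of g_mu):

```python
# Keep only the original observations (drop the dummy-labeled copies).
adata_obs = adata_aug[adata_aug.obs["donor"] != "unintegrated"].copy()

# (1) Embeddings computed with each cell's observed donor label.
z_batch_encoded = model_aug.get_latent_representation(adata_obs)

# (2) Embeddings computed with every cell relabeled 'unintegrated',
# i.e. with the batch contribution to the encoder 'switched off'.
adata_switched = adata_obs.copy()
adata_switched.obs["donor"] = "unintegrated"
z_unintegrated = model_aug.get_latent_representation(adata_switched)
```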
We can try this approach using the AIDA data, and see how these two different versions of the cell embeddings do or do not contain variation due to donor IDs.
These results were surprising to me! My expectation was that the batch-encoding case would integrate out differences between donors, while using the fixed 'unintegrated' label would retain variation between donors. Instead, we are seeing the opposite: when batches are provided to the encoder, variation between batches is introduced into the representations.
On the positive side, there is a substantial difference between the two versions of cell embeddings. I don’t understand why the model ends up having this behavior.
I have always viewed the inference models in VAE-based models as a clever solution to the optimization problem, and focused my attention on the generative model. Through this experiment I have gained an appreciation for how the encoders can potentially be leveraged to solve problems. In particular, I think it is interesting how we can expand the capabilities of a model by expanding and augmenting the training data.
Notebooks with analysis code are available on Github: https://github.com/vals/Blog/tree/master/250327-optional-integration
References
Kingma, Diederik P., and Max Welling. 2013. “Auto-Encoding Variational Bayes.” arXiv [Stat.ML]. arXiv. http://arxiv.org/abs/1312.6114v10.
Sohn, Kihyuk, Honglak Lee, and Xinchen Yan. 2015. “Learning Structured Output Representation Using Deep Conditional Generative Models.” Neural Information Processing Systems, December, 3483–91.
Tian, Chi, Yuntian Zhang, Yihan Tong, Kian Hong Kock, Donald Yuhui Sim, Fei Liu, Jiaqi Dong, et al. 2024. “Single-Cell RNA Sequencing of Peripheral Blood Links Cell-Type-Specific Regulation of Splicing to Autoimmune and Inflammatory Diseases.” Nature Genetics 56 (12): 2739–52.