t-SNE: Evolving Symmetry of High-Dimensional Data in Low-Dimensional Space
Himanshu S.
Engineering Leader, ML & LLM / Gen AI enthusiast, Senior Engineering Manager @ Cohesity | Ex-Veritas | Ex-DDN
For this article, I am going to build a use case around something that has interested me for years: cancer research. I read a lot about the chemical pathways that can be targeted to attack it, including the glycolysis, glutaminolysis, and gluconeogenesis pathways. This time, however, I want to focus on analyzing tumor cells themselves: understanding what they are and finding patterns of dominant traits in them.
Imagine you’re a researcher studying cancer progression. You’ve collected data from thousands of individual cells, each holding clues about its behavior in the form of gene expression (the activity levels of its genes). Each cell has DNA, which contains many genes, and these genes drive protein synthesis in the body. Protein synthesis depends on the mRNA transcribed from a gene: the number of mRNA molecules a gene expresses correlates with its level of activity, so higher mRNA expression means more activity and more protein synthesis.
If all this terminology feels overwhelming, consider the example of the COVID-19 mRNA vaccine by Moderna. It contained synthetic mRNA, which entered cells and directed them to produce the spike protein, mimicking the one found in the COVID-19 virus. This spike protein helped the body recognize the virus and develop defenses by producing antibodies.
For simplicity, what we have learned is this: each cell has DNA, DNA contains genes, genes express mRNAs, and mRNAs are translated into proteins. You can imagine a dataset with a cell ID, cell type, and an mRNA count per gene. Now imagine thousands of such records, covering similar and different types of cells taken from tumor sites. For each cell there could be thousands of gene-expression values, meaning there are potentially thousands of dimensions in the data. How would you figure out which genes are more active together and which ones aren’t?
This is where t-SNE comes in handy: it is a great visualization tool that compresses thousands of dimensions into two or three so you can actually see the data. Does the data cluster together? Is it segregated? In other words, it helps you figure out what kind of relationships exist in the data. When compressed, the data could look something like this.
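To make this concrete, here is a minimal sketch (my own illustration, with made-up data shapes and parameter values rather than anything from a real study) of running t-SNE on a hypothetical gene-expression matrix with scikit-learn:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(42)
# Hypothetical data: 1,000 cells, each described by 5,000 gene-expression counts.
expression_matrix = rng.poisson(lam=2.0, size=(1000, 5000)).astype(float)

# Compress the 5,000 gene dimensions down to 2 for plotting.
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
embedding = tsne.fit_transform(expression_matrix)   # shape (1000, 2)

plt.scatter(embedding[:, 0], embedding[:, 1], s=3)
plt.title("t-SNE map of a hypothetical gene-expression matrix")
plt.show()
```

Each dot is one cell; cells with similar expression profiles end up near each other, which is exactly the clustering behavior we want to inspect.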
Now imagine you are in a very high-dimensional space where each of the thousands of genes in a cell has its own dimension, and a cell’s coordinate along that dimension is the number of mRNA molecules the gene expresses. Some data points are nearby, some are far away, and there could be intricate relationships between all the points in the space.
So how do you evaluate these relationships? You might say, well, for each point I am going to figure out its relationship with every other point in space, but that could be computationally expensive. Instead, you can do a pairwise analysis of each point against a limited number of relevant neighbors. This is also more meaningful, because the nearest neighbors carry most of the information about local structure. How many neighbors effectively matter is controlled by perplexity (a hyperparameter in t-SNE). The same process is then repeated for every point in the space.
The question is, how do you calculate the relative distance? If you have read my earlier articles, the way to do it is to calculate the Euclidean distance using the following.
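For reference, the Euclidean distance between two cells x_i and x_j described by D gene-expression values is:

$$ d_{ij} = \lVert x_i - x_j \rVert = \sqrt{\sum_{k=1}^{D} \left(x_{ik} - x_{jk}\right)^{2}} $$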
You can imagine each point lying on a canvas together with its set of neighboring points determined by the perplexity (a common default is perplexity = 30). You calculate the pairwise Euclidean distance between this point and each of those neighbors. Just as you created a canvas for this point, you can create a canvas for every point in space, each with its own bandwidth tuned to match the chosen perplexity.
Given the distances, the idea is to convert them into probabilities in the range (0, 1). That makes them easy to interpret and normalized to a common scale. Look at the formula below. What is it calculating?
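For reference, this is the standard high-dimensional similarity used in t-SNE, written in my notation with d_ij the Euclidean distance from above and σ_i the bandwidth chosen for point i to match the perplexity:

$$ p_{j\mid i} = \frac{\exp\!\left(-d_{ij}^{2} / 2\sigma_i^{2}\right)}{\sum_{k \neq i} \exp\!\left(-d_{ik}^{2} / 2\sigma_i^{2}\right)} $$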
In the numerator, it turns the distance between points i and j into an (unnormalized) probability. Here d_ij² is the squared Euclidean distance, and 2σ² is a controlling factor that decides whether the canvas has to be bigger or smaller, depending on the perplexity. The square on σ keeps the term on the same scale as the squared Euclidean distance. For larger σ the canvas grows; for smaller σ the exponent becomes a larger negative number and the expression evaluates to a smaller value, so the canvas shrinks.
The term in the denominator sums the same exponential expression, distance divided by the controlling factor 2σ², over all the other points on the canvas, which normalizes the numerator. This is simply a way to turn a single pairwise relationship into a probability, taking into account the distances between points on a canvas that stretches and shrinks.
Imagine many such canvases, one for every point in space, each stretching and shrinking as in the earlier example. If i and j each have their own canvas, the similarity between i and j could differ depending on whether you go from i to j or from j to i, since one canvas may have stretched or shrunk more than the other. How do you normalize this?
Averaging the two directions makes sense, but why also divide by N? We want the pairwise probabilities over all N points in the space to form a single distribution (summing to 1), so we divide the averaged probabilities by N, the number of data points. The result is called the similarity score.
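In the standard t-SNE formulation, this averaging over both directions and scaling by N looks like:

$$ p_{ij} = \frac{p_{j\mid i} + p_{i\mid j}}{2N} $$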
This entire exercise is done in the high-dimensional space, where points can be well spread out rather than crowded together as they inevitably become in a low-dimensional space.
At this stage we have similarity scores for all points in the high-dimensional space. Now suppose you project that space down to two dimensions. The data points must absorb the context of all the discarded dimensions, so their 2D coordinates shift from where they would otherwise sit. Points that were very close tend to shift along together, but points that were only moderately far away in the high-dimensional space can end up pushed much further apart. So while we keep clusters of nearby points together and preserve good local similarity, we lose some context: points that were relatively close in high dimensions can drift too far away in low dimensions.
You need a formula that keeps close points together while still giving distant points a reasonable place, so you don’t lose the context of points that are far away. This is where the heavy tail of the Student’s t-distribution plays a role. Unlike the Gaussian used in the high-dimensional space, it gives distant points a small but non-zero influence, which helps avoid overcrowding while keeping the clusters that belong together intact. In short, it helps maintain both local and global structure.
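For reference, the standard low-dimensional similarity in t-SNE, built from a Student’s t-distribution with one degree of freedom (here y_i and y_j are the positions of points i and j in the low-dimensional map), is:

$$ q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}{\sum_{k \neq l}\left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}} $$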
The numerator takes the inverse of one plus the squared Euclidean distance between the two points in the low-dimensional map. It is then divided by the sum of the same inverse term over all pairs. This inverse form means that points close together get high probability, while points far apart still get a small but non-zero probability.
Now we have pairwise probabilities in both the high- and low-dimensional spaces, but how do we know the points are actually mapped correctly from high to low dimension? For that we compare the two sets of probabilities to see how far off the low-dimensional ones are. To do this, we use KL divergence.
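The cost function being minimized is the KL divergence between the high-dimensional distribution P and the low-dimensional distribution Q:

$$ C = \mathrm{KL}(P \,\Vert\, Q) = \sum_{i}\sum_{j \neq i} p_{ij} \log \frac{p_{ij}}{q_{ij}} $$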
The ideal value of KL divergence is 0, and the log ratio inside it can swing far in either direction for an individual pair, which is why the logarithm is used here. If Q is higher than P, that pair is over-represented in the low dimension and needs adjustment to bring it back toward P. If P is higher than Q, the pair is not yet clustered closely enough in the low dimension, so that also needs adjustment. In a perfect situation, when P and Q are equal, the KL divergence is zero (log(1) = 0), meaning everything is perfectly aligned. In simpler terms, the logarithm penalizes ratios that are significantly off in either direction.
That brings up another question: why is this term scaled by P? Well, P comes from the original high-dimensional space, and by weighting each term with it we give the KL divergence the context of the high dimension. Essentially we are asking: in the context of the high dimension, how large is the divergence in the lower dimension? In layman’s terms: look, the high-dimensional picture was good; compared with it, here is how far off the low-dimensional one is. Now, what can you do to fix it?
In comes gradient descent, familiar from neural networks!
If you differentiate the KL divergence with respect to a point in the low-dimensional space, you get the following formula for gradient updates. This is similar to how we use gradient descent in neural networks by differentiating the loss function.
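For reference, the gradient of the KL divergence with respect to a low-dimensional point y_i, as derived in the original t-SNE paper, is:

$$ \frac{\partial C}{\partial y_i} = 4\sum_{j}\left(p_{ij} - q_{ij}\right)\left(y_i - y_j\right)\left(1 + \lVert y_i - y_j\rVert^2\right)^{-1} $$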
Using the gradient formula, we iteratively adjust each point in the low-dimensional space.
Here, η is the learning rate, a hyperparameter controlling the step size of these adjustments. Each point is shifted from its current position (t) to a new position (t+1), refining the low-dimensional map.
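Written as a plain gradient-descent step (real implementations usually add a momentum term on top of this), the update is:

$$ y_i^{(t+1)} = y_i^{(t)} - \eta\,\frac{\partial C}{\partial y_i} $$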
After every update, we recalculate the KL divergence to measure how closely the low-dimensional probabilities (Q) align with the high-dimensional probabilities (P). This process continues until the KL divergence stops improving, or a set number of iterations is reached, meaning the mapping is about as good as it is going to get.
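To make the loop concrete, here is a toy sketch (my own illustration, not production code) of a single gradient step on the low-dimensional points Y, given a precomputed high-dimensional similarity matrix P. Real implementations add momentum and early exaggeration on top of this:

```python
import numpy as np

def tsne_gradient_step(Y, P, learning_rate=200.0):
    """One plain gradient-descent step of t-SNE on the low-dimensional points Y."""
    diff = Y[:, None, :] - Y[None, :, :]      # pairwise differences y_i - y_j, shape (n, n, d)
    dist_sq = np.sum(diff ** 2, axis=-1)      # squared Euclidean distances, shape (n, n)
    inv = 1.0 / (1.0 + dist_sq)               # Student-t kernel (1 + ||y_i - y_j||^2)^(-1)
    np.fill_diagonal(inv, 0.0)                # a point is not its own neighbor
    Q = inv / inv.sum()                       # low-dimensional probabilities q_ij
    # Gradient of KL(P || Q): 4 * sum_j (p_ij - q_ij)(y_i - y_j)(1 + ||y_i - y_j||^2)^(-1)
    grad = 4.0 * np.sum(((P - Q) * inv)[:, :, None] * diff, axis=1)
    return Y - learning_rate * grad           # move each point downhill on the KL divergence

# Usage sketch: a random (but valid) symmetric P and a small random initial layout.
rng = np.random.default_rng(0)
P = rng.random((100, 100))
P = P + P.T
np.fill_diagonal(P, 0.0)
P /= P.sum()
Y = rng.normal(scale=1e-2, size=(100, 2))
for _ in range(200):
    Y = tsne_gradient_step(Y, P)
```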
Eventually the diagram would look like this.
This looks like a great technique, but it has some issues.
1. It can pick up and map noise in the data as well.
2. It can be computationally expensive.
3. It doesn’t scale well to large amounts of data.
Some of these problems can be reduced by lowering the data’s dimensionality with PCA before applying t-SNE, but other concerns still remain. In the next article, we will explore UMAP, which addresses some of these concerns while providing equally powerful insights.
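As a quick illustration of that mitigation (again a sketch with hypothetical data and parameter choices), PCA can first reduce the data to a few dozen components before t-SNE takes over:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5000))   # hypothetical 2,000 cells x 5,000 genes

# Step 1: PCA keeps the strongest ~50 directions of variation and drops much of the noise.
X_reduced = PCA(n_components=50, random_state=0).fit_transform(X)

# Step 2: t-SNE now runs on 50 dimensions instead of 5,000, which is far cheaper.
embedding = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(X_reduced)
```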
As we end this article, I will borrow a quote from Richard Feynman: 'You can know the name of a bird in all the languages of the world, but when you’re finished, you’ll know absolutely nothing about the bird.' In that spirit, now that we’ve explored the nature of the bird, we can finally talk about its name: t-distributed Stochastic Neighbor Embedding.