**1. Introduction**

Supervised contrastive learning has emerged as a promising method for training deep models, with strong empirical results over traditional supervised learning [1]. Recent theoretical work has shown that under certain assumptions, *class collapse*—when the representation of every point from a class collapses to the same embedding on the hypersphere, as in Figure 1—minimizes the supervised contrastive loss *LSC* [2]. Furthermore, modern deep networks, which can memorize arbitrary labels [3], are powerful enough to produce class collapse.

Although class collapse minimizes *LSC* and produces accurate models, it loses information that is not explicitly encoded in the class labels. For example, consider images with the label "cat." As shown in Figure 1, some cats may be sleeping, some may be jumping, and some may be swatting at a bug. We call each of these semantically-unique categories of data—some of which are rarer than others, and none of which are explicitly labeled—a *stratum*. Distinguishing strata is important; it can empirically improve model performance [4] and fine-grained robustness [5]. It is also critical in high-stakes applications such as medical imaging [6]. However, *LSC* maps the sleeping, jumping, and swatting cats all to a single "cat" embedding, losing strata information. As a result, these embeddings are less useful for common downstream applications in the modern machine learning landscape, such as transfer learning.

In this paper, we explore a simple modification to *LSC* that prevents class collapse. We study how this modification affects embedding quality by considering how strata are represented in embedding space. We evaluate our loss both in terms of embedding quality, which we measure through three downstream applications, and in terms of end-model quality.

**Figure 1.** Classes contain critical information that is not explicitly encoded in the class labels. Supervised contrastive learning (left) loses this information, since it maps unlabeled strata such as sleeping cats, jumping cats, and swatting cats to a single embedding. We introduce a new loss function *Lspread* that prevents class collapse and maintains strata distinctions. *Lspread* produces higher-quality embeddings, which we evaluate with three downstream applications.

In Section 3, we present our modification to *LSC*, which prevents class collapse by changing how embeddings are pushed and pulled apart. *LSC* pushes together embeddings of points from the same class and pulls apart embeddings of points from different classes. In contrast, our modified loss *Lspread* includes an additional class-conditional InfoNCE loss term that uniformly pulls apart individual points from within the same class. This term on its own encourages points from the same class to be maximally spread apart in embedding space, which discourages class collapse (see Figure 1 middle). Even though *Lspread* does not use strata labels, we observe that it still produces embeddings that qualitatively appear to retain more strata information than those produced by *LSC* (see Figure 2).
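To make the spreading mechanism concrete, the sketch below shows one way a class-conditional InfoNCE term of this kind could be written in PyTorch. The positive-pair construction (an augmented view of each anchor), the mixing weight `alpha`, and the function names are our own illustrative assumptions; the precise form of *Lspread* is given in Section 3.

```python
import torch
import torch.nn.functional as F


def intra_class_infonce(z, z_aug, labels, tau=0.1):
    """Class-conditional InfoNCE ("spread") term: for each anchor, the positive
    is its augmented view and the negatives are other points from the SAME
    class, so embeddings within a class are pushed apart.
    z, z_aug: (n, d) embeddings of a batch and its augmented views; labels: (n,)."""
    z, z_aug = F.normalize(z, dim=1), F.normalize(z_aug, dim=1)  # stay on the unit hypersphere
    pos = (z * z_aug).sum(dim=1) / tau                  # sigma(x_i, aug(x_i))
    sim = (z @ z.T) / tau                               # pairwise sigma(x_i, x_a)
    self_mask = torch.eye(len(z), dtype=torch.bool, device=z.device)
    same_class = (labels[:, None] == labels[None, :]) & ~self_mask
    neg = sim.masked_fill(~same_class, float("-inf"))   # keep only same-class negatives
    denom = torch.logsumexp(torch.cat([pos[:, None], neg], dim=1), dim=1)
    return (denom - pos).mean()                         # -log softmax score of the positive


# L_spread would then add this term to the SupCon loss of Section 2.2, e.g.
#   loss = supcon + alpha * intra_class_infonce(z, z_aug, labels)   # alpha is hypothetical
```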

In Section 4, motivated by these empirical observations, we study how well *Lspread* preserves distinctions between strata in the representation space. Previous theoretical tools that study the optimal embedding distribution fail to characterize the geometry of strata. Instead, we propose a simple thought experiment that considers the embeddings the supervised contrastive loss generates when it is trained on a partial sample of the dataset. This setup enables us to distinguish strata based on their sizes by considering how likely each stratum is to be represented in the sample (larger strata are more likely to appear in a small sample). In particular, we find that points from rarer and more distinct strata are clustered less tightly than points from common strata, and we show that this clustering property can improve embedding quality and reduce generalization error.
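As a back-of-the-envelope illustration of the sampling intuition (our own arithmetic, not part of the formal analysis in Section 4): if a stratum has probability mass $p_z$ under $p(z)$, the probability that it contributes at least one point to an i.i.d. sample of $n$ points is

$$\Pr[\text{stratum } z \text{ appears in the sample}] = 1 - (1 - p_z)^n,$$

so for $n = 50$, a common stratum with $p_z = 0.1$ appears with probability about $0.99$, while a rare stratum with $p_z = 0.01$ appears with probability only about $0.39$.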

In Section 5, we empirically validate several downstream implications of these insights. First, we demonstrate that *Lspread* produces embeddings that retain more information about strata, resulting in lift on three downstream applications that require strata recovery:


points from common strata. Our coresets outperform prior work by 1.0 points when coreset size is 30% of the training set.

Next, we find that *Lspread* produces higher-quality models, outperforming *LSC* by up to 4.0 points across 9 tasks. Finally, we discuss related work in Section 6 and conclude in Section 7.

**Figure 2.** *Lspread* produces embeddings that are qualitatively better than those produced by *LSC*. We show t-SNE visualizations of embeddings for the CIFAR10 test set and report cosine similarity metrics (average intracluster cosine similarities, and similarities between individual points and the class cluster). *Lspread* produces lower intraclass cosine similarity and embeds images from rare strata further out over the hypersphere than *LSC*.

## **2. Background**

We present our generative model for strata (Section 2.1). Then, we discuss supervised contrastive learning—in particular the SupCon loss *LSC* from [1] and its optimal embedding distribution [2]—and the end model for classification (Section 2.2).

### *2.1. Data Setup*

We have a labeled input dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N$, where $(x, y) \sim \mathcal{P}$ for $x \in \mathcal{X}$ and $y \in \mathcal{Y} = \{1, \dots, K\}$. For a particular data point $x$, we denote its label as $h(x) \in \mathcal{Y}$ with distribution $p(y|x)$. We assume that the data is class-balanced, such that $p(y = i) = \frac{1}{K}$ for all $i \in \mathcal{Y}$. The goal is to learn a model $\hat{p}(y|x)$ on $\mathcal{D}$ to classify points.

Data points also belong to categories beyond their labels, called *strata*. Following [5], we denote a stratum as a latent variable $z$, which can take on values in $\mathcal{Z} = \{1, \dots, C\}$. $\mathcal{Z}$ can be partitioned into disjoint subsets $S_1, \dots, S_K$ such that if $z \in S_k$, then its corresponding $y$ label is equal to $k$. Let $S(c)$ denote the deterministic label corresponding to stratum $c$. We model the data-generating process as follows. First, the latent stratum is sampled from the distribution $p(z)$. Then, the data point $x$ is sampled from the distribution $\mathcal{P}_z = p(\cdot|z)$, and its corresponding label is $y = S(z)$ (see Figure 2 of [5]). We assume that each class has $m$ strata, and that there exist at least two strata $z_1, z_2$ such that $S(z_1) = S(z_2)$ and $\mathrm{supp}(\mathcal{P}_{z_1}) \cap \mathrm{supp}(\mathcal{P}_{z_2}) = \emptyset$.
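As a toy illustration of this generative process (made-up distributions and dimensions, purely for intuition; not part of the paper's setup):

```python
import numpy as np

rng = np.random.default_rng(0)
K, m = 2, 3                        # K classes, m strata per class
C = K * m                          # total number of strata
p_z = rng.dirichlet(np.ones(C))    # hypothetical stratum distribution p(z); some strata are rarer
S = np.repeat(np.arange(K), m)     # S(z): deterministic class label of each stratum


def sample_point():
    z = rng.choice(C, p=p_z)                          # 1. sample latent stratum z ~ p(z)
    x = rng.normal(loc=3.0 * z, scale=1.0, size=5)    # 2. sample x ~ P_z (toy Gaussian per stratum;
                                                      #    disjoint supports are only approximated)
    y = S[z]                                          # 3. label is y = S(z)
    return x, y, z


dataset = [sample_point() for _ in range(100)]
```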

### *2.2. Supervised Contrastive Loss*

Supervised contrastive loss pushes together pairs of points from the same class (called positives) and pulls apart pairs of points from different classes (called negatives) to train an encoder $f : \mathcal{X} \to \mathbb{R}^d$. Following previous work, we make three assumptions on the encoder: (1) we restrict the encoder output space to $\mathcal{S}^{d-1}$, the unit hypersphere; (2) we assume $K \le d + 1$, which allows Graf et al. [2] to recover the optimal embedding geometry; and (3) we assume the encoder $f$ is "infinitely powerful," meaning that any distribution on $\mathcal{S}^{d-1}$ is realizable by $f(x)$.

### 2.2.1. SupCon and Collapsed Embeddings

We focus on the SupCon loss $L_{SC}$ from [1]. Denote $\sigma(x, x') = f(x)^\top f(x')/\tau$, where $\tau$ is a temperature hyperparameter. Let $\mathcal{B}$ be the set of batches of labeled data on $\mathcal{D}$, and let $P(i, B) = \{p \in B \setminus i : h(x_p) = h(x_i)\}$ be the points in $B$ with the same label as $x_i$. For an anchor $x_i$, the SupCon loss is

$$\hat{L}_{SC}(f, x_i, B) = \frac{-1}{|P(i,B)|} \sum_{p \in P(i,B)} \log \frac{\exp(\sigma(x_i, x_p))}{\sum_{a \in B \setminus i} \exp(\sigma(x_i, x_a))},$$

where $P(i, B)$ forms positive pairs and $B \setminus i$ forms negative pairs.
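For reference, a direct transcription of this per-anchor loss into PyTorch (averaged over anchors in a batch) might look as follows; this is a sketch assuming L2-normalized embeddings and integer labels, not the authors' released code.

```python
import torch


def supcon_loss(z, labels, tau=0.1):
    """Batch SupCon loss: mean over anchors of L_SC_hat(f, x_i, B).
    z: (n, d) L2-normalized embeddings f(x_i); labels: (n,) values of h(x_i)."""
    sim = (z @ z.T) / tau                                          # sigma(x_i, x_a) for all pairs
    n = len(z)
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask   # P(i, B)
    # log of exp(sigma(x_i, x_p)) / sum_{a in B\i} exp(sigma(x_i, x_a))
    log_prob = sim - torch.logsumexp(sim.masked_fill(self_mask, float("-inf")), dim=1, keepdim=True)
    per_anchor = -(log_prob * pos_mask).sum(dim=1) / pos_mask.sum(dim=1).clamp(min=1)
    return per_anchor.mean()
```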

The optimal embedding distribution that minimizes $L_{SC}$ has one embedding per class, with the per-class embeddings collectively forming a regular simplex inscribed in the hypersphere [2]. Formally, if $h(x) = i$, then $f(x) = v_i$. The set $\{v_i\}_{i=1}^K$ makes up the regular simplex, defined by: a) $\sum_{i=1}^K v_i = 0$; b) $\|v_i\|_2 = 1$; and c) $\exists\, c_K \in \mathbb{R}$ such that $v_i^\top v_j = c_K$ for all $i \ne j$. We describe this property as *class collapse* and define the distribution of $f(x)$ that satisfies these conditions as *collapsed embeddings*.
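As a concrete check of these conditions, the following sketch builds one standard regular-simplex construction (our choice of construction, not taken from [2]) and verifies a)–c) numerically:

```python
import numpy as np

K, d = 4, 4                           # K classes embedded in R^d (requires K <= d + 1)
# One standard construction of a regular simplex on the unit hypersphere:
# v_i = sqrt(K/(K-1)) * (e_i - (1/K) * 1), spanning a (K-1)-dimensional subspace.
V = np.sqrt(K / (K - 1)) * (np.eye(K) - np.ones((K, K)) / K)

print(np.allclose(V.sum(axis=0), 0))                       # a) vectors sum to zero
print(np.allclose(np.linalg.norm(V, axis=1), 1))           # b) unit norm
G = V @ V.T
off_diag = G[~np.eye(K, dtype=bool)]
print(np.allclose(off_diag, -1 / (K - 1)))                  # c) constant c_K for i != j
```

Here $c_K = -1/(K-1)$, the value forced by conditions a) and b), since $0 = \|\sum_i v_i\|_2^2 = K + K(K-1)c_K$.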

### 2.2.2. End Model

After the supervised contrastive loss is used to train an encoder, a linear classifier $W \in \mathbb{R}^{K \times d}$ is trained on top of the representations $f(x)$ by minimizing cross-entropy loss over softmax scores. We assume that $\|W_y\|_2 \le 1$ for each $y \in \mathcal{Y}$. The end model's empirical loss can be defined as

$$\hat{\mathcal{L}}(W, \mathcal{D}) = \sum_{x_i \in \mathcal{D}} -\log \frac{\exp(f(x_i)^\top W_{h(x_i)})}{\sum_{j=1}^K \exp(f(x_i)^\top W_j)}.$$

The model uses softmax scores constructed with $f(x)$ and $W$ to generate predictions $\hat{p}(y|x)$, which we also write as $\hat{p}(y|f(x))$. Finally, the generalization error of the model on $\mathcal{P}$ is the expected cross-entropy between $\hat{p}(y|x)$ and $p(y|x)$, namely $\mathcal{L}(x, y, f) = \mathbb{E}_{x,y \sim \mathcal{P}}[-\log \hat{p}(y|f(x))]$.
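A compact sketch of this end-model stage (frozen embeddings, a plain SGD loop, hypothetical helper names; the norm constraint on $W_y$ is not enforced here):

```python
import torch
import torch.nn.functional as F


def train_linear_head(Z, y, K, epochs=100, lr=0.1):
    """Fit W in R^{K x d} on frozen embeddings Z = f(x_i), minimizing the
    empirical cross-entropy L_hat(W, D). y: (n,) long tensor of labels h(x_i)."""
    W = torch.zeros(K, Z.shape[1], requires_grad=True)
    opt = torch.optim.SGD([W], lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = F.cross_entropy(Z @ W.T, y, reduction="sum")   # empirical L_hat(W, D)
        loss.backward()
        opt.step()
    return W.detach()


def generalization_error(Z_test, y_test, W):
    """Expected cross-entropy E[-log p_hat(y | f(x))], estimated on held-out data."""
    return F.cross_entropy(Z_test @ W.T, y_test, reduction="mean").item()
```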
