Apr 2, 2022 · 25 min read
An Introduction to Slice Discovery with Domino
Sabri Eyuboglu, Maya Varma, Khaled Saab, Jared Dunnmon, James Zou and Chris Ré.
In this post, we introduce Domino, a new approach for discovering the systematic errors made by machine learning models. We also discuss a framework for quantitatively evaluating methods like Domino.
Links: 📄 Paper (ICLR 2022) 🌍 SAIL Blog Post 💻 GitHub 📘 Docs 📒 Google Colab
Machine learning models that achieve high overall accuracy often make systematic errors on coherent slices of validation data. For example, an image classification model trained to detect cars may perform poorly on the data slice consisting of vintage vehicles.
The search for underperforming slices is a critical (but often overlooked) part of model evaluation. Slice-based evaluations can enable practitioners to make informed decisions around model deployment. This is particularly important in safety-critical settings like medicine. It can also help practitioners debug and improve models: after an underperforming slice is identified, we can improve model robustness by either updating the training dataset or using robust optimization techniques (e.g. Sohoni et al., 2020; Sagawa et al., 2020).
However, identifying underperforming slices is difficult in practice. Many data slices are “hidden”, meaning they are not annotated in metadata and are difficult to extract programmatically from unstructured inputs (e.g. images, video, time-series data). Returning to our example from earlier, it is likely that your image dataset will not include fine-grained labels identifying vintage cars, making it difficult to evaluate performance on the slice. This raises the following question:
How can we automatically identify data slices on which our model underperforms?
In this blog post, we discuss our recent exploration of this question. We introduce Domino, a novel method for identifying and describing underperforming slices. We also discuss an evaluation framework for rigorously evaluating our method across diverse slice types, tasks, and datasets.
Blog Post outline. We’ll begin by grabbing a popular machine learning model off-the-shelf and trying to evaluate it using standard techniques: point estimates of metrics, precision-recall curves, and manual error analysis. This exercise highlights the limitations of this standard toolbox when it comes to identifying “hidden” data slices with poor performance.
Next, we’ll discuss a class of slice discovery methods (SDMs) designed to discover these hidden slices. We’ll introduce our own SDM, Domino (a nod to the pizza chain of the same name, known for its reliable slice deliveries) and apply it to the same off-the-shelf model.
Finally, we’ll discuss a framework for evaluating slice discovery methods, highlighting some of their limitations and exciting directions for future work.
Evaluating a model with the standard toolbox
Let’s investigate a model that you’ve likely used before: a ResNet18 pretrained on ImageNet.
💻 Loading the model and dataset
import meerkat as mk
from torchvision.models import resnet18
dp = mk.datasets.get(
"imagenet",
dataset_dir="/home/common/datasets/imagenet",
download=False
)
dp = dp.lz[dp["split"] == "valid"]
model = resnet18(pretrained=True)
Suppose we’d like to use this pretrained ResNet to scrape images of cars from the internet. Before doing so, it would be prudent to measure how good the model actually is at distinguishing cars from other objects.
To do so, we can take the full ImageNet validation set (a collection of 50,000 Flickr images), compute the model’s predictions, and calculate some metrics. We find that the model is quite discriminative (99.77% AUROC), but still not perfect: tracing along the precision-recall curve to the right, we see that there is no decision threshold for which the model achieves both precision and recall above 85%.
Ideally, we would like to better understand the model’s failure modes before deciding to use it and choosing a decision threshold. For example, if the model struggles to recognize certain types of cars, then using the model with a high threshold could mean our collection of scraped images will omit whole subclasses of cars.
💻 Running inference
import torchvision.transforms as transforms

# The standard ImageNet preprocessing transform.
# source: https://github.com/pytorch/vision/issues/39
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
dp["input"] = dp["image"].to_lambda(transform)
import torch
model.to(0)
@torch.autograd.no_grad()
def forward(batch):
return model(batch.data.to(0)).cpu().numpy()
dp["preds"] = dp["input"].map(forward, is_batched_fn=True, batch_size=128, pbar=True)
💻 Computing metrics
import numpy as np
from sklearn.metrics import roc_auc_score
# There isn't actually a general "car" class in ImageNet.
# There are only more specific subclasses like "sports car",
# "jeep", "convertible". Luckily, because imagenet labels
# are organized into an ontology, we can rather easily get
# labels for a general "car" class. This requires a bit of
# code that is out of the scope of this blogpost, so we've
# abstracted it away below.
# See the notebook for more details.
dp["target"] = get_target_labels(dp, target_synset="car.n.01")
dp["prob"] = torch.softmax(torch.tensor(dp["preds"].data), dim=-1)[
:, np.unique(dp.lz[dp["target"]]["class_idx"])
].sum(axis=1)
roc_auc_score(dp["target"], dp["prob"].cpu().numpy())
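For the curious, here is roughly how the ontology trick could work in principle (a hedged sketch, not the notebook's actual get_target_labels implementation): WordNet gives us every hyponym of car.n.01, and ImageNet class IDs are derived from WordNet offsets, so the two sets can be intersected.
💻 Deriving car labels from the WordNet ontology (illustrative sketch)
from nltk.corpus import wordnet as wn  # requires nltk.download("wordnet")

car = wn.synset("car.n.01")
# Every synset at or below car.n.01 in the hyponym hierarchy, plus car.n.01 itself.
car_synsets = set(car.closure(lambda s: s.hyponyms())) | {car}
# ImageNet class IDs ("wnids") are "n" plus the synset's 8-digit WordNet offset.
car_wnids = {f"n{s.offset():08d}" for s in car_synsets}
# A class counts as a "car" class if its wnid is in car_wnids; mapping wnids to
# the dataset's integer class indices depends on the dataset's metadata.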
💻 Plotting precision-recall curve
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
import seaborn as sns
precision, recall, thresholds = precision_recall_curve(
dp["target"],
dp["prob"]
)
sns.lineplot(x=recall, y=precision)
plt.xlabel("recall")
plt.ylabel("precision")
How can we better understand a model’s failure modes? The simplest (and perhaps most commonly employed) error analysis strategy involves manually inspecting the set of images where the model made a mistake and looking for any worrying patterns. Below, we show the 155 incorrectly classified images in the validation set (when using a decision threshold of 0.5). Try scrolling through the images and see if you can find any trends in the errors.
💻 Displaying errors
dp.lz[
(dp["target"].data != (dp["prob"] > 0.5)).astype(bool)
][['image_id', "image", "target", "prob"]]
We see from this example that manually identifying trends across incorrect predictions is difficult and time-consuming, especially on large datasets with diverse examples. Slice discovery methods, like Domino, offer an alternative path forward.
Evaluating a model with slice discovery
This brings us to slice discovery: the task of mining unstructured input data for semantically meaningful subgroups on which a model performs poorly. We refer to automated techniques that perform this task as slice discovery methods (SDMs).
In order to be broadly useful across diverse settings, an ideal SDM should satisfy the following desiderata:
- Identified slices should contain examples on which the model underperforms, or has a high error rate.
- Identified slices should contain examples that are coherent, or align closely with a human-understandable concept.
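To make these desiderata concrete, here is a minimal sketch (our own illustration, not part of the Domino API) of how one might score a candidate slice on both axes: its error rate, and a simple coherence proxy based on the average pairwise cosine similarity of the slice's embeddings. The helper names and the choice of proxy are ours, purely for intuition.
💻 Scoring a candidate slice (illustrative sketch)
import numpy as np

def slice_error_rate(targets, preds, in_slice):
    # Fraction of examples in the slice that the model gets wrong.
    return float((targets[in_slice] != preds[in_slice]).mean())

def slice_coherence(embeddings, in_slice):
    # Average off-diagonal cosine similarity between embeddings in the slice;
    # higher values suggest the examples share a common concept.
    # Assumes the slice contains at least two examples.
    embs = embeddings[in_slice]
    embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    sims = embs @ embs.T
    n = len(embs)
    return float((sims.sum() - n) / (n * (n - 1)))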
Below, we apply our new slice discovery method, Domino, to the same model we began auditing above. Domino was designed to identify coherent, underperforming data slices (i.e. groups of similar validation datapoints on which the model makes errors). It leverages a powerful class of recently developed cross-modal representation learning approaches, which yield semantically meaningful representations by embedding images and text in the same latent space. We demonstrate that using cross-modal representations both improves slice coherence and enables Domino to generate natural language descriptions for identified slices!
Domino follows a three-step procedure:
1. Embed: Domino encodes the validation images alongside text in a cross-modal embedding space using a model like CLIP.
💻 Embedding images
from domino import embed

dp = embed(
    data=dp,
    input_col="image",
    encoder="clip",
    device=0
)
2. Slice: Using an error-aware mixture model, Domino identifies regions in the embedding space with a high concentration of errors.
💻 Running Domino mixture model
from domino import DominoSlicer

domino = DominoSlicer(
    y_log_likelihood_weight=40,
    y_hat_log_likelihood_weight=40,
    n_mixture_components=60,
    n_slices=10,
    max_iter=10,
    n_pca_components=128,
    init_params="confusion",
    confusion_noise=3e-3
)
domino.fit(data=dp, embeddings="emb", targets="target", pred_probs="prob")
dp["domino_slices"] = domino.predict_proba(
    data=dp, embeddings="emb", targets="target", pred_probs="prob"
)
3. Describe: Finally, Domino uses the text embeddings computed in Step 1 to generate natural language descriptions for discovered slices. By describing underperforming slices with natural language, Domino can enable practitioners to address dataset biases and correct models.
💻 Embedding candidate descriptions
from domino import generate_candidate_descriptions

phrase_templates = [
    "a photo of [MASK].",
    "a photo of {} [MASK].",
    "a photo of [MASK] {}.",
    "a photo of [MASK] {} [MASK].",
]
text_dp = generate_candidate_descriptions(
    templates=phrase_templates
)
text_dp = embed(
    text_dp,
    input_col="output_phrase",
    encoder="clip",
    device=0
)
💻 Describing slices
from domino import describe

description_dp = describe(
    data=dp,
    embeddings="clip(image)",
    pred_probs="prob",
    targets="target",
    slices="domino_slices",
    slice_idx=0,
    text=text_dp,
    text_embeddings="clip(output_phrase)",
)
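To give a rough sense of what the error-aware mixture model in the Slice step is doing, here is a simplified, illustrative scoring function (our own sketch, not the library's implementation): each slice is modeled with a Gaussian over embeddings and categorical distributions over the label and the model's prediction, and the label and prediction terms are upweighted so that slices concentrate on a single error type.
💻 Error-aware slice scoring (illustrative sketch)
import numpy as np
from scipy.stats import multivariate_normal

def slice_log_likelihood(z, y, y_hat, slice_params, y_weight=40, y_hat_weight=40):
    # z: embedding vector; y: class label; y_hat: predicted label.
    # slice_params: dict with "mean" and "cov" for the Gaussian over embeddings,
    # plus "p_y" and "p_y_hat" categorical distributions (hypothetical structure).
    log_p_z = multivariate_normal.logpdf(z, mean=slice_params["mean"], cov=slice_params["cov"])
    log_p_y = np.log(slice_params["p_y"][y])
    log_p_y_hat = np.log(slice_params["p_y_hat"][y_hat])
    # Upweighting the label and prediction terms encourages slices that are
    # "pure" in their error type (e.g. mostly false negatives).
    return log_p_z + y_weight * log_p_y + y_hat_weight * log_p_y_hat
In DominoSlicer, the y_log_likelihood_weight and y_hat_log_likelihood_weight arguments control this balance, and the slice parameters are fit with an expectation-maximization procedure.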
Let’s return to our car example from earlier. We first embed the images using CLIP representations, and we then slice in the embedding space. Let’s take a look at one of the slices Domino discovers.
First, we’ll compare the distribution of model predictions in the slice to the distribution of model predictions in the rest of the dataset. The plot below shows these distributions for one of the slices discovered by Domino. Clearly, the prediction distribution in the slice differs from the distribution in the rest of the dataset. This tells us that the slice consists mostly of false negatives and low-confidence true positives. Because they reflect the model’s confidence (not just its prediction), plots like these help us better understand how the model’s behavior in the slice differs from its behavior out of the slice.
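If you'd like to produce a plot along these lines yourself, a minimal sketch looks something like the following (it assumes we threshold Domino's soft slice assignments in dp["domino_slices"] to get a hard membership mask; the 0.5 threshold is arbitrary).
💻 Comparing prediction distributions (illustrative sketch)
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

slice_idx = 0  # which discovered slice to inspect
in_slice = np.asarray(dp["domino_slices"].data[:, slice_idx] > 0.5)
probs = dp["prob"].cpu().numpy()

# Compare the model's predicted probability of "car" inside and outside the slice.
sns.kdeplot(probs[in_slice], label="in slice")
sns.kdeplot(probs[~in_slice], label="out of slice")
plt.xlabel("predicted probability of car")
plt.legend()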
With an understanding of which error types are represented in the slice (in this case, mostly false negatives), we now look to describe the slice with natural language. Let’s inspect the natural language descriptions produced by Domino.
Domino’s natural language descriptions of the discovered slice. The output_phrase column contains the descriptions and the score column contains the inner-product scores Domino assigned to each phrase. The top five descriptions, ranked by inner-product score, are shown.
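As a rough mental model of where these scores come from (this is a simplification, not the exact computation inside domino.describe, and the column names are those used in the describe call above), each candidate phrase can be scored by the inner product between its text embedding and the direction pointing from the dataset's mean image embedding toward the slice's mean image embedding.
💻 Scoring candidate phrases (illustrative sketch)
import numpy as np

slice_idx = 0
slice_weights = np.asarray(dp["domino_slices"].data[:, slice_idx])
img_embs = np.asarray(dp["clip(image)"].data)                # image embeddings
text_embs = np.asarray(text_dp["clip(output_phrase)"].data)  # phrase embeddings

# Direction from the dataset's mean image embedding toward the slice's (weighted) mean embedding.
slice_mean = (slice_weights[:, None] * img_embs).sum(axis=0) / slice_weights.sum()
direction = slice_mean - img_embs.mean(axis=0)

# Inner-product score for each candidate phrase; print the top five.
scores = text_embs @ direction
for i in np.argsort(-scores)[:5]:
    print(text_dp["output_phrase"][int(i)], scores[i])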
These descriptions suggest that the model fails to identify photos of cars taken from the interior. Finally, we can scroll through the images to confirm this and look for any other visual patterns.
Let’s look at some of the other slices discovered by Domino.
🍕 Slice 1: Photos of car interiors.
🍕 Slice 2: Photos of cars offroading.
🍕 Slice 3: Photos of vintage trucks.
🍕 Slice 4: Photos of racecars.
🍕 Slice 5: Photos of police vehicles.
Domino is the first SDM that can describe identified slices with natural language. This has the potential to shorten error analysis iteration cycles, since practitioners will not be required to manually analyze long lists of prediction errors. Our hope is that users will be able to leverage the descriptions to quickly identify sources of systematic errors, allowing them to take the appropriate actions to patch the model.
Of course, Domino isn’t perfect. Sometimes an identified slice is incoherent (i.e. the examples in it seem to have nothing in common).
🍕 Slice 6: An incoherent slice.
Domino may also fail to identify a coherent slice on which the model underperforms. Given the potential for these failures, how can we measure how well Domino does its job? How does it compare to other slice discovery methods?
Evaluating slice discovery methods
SDMs like Domino have traditionally been evaluated qualitatively, due to the lack of a simple quantitative alternative. Typically, in these evaluations, the SDM is applied to a few different models and the identified slices are visualized, just as we did above. Practitioners can then inspect the slices and judge whether they make sense. However, these qualitative evaluations are subjective and do not scale beyond a few models. Moreover, they cannot tell us if the SDM has missed an important, coherent slice.
Ideally, we’d like to estimate the failure rate of an SDM: how often it fails to identify a coherent slice on which the model underperforms. Estimating this failure rate is very challenging because we don’t typically know the set of slices on which a model underperforms. How could we possibly know if the SDM is missing any?
To solve this problem, we trained thousands of models that were specifically constrained to underperform on pre-defined slices. Our approach involved (1) obtaining a dataset with some annotated slices, and (2) manipulating the dataset such that, with high probability, a model trained on it would exhibit poor performance on one or more of the annotated slices.
An illustrative example. To better understand this, let’s step through the process of training one of these models. CelebA is one such dataset with annotated slices: it consists of over 200,000 celebrity photos annotated with attributes like “black hair” and “mustache”. See a sample of the data below:
💻 Loading the CelebA dataset
import meerkat as mk
dp = mk.datasets.get("celeba")
dp.head(20)
Leveraging these annotations, we can train a model which we know will underperform on a coherent slice of data. First, we choose a target task: detect whether the celebrity is wearing a necktie. Next, we subsample the dataset such that there is a correlation between the target attribute necktie and some other attribute, say, eyeglasses. We call this process inducing a spurious correlation.
💻 Inducing a spurious correlation
from domino.eval.utils import induce_correlation
idxs = induce_correlation(
dp=dp,
corr=0.8,
n=30_000,
attr_a="eyeglasses",
attr_b="wearing_necktie"
)
corr_dp = dp.lz[idxs]
corr_dp[["image", "eyeglasses", "wearing_necktie"]].head(20)
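As a quick sanity check (not in the original walkthrough), we can verify that the subsample actually exhibits the intended association by computing the Pearson correlation between the two binary attribute columns.
💻 Checking the induced correlation (illustrative)
import numpy as np

eyeglasses = np.asarray(corr_dp["eyeglasses"].data)
necktie = np.asarray(corr_dp["wearing_necktie"].data)
print(np.corrcoef(eyeglasses, necktie)[0, 1])  # should come out close to 0.8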
In this example, we induced a rather strong association: the Pearson correlation coefficient is 0.8. Due to this artificial association between neckties and eyeglasses, a model trained to detect neckties on this subsampled dataset will likely underperform on the coherent slice of celebrities wearing eyeglasses but not wearing neckties. To confirm this, we’ll train a model on the new dataset:
💻 Training a model
from domino.eval.train import train
model = train(
dp=corr_dp,
input_column="image",
target_column="wearing_necktie",
id_column="file",
batch_size=128,
max_epochs=1
)
With the model trained, we can confirm that it underperforms on the slice. In this case, the model achieves a specificity (i.e. true-negative rate) of 99.1% overall, but only 51.9% among people wearing eyeglasses.
💻 Checking model performance on slice
corr_dp["input"] = corr_dp["image"].to_lambda(model.config["transform"])
model.to(0)
@torch.autograd.no_grad()
def forward(batch):
return torch.softmax(model(batch.data.to(0)), dim=1).cpu().numpy()
corr_dp["probs"] = corr_dp["input"].map(
forward,
is_batched_fn=True,
batch_size=128,
pbar=True,
num_workers=6
)
valid_dp = corr_dp.lz[corr_dp["split"] == "valid"]
neg_dp = valid_dp.lz[valid_dp["wearing_necktie"] != 1]
print(f'Overall Specificity: {(neg_dp["probs"][:, 0] > 0.5).mean()}')
print(f'In-slice Specificity: {(neg_dp.lz[neg_dp["eyeglasses"] == 1]["probs"][:, 0] > 0.5).mean()}')
We now know that there is a coherent slice on which the model significantly underperforms. We can then run a slice discovery method (e.g. Domino) and check whether it is able to identify that slice without access to the ground-truth slice annotations.
💻 Running the SDM
from domino import DominoSlicer
phrase_templates = [
"a photo of a person {} [MASK].",
"a photo of a person [MASK] {}.",
"a photo of a person [MASK] {} [MASK].",
"a photo of a person [MASK] [MASK] {}.",
"a photo of [MASK] {} person",
"[MASK] {} photo of a person",
]
from domino import embed, generate_candidate_descriptions
text_dp = generate_candidate_descriptions(
templates=phrase_templates
)
text_dp = embed(
text_dp,
input_col="output_phrase",
encoder="clip",
device=0
)
valid_dp = corr_dp.lz[corr_dp["split"] == "valid"]
valid_dp = embed(data=valid_dp, input_col="image", encoder="clip", device=0)
domino = DominoSlicer(
y_log_likelihood_weight=20,
y_hat_log_likelihood_weight=20,
n_mixture_components=30,
n_slices=10,
max_iter=10,
n_pca_components=128,
init_params="confusion",
confusion_noise=3e-3
)
domino.fit(
    data=valid_dp, embeddings="emb", targets="wearing_necktie", pred_probs="probs"
)
valid_dp["domino_slices"] = domino.predict_proba(
    data=valid_dp, embeddings="emb", targets="wearing_necktie", pred_probs="probs"
)
The prediction-distribution plot indicates that the slice consists mostly of false positives, the textual descriptions exactly describe the induced slice, and the images are almost all of folks wearing eyeglasses and no necktie!
Because we have access to the ground-truth slice annotations, we can also compute quantitative metrics like the precision-at-10: the fraction of the top ten examples in Domino’s discovered slice (ranked by the score that the mixture model assigns to each example) that are actually in the ground-truth slice.
💻 Computing precision-at-10
slice_dp = valid_dp.lz[(-valid_dp["domino_slices"].data[:, 5]).argsort()]
slice_dp["target_slice"] = (slice_dp["wearing_necktie"] != 1) & (slice_dp["eyeglasses"] == 1)
slice_dp.lz[:10]["target_slice"].mean()
In this example, the precision-at-10 is 100%. Of the top ten examples in Domino’s predicted slice, all of them show a person wearing eyeglasses but not wearing a necktie. So, for this one model, we can say Domino successfully identified the coherent slice that we induced.
Scaling up. In order to estimate an SDM’s failure rate (i.e. how often it fails to identify a coherent slice on which the model underperforms), we repeated this process thousands of times across a variety of different datasets, tasks and slice types. Our choice of slice types and datasets was informed by a literature survey of underperforming slices occurring in the wild (see Section A.1 of our paper for a summary).
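Concretely, once many (model, induced slice) pairs are available, the failure rate can be estimated by counting how often the SDM fails to surface the induced slice. The sketch below is illustrative only: the success criterion (a precision-at-10 threshold) and the helper function are ours, not the exact metric reported in the paper.
💻 Estimating an SDM's failure rate (illustrative sketch)
import numpy as np

def estimate_failure_rate(precisions_at_10, success_threshold=0.5):
    # One precision-at-10 value per (model, induced slice) setting, computed as
    # in the CelebA example above. A setting counts as a success when at least
    # `success_threshold` of the top-10 examples fall in the ground-truth slice.
    precisions_at_10 = np.asarray(precisions_at_10)
    return float((precisions_at_10 < success_threshold).mean())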
Using this evaluation framework, we were able to compare SDMs and run ablation studies evaluating specific SDM design choices. Two key findings emerged from these experiments:
- Cross-modal embeddings improve SDM performance. We found that the choice of representation matters – a lot! Slice discovery methods based on cross-modal embeddings outperform those based on a single modality by at least 9 percentage points. When compared with using the activations of the trained model, the gap grows to 15 percentage points. This finding is of particular interest given that classifier activations are a popular embedding choice in existing SDMs.
- Modeling both the prediction and class label enables accurate slice discovery. Good embeddings alone do not suffice – the choice of algorithm for actually extracting the underperforming slices from the embedding space is significant as well. We find that a simple mixture model that jointly models the embeddings, labels and predictions enables a 105% improvement over the next best slicing algorithm. We hypothesize that this is because this algorithm is unique in modeling the class labels and the model’s predictions as separate variables, which leads to slices that are “pure” in their error type (false positive vs. false negative).
However, there’s still a long way to go: slice discovery is a challenging task, and Domino, the best performing method in our experiments, still fails to recover over 60% of coherent slices. We see a number of exciting avenues for future work that could begin to close this gap.
- We suspect that improvements in the embeddings that power slice discovery will be driven by large cross-modal datasets, so work in dataset curation and management could help move the needle.
- In our work, we compared slicing algorithms empirically. Applying theoretical tools to slice discovery may help us understand the tradeoffs between different slicing algorithms.
- In our experiments, we performed slice discovery on validation sets with tens of thousands of examples. In practice, validation datasets are sometimes much larger, so it will be important to address the systems challenges that arise when applying these algorithms on very large validation datasets with hundreds of thousands or millions of examples.
- In this blog post, we described slice discovery as a fully automated process. In the future, we expect effective slice discovery systems to be highly interactive: practitioners will be able to quickly explore slices and provide feedback. Forager, a system for rapid data exploration, is an exciting step in this direction.
We are really excited to continue working on this important problem and to collaborate with others as we seek to develop more reliable slice discovery methods. To facilitate this process, we are releasing 984 models and their associated slices as part of DCBench, a suite of data-centric benchmarks. This will allow others to reproduce our results and develop new slice discovery methods. We are also releasing a Domino Python package, which includes implementations of popular slice discovery methods.
pip install domino
If you’ve developed a new slice discovery method and would like us to add it to the library, please reach out!
We'd like to thank Arjun Desai, Megan Leszczynski, Michael Zhang, and Simran Arora for providing valuable feedback on this blog post.