Creating a confusion matrix using scikit-learn

Einblick Content Team - April 13th, 2023

When creating any kind of machine learning model, evaluation methods are critical. In this post, we'll go over how to create a confusion matrix in scikit-learn. The first function will create the values for the 4 quadrants of a confusion matrix, and the second will create a nicely formatted plot.

For this example, we used an Adidas sales dataset from Kaggle. Below our code snippets, we've included more information about confusion matrices: what they are and why they are useful.

Prep: build classification model and get predictions

In order to create the confusion matrix, we'll first need to build a classification model. We'll be using a few variables in the Adidas dataset to "predict" which region the transaction took place in. We've subset for transactions in the Northeast and the West.

Example: DecisionTreeClassifier()

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# df holds the Adidas sales data from Kaggle, already subset to the Northeast and West regions

# Split data into training and testing sets
# (train_size = 0.25 reserves 25% of rows for training and 75% for testing)
X = df[["Units Sold", "Operating Profit", "Price per Unit"]]
y = df[["Region"]]
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size = 0.25, random_state = 42)

# Build decision tree
clf_dt = DecisionTreeClassifier(max_depth = 5)
fit_dt = clf_dt.fit(X_train, y_train)
y_pred = fit_dt.predict(X_test)

# Print test and predictions
print(y_test)
print(y_pred)

Output:

         Region
3221  Northeast
1117       West
468   Northeast
1129  Northeast
2754  Northeast
...         ...
2981       West
40         West
1379       West
2089       West
2175  Northeast
[3618 rows x 1 columns]
['Northeast' 'Northeast' 'Northeast' ... 'Northeast' 'Northeast'
 'Northeast']
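Before building the matrix, a quick sanity check on these label lists is to compute overall accuracy with scikit-learn's accuracy_score(). The snippet below is a minimal sketch using hypothetical stand-in lists rather than the actual y_test and y_pred from above:

```python
from sklearn.metrics import accuracy_score

# Hypothetical stand-ins for the actual labels and predictions above
y_test = ["Northeast", "West", "West", "Northeast", "West"]
y_pred = ["Northeast", "Northeast", "West", "Northeast", "West"]

# Fraction of predictions that exactly match the actual labels
print(accuracy_score(y_test, y_pred))  # 0.8
```

A single accuracy number is a useful first look, but it hides which class is being misclassified, which is exactly what the confusion matrix shows.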

Based on the output, we can see that the testing set and the predictions are just a list of labels. In order to turn these labels into something meaningful, we'll use two functions from scikit-learn: confusion_matrix() and ConfusionMatrixDisplay().

Basic Syntax: confusion_matrix(y_test, y_pred, labels)

To use this function, you just need three arguments:

  • y_test: a list of the actual labels (the testing set)
  • y_pred: a list of the predicted labels (you can see how we got these in the above code snippet). If you're not using a decision tree classifier, your model will have an analogous predict() method.
  • labels: class labels (in this case accessed as an attribute of the classifier, clf_dt)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Get the confusion matrix values
cm = confusion_matrix(y_test, y_pred, labels = clf_dt.classes_)
print(cm)

Output:

[[1211  603]
 [ 487 1317]]

Each row of this output corresponds to an actual class and each column to a predicted class, so 1211 Northeast transactions were classified correctly while 603 were misclassified as West. In order to create a more digestible output, we'll use a specific plotting function from scikit-learn.

Create Visualization: ConfusionMatrixDisplay(confusion_matrix, display_labels)

To use the function, we just need two arguments:

  • confusion_matrix: an array of values for the plot; the output of the scikit-learn confusion_matrix() function works directly
  • display_labels: class labels (in this case accessed as an attribute of the classifier, clf_dt)
import matplotlib.pyplot as plt

# Create the plot
disp = ConfusionMatrixDisplay(confusion_matrix = cm, display_labels = clf_dt.classes_)
disp.plot()
plt.show()

Output:

[Image: confusion matrix plot produced by ConfusionMatrixDisplay]

Why is a confusion matrix useful?

Fundamentally, a confusion matrix allows you to quickly evaluate the performance of a classifier by breaking its predictions down into true positives, true negatives, false positives, and false negatives. Additionally, if you’re comparing multiple models, you can use confusion matrices:

  • To identify which one has the best performance by comparing the rate of true positives, false negatives, false positives, and true negatives
  • With precision-recall curves to select an appropriate decision threshold in binary classification problems.
[Image: confusion matrix reference diagram]

See above for a reference image of confusion matrices, created in Lucidchart:

  • True positive (upper left): data points that the model assigned label 1, and that actually fall under label 1
  • True negative (bottom right): data points that the model assigned label 2, and that actually fall under label 2
  • False positive (upper right): data points that the model assigned label 1, but that actually fall under label 2
  • False negative (bottom left): data points that the model assigned label 2, but that actually fall under label 1
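Note that scikit-learn's own layout differs from the diagram above: with the negative class listed first, true negatives sit in the upper left, and ravel() unpacks the four quadrants directly. A minimal sketch with made-up binary labels:

```python
from sklearn.metrics import confusion_matrix

# Made-up binary labels: 1 = positive class, 0 = negative class
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 1, 0, 1, 1, 0, 1, 0]

# With labels ordered [0, 1], ravel() flattens the 2x2 matrix
# row by row into tn, fp, fn, tp
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

precision = tp / (tp + fp)  # 3 / (3 + 1) = 0.75
recall = tp / (tp + fn)     # 3 / (3 + 1) = 0.75
print(tn, fp, fn, tp)  # 3 1 1 3
```

These four counts are the building blocks for most classification metrics, so unpacking them this way is handy when you want a metric scikit-learn doesn't compute directly.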

About

Einblick is an agile data science platform that provides data scientists with a collaborative workflow to swiftly explore data, build predictive models, and deploy data apps. Founded in 2020, Einblick was developed based on six years of research at MIT and Brown University. Einblick customers include Cisco, DARPA, Fuji, NetApp and USDA. Einblick is funded by Amplify Partners, Flybridge, Samsung Next, Dell Technologies Capital, and Intel Capital. For more information, please visit www.einblick.ai and follow us on LinkedIn and Twitter.
