When creating any kind of machine learning model, evaluation methods are critical. In this post, we'll go over how to create a confusion matrix in scikit-learn. The first function will create the values for the four quadrants of a confusion matrix, and the second function will create a nicely formatted plot. For this example, we used an Adidas sales dataset from Kaggle. Below the code snippets, we've included more information about confusion matrices: what they are and why they are useful.
Prep: build classification model and get predictions
In order to create the confusion matrix, we'll first need to build a classification model. We'll be using a few variables in the Adidas dataset to "predict" which region the transaction took place in. We've subset for transactions in the Northeast and the West.
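The Kaggle CSV itself isn't included here, so as a minimal sketch of the prep step (the file name, column labels, and toy values are assumptions standing in for the real dataset), loading and subsetting might look like:

```python
import pandas as pd

# Hypothetical stand-in for the Kaggle file; in practice you'd load it with
# pd.read_csv("adidas_sales.csv")
df = pd.DataFrame({
    "Region": ["Northeast", "West", "South", "Northeast", "Midwest", "West"],
    "Units Sold": [120, 85, 60, 200, 45, 150],
    "Operating Profit": [3000.0, 2100.0, 1500.0, 5200.0, 900.0, 3900.0],
    "Price per Unit": [50.0, 45.0, 40.0, 55.0, 35.0, 48.0],
})

# Keep only the two regions we want the classifier to distinguish
df = df[df["Region"].isin(["Northeast", "West"])].reset_index(drop=True)
print(df["Region"].unique())
```

The snippet below then assumes `df` holds this two-region subset.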
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Split data into training and testing sets
X = df[["Units Sold", "Operating Profit", "Price per Unit"]]
y = df[["Region"]]
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size = 0.25, random_state = 42)

# Build decision tree
clf_dt = DecisionTreeClassifier(max_depth = 5)
fit_dt = clf_dt.fit(X_train, y_train)
y_pred = fit_dt.predict(X_test)

# Print test and predictions
print(y_test)
print(y_pred)
          Region
3221   Northeast
1117        West
468    Northeast
1129   Northeast
2754   Northeast
...          ...
2981        West
40          West
1379        West
2089        West
2175   Northeast

[3618 rows x 1 columns]
['Northeast' 'Northeast' 'Northeast' ... 'Northeast' 'Northeast' 'Northeast']
Based on the output, we can see that the testing set and the predictions are just lists of labels. In order to turn these labels into something meaningful, we'll use two functions from sklearn.metrics: confusion_matrix() and ConfusionMatrixDisplay().
confusion_matrix(y_test, y_pred, labels)
To use this function, you just need three arguments:
y_test: a list of the actual labels (the testing set)
y_pred: a list of the predicted labels (you can see how we got these in the above code snippet). If you're not using a decision tree classifier, the predict() method of whichever model you've fit produces the same kind of output.
labels: class labels (in this case accessed as an attribute of the classifier, clf_dt.classes_)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Get the confusion matrix values
cm = confusion_matrix(y_test, y_pred, labels = clf_dt.classes_)
print(cm)
[[1211  603]
 [ 487 1317]]
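Because the result is just a NumPy array, you can compute summary metrics from it directly. As a quick sketch using the values printed above (rows are the true labels, in the alphabetical order of clf_dt.classes_, i.e. Northeast then West):

```python
import numpy as np

# The 2x2 array returned by confusion_matrix() above
cm = np.array([[1211,  603],
               [ 487, 1317]])

# Overall accuracy: correct predictions (the diagonal) over all predictions
accuracy = np.trace(cm) / cm.sum()

# Per-class recall: diagonal divided by each row's total (rows are true labels)
recall = np.diag(cm) / cm.sum(axis=1)

print(round(accuracy, 3))  # proportion of test rows classified correctly
print(recall.round(3))     # recall for Northeast and West respectively
```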
In order to create a more digestible output, we'll use a specific plotting function from sklearn.metrics: ConfusionMatrixDisplay().
To use the function, we just need two arguments:
confusion_matrix: an array of values for the plot; the output from the confusion_matrix() function is sufficient
display_labels: class labels (in this case accessed as an attribute of the classifier, clf_dt.classes_)
import matplotlib.pyplot as plt

# Create the plot
disp = ConfusionMatrixDisplay(confusion_matrix = cm, display_labels = clf_dt.classes_)
disp.plot()
plt.show()
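A common variant is to show per-class rates instead of raw counts, which confusion_matrix() supports via its normalize parameter. A small sketch with toy labels (hypothetical stand-ins for y_test and y_pred from the model above):

```python
from sklearn.metrics import confusion_matrix

# Hypothetical true and predicted labels
y_true = ["Northeast", "West", "West", "Northeast", "West"]
y_hat  = ["Northeast", "West", "Northeast", "Northeast", "West"]

# normalize="true" divides each row by its total, so every cell becomes the
# fraction of that true class predicted as each label (rows sum to 1)
cm_norm = confusion_matrix(y_true, y_hat, labels=["Northeast", "West"], normalize="true")
print(cm_norm)
```

The normalized array can be passed to ConfusionMatrixDisplay in exactly the same way as the raw counts.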
Why is a confusion matrix useful?
Fundamentally, a confusion matrix allows you to quickly evaluate the performance of a classifier by breaking misclassifications down into false positives and false negatives. Additionally, if you're comparing multiple models, you can use confusion matrices:
- To identify which one has the best performance by comparing the rate of true positives, false negatives, false positives, and true negatives
- With precision-recall curves to select an appropriate threshold in multi-class classification problems.
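To illustrate that kind of comparison, here is a small sketch with toy labels (the two "model" outputs are hypothetical). classification_report() bundles the per-class precision, recall, and F1 that you'd otherwise read off each confusion matrix:

```python
from sklearn.metrics import classification_report

# Hypothetical predictions from two competing models on the same test set
y_true  = ["Northeast", "West", "West", "Northeast", "West", "Northeast"]
model_a = ["Northeast", "West", "Northeast", "Northeast", "West", "West"]
model_b = ["Northeast", "West", "West", "Northeast", "Northeast", "Northeast"]

# Per-class precision/recall/F1 make side-by-side comparison easier than raw counts
print(classification_report(y_true, model_a))
print(classification_report(y_true, model_b))
```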
See above for a reference image of confusion matrices, created in Lucidchart. Following scikit-learn's layout, rows correspond to the true labels and columns to the predicted labels; treating label 1 as the positive class:
- True positive (upper left): data points that the model assigned label 1, that are actually categorized under label 1
- False negative (upper right): data points that the model assigned label 2, that are actually categorized under label 1
- False positive (bottom left): data points that the model assigned label 1, that are actually categorized under label 2
- True negative (bottom right): data points that the model assigned label 2, that are actually categorized under label 2
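For binary problems, scikit-learn lets you unpack these four quadrants directly with ravel(). A short sketch with toy labels, where "yes" plays the role of label 1 (the positive class):

```python
from sklearn.metrics import confusion_matrix

# Toy binary labels; "yes" is treated as the positive class
y_true = ["yes", "yes", "no", "no", "yes", "no"]
y_hat  = ["yes", "no",  "no", "yes", "yes", "no"]

# Listing the positive class first in labels makes row/column 0 the positive
# class, so reading the 2x2 array left-to-right, top-to-bottom gives
# tp, fn, fp, tn (rows = true labels, columns = predicted labels)
tp, fn, fp, tn = confusion_matrix(y_true, y_hat, labels=["yes", "no"]).ravel()
print(tp, fn, fp, tn)  # → 2 1 1 2
```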