Understanding Correlation Matrix Plot
I’m loving this top-down approach. Initially, I completed an ML project without a clear understanding of how it worked. Now that I understand it at a 10,000 ft level, I’m diving one step deeper to understand the “how” and the “why”.
Today I dug deeper into the correlation matrix plot. I have used this plot in my projects before, but I did not know why I was executing certain lines of code. In this post, I’ll explain what each line of code does.
For this exercise, I took a new data set, the “Breast Cancer Wisconsin (Diagnostic) Data Set”. It has 32 attributes (an ID, the diagnosis and 30 numeric features), which makes it a good match for understanding the correlation matrix plot.
Before we get to the plot, I need to load the data set. As usual, I’ll import the necessary libraries and then load the data. Since I have explained these steps before, I won’t repeat them and will simply show the code.
import numpy as np
import pandas as pd
from pandas import read_csv
# Load dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data"
names=['id','dia','rad_mean','tex_mean','per_mean','are_mean','smo_mean','com_mean','con_mean','conp_mean','sym_mean','fra_mean','rad_se','tex_se','per_se','are_se','smo_se','com_se','con_se','conp_se','sym_se','fra_se','rad_worst','tex_worst','per_worst','are_worst','smo_worst','com_worst','con_worst','conp_worst','sym_worst','fra_worst']
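dataset = read_csv(url, names=names)
# 'dia' holds text labels (M = malignant, B = benign); encode them as 1/0 so corr() can include the diagnosis
dataset['dia'] = dataset['dia'].map({'M': 1, 'B': 0})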
To plot, we need to import matplotlib’s pyplot module:
import matplotlib.pyplot as plt
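As my objective is to plot a correlation matrix, I’ll first compute it. The pandas corr() method returns the pairwise correlation between columns (Pearson correlation by default):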
corr = dataset.corr()
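corr() only works with numeric columns. In this data set, “dia” contains the letters M and B, so in pandas 2.0 and later corr() raises an error on it, while older versions silently leave it out of the matrix. Here is a minimal sketch on a small made-up DataFrame (the column values are invented for illustration, and the M = 1 / B = 0 encoding is just one reasonable choice) showing the difference:

```python
import pandas as pd

# Tiny made-up frame that mimics the layout of the real data set
df = pd.DataFrame({
    "dia": ["M", "B", "M", "B", "M"],
    "rad_mean": [17.9, 13.5, 20.3, 11.4, 19.7],
    "tex_mean": [10.4, 14.4, 21.3, 20.4, 14.3],
})

# Keeping only numeric columns drops the text column 'dia' entirely
numeric_only = df.select_dtypes(include="number")
print(numeric_only.corr().shape)  # (2, 2)

# Encoding the diagnosis as numbers (M = malignant -> 1, B = benign -> 0) keeps it in
encoded = df.assign(dia=df["dia"].map({"M": 1, "B": 0}))
corr = encoded.corr()
print(corr.shape)          # (3, 3)
print(list(corr.columns))  # ['dia', 'rad_mean', 'tex_mean']
```

Encoding the column, rather than dropping it, is what lets the diagnosis show up as its own row and column in the plot.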
If I display the variable corr, this is what I see
As you can observe, it is basically a table of correlation values, one for every pair of attributes. If you recall, +1 means a perfect positive correlation, -1 a perfect negative correlation and 0 no linear correlation at all. We are going to display this matrix in a more colorful and appealing way.
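To see where a single value in the matrix comes from, here is a small sketch that computes the Pearson correlation by hand (the covariance of two columns divided by the product of their standard deviations) and compares it with pandas. The helper name pearson and the data are made up for illustration:

```python
import numpy as np
import pandas as pd

def pearson(x, y):
    """Pearson correlation: sum of co-deviations over the root of the summed squared deviations."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()
    dy = y - y.mean()
    return float((dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum()))

# Made-up values: b rises roughly with a, c falls exactly as a rises (c = 6 - a)
df = pd.DataFrame({
    "a": [1.0, 2.0, 3.0, 4.0, 5.0],
    "b": [2.1, 3.9, 6.2, 8.1, 9.8],
    "c": [5.0, 4.0, 3.0, 2.0, 1.0],
})

print(pearson(df["a"], df["b"]))  # close to +1
print(pearson(df["a"], df["c"]))  # -1.0
print(df.corr())                  # pandas gives the same values, up to rounding
```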
First, I need to create a figure object where I can draw a plot.
fig = plt.figure(figsize = (16,16))
Let’s understand the figsize parameter here. It sets the (width, height) of the figure in inches. If I don’t provide it, like this:
fig = plt.figure()
The output looks something like this:
If you look at the labels at the top, you can hardly make them out. Since we have 32 attributes, we need a larger plot so that we can see each attribute clearly.
Now let’s go back to the 16x16 size:
fig = plt.figure(figsize = (16,16))
The output for 16x16 looks like this:
This is better than the default, but many labels still overlap. So I’ll change the size to 32x32.
fig = plt.figure(figsize = (32,32))
The output looks like this:
If you open the above image in a new tab, you can see each attribute very clearly.
So, depending on the number of attributes in the matrix, you can increase the size to get a clearer picture.
Word of caution: the images above were produced with additional code that we haven’t written yet. I showed them only to explain the size parameter; so far, we haven’t added the code required for the final output.
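Since figsize is in inches, the pixel size of the image is figsize multiplied by the figure’s dpi (dots per inch). A small sketch, assuming the off-screen Agg backend and an explicit dpi of 100 (the helper name pixel_size is hypothetical), that reports the pixel dimensions for the two sizes tried above:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window needed
import matplotlib.pyplot as plt

def pixel_size(figsize, dpi=100):
    """Create a figure and report its size in pixels (inches * dpi)."""
    fig = plt.figure(figsize=figsize, dpi=dpi)
    width, height = fig.canvas.get_width_height()
    plt.close(fig)  # free the figure; we only wanted its dimensions
    return width, height

for size in [(16, 16), (32, 32)]:
    print(size, "->", pixel_size(size))
# (16, 16) -> (1600, 1600)
# (32, 32) -> (3200, 3200)
```

Doubling figsize doubles the room available for each of the 32 labels, which is why the 32x32 version no longer overlaps.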
Let’s continue.
Next, I need to add a subplot to the figure object.
ax = fig.add_subplot(111)
We can add any number of subplots to a figure, and the argument 111 defines the layout: the first two digits describe the grid (number of rows and number of columns) and the last digit is the position of this plot within the grid. So 111 means “1x1 grid, first subplot”, which is the same as fig.add_subplot(1, 1, 1).
To display the matrix as an image, we call the matshow() method:
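A quick way to check what the three-digit shorthand means is to ask matplotlib for the grid geometry of the resulting axes. This sketch (made-up figures, off-screen Agg backend) compares 111 with the explicit three-argument form and shows a 2x2 example:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig1 = plt.figure()
ax_short = fig1.add_subplot(111)     # three-digit shorthand
fig2 = plt.figure()
ax_long = fig2.add_subplot(1, 1, 1)  # same layout as (rows, cols, index)

# get_geometry() returns (rows, cols, first cell, last cell); cells are counted from 0
print(ax_short.get_subplotspec().get_geometry())  # (1, 1, 0, 0)
print(ax_long.get_subplotspec().get_geometry())   # (1, 1, 0, 0)

fig3 = plt.figure()
ax_grid = fig3.add_subplot(2, 2, 3)  # 2x2 grid, third slot (bottom-left)
print(ax_grid.get_subplotspec().get_geometry())   # (2, 2, 2, 2)
plt.close("all")
```

The shorthand only works while rows, columns and position are all single digits; the three-argument form has no such limit.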
cax = ax.matshow(corr, vmin=-1, vmax=1)
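Here we pass the correlation matrix (corr) and fix the color range from -1 to +1 with vmin and vmax. Without them, matplotlib would stretch the colors over the smallest and largest values actually present in the matrix.
To display the legend, we need to add a colorbar. The legend is the color scale shown to the right of the matrix in the images above.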
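To see what vmin and vmax change, this sketch draws the same made-up 3x3 matrix twice, once with automatic limits and once pinned to [-1, 1], and reads back the color limits with get_clim():

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Made-up correlation matrix whose values happen to stay between 0.2 and 1
corr = np.array([
    [1.0, 0.6, 0.2],
    [0.6, 1.0, 0.4],
    [0.2, 0.4, 1.0],
])

fig, (left, right) = plt.subplots(1, 2)
auto = left.matshow(corr)                     # color range taken from the data
fixed = right.matshow(corr, vmin=-1, vmax=1)  # color range pinned to [-1, 1]

print(auto.get_clim())   # limits follow the data: 0.2 and 1.0
print(fixed.get_clim())  # limits are -1.0 and 1.0
plt.close(fig)
```

With automatic limits, 0.2 would get the darkest color and look like a strong negative correlation; pinning the range means a given color always stands for the same correlation value.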
fig.colorbar(cax)
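fig.colorbar() does not draw inside the matrix axes; it creates a second, narrow axes next to it and reads its colors and limits from the image returned by matshow(). A sketch on a made-up 2x2 matrix (the label text is just an example):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

corr = np.array([[1.0, -0.3], [-0.3, 1.0]])  # made-up 2x2 correlation matrix

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(corr, vmin=-1, vmax=1)
n_before = len(fig.axes)  # 1: only the matrix so far

cbar = fig.colorbar(cax)                   # the colorbar gets its own axes
cbar.set_label("correlation coefficient")  # optional title for the scale
n_after = len(fig.axes)   # 2: matrix + colorbar

print(n_before, n_after)      # 1 2
print(cbar.mappable is cax)   # True
plt.close(fig)
```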
Now if I check the plot, this is how it looks
One important observation here: the numbers shown along the top and the left side do not make any sense to us, because they are just row and column positions. So let’s replace them with the attribute names.
ax.set_xticklabels(corr.columns)  # take the names from the matrix itself so labels match its rows/columns
ax.set_yticklabels(corr.columns)
Now the plot looks like this:
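Even though labels are now shown, did you notice that not all the attributes appear? By default, matplotlib places only a few ticks along each axis, and a label can only be drawn at a tick, so we need to explicitly ask for a tick at every position.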
ticks = np.arange(len(corr.columns))  # one tick per attribute: 0, 1, 2, ...
ax.set_xticks(ticks)
ax.set_yticks(ticks)
These are the “ticks”: we are asking for one tick for every attribute, that is, for every row and column of the matrix.
With that in place, when I call
plt.show()
This is what the plot looks like
Notice the ticks on both axes and compare them with the previous images.
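Putting the tick steps together, here is a sketch on a made-up stand-in for the real data (32 random columns named col0 to col31, which are hypothetical). It sets the tick positions first and then one label per position, taking the labels from corr.columns so the label list always has exactly as many entries as the matrix has columns:

```python
import matplotlib
matplotlib.use("Agg")  # draw off-screen
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up stand-in for the real data: 32 random columns
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(100, 32)),
                    columns=[f"col{i}" for i in range(32)])
corr = data.corr()

fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(111)
cax = ax.matshow(corr, vmin=-1, vmax=1)
fig.colorbar(cax)
default_ticks = ax.get_xticks()  # matplotlib's automatic choice: only a few positions

ticks = np.arange(len(corr.columns))  # 0, 1, ..., 31: one tick per row/column
ax.set_xticks(ticks)  # set the positions first,
ax.set_yticks(ticks)
ax.set_xticklabels(corr.columns, rotation=90)  # then attach one label per position
ax.set_yticklabels(corr.columns)

fig.canvas.draw()  # render once so the tick label text is filled in
ylabels = [t.get_text() for t in ax.get_yticklabels()]
print(len(default_ticks), "automatic ticks vs", len(ax.get_xticks()), "after set_xticks")
plt.close(fig)
```

Rotating the x labels by 90 degrees is optional; it is one more way, besides a larger figsize, to stop long names from overlapping.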
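Final observation: If we closely observe the final plot, we can see the correlation between every pair of attributes and spot which ones are highly correlated: they are the cells whose color sits at the +1 end of the colorbar (strong negative correlations sit at the -1 end). The diagonal is always 1, because every attribute is perfectly correlated with itself. For example, rad_mean, per_mean and are_mean are strongly correlated with each other, which is expected because a cell nucleus’s perimeter and area both grow with its radius. Keep in mind that corr() only works with numeric columns, so the diagnosis (“dia”, stored as the letters M and B) appears in the matrix only if it has been converted to numbers first.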
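Reading pairs off the colors works, but with 32 attributes it is easy to miss some. Here is a sketch, assuming corr is a pandas DataFrame such as the one returned by dataset.corr(), that lists the strongest pairs numerically. The helper name strongest_pairs and the 0.9 threshold are arbitrary choices for illustration, and the demo uses made-up data:

```python
import numpy as np
import pandas as pd

def strongest_pairs(corr, threshold=0.9):
    """List attribute pairs with |correlation| >= threshold, strongest first."""
    # Cells above the diagonal only: skips the diagonal (always 1)
    # and the mirrored lower half (the matrix is symmetric)
    rows, cols = np.triu_indices(len(corr.columns), k=1)
    pairs = pd.Series(
        corr.to_numpy()[rows, cols],
        index=pd.MultiIndex.from_arrays([corr.index[rows], corr.columns[cols]]),
    )
    strong = pairs[pairs.abs() >= threshold]
    # Sort by absolute value so strong negative correlations rank alongside positive ones
    return strong.sort_values(key=lambda s: s.abs(), ascending=False)

# Made-up data: per is almost a multiple of rad, tex is independent noise
rng = np.random.default_rng(1)
rad = rng.normal(14, 3, size=200)
data = pd.DataFrame({
    "rad": rad,
    "per": 6.3 * rad + rng.normal(0, 0.5, size=200),
    "tex": rng.normal(19, 4, size=200),
})
result = strongest_pairs(data.corr())
print(result)  # only the (rad, per) pair passes the threshold
```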