10. Advanced Plotting with matplotlib

Figures and main concepts from: realpython.com.

Learning Objectives

Be able to:

  • Read a publicly available file directly from Google Drive

  • Use matplotlib’s advanced layout features for multiple plots

  • Understand the difference between a “figure” and “axes” in matplotlib

Data files needed:

Note

You can load data files directly from my Google Drive by using:
raw=pd.read_csv('https://drive.google.com/uc?id=FILE_ID')
where “FILE_ID” is the Google Drive file identifier given previously. Use read_csv or read_excel, whichever matches the file type. For the California housing data, we would write:
raw=pd.read_csv('https://drive.google.com/uc?id=1lcMbOh19eUUhCdvSloNs5mOYLBCMkIgz')
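The URL pattern in this note can be wrapped in a small helper so that only the FILE_ID changes from file to file. This is just a sketch: the function name drive_url is made up here for illustration and is not part of pandas. It only builds the string, and the download itself still requires that the file is shared publicly.

```python
import pandas as pd

def drive_url(file_id):
    """Build the direct-download URL for a publicly shared Google Drive file.

    drive_url is a hypothetical helper name used only in this sketch.
    """
    return 'https://drive.google.com/uc?id=' + file_id

# FILE_ID of the California housing data, as given in the note above
housing_url = drive_url('1lcMbOh19eUUhCdvSloNs5mOYLBCMkIgz')
print(housing_url)

# Reading the file needs an internet connection, so it is left commented out:
# raw = pd.read_csv(housing_url)
```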

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

10.1. “Figure” versus “Axes” according to matplotlib

In matplotlib, a “figure” is the outer container that can hold one or more plots. Each individual plot is referred to as an “axes”, so we create one axes for every plot that we want. A single axes can itself contain multiple curves, drawn with multiple plt.plot() commands. In all of our previous lessons, we have defined only a single axes using, for example, fig, ax = plt.subplots(nrows=1, ncols=1).

[Figure: fig vs axes]

Examples
fig, ax = plt.subplots(nrows=1, ncols=1) if we only have one axes (plot)
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2) if we define 2 axes in one figure with the axes named “ax1” and “ax2”.

Note

The names “fig”, “ax1”, “ax2”, etc. are up to you. There is nothing special about these particular names other than that they are descriptive (easy to read) and short.
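The container relationship can be checked directly on the objects that plt.subplots() returns. The following is a minimal sketch with made-up data points: the figure keeps a list of its axes, each axes knows which figure it belongs to, and one axes can hold several curves.

```python
import matplotlib.pyplot as plt

# one figure (the outer container) holding two axes (the individual plots)
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# two curves drawn in the same axes (illustrative data points)
ax1.plot([0, 1, 2], [0, 1, 4], label='curve 1')
ax1.plot([0, 1, 2], [0, 2, 3], label='curve 2')
ax1.legend()

print(len(fig.axes))      # number of axes contained in the figure
print(ax1.figure is fig)  # each axes points back to its parent figure
print(len(ax1.lines))     # number of curves drawn in ax1

# plt.show()  # uncomment to display the figure outside a notebook
```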

10.2. Creating rows and columns of axes

The command plt.subplots() creates a figure with the number of axes specified by the call. For example, plt.subplots(nrows=2, ncols=3) creates a figure with 6 sets of axes (2 rows by 3 columns) as shown below.

plt.subplots(nrows=2, ncols=3)
plt.show()
[Figure: 2 x 3 grid of empty axes]

To make a plot within one of the axes, we need to refer to that axes by name. We name each axes when we create the figure, usually ax1, ax2, etc., and write the names in the same row-by-column arrangement as the grid. For example:
fig, ((ax1, ax2, ax3),(ax4,ax5,ax6))=plt.subplots(nrows=2, ncols=3)
fig, (ax1, ax2, ax3)=plt.subplots(nrows=1, ncols=3)
fig, (ax1, ax2, ax3)=plt.subplots(nrows=3, ncols=1)

fig, ((ax1, ax2, ax3),(ax4,ax5,ax6))=plt.subplots(nrows=2, ncols=3)
ax1.annotate('This is "ax1"',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
ax5.annotate('This is "ax5"',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
fig.tight_layout() # this removes the overlap that we had above
plt.show()
[Figure: 2 x 3 grid of axes with “ax1” and “ax5” labeled]
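Unpacking six names is one option. As an alternative sketch, the axes can be kept together in the single NumPy array that plt.subplots() returns when there is more than one row and column. The variable name axs is just a choice made here. Indexing is zero-based, so axs[1, 1] is the axes called “ax5” in the example above.

```python
import matplotlib.pyplot as plt

# keep all axes in one array instead of unpacking six separate names
fig, axs = plt.subplots(nrows=2, ncols=3)
print(axs.shape)  # 2 rows by 3 columns

# axs[row, col] with zero-based indices: second row, second column
axs[1, 1].annotate('This is "ax5"', (0.5, 0.5), xycoords='axes fraction',
                   va='center', ha='center')

# axs.flat walks through the axes row by row, which is handy for loops
for i, ax in enumerate(axs.flat, start=1):
    ax.set_title('ax' + str(i))

fig.tight_layout()
# plt.show()  # uncomment to display the figure outside a notebook
```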

And for 2 axes in a row:

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,figsize=(8, 4))
ax1.annotate('This is "ax1"',(0.5,0.5), xycoords='axes fraction', va='center', ha='center', bbox=dict(facecolor='pink', alpha=0.6),fontsize=10)
ax2.annotate('This is "ax2"',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
plt.show()
[Figure: two axes side by side, labeled “ax1” and “ax2”]

or for 2 axes in a column:

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1,figsize=(8, 4))
ax1.annotate('This is axes 1',(0.5,0.5), xycoords='axes fraction', va='center', ha='center', bbox=dict(facecolor='seagreen', alpha=0.6),fontsize=10)
ax2.annotate('This is axes 2',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
plt.show()
[Figure: two axes stacked in a column, labeled “axes 1” and “axes 2”]

10.3. Adding data to our axes

Now let’s add some data to our axes and try some different plot types. First, let’s generate some data to plot:

x = np.random.randint(low=1, high=11, size=50) #generate 50 random integers between 1 and 10
y = x + np.random.randint(1, 5, size=x.size)  #generate 50 random integers between 1 and 4 and add to x
data = np.column_stack((x, y))
data[0:5] #with numpy arrays we use different notation to slice elements as compared to pandas
array([[ 1,  2],
       [ 6,  7],
       [ 5,  7],
       [ 8, 11],
       [10, 13]])
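The comment about slicing can be made concrete with a small side-by-side comparison. This sketch reuses the five rows printed above as a fixed array (so no random numbers are involved) and wraps them in a DataFrame whose column names 'x' and 'y' are chosen here for illustration.

```python
import numpy as np
import pandas as pd

# the five rows shown in the output above, typed in as a fixed array
data = np.array([[1, 2], [6, 7], [5, 7], [8, 11], [10, 13]])

# NumPy: plain square brackets, positions only, written as [rows, columns]
print(data[0:2])    # first two rows, all columns
print(data[:, 0])   # all rows, first column (the x values)
print(data[3, 1])   # single element: row 3, column 1

# pandas: the same selections go through .iloc (positions) or labels
df = pd.DataFrame(data, columns=['x', 'y'])
print(df.iloc[0:2])    # first two rows by position
print(df['x'])         # one column by label
print(df.loc[3, 'y'])  # single element by row label and column label
```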

Now we will add a few options and make some nice plots.

# create a figure with two axes in a row and adjust the overall figure size
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,figsize=(8, 4))
# make a scatter plot in the first axes (ax1)  
# marker = circle, color = red, edgecolor = blue, size=75 points
ax1.scatter(x=x, y=y, marker='o', c='r', edgecolor='b', s=75)
ax1.set_title('Scatter: x versus y') # add title to ax1
ax1.set_xlabel('x') # add xlabel to ax1
ax1.set_ylabel('y') # add ylabel to ax1

# now let's work on the second axes (ax2)
ax2.hist(data, bins=np.arange(data.min(), data.max()),label=('x', 'y')) #define a histogram
ax2.legend(loc=(0.65, 0.8)) #by using "label" in hist() we can now show a legend, location at 65% in horizontal and 80% in vertical direction
ax2.set_title('Frequencies of x and y') # add title to ax2
ax2.yaxis.tick_right() # place the y-axis ticks on the right side rather than left

fig.tight_layout() # keep axes tight together but avoid overlap
[Figure: scatter plot of x versus y (left) and histogram of the frequencies of x and y (right)]
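One detail of the histogram call above may be worth checking. np.arange() excludes its stop value, so bins=np.arange(data.min(), data.max()) places the last bin edge at data.max() - 1, and values beyond the last edge are not counted. Whether that matters depends on what the plot is meant to show. The sketch below uses a small made-up sample (the array and the variable names are illustrative, not from the lesson) and np.histogram(), which follows the same binning rules as hist(), to compare those edges with edges that run one step past the maximum.

```python
import numpy as np

values = np.array([1, 2, 2, 3, 4, 4, 4])  # small made-up sample

# edges as in the call above: the stop value (the maximum) is excluded
edges_short = np.arange(values.min(), values.max())
# edges extended so that every integer, including the maximum, has its own bin
edges_full = np.arange(values.min(), values.max() + 2)

counts_short, _ = np.histogram(values, bins=edges_short)
counts_full, _ = np.histogram(values, bins=edges_full)

print(edges_short, counts_short, counts_short.sum())  # [1 2 3] [1 3] 4
print(edges_full, counts_full, counts_full.sum())     # [1 2 3 4 5] [1 2 1 3] 7
```

With the shorter edges, the three values equal to 4 fall outside the bins and are left out of the counts.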

10.4. More advanced layouts (span rows and/or columns)

If you need a more advanced layout than a simple grid of plots, Matplotlib’s gridspec module allows for more subplot customization. Pyplot’s subplot2grid() function interacts with this module.

Say we want this setup, with one large axes (ax1) above two smaller axes (ax2 and ax3):
[Figure: target layout]

Set it up as a 3 row x 2 column grid, then let ax1 span 2 columns and 2 rows. Note that grid positions are numbered from zero: the upper-left cell is at location (0,0) and the lower-right cell is at (2,1).
[Figure: numbering of the 3 x 2 grid]

gridsize = (3, 2) # 3 rows x 2 columns
fig = plt.figure(figsize=(12, 8)) #overall size of figure contains total grid 3x2
ax1 = plt.subplot2grid(gridsize, (0, 0), colspan=2, rowspan=2) # start at axes (0,0) and span 2 col and 2 rows
ax2 = plt.subplot2grid(gridsize, (2, 0)) #3rd row, 1st col
ax3 = plt.subplot2grid(gridsize, (2, 1)) #3rd row, 2nd col
ax1.annotate('This is axes 1',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
ax2.annotate('This is axes 2',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
ax3.annotate('This is axes 3',(0.5,0.5), xycoords='axes fraction', va='center', ha='center')
plt.show()
[Figure: axes 1 spanning the top two rows, with axes 2 and axes 3 side by side in the third row]

You could make many creative layouts using the simple concepts shown in the previous output.
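Since subplot2grid() is built on the gridspec module, the same layout can also be written with a GridSpec object directly. The following is a sketch of that alternative, not the approach used in this lesson: fig.add_gridspec() creates the 3 x 2 grid, and slicing it with [rows, columns] selects the cells that each axes should cover.

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(nrows=3, ncols=2)  # 3 rows x 2 columns

ax1 = fig.add_subplot(gs[0:2, :])  # rows 0 and 1, all columns
ax2 = fig.add_subplot(gs[2, 0])    # 3rd row, 1st column
ax3 = fig.add_subplot(gs[2, 1])    # 3rd row, 2nd column

for name, ax in [('axes 1', ax1), ('axes 2', ax2), ('axes 3', ax3)]:
    ax.annotate('This is ' + name, (0.5, 0.5), xycoords='axes fraction',
                va='center', ha='center')

# plt.show()  # uncomment to display the figure outside a notebook
```

The slice notation follows the usual NumPy rules, so gs[0:2, :] covers the same cells as starting at (0, 0) with rowspan=2 and colspan=2.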

10.5. Example 1 California housing data

Let’s take a look at some California housing data whose description can be found at https://developers.google.com/machine-learning/crash-course/california-housing-data-description

raw2=pd.read_csv('https://drive.google.com/uc?id=1lcMbOh19eUUhCdvSloNs5mOYLBCMkIgz')

housing=raw2.copy()
housing
       longitude  latitude  housing median age  total rooms  total bedrooms  population  households  median income (10k)  median house value
    0    -122.23     37.88                  41          880             129         322         126               8.3252              452600
    1    -122.22     37.86                  21         7099            1106        2401        1138               8.3014              358500
    2    -122.24     37.85                  52         1467             190         496         177               7.2574              352100
    3    -122.25     37.85                  52         1274             235         558         219               5.6431              341300
    4    -122.25     37.85                  52         1627             280         565         259               3.8462              342200
  ...        ...       ...                 ...          ...             ...         ...         ...                  ...                 ...
20635    -121.09     39.48                  25         1665             374         845         330               1.5603               78100
20636    -121.21     39.49                  18          697             150         356         114               2.5568               77100
20637    -121.22     39.43                  17         2254             485        1007         433               1.7000               92300
20638    -121.32     39.43                  18         1860             409         741         349               1.8672               84700
20639    -121.24     39.37                  16         2785             616        1387         530               2.3886               89400

20640 rows × 9 columns

y = housing.iloc[:, -1] #median house value (last column)
pop=housing.loc[:, 'population'] #population
age=housing.loc[:, 'housing median age'] #median age of homes
#create layout
gridsize = (3, 2)
fig = plt.figure(figsize=(8, 7)) #overall size of figure
ax1 = plt.subplot2grid(gridsize, (0, 0), colspan=2, rowspan=2)
ax2 = plt.subplot2grid(gridsize, (2, 0))
ax3 = plt.subplot2grid(gridsize, (2, 1))

#fill layout with plots
ax1.set_title('Home value as a function of home age & area population',fontsize=14)
sctr = ax1.scatter(x=age, y=pop, c=y, cmap='RdYlGn')
plt.colorbar(sctr, ax=ax1, format='$%d')
ax1.set_yscale('log')
ax2.hist(age, bins='auto')
ax3.hist(pop, bins='auto', log=True)

ax2.annotate('Histogram: home age',(0.5,0.9), xycoords='axes fraction', va='center', ha='center', bbox=dict(facecolor='white', alpha=0.6),fontsize=10) 

ax3.annotate('Histogram: area population (log scl.)',(0.5,0.9), xycoords='axes fraction', va='center', ha='center', bbox=dict(facecolor='white', alpha=0.6),fontsize=10) 


plt.show()
[Figure: scatter plot of home value as a function of home age and area population (top), with histograms of home age and area population (bottom)]

10.6. Example 2 County income data

raw=pd.read_excel('https://drive.google.com/uc?id=1rmIPRkMAqqfY3TNpWZ13aV4PfuKihguM')
data=raw.copy()
data.columns=['county', 'state', 'households',
       '<10k',
       '10k_15k',
       '15k_25k',
       '25k_35k',
       '35k_50k',
       '50k_75k',
       '75k_100k']
data
          county    state  households  <10k  10k_15k  15k_25k  25k_35k  35k_50k  50k_75k  75k_100k
   0     Autauga  Alabama       21397  1417     1172     2219     1684     2721     3817      2931
   1     Baldwin  Alabama       80930  4068     3685     7092     8403    11245    14928     10374
   2     Barbour  Alabama        9345  1391      796     1641     1092     1331     1239       769
   3        Bibb  Alabama        6891   570      554      897      610     1021     1100      1119
   4      Blount  Alabama       20847  1706     1145     2199     2600     2921     4050      2497
 ...         ...      ...         ...   ...      ...      ...      ...      ...      ...       ...
4925  Sweetwater  Wyoming       15523   805      645     1202      868     1653     2605      2506
4926       Teton  Wyoming        9019   182      198      667      383      838     1552      1448
4927       Uinta  Wyoming        7597   382      232      559      791      934     1451      1241
4928    Washakie  Wyoming        3365   144      174      286      379      541      710       519
4929      Weston  Wyoming        2891   181      143      338      246      398      433       433

4930 rows × 10 columns

The above data shows the total number of households in each county of each state and a breakdown of the number of households that have an income in one of the indicated ranges.

For example, 2931 households have an income between 75,000 and 100,000 in Autauga County, Alabama. This represents 2931/21397*100 = 13.7% of all households in Autauga County.
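The arithmetic in this example can be reproduced with pandas. The sketch below types in two rows from the table above by hand (only the columns needed for the calculation) rather than reading the file, and the new column name pct_75k_100k is chosen here for illustration.

```python
import pandas as pd

# two rows copied from the table above, limited to the columns used here
rows = pd.DataFrame({
    'county': ['Autauga', 'Baldwin'],
    'state': ['Alabama', 'Alabama'],
    'households': [21397, 80930],
    '75k_100k': [2931, 10374],
})

# share of households in the 75k-100k range, as a percentage
rows['pct_75k_100k'] = rows['75k_100k'] / rows['households'] * 100

print(rows.round(1))  # Autauga comes out at 13.7, matching the text
```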

10.7. Exercises

10.7.1. Problem 1

Complete the following steps on the county income data:

First, remove the duplicate rows from the data set using drop_duplicates(). Instead of 4930 rows, you should be left with 3142 records.

Determine the following:

  1. The number of households that have an income greater than 100,000 for each county across the USA. Add this data to the dataframe.

  2. The percentage of households for each income range for each county, i.e. divide the data for each row by “households” and multiply by 100.
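As a reminder of how the duplicate-removal step at the start of this problem behaves, here is a minimal sketch on a toy DataFrame. The three rows are made up for illustration (one of them is an exact repeat), and resetting the index afterwards is an optional extra, not something the problem asks for.

```python
import pandas as pd

# toy data: the second row is an exact duplicate of the first
toy = pd.DataFrame({
    'county': ['Autauga', 'Autauga', 'Baldwin'],
    'state': ['Alabama', 'Alabama', 'Alabama'],
    'households': [21397, 21397, 80930],
})

# drop_duplicates() keeps the first copy of each fully identical row
clean = toy.drop_duplicates().reset_index(drop=True)

print(len(toy), len(clean))  # 3 2
print(clean)
```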

10.7.2. Problem 2

Using the cleaned data from Problem 1, plot the following in the layout described below.

  1. The percentage of households with an income less than 50,000 and the percentage of households with an income greater than 100,000 for every state in the US.

  2. The same data but now for every county in Virginia.

  3. A bar chart for the same data but now for Montgomery County, Virginia.

The three plots should be laid out similarly to the figure below, where each plot has the same height but plot 1 has a width of 3, plot 2 has a width of 2, and plot 3 has a width of 1.
[Figure: layout of the three plots]