os_labs/lab12/thinkplot.py

505 lines
12 KiB
Python

"""This file contains code for use with "Think Stats",
by Allen B. Downey, available from greenteapress.com
Copyright 2010 Allen B. Downey
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
"""
import math
import matplotlib
import matplotlib.pyplot as pyplot
import numpy as np
# customize some matplotlib attributes
#matplotlib.rc('figure', figsize=(4, 3))
#matplotlib.rc('font', size=14.0)
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
#matplotlib.rc('legend', fontsize=20.0)
#matplotlib.rc('xtick.major', size=6.0)
#matplotlib.rc('xtick.minor', size=3.0)
#matplotlib.rc('ytick.major', size=6.0)
#matplotlib.rc('ytick.minor', size=3.0)
class Brewer(object):
    """Encapsulates a nice sequence of colors.

    Shades of blue that look good in color and can be distinguished
    in grayscale (up to a point).

    Borrowed from http://colorbrewer2.org/
    """
    # iterator shared by all plotting calls; None when no hint has been given
    color_iter = None

    # the palette, as hex color strings
    colors = ['#081D58',
              '#253494',
              '#225EA8',
              '#1D91C0',
              '#41B6C4',
              '#7FCDBB',
              '#C7E9B4',
              '#EDF8B1',
              '#FFFFD9']

    # lists that indicate which colors to use depending on how many are used
    # (which_colors[n] holds the indices into colors for n lines, n in 0..7)
    which_colors = [[],
                    [1],
                    [1, 3],
                    [0, 2, 4],
                    [0, 2, 4, 6],
                    [0, 2, 3, 5, 6],
                    [0, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6],
                    ]

    @classmethod
    def Colors(cls):
        """Returns the list of colors.
        """
        return cls.colors

    @classmethod
    def ColorGenerator(cls, n):
        """Returns an iterator of color strings.

        n: how many colors will be used; must be a valid index into
           which_colors, otherwise IndexError is raised when the
           iterator is first advanced (the lookup is lazy).

        Once the colors are used up the iterator simply ends, so
        calling next() on it raises StopIteration as usual.
        """
        # Do not raise StopIteration explicitly here: inside a generator
        # that is converted to RuntimeError on Python 3.7+ (PEP 479).
        # Falling off the end signals exhaustion on every Python version.
        for i in cls.which_colors[n]:
            yield cls.colors[i]

    @classmethod
    def InitializeIter(cls, num):
        """Initializes the color iterator with the given number of colors."""
        cls.color_iter = cls.ColorGenerator(num)

    @classmethod
    def ClearIter(cls):
        """Sets the color iterator to None."""
        cls.color_iter = None

    @classmethod
    def GetIter(cls):
        """Gets the color iterator."""
        return cls.color_iter
def PrePlot(num=None, rows=1, cols=1):
    """Takes hints about what's coming.

    num: number of lines that will be plotted; when truthy, the Brewer
         color iterator is initialized for that many colors
    rows: number of subplot rows
    cols: number of subplot columns

    When more than one row or column is requested, a grid of subplots
    is created and its shape is recorded in the module-level names
    SUBPLOT_ROWS and SUBPLOT_COLS.
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        Brewer.InitializeIter(num)

    # TODO: get sharey and sharex working.  probably means switching
    # to subplots instead of subplot.
    # also, get rid of the gray background.
    wants_grid = rows > 1 or cols > 1
    if not wants_grid:
        return

    pyplot.subplots(rows, cols, sharey=True)
    SUBPLOT_ROWS, SUBPLOT_COLS = rows, cols
def SubPlot(rows, cols, plot_number):
    """Selects one cell of a subplot grid as the current plot.

    rows: int, number of rows in the grid
    cols: int, number of columns in the grid
    plot_number: int, which cell to make current (passed to pyplot.subplot)
    """
    grid_spec = (rows, cols, plot_number)
    pyplot.subplot(*grid_spec)
class InfiniteList(list):
    """A list whose every index lookup yields one fixed value.

    The underlying list stays empty; only item access is overridden.
    """

    def __init__(self, val):
        """Creates the list and remembers the value to hand back.

        val: value returned for every index
        """
        super(InfiniteList, self).__init__()
        self.val = val

    def __getitem__(self, index):
        """Returns the stored value, whatever the index is.

        index: ignored
        returns: the value given at construction
        """
        return self.val
def Underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    If d is None, create a new dictionary.  Otherwise d is modified
    in place and the same object is returned.

    d: dictionary
    options: keyword args to add to d

    returns: the updated dictionary
    """
    if d is None:
        d = {}

    # items() works on both Python 2 and 3; iteritems() is Python 2 only
    for key, val in options.items():
        d.setdefault(key, val)

    return d
def Clf():
    """Resets plotting state.

    Drops the Brewer color iterator (the hint set by PrePlot) and then
    clears the current pyplot figure.
    """
    Brewer.ClearIter()
    pyplot.clf()
def Figure(**options):
    """Creates a figure via pyplot.figure.

    options: keyword args passed to pyplot.figure; figsize defaults
             to (6, 8) unless the caller provides it
    """
    figure_kwargs = Underride(options, figsize=(6, 8))
    pyplot.figure(**figure_kwargs)
def Plot(xs, ys, style='', **options):
    """Plots a line.

    If a Brewer color iterator is active, the next color is taken from
    it and used unless the caller passed an explicit color.  When the
    iterator is exhausted a warning is printed and the iterator is
    cleared, so later calls fall back to pyplot's own colors.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot
    """
    color_iter = Brewer.GetIter()

    if color_iter:
        try:
            # next() builtin works on Python 2.6+ and 3; .next() is 2-only
            options = Underride(options, color=next(color_iter))
        except StopIteration:
            print('Warning: Brewer ran out of colors.')
            Brewer.ClearIter()

    options = Underride(options, linewidth=3, alpha=0.8)
    pyplot.plot(xs, ys, style, **options)
def Scatter(xs, ys, **options):
    """Draws a scatter plot with pyplot.scatter.

    xs: x values
    ys: y values
    options: keyword args passed to pyplot.scatter; color, alpha, s and
             edgecolors get defaults when the caller omits them
    """
    defaults = dict(color='blue', alpha=0.2, s=30, edgecolors='none')
    scatter_kwargs = Underride(options, **defaults)
    pyplot.scatter(xs, ys, **scatter_kwargs)
def Pmf(pmf, **options):
    """Draws a Pmf or Hist as a line.

    The object's name, when truthy, becomes the default legend label.

    Args:
      pmf: Hist or Pmf object (must provide Render() and a name attribute)
      options: keyword args passed along to Plot
    """
    xs, ps = pmf.Render()

    label = pmf.name
    if label:
        options = Underride(options, label=label)

    Plot(xs, ps, **options)
def Pmfs(pmfs, **options):
    """Draws each PMF in a sequence as a line.

    The same options are used for every PMF.  For per-PMF options,
    call Pmf once for each object instead.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args forwarded to Pmf for every element
    """
    for one_pmf in pmfs:
        Pmf(one_pmf, **options)
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist.  If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    If there are fewer than two values there is no difference to
    measure; in that case no width is set and pyplot.bar uses its own
    default.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to pyplot.bar
    """
    xs, fs = hist.Render()

    if hist.name:
        options = Underride(options, label=hist.name)

    options = Underride(options,
                        align='center',
                        linewidth=0)

    # Only derive a width when the caller did not supply one, and only
    # when there are adjacent values to compare: min() of an empty
    # sequence raises ValueError.
    if 'width' not in options:
        gaps = Diff(xs)
        if gaps:
            options['width'] = min(gaps)

    pyplot.bar(xs, fs, **options)
def Hists(hists, **options):
    """Draws each histogram in a sequence as a bar plot.

    The same options are used for every histogram.  For per-histogram
    options, call Hist once for each object instead.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args forwarded to Hist for every element
    """
    for one_hist in hists:
        Hist(one_hist, **options)
def Diff(t):
    """Returns the gaps between neighbouring elements of a sequence.

    Args:
      t: sequence of numbers (must support len() and integer indexing)

    Returns:
      list of t[i+1] - t[i] values; one element shorter than t, and
      empty when t has fewer than two elements
    """
    deltas = []
    for right in range(1, len(t)):
        deltas.append(t[right] - t[right - 1])
    return deltas
def Cdf(cdf, complement=False, transform=None, **options):
    """Draws a CDF as a line, optionally transformed.

    Args:
      cdf: Cdf object (must provide Render() and a name attribute)
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to pyplot.plot; 'xscale' and 'yscale'
               entries are removed and used as the starting scale

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()

    # caller-supplied scales win over the linear defaults and must not
    # leak into the plot options
    scale = dict(xscale=options.pop('xscale', 'linear'),
                 yscale=options.pop('yscale', 'linear'))

    # exponential and pareto are both plotted as complementary CDFs on a
    # log y axis; pareto additionally uses a log x axis
    if transform in ('exponential', 'pareto'):
        complement = True
        scale['yscale'] = 'log'
        if transform == 'pareto':
            scale['xscale'] = 'log'

    if complement:
        ps = [1.0 - p for p in ps]

    if transform == 'weibull':
        # the last point is dropped before taking log(1 - p)
        xs.pop()
        ps.pop()
        ps = [-math.log(1.0 - p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'
    elif transform == 'gumbel':
        # the first point is dropped before taking log(p)
        xs.pop(0)
        ps.pop(0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    label = cdf.name
    if label:
        options = Underride(options, label=label)

    Plot(xs, ps, **options)
    return scale
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Draws each CDF in a sequence as a line.

    The scale dictionaries returned by Cdf are discarded.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args forwarded to Cdf for every element
    """
    for one_cdf in cdfs:
        Cdf(one_cdf, complement=complement, transform=transform, **options)
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolor and/or pyplot.contour
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    # keys() works on both Python 2 and 3; iterkeys() is Python 2 only
    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)

    def lookup(x, y):
        # grid points that are missing from the map count as zero
        return d.get((x, y), 0)

    Z = np.vectorize(lookup)(X, Y)

    # show plain tick values on the x axis rather than an offset
    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = pyplot.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(cs, inline=1, fontsize=10)
    if imshow:
        extent = xs[0], xs[-1], ys[0], ys[-1]
        pyplot.imshow(Z, extent=extent, **options)
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Draws a pseudocolor and/or contour plot of gridded values.

    xs: x coordinates, passed to np.meshgrid
    ys: y coordinates, passed to np.meshgrid
    zs: values to plot on the grid built from xs and ys
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolormesh and/or pyplot.contour
    """
    Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # show plain tick values on the x axis rather than an offset
    plain_ticks = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(plain_ticks)

    if pcolor:
        pyplot.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contour_set = pyplot.contour(grid_x, grid_y, zs, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
def Config(**options):
    """Applies plot-level settings.

    For each recognized option that is present, the pyplot function of
    the same name is called with the option's value.  A legend is drawn
    unless legend is passed as a false value; loc (default 0) chooses
    its location.

    options: keyword args; recognized names are title, xlabel, ylabel,
             xscale, yscale, xticks, yticks, axis, loc and legend
    """
    for name in ('title', 'xlabel', 'ylabel', 'xscale', 'yscale',
                 'xticks', 'yticks', 'axis'):
        if name in options:
            setter = getattr(pyplot, name)
            setter(options[name])

    legend_loc = options.get('loc', 0)
    if options.get('legend', True):
        pyplot.legend(loc=legend_loc)
def Show(**options):
    """Configures the plot and then displays it.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions
    """
    # TODO: figure out how to show more than one plot
    Config(**options)
    pyplot.show()
def Save(root=None, formats=None, **options):
    """Configures the plot, writes it out, and clears the figure.

    For options, see Config.  Nothing is written when root is falsy,
    but the figure is still configured and cleared.

    Args:
      root: string filename root
      formats: list of string formats; defaults to ['pdf', 'eps']
      options: keyword args used to invoke various pyplot functions
    """
    Config(**options)

    targets = ['pdf', 'eps'] if formats is None else formats

    if root:
        for fmt in targets:
            SaveFormat(root, fmt)

    Clf()
def SaveFormat(root, fmt='eps'):
    """Writes the current figure to a file in the given format.

    The file is named '<root>.<fmt>' and saved at 300 dpi; the filename
    is printed before writing.

    Args:
      root: string filename root
      fmt: string format
    """
    filename = '%s.%s' % (root, fmt)
    # single-argument print call behaves the same on Python 2 and 3
    print('Writing %s' % filename)
    pyplot.savefig(filename, format=fmt, dpi=300)
# provide aliases for calling functions with lower-case names
# (each alias is bound to the same function object as its capitalized name)
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
scatter = Scatter
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
config = Config
show = Show
save = Save
def main():
    """Prints the seven colors Brewer selects when seven lines are plotted."""
    color_iter = Brewer.ColorGenerator(7)
    for color in color_iter:
        # single-argument print call behaves the same on Python 2 and 3
        print(color)


if __name__ == '__main__':
    main()