Source code for bumps.dream.corrplot

# This program is public domain
# Author Paul Kienzle
"""
2-D correlation histograms

Generate 2-D correlation histograms and display them in a figure.

Uses false color plots of density.
"""

__all__ = ["Corr2d"]

import numpy as np
from numpy import inf

from matplotlib import cm, colors, image, artist
from matplotlib.font_manager import FontProperties
from matplotlib.ticker import MaxNLocator

# Colour map for the density plots: runs from white at the low end through
# yellow, green and blue to red at the high end.  If the custom map cannot be
# built, fall back to matplotlib's stock reversed "gist_earth" map.
try:
    COLORMAP = colors.LinearSegmentedColormap.from_list(
        "density",
        ("w", "y", "g", "b", "r"),
    )
except Exception:  # deliberate best effort: any failure selects the fallback
    COLORMAP = cm.gist_earth_r


class Corr2d(object):
    """
    Generate and manage 2D correlation histograms.

    *data* is a sequence of sample vectors, one per variable.  A histogram is
    built for every pair of variables (i, j) with i < j; extra keyword
    arguments are forwarded to the histogram builder (e.g. *ranges*, or
    *bins* for :func:`numpy.histogram2d`).
    """

    def __init__(self, data, labels=None, **kw):
        if labels is None:
            # Default names P1, P2, ... one per variable.
            labels = ["P" + str(i + 1) for i, _ in enumerate(data)]
        self.N = len(data)  # number of variables
        self.labels = labels
        self.data = data
        # {(i, j): (histogram, edges_i, edges_j)} for every i < j
        self.hists = _hists(data, **kw)
        self.ax = None  # subplot axes dict; will be set on plot

    def R(self):
        """
        Return the correlation coefficient matrix of the variables.

        Computed with :func:`numpy.corrcoef`, treating each entry of *data*
        as one variable.
        """
        return np.corrcoef(self.data)

    def __getitem__(self, key):
        """
        Retrieve correlation histogram for data[i] X data[j].

        *key* is a pair (i, j) with i < j; only those pairs are stored, so
        any other pair raises KeyError.

        Returns (histogram, bin edges for data[i], bin edges for data[j]),
        in the order produced by :func:`numpy.histogram2d`.
        """
        i, j = key
        return self.hists[i, j]

    def plot(self, title=None, fig=None):
        """
        Plot the correlation histograms on the specified figure.

        If *fig* is None, the current pylab figure is cleared and used.  An
        optional *title* is drawn at the top centre of the figure.  The
        subplot axes returned by the plotting helper are stored in
        ``self.ax``.
        """
        import pylab  # imported locally; only this method needs it

        if fig is None:
            pylab.clf()
            fig = pylab.gcf()
        if title is not None:
            fig.text(0.5, 0.95, title, horizontalalignment="center", fontproperties=FontProperties(size=16))
        self.ax = _plot(fig, self.hists, self.labels, self.N)
def _hists(data, ranges=None, **kw): """ Generate pair-wise correlation histograms """ n = len(data) if ranges is None: low, high = np.min(data, axis=1), np.max(data, axis=1) ranges = [(l, h) for l, h in zip(low, high)] return dict( ((i, j), np.histogram2d(data[i], data[j], range=[ranges[i], ranges[j]], **kw)) for i in range(0, n) for j in range(i + 1, n) ) def _plot(fig, hists, labels, n, show_ticks=None): """ Plot pair-wise correlation histograms """ if n <= 1: fig.text(0.5, 0.5, "No correlation plots when only one variable", ha="center", va="center") return vmin, vmax = inf, -inf for data, _, _ in hists.values(): positive = data[data > 0] if len(positive) > 0: vmin = min(vmin, np.amin(positive)) vmax = max(vmax, np.amax(positive)) norm = colors.LogNorm(vmin=vmin, vmax=vmax, clip=False) # norm = colors.Normalize(vmin=vmin, vmax=vmax) mapper = image.FigureImage(fig) mapper.set_array(np.zeros((1, 1))) mapper.set_cmap(cmap=COLORMAP) mapper.set_norm(norm) if show_ticks is None: show_ticks = n < 3 ax = {} Nr = Nc = n - 1 for i in range(0, n - 1): for j in range(i + 1, n): sharex = ax.get((0, j), None) sharey = ax.get((i, i + 1), None) a = fig.add_subplot(Nr, Nc, (Nr - i - 1) * Nc + j, sharex=sharex, sharey=sharey) ax[(i, j)] = a a.xaxis.set_major_locator(MaxNLocator(4, steps=[1, 2, 4, 5, 10])) a.yaxis.set_major_locator(MaxNLocator(4, steps=[1, 2, 4, 5, 10])) data, x, y = hists[(i, j)] data = np.clip(data, vmin, vmax) a.pcolorfast(y, x, data, cmap=COLORMAP, norm=norm) # Show labels or hide ticks if i != 0: artist.setp(a.get_xticklabels(), visible=False) if i == n - 2 and j == n - 1: a.set_xlabel(labels[j]) # a.xaxis.set_label_position("top") # a.xaxis.set_offset_position("top") if not show_ticks: a.xaxis.set_ticks([]) if j == i + 1: a.set_ylabel(labels[i]) else: artist.setp(a.get_yticklabels(), visible=False) if not show_ticks: a.yaxis.set_ticks([]) a.zoomable = True # Adjust subplots and add the colorbar fig.subplots_adjust(left=0.07, bottom=0.07, top=0.9, right=0.85, 
wspace=0.0, hspace=0.0) cax = fig.add_axes([0.88, 0.2, 0.04, 0.6]) fig.colorbar(mapper, cax=cax, orientation="vertical") return ax def zoom(event, step): ax = event.inaxes if not hasattr(ax, "zoomable"): return # TODO: test logscale step *= 3 if ax.zoomable is not True and "mapper" in ax.zoomable: mapper = ax.zoomable["mapper"] if event.ydata is not None: lo, hi = mapper.get_clim() pt = event.ydata * (hi - lo) + lo lo, hi = _rescale(lo, hi, pt, step) mapper.set_clim((lo, hi)) if ax.zoomable is True and event.xdata is not None: lo, hi = ax.get_xlim() lo, hi = _rescale(lo, hi, event.xdata, step) ax.set_xlim((lo, hi)) if ax.zoomable is True and event.ydata is not None: lo, hi = ax.get_ylim() lo, hi = _rescale(lo, hi, event.ydata, step) ax.set_ylim((lo, hi)) ax.figure.canvas.draw_idle() def _rescale(lo, hi, pt, step): scale = float(hi - lo) * step / (100 if step > 0 else 100 - step) bal = float(pt - lo) / (hi - lo) new_lo = lo - bal * scale new_hi = hi + (1 - bal) * scale return new_lo, new_hi