import numpy
import matplotlib.colors
import matplotlib.cm
import mpl_toolkits.mplot3d
# Adapted from http://www.ster.kuleuven.be/~pieterd/python/html/plotting/interactive_colorbar.html
# which in turn is based on an example from http://matplotlib.org/users/event_handling.html
class DraggableColorbar(object):
    """Interactive behaviour for a colorbar: drag to rescale, arrow keys to
    switch colormap.

    Dragging the mouse vertically inside the colorbar axes changes the
    limits of the colorbar's norm: mouse button 1 moves ``vmin`` and
    ``vmax`` together, mouse button 3 moves them in opposite directions.
    The 'up' and 'down' keys step through the available colormaps.

    The handlers are only active between :meth:`connect` and
    :meth:`disconnect`, and the instance must be kept alive by the caller.

    :param cbar: the colorbar to make interactive; its ``ax`` and ``norm``
        attributes are used
    :param mappable: the artist whose norm and colormap are updated
        (``set_norm`` / ``set_cmap`` are called on it)
    """

    def __init__(self, cbar, mappable):
        self.cbar = cbar
        self.mappable = mappable
        # last seen (x, y) mouse position while a drag is in progress
        self.press = None
        # names of every attribute of matplotlib.cm that has an ``N``
        # attribute, in alphabetical order
        self.cycle = sorted(name for name in dir(matplotlib.cm)
                            if hasattr(getattr(matplotlib.cm, name), 'N'))
        # Ask the colorbar for its colormap when it offers get_cmap(),
        # otherwise ask the mappable.
        cmap_owner = cbar if hasattr(cbar, 'get_cmap') else mappable
        try:
            self.index = self.cycle.index(cmap_owner.get_cmap().name)
        except ValueError:
            # The current colormap is not one of the names in the cycle
            # (e.g. a custom colormap): start cycling from the first entry
            # instead of failing at construction time.
            self.index = 0
        # The canvas is looked up through the colorbar axes, the same
        # attribute the mouse handlers compare events against.
        self.canvas = self.cbar.ax.figure.canvas

    def connect(self):
        """Register the mouse and keyboard handlers on the canvas."""
        self.cidpress = self.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = self.canvas.mpl_connect('button_release_event', self.on_release)
        self.cidmotion = self.canvas.mpl_connect('motion_notify_event', self.on_motion)
        self.cidkeypress = self.canvas.mpl_connect('key_press_event', self.key_press)

    def disconnect(self):
        """Remove the handlers registered by :meth:`connect`."""
        self.canvas.mpl_disconnect(self.cidpress)
        self.canvas.mpl_disconnect(self.cidrelease)
        self.canvas.mpl_disconnect(self.cidmotion)
        self.canvas.mpl_disconnect(self.cidkeypress)

    def on_press(self, event):
        """Start a drag when a mouse button is pressed inside the colorbar."""
        if event.inaxes == self.cbar.ax:
            self.press = event.x, event.y

    def key_press(self, event):
        """Select the next ('down') or previous ('up') colormap.

        The index wraps around at both ends of the cycle. Any other key is
        ignored.
        """
        if event.key == 'down':
            step = 1
        elif event.key == 'up':
            step = -1
        else:
            return
        # The modulo keeps the index inside [0, len(cycle)) in both
        # directions, so stepping 'up' from the first entry lands on the
        # last one.
        self.index = (self.index + step) % len(self.cycle)
        self.mappable.set_cmap(self.cycle[self.index])
        self.canvas.draw()

    def on_motion(self, event):
        """Rescale the norm while the mouse is dragged inside the colorbar.

        Only the sign of the vertical movement since the previous event is
        used, so every motion event changes the limits by one fixed step.
        """
        if self.press is None or event.inaxes != self.cbar.ax:
            return
        _, yprev = self.press
        direction = numpy.sign(event.y - yprev)
        self.press = event.x, event.y
        norm = self.cbar.norm
        if isinstance(norm, matplotlib.colors.LogNorm):
            # multiplicative step, derived from the current number of decades
            factor = (0.999 * numpy.log10(norm.vmax / norm.vmin)) ** direction
            if event.button == 1:
                norm.vmin *= factor
                norm.vmax *= factor
            elif event.button == 3:
                norm.vmin *= factor
                norm.vmax /= factor
        else:
            # additive step of 3% of the current range
            shift = 0.03 * (norm.vmax - norm.vmin) * direction
            if event.button == 1:
                norm.vmin -= shift
                norm.vmax -= shift
            elif event.button == 3:
                norm.vmin -= shift
                norm.vmax += shift
        self.mappable.set_norm(norm)
        self.canvas.draw()

    def on_release(self, event):
        """End the drag and force a redraw."""
        self.press = None
        self.mappable.set_norm(self.cbar.norm)
        self.canvas.draw()
def get_clipped_norm(data, clipping=0.0, log=True):
    """Build a matplotlib norm whose limits follow the range of *data*.

    :param data: array of values; for objects that provide ``compressed()``
        (such as masked arrays) that method is used to obtain the values,
        otherwise the array is flattened
    :param clipping: fraction of the values to discard at each end of the
        sorted data before taking the limits; ``0`` uses the plain
        minimum and maximum
    :param log: when true only strictly positive values are considered and
        a ``LogNorm`` is returned, otherwise a linear ``Normalize``
    :returns: ``matplotlib.colors.LogNorm`` or ``matplotlib.colors.Normalize``

    When no usable values remain a default norm is returned:
    ``LogNorm(1, 10)`` for logarithmic and ``Normalize(0, 1)`` for linear
    scaling.
    """
    if hasattr(data, 'compressed'):
        values = data.compressed()
    else:
        values = data.flatten()
    if log:
        # a logarithmic norm cannot represent zero or negative values
        values = values[values > 0]

    if values.size == 0:
        if log:
            return matplotlib.colors.LogNorm(1, 10)
        return matplotlib.colors.Normalize(0, 1)

    if clipping:
        chop = int(round(values.size * clipping))
        # Always keep at least one value, and never let a negative
        # fraction turn into a negative index.
        chop = max(0, min(chop, (values.size - 1) // 2))
        ordered = numpy.sort(values)
        # drop the same number of values from both ends
        vmin, vmax = ordered[chop], ordered[values.size - 1 - chop]
    else:
        vmin, vmax = values.min(), values.max()

    if log:
        return matplotlib.colors.LogNorm(vmin, vmax)
    return matplotlib.colors.Normalize(vmin, vmax)
def plot(space, fig, ax, log=True, loglog=False, clipping=0.0, fit=None, norm=None, colorbar=True, labels=True, interpolation='nearest', **plotopts):
    """Plot a 1, 2 or 3 dimensional space on the given axes.

    :param space: object to plot; ``dimension``, ``axes``, ``get_masked()``
        and, for three dimensions, ``get()`` and ``get_grid()`` are used
    :param fig: figure owning *ax*; used to create the colorbar of a 2D plot
        and to hold the ``DraggableColorbar`` instance
    :param ax: axes to draw on; must be an ``Axes3D`` for 3D spaces
    :param log: use a logarithmic intensity scale; takes precedence over
        *loglog*
    :param loglog: 1D only, both axes logarithmic (only when *log* is false)
    :param clipping: passed to ``get_clipped_norm`` when *norm* is not given
    :param fit: optional fit result; in 1D it is drawn on top of the data,
        in 2D it is shown instead of the data
    :param norm: norm to use for 2D and 3D plots; derived from the data
        when ``None``
    :param colorbar: 2D only, add a draggable colorbar
    :param labels: set the axis labels
    :param interpolation: 2D only, passed to ``imshow``
    :param plotopts: extra keyword arguments for the plotting call (1D and
        2D); in 3D only ``cmap``, a colormap name, is used
    :returns: the list of lines (1D), the image (2D) or the scatter
        collection (3D)
    :raises ValueError: for 3D spaces when *ax* is not an ``Axes3D``, and
        for spaces with more than three dimensions
    """
    if space.dimension == 1:
        data = space.get_masked()
        coords = numpy.ma.array(space.axes[0][:], mask=data.mask)
        if log:
            draw = ax.semilogy
        elif loglog:
            draw = ax.loglog
        else:
            draw = ax.plot
        if fit is not None:
            p1 = draw(coords, data, 'wo', **plotopts)
            p2 = draw(coords, fit, 'r', linewidth=2, **plotopts)
        else:
            p1 = draw(coords, data, **plotopts)
            p2 = []
        if labels:
            ax.set_xlabel(space.axes[0].label)
            ax.set_ylabel('Intensity (a.u.)')
        return p1 + p2
    elif space.dimension == 2:
        data = space.get_masked()
        extent = (space.axes[0].min, space.axes[0].max,
                  space.axes[1].min, space.axes[1].max)
        if norm is None:
            # the colour scale always follows the data, also when the fit
            # is the image being shown
            norm = get_clipped_norm(data, clipping, log)
        image = data if fit is None else fit
        im = ax.imshow(image.transpose(), origin='lower', extent=extent, aspect='auto', norm=norm, interpolation=interpolation, **plotopts)
        if labels:
            ax.set_xlabel(space.axes[0].label)
            ax.set_ylabel(space.axes[1].label)
        if colorbar:
            cbarwidget = fig.colorbar(im)
            # the instance has to be stored somewhere to stay alive
            fig._draggablecbar = DraggableColorbar(cbarwidget, im)
            fig._draggablecbar.connect()
        return im
    elif space.dimension == 3:
        if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D):
            raise ValueError("For 3D plots, the 'ax' parameter must be an Axes3D instance (use for example gca(projection='3d') to get one)")
        cmap = getattr(matplotlib.cm, plotopts.pop('cmap', 'jet'))
        if norm is None:
            norm = get_clipped_norm(space.get_masked(), clipping, log)
        data = space.get()
        # leave out voxels that are not finite or exactly zero
        mask = numpy.bitwise_or(~numpy.isfinite(data), data == 0)
        gridx, gridy, gridz = tuple(grid[~mask] for grid in space.get_grid())
        im = ax.scatter(gridx, gridy, gridz, c=cmap(norm(data[~mask])), marker=',', alpha=0.7, linewidths=0)
        if labels:
            ax.set_xlabel(space.axes[0].label)
            ax.set_ylabel(space.axes[1].label)
            ax.set_zlabel(space.axes[2].label)
        # The attribute only exists once a 2D plot with a colorbar has been
        # drawn on this figure, so it cannot be read unconditionally.
        draggable = getattr(fig, '_draggablecbar', None)
        if draggable:
            draggable.disconnect()
        return im
    elif space.dimension > 3:
        raise ValueError("Cannot plot 4 or higher dimensional spaces, use projections or slices to decrease dimensionality.")