From 4a9433aeb032385b49879dcde144324d2275f541 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 2 Jul 2018 21:17:52 +0100 Subject: [PATCH] Added new Panelled canvas (#15) Added new Panelled canvas. --- docs/source/api/canvases.rst | 1 + docs/source/api/canvases/panelled.rst | 9 ++ docs/source/getting_started.rst | 6 +- mATLASplotlib/canvases/__init__.py | 3 +- mATLASplotlib/canvases/base_canvas.py | 119 ++++++++++------- mATLASplotlib/canvases/panelled.py | 161 +++++++++++++++++++++++ mATLASplotlib/canvases/ratio.py | 54 +++----- mATLASplotlib/canvases/simple.py | 21 ++- mATLASplotlib/decorations/legend.py | 10 +- mATLASplotlib/formatters/__init__.py | 4 + mATLASplotlib/formatters/label.py | 34 +++++ setup.py | 2 +- tests/canvases/test_canvases_base.py | 15 ++- tests/canvases/test_canvases_panelled.py | 114 ++++++++++++++++ tests/canvases/test_canvases_ratio.py | 26 ++-- tests/canvases/test_canvases_simple.py | 13 ++ 16 files changed, 484 insertions(+), 108 deletions(-) create mode 100644 docs/source/api/canvases/panelled.rst create mode 100644 mATLASplotlib/canvases/panelled.py create mode 100644 mATLASplotlib/formatters/__init__.py create mode 100644 mATLASplotlib/formatters/label.py create mode 100644 tests/canvases/test_canvases_panelled.py diff --git a/docs/source/api/canvases.rst b/docs/source/api/canvases.rst index e99a452..3bfc206 100644 --- a/docs/source/api/canvases.rst +++ b/docs/source/api/canvases.rst @@ -5,5 +5,6 @@ canvases :maxdepth: 2 canvases/base_canvas + canvases/panelled canvases/ratio canvases/simple \ No newline at end of file diff --git a/docs/source/api/canvases/panelled.rst b/docs/source/api/canvases/panelled.rst new file mode 100644 index 0000000..8dc4b07 --- /dev/null +++ b/docs/source/api/canvases/panelled.rst @@ -0,0 +1,9 @@ +panelled +======== + +.. automodule:: mATLASplotlib.canvases.panelled + :members: + :special-members: __init__ + :inherited-members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index b1c1249..ced1817 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -28,7 +28,9 @@ this should have drawn 10000 samples from a normal distribution and added them t 3. Setting up a canvas ---------------------- We use a context manager to open the canvas, which ensures that necessary cleanup is done when the canvas is no longer needed. -Currently the supported canvases are the :py:class:`.Simple` canvas which contains one set of ``matplotlib`` axes and the :py:class:`.Ratio` canvas, which contains a main plot and a ratio plot underneath. +Currently the supported canvases are the :py:class:`.Simple` canvas which contains one set of ``matplotlib`` axes, +the :py:class:`.Ratio` canvas, which contains a main plot and a ratio plot underneath, +and the the :py:class:`.Panelled` canvas which contains a top panel and an arbitrary number of lower panels beneath it. .. code:: python @@ -36,7 +38,7 @@ Currently the supported canvases are the :py:class:`.Simple` canvas which contai with mATLASplotlib.canvases.Simple(shape="square") as canvas: canvas.plot_dataset(hist, style="scatter", label="Generated data", colour="black") -The two shapes preferred in the ATLAS style guide are "square" (600 x 600 pixels) and "landscape" (600 x 800 pixels). +The three shapes preferred by the ATLAS style guide are "square" (600 x 600 pixels), "landscape" (600 x 800 pixels) and "portrait" (800 x 600 pixels). Here we have chosen to use "square". After setting up the canvas, we can plot the dataset we constructed earlier using the :py:meth:`plot_dataset <.BaseCanvas.plot_dataset>` method. diff --git a/mATLASplotlib/canvases/__init__.py b/mATLASplotlib/canvases/__init__.py index 87ea418..9b966b2 100644 --- a/mATLASplotlib/canvases/__init__.py +++ b/mATLASplotlib/canvases/__init__.py @@ -1,5 +1,6 @@ """This subpackage contains the various canvas types""" +from panelled import Panelled from ratio import Ratio from simple import Simple -__all__ = ["Ratio", "Simple"] +__all__ = ["Panelled", "Ratio", "Simple"] diff --git a/mATLASplotlib/canvases/base_canvas.py b/mATLASplotlib/canvases/base_canvas.py index 47f73f5..4a182f8 100644 --- a/mATLASplotlib/canvases/base_canvas.py +++ b/mATLASplotlib/canvases/base_canvas.py @@ -2,8 +2,10 @@ import logging import math import matplotlib +import numpy as np from .. import style from ..converters import Dataset +from ..formatters import force_extra_ticks from ..plotters import get_plotter from ..decorations import draw_ATLAS_text, draw_text, Legend @@ -13,6 +15,7 @@ class BaseCanvas(object): """Base class for canvas properties.""" + #: Map of locations to matplotlib coordinates location_map = {"upper right": ["right", "top"], "upper left": ["left", "top"], "centre left": ["left", "center"], @@ -20,10 +23,16 @@ class BaseCanvas(object): "lower right": ["right", "bottom"], "lower left": ["left", "bottom"]} + #: List of sensible tick intervals + auto_tick_intervals = [0.001, 0.002, 0.0025, 0.004, 0.005, + 0.01, 0.02, 0.025, 0.04, 0.05, + 0.1, 0.2, 0.25, 0.4, 0.5, + 1.0, 2.0, 2.5, 4.0, 5.0] + def __init__(self, shape="square", **kwargs): """Set up universal canvas properties. - :param shape: use either the 'square' or 'landscape' ATLAS proportions + :param shape: use either the 'square', 'landscape' or 'portrait' ATLAS proportions :type shape: str :Keyword Arguments: @@ -37,7 +46,7 @@ def __init__(self, shape="square", **kwargs): # Set ATLAS style style.set_atlas() # Set up figure - n_pixels = {"square": (600, 600), "landscape": (800, 600)}[shape] + n_pixels = {"square": (600, 600), "landscape": (800, 600), "portrait": (600, 800)}[shape] self.figure = matplotlib.pyplot.figure(figsize=(n_pixels[0] / 100.0, n_pixels[1] / 100.0), dpi=100, facecolor="white") self.main_subplot = None # Set properties from arguments @@ -48,6 +57,7 @@ def __init__(self, shape="square", **kwargs): # Set up value holders self.legend = Legend() self.axis_ranges = {} + self.axis_tick_ndps = {} self.subplots = {} self.internal_header_fraction = None @@ -79,14 +89,14 @@ def plot_dataset(self, *args, **kwargs): * **label**: (*str*) -- label to use in automatic legend generation * **sort_as**: (*str*) -- override """ - axes = kwargs.pop("axes", self.main_subplot) + subplot_name = kwargs.pop("axes", self.main_subplot) plot_style = kwargs.pop("style", None) remove_zeros = kwargs.pop("remove_zeros", False) dataset = Dataset(*args, remove_zeros=remove_zeros, **kwargs) plotter = get_plotter(plot_style) if "label" in kwargs: self.legend.add_dataset(label=kwargs["label"], is_stack=("stack" in plot_style), sort_as=kwargs.pop("sort_as", None)) - plotter.add_to_axes(dataset=dataset, axes=self.subplots[axes], **kwargs) + plotter.add_to_axes(dataset=dataset, axes=self.subplots[subplot_name], **kwargs) def add_legend(self, x, y, anchor_to="lower left", fontsize=None, axes=None): """Add a legend to the canvas at (x, y). @@ -102,9 +112,8 @@ def add_legend(self, x, y, anchor_to="lower left", fontsize=None, axes=None): :param axes: which of the different axes in this canvas to use. :type axes: str """ - if axes is None: - axes = self.main_subplot - self.legend.plot(x, y, self.subplots[axes], anchor_to, fontsize) + subplot_name = self.main_subplot if axes is None else axes + self.legend.plot(x, y, self.subplots[subplot_name], anchor_to, fontsize) def add_ATLAS_label(self, x, y, plot_type=None, anchor_to="lower left", fontsize=None, axes=None): """Add an ATLAS label to the canvas at (x, y). @@ -122,11 +131,8 @@ def add_ATLAS_label(self, x, y, plot_type=None, anchor_to="lower left", fontsize :param axes: which of the different axes in this canvas to use. :type axes: str """ - if axes is None: - axes = self.main_subplot - # ha, va = self.location_map[anchor_to] - # draw_ATLAS_text(x, y, self.subplots[axes], ha=ha, va=va, plot_type=plot_type, fontsize=fontsize) - draw_ATLAS_text(self.subplots[axes], (x, y), self.location_map[anchor_to], plot_type=plot_type, fontsize=fontsize) + subplot_name = self.main_subplot if axes is None else axes + draw_ATLAS_text(self.subplots[subplot_name], (x, y), self.location_map[anchor_to], plot_type=plot_type, fontsize=fontsize) def add_luminosity_label(self, x, y, sqrts_TeV, luminosity, units="fb-1", anchor_to="lower left", fontsize=14, axes=None): """Add a luminosity label to the canvas at (x, y). @@ -148,13 +154,12 @@ def add_luminosity_label(self, x, y, sqrts_TeV, luminosity, units="fb-1", anchor :param axes: which of the different axes in this canvas to use. :type axes: str """ - if axes is None: - axes = self.main_subplot + subplot_name = self.main_subplot if axes is None else axes text_sqrts = r"$\sqrt{\mathsf{s}} = " +\ str([sqrts_TeV, int(1000 * sqrts_TeV)][sqrts_TeV < 1.0]) +\ r"\,\mathsf{" + ["TeV", "GeV"][sqrts_TeV < 1.0] + "}" text_lumi = "$" if luminosity is None else ", $" + str(luminosity) + " " + units.replace("-1", "$^{-1}$") - draw_text(text_sqrts + text_lumi, self.subplots[axes], (x, y), self.location_map[anchor_to], fontsize=fontsize) + draw_text(text_sqrts + text_lumi, self.subplots[subplot_name], (x, y), self.location_map[anchor_to], fontsize=fontsize) def add_text(self, x, y, text, **kwargs): """Add text to the canvas at (x, y). @@ -166,9 +171,9 @@ def add_text(self, x, y, text, **kwargs): :param text: text to add. :type text: str """ - axes = kwargs.pop("axes", self.main_subplot) + subplot_name = kwargs.pop("axes", self.main_subplot) anchor_to = kwargs.pop("anchor_to", "lower left") - draw_text(text, self.subplots[axes], (x, y), self.location_map[anchor_to], **kwargs) + draw_text(text, self.subplots[subplot_name], (x, y), self.location_map[anchor_to], **kwargs) def save(self, output_name, extension="pdf"): """Save the current state of the canvas to a file. @@ -242,6 +247,16 @@ def set_axis_ticks(self, axis_name, ticks): """ raise NotImplementedError("set_axis_ticks not defined by {0}".format(type(self))) + def set_axis_tick_ndp(self, axis_name, ndp): + """Set number of decimal places to show. + + :param axis_name: which axis to apply this to. + :type axis_name: str + :param ndp: how many decimal places to show. + :type ndp: int + """ + self.axis_tick_ndps[axis_name] = ndp + def set_axis_log(self, axis_names): """Set the specified axis to be on a log-scale. @@ -280,39 +295,40 @@ def y_tick_label_size(self): def __finalise_plot_formatting(self): """Finalise plot by applying previously requested formatting.""" - for _, axes in self.subplots.items(): + for _, subplot in self.subplots.items(): # Apply axis limits self._apply_axis_limits() # Draw x ticks if self.x_tick_labels is not None: - x_interval = (max(axes.get_xlim()) - min(axes.get_xlim())) / (len(self.x_tick_labels)) - axes.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(x_interval)) + x_interval = (max(subplot.get_xlim()) - min(subplot.get_xlim())) / (len(self.x_tick_labels)) + subplot.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(x_interval)) tmp_kwargs = {"fontsize": self.x_tick_label_size} if self.x_tick_label_size is not None else {} - axes.set_xticklabels([""] + self.x_tick_labels, **tmp_kwargs) # the first and last ticks are off the scale so add a dummy label + subplot.set_xticklabels([""] + self.x_tick_labels, **tmp_kwargs) # the first and last ticks are off the scale so add a dummy label # Draw y ticks if self.y_tick_labels is not None: - y_interval = (max(axes.get_ylim()) - min(axes.get_ylim())) / (len(self.y_tick_labels)) - axes.yaxis.set_major_locator(matplotlib.ticker.MultipleLocator(y_interval)) + y_interval = (max(subplot.get_ylim()) - min(subplot.get_ylim())) / (len(self.y_tick_labels)) + subplot.yaxis.set_major_locator(matplotlib.ticker.MultipleLocator(y_interval)) tmp_kwargs = {"fontsize": self.y_tick_label_size} if self.y_tick_label_size is not None else {} - axes.set_yticklabels([""] + self.y_tick_labels, **tmp_kwargs) # the first and last ticks are off the scale so add a dummy label + subplot.set_yticklabels([""] + self.y_tick_labels, **tmp_kwargs) # the first and last ticks are off the scale so add a dummy label + # Set x-axis locators if "x" in self.log_type: - xlocator = axes.xaxis.get_major_locator() - axes.set_xscale("log", subsx=[2, 3, 4, 5, 6, 7, 8, 9]) - axes.yaxis.set_major_locator(xlocator) - axes.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) - axes.xaxis.set_minor_formatter(matplotlib.ticker.FuncFormatter(self.__force_extra_x_ticks)) # only show certain minor labels + xlocator = subplot.xaxis.get_major_locator() + subplot.set_xscale("log", subsx=[2, 3, 4, 5, 6, 7, 8, 9]) + subplot.yaxis.set_major_locator(xlocator) + subplot.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) + subplot.xaxis.set_minor_formatter(matplotlib.ticker.FuncFormatter(force_extra_ticks(self.x_ticks_extra))) # only show certain minor labels else: - axes.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator()) + subplot.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator()) # Set y-axis locators if "y" in self.log_type: - locator = axes.yaxis.get_major_locator() - axes.set_yscale("log") - axes.yaxis.set_major_locator(locator) + locator = subplot.yaxis.get_major_locator() + subplot.set_yscale("log") + subplot.yaxis.set_major_locator(locator) fixed_minor_points = [10**x * val for x in range(-100, 100) for val in [2, 3, 4, 5, 6, 7, 8, 9]] - axes.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(fixed_minor_points)) + subplot.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(fixed_minor_points)) else: - axes.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator()) + subplot.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator()) # Finish by adding internal header if self.internal_header_fraction is not None: @@ -333,21 +349,6 @@ def _apply_final_formatting(self): """Apply any necessary final formatting.""" pass - def __force_extra_x_ticks(self, x, pos): - """Implement user-defined tick positions. - - :param x: tick value. - :type x: float - :param pos: position. - :type pos: float - :return: formatted tick position string - :rtype: str - """ - del pos # this function signature is required by FuncFormatter - if any(int(x) == elem for elem in self.x_ticks_extra): - return "{0:.0f}".format(x) - return "" - def get_axis_label(self, axis_name): """Get the label for the chosen axis @@ -370,3 +371,21 @@ def get_axis_range(self, axis_name): return self.axis_ranges[axis_name] else: raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def _get_auto_axis_ticks(self, axis_name, n_approximate=4): + """Choose axis ticks to be sensibly spaced and always include 1.0. + + :param axis_name: name of axis to work on + :type axis_name: str + :param n_approximate: approximate number of ticks to use. + :type n_approximate: int + :return: list of tick positions + :rtype: list + """ + # Underestimate the interval size since we might be removing the highest tick + interval = 0.99 * abs(self.axis_ranges[axis_name][1] - self.axis_ranges[axis_name][0]) + tick_size = min(self.auto_tick_intervals, key=lambda x: abs((interval / x) - n_approximate)) + tick_list = np.arange(1.0 - 10 * tick_size, 1.0 + 10 * tick_size, tick_size) + # Remove topmost tick if it would be at the top of the axis + tick_list = [t for t in tick_list if not np.allclose(t, self.axis_ranges[axis_name][1])] + return tick_list diff --git a/mATLASplotlib/canvases/panelled.py b/mATLASplotlib/canvases/panelled.py new file mode 100644 index 0000000..a7b84c9 --- /dev/null +++ b/mATLASplotlib/canvases/panelled.py @@ -0,0 +1,161 @@ +""" This module provides the ``Panelled`` canvas.""" +from matplotlib.ticker import FixedLocator, FuncFormatter, NullLocator +from base_canvas import BaseCanvas +from ..formatters import force_ndp + + +class Panelled(BaseCanvas): + """Panelled canvas with standard ATLAS setup.""" + + def __init__(self, shape="portrait", n_panels=3, top_panel_fraction=0.16, **kwargs): + """Set up Panelled canvas properties. + + The canvas consists of a single top panel and ``n_panels`` equally sized additional panels underneath. + These additional panels are called ``plot0``, ``plot1``, etc. with the numbering starting from the top. + + :param shape: use either the 'square', 'landscape' or 'portrait' ATLAS proportions + :type shape: str + :param n_panels: how many panels to include + :type n_panels: int + :param top_panel_fraction: fraction of vertical space that the top panel should use up + :type top_panel_fraction: float + + :Keyword Arguments: as for :py:class:`.BaseCanvas` + """ + super(Panelled, self).__init__(shape=shape, **kwargs) + _margin_top, _margin_bottom = 0.02, 0.08 + self.n_panels = n_panels + subplot_height = (1.0 - _margin_top - _margin_bottom - top_panel_fraction) / self.n_panels + self.subplots["top"] = self.figure.add_axes([0.15, 1.0 - _margin_top - top_panel_fraction, 0.8, top_panel_fraction]) + for idx in range(n_panels): + _panel_limits = [0.15, (1.0 - _margin_top - top_panel_fraction - (idx + 1) * subplot_height), 0.8, subplot_height] + self.subplots["plot{0}".format(idx)] = self.figure.add_axes(_panel_limits) + self.axis_ranges["y_plot{0}".format(idx)] = [0.5, 1.5] + self.use_auto_ratio_ticks = dict((name, True) for name in self.subplots if name != "top") + self.main_subplot = "plot0" + + def plot_dataset(self, *args, **kwargs): + subplot_name = kwargs.get("axes", self.main_subplot) + super(Panelled, self).plot_dataset(*args, **kwargs) + if "x" not in self.axis_ranges: + self.set_axis_range("x", self.subplots[subplot_name].get_xlim()) + if "plot" in subplot_name: + y_axis_name = "y_{0}".format(subplot_name) + self.set_axis_range(y_axis_name, self.subplots[subplot_name].get_ylim()) + + def add_legend(self, x, y, anchor_to="lower left", fontsize=None, axes=None): + """Add a legend to the canvas at (x, y). + + If added to the ``top`` panel then all elements from the lower panels will be included in it. + + :Arguments: as for :py:meth:`.BaseCanvas.add_legend` + """ + subplot_name = self.main_subplot if axes is None else axes + if subplot_name == "top": + subplots = [self.subplots["plot{0}".format(idx)] for idx in range(self.n_panels)] + self.legend.plot(x, y, self.subplots[subplot_name], anchor_to, fontsize, use_axes=subplots) + else: + self.legend.plot(x, y, self.subplots[subplot_name], anchor_to, fontsize) + + def get_axis_label(self, axis_name): + if axis_name == "x": + return self.subplots[self.bottom_panel].get_xlabel() + elif axis_name == "y": + return self.subplots["top"].get_ylabel() + raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def set_axis_label(self, axis_name, axis_label, fontsize=16): + if axis_name == "x": + self.subplots[self.bottom_panel].set_xlabel(axis_label, position=(1.0, 0.0), + fontsize=fontsize, va="top", ha="right") + elif axis_name == "y": + self.subplots["top"].set_ylabel(axis_label, fontsize=fontsize) + self.subplots["top"].yaxis.get_label().set_ha("right") + self.subplots["top"].yaxis.set_label_coords(-0.13, 1.0) + else: + raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def set_axis_max(self, axis_name, maximum): + if axis_name in self.axis_ranges: + self.axis_ranges[axis_name] = (self.axis_ranges[axis_name][0], maximum) + if axis_name == "x": + for subplot in self.subplots.values(): + subplot.set_xlim(right=maximum) + elif axis_name[0] == "y": + self.subplots[axis_name.replace("y_", "")].set_ylim(top=maximum) + else: + raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def set_axis_min(self, axis_name, minimum): + if axis_name in self.axis_ranges: + self.axis_ranges[axis_name] = (minimum, self.axis_ranges[axis_name][1]) + if axis_name == "x": + for subplot in self.subplots.values(): + subplot.set_xlim(left=minimum) + elif "y_plot" in axis_name: + self.subplots[axis_name.replace("y_", "")].set_ylim(bottom=minimum) + else: + raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def set_axis_range(self, axis_name, axis_range): + if axis_name == "x": + self.axis_ranges["x"] = axis_range + elif "y_plot" in axis_name: + self.axis_ranges[axis_name] = axis_range + else: + raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def set_axis_ticks(self, axis_name, ticks): + if axis_name == "x": + for subplot in [p for n, p in self.subplots.items() if n != "top"]: + subplot.xaxis.set_major_locator(FixedLocator(ticks)) + elif "y_plot" in axis_name: + subplot_name = axis_name.replace("y_", "") + self.subplots[subplot_name].yaxis.set_major_locator(FixedLocator(ticks)) + self.use_auto_ratio_ticks[subplot_name] = False + else: + raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) + + def set_title(self, title): + self.subplots["top"].set_title(title) + + def _apply_axis_limits(self): + if "x" in self.axis_ranges: + for subplot in self.subplots.values(): + subplot.set_xlim(self.axis_ranges["x"]) + for axis_name in self.axis_ranges: + if "y_plot" in axis_name: + self.subplots[axis_name.replace("y_", "")].set_ylim(self.axis_ranges[axis_name]) + + def _apply_final_formatting(self): + """Apply final formatting. Remove unnecessary ticks and labels.""" + # Set axis decimal places + for axis_name, ndp in self.axis_tick_ndps.items(): + if axis_name == "x": + self.subplots[self.bottom_panel].xaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + elif "y_plot" in axis_name: + self.subplots[axis_name.replace("y_", "")].yaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + + # Remove x-axis labels from plots + for subplot in [a for n, a in self.subplots.items() if n != self.bottom_panel]: + subplot.set_xticklabels([]) + + # Set the ratio ticks appropriately + for subplot_name in [n for n in self.subplots if n != "top"]: + if self.use_auto_ratio_ticks[subplot_name]: + axis_name = "y_{0}".format(subplot_name) + self.set_axis_ticks(axis_name, self._get_auto_axis_ticks(axis_name=axis_name)) + + # Remove all tick marks from top plot + self.subplots["top"].xaxis.set_major_locator(NullLocator()) + self.subplots["top"].yaxis.set_major_locator(NullLocator()) + self.subplots["top"].xaxis.set_minor_locator(NullLocator()) + self.subplots["top"].yaxis.set_minor_locator(NullLocator()) + + # Shift y-axis label downwards + self.subplots["top"].yaxis.set_label_coords(-0.12, 0.3) + + @property + def bottom_panel(self): + """Name of the bottom-most panel.""" + return "plot{0}".format(self.n_panels - 1) diff --git a/mATLASplotlib/canvases/ratio.py b/mATLASplotlib/canvases/ratio.py index 5319ff2..caba92b 100644 --- a/mATLASplotlib/canvases/ratio.py +++ b/mATLASplotlib/canvases/ratio.py @@ -1,8 +1,9 @@ """ This module provides the ``Ratio`` canvas.""" from matplotlib.lines import Line2D -from matplotlib.ticker import FixedLocator +from matplotlib.ticker import FixedLocator, FuncFormatter import numpy as np from base_canvas import BaseCanvas +from ..formatters import force_ndp class Ratio(BaseCanvas): @@ -11,7 +12,7 @@ class Ratio(BaseCanvas): def __init__(self, shape="square", line_ypos=1.0, **kwargs): """Set up Ratio canvas properties. - :param shape: use either the 'square' or 'rectangular' ATLAS proportions + :param shape: use either the 'square', 'landscape' or 'portrait' ATLAS proportions :type shape: str :param line_ypos: where to draw the reference line in the ratio plot :type line_ypos: float @@ -24,25 +25,20 @@ def __init__(self, shape="square", line_ypos=1.0, **kwargs): self.line_ypos = line_ypos self.main_subplot = "top" self.axis_ranges["y_ratio"] = [0.5, 1.5] - self.auto_ratio_tick_intervals = [0.001, 0.002, 0.0025, 0.004, 0.005, - 0.01, 0.02, 0.025, 0.04, 0.05, - 0.1, 0.2, 0.25, 0.4, 0.5, - 1.0, 2.0, 2.5, 4.0, 5.0] self.use_auto_ratio_ticks = True def plot_dataset(self, *args, **kwargs): - axes = kwargs.get("axes", self.main_subplot) + subplot_name = kwargs.get("axes", self.main_subplot) super(Ratio, self).plot_dataset(*args, **kwargs) if "x" not in self.axis_ranges: - self.set_axis_range("x", self.subplots[axes].get_xlim()) - if axes == "top": + self.set_axis_range("x", self.subplots[subplot_name].get_xlim()) + if subplot_name == "top": if "y" not in self.axis_ranges: - self.set_axis_range("y", self.subplots[axes].get_ylim()) - elif axes == "bottom": + self.set_axis_range("y", self.subplots[subplot_name].get_ylim()) + elif subplot_name == "bottom": if np.array_equal(self.axis_ranges["y_ratio"], [0.5, 1.5]): - self.set_axis_range("y_ratio", self.subplots[axes].get_ylim()) + self.set_axis_range("y_ratio", self.subplots[subplot_name].get_ylim()) - # Axis labels def get_axis_label(self, axis_name): if axis_name == "x": return self.subplots["bottom"].get_xlabel() @@ -92,8 +88,6 @@ def set_axis_min(self, axis_name, minimum): self.subplots["bottom"].set_ylim(bottom=minimum) else: raise ValueError("axis {0} not recognised by {1}".format(axis_name, type(self))) - if axis_name in self.axis_ranges: - self.axis_ranges[axis_name] = (minimum, self.axis_ranges[axis_name][1]) def set_axis_range(self, axis_name, axis_range): if axis_name == "x": @@ -107,7 +101,8 @@ def set_axis_range(self, axis_name, axis_range): def set_axis_ticks(self, axis_name, ticks): if axis_name == "x": - self.subplots["top"].xaxis.set_major_locator(FixedLocator(ticks)) + for subplot in self.subplots.values(): + subplot.xaxis.set_major_locator(FixedLocator(ticks)) elif axis_name == "y": self.subplots["top"].yaxis.set_major_locator(FixedLocator(ticks)) elif axis_name == "y_ratio": @@ -135,6 +130,15 @@ def _apply_final_formatting(self): transform=self.subplots["bottom"].transData, linewidth=1, linestyle="--", color="black")) + # Set axis decimal places + for axis_name, ndp in self.axis_tick_ndps.items(): + if axis_name == "x": + self.subplots["bottom"].xaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + elif axis_name == "y": + self.subplots["top"].yaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + elif axis_name == "y_ratio": + self.subplots["bottom"].yaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + # Set ratio plot to linear scale if self.log_type.find("y") != -1: self.subplots["bottom"].set_yscale("linear") @@ -145,20 +149,4 @@ def _apply_final_formatting(self): # Set the ratio ticks appropriately if self.use_auto_ratio_ticks: - self.set_axis_ticks("y_ratio", self.__get_auto_ratio_ticks()) - - def __get_auto_ratio_ticks(self, n_approximate=4): - """Choose ratio ticks to be sensibly spaced and always include 1.0. - - :param n_approximate: approximate number of ticks to use. - :type n_approximate: int - :return: list of tick positions - :rtype: list - """ - # Underestimate the interval size since we might be removing the highest tick - interval = 0.99 * abs(self.axis_ranges["y_ratio"][1] - self.axis_ranges["y_ratio"][0]) - tick_size = min(self.auto_ratio_tick_intervals, key=lambda x: abs((interval / x) - n_approximate)) - tick_list = np.arange(1.0 - 10 * tick_size, 1.0 + 10 * tick_size, tick_size) - # Remove topmost tick if it would be at the top of the axis - tick_list = [t for t in tick_list if not np.allclose(t, self.axis_ranges["y_ratio"][1])] - return tick_list + self.set_axis_ticks("y_ratio", self._get_auto_axis_ticks(axis_name="y_ratio")) diff --git a/mATLASplotlib/canvases/simple.py b/mATLASplotlib/canvases/simple.py index 1cd4f1b..b29c496 100644 --- a/mATLASplotlib/canvases/simple.py +++ b/mATLASplotlib/canvases/simple.py @@ -1,6 +1,7 @@ """ This module provides the ``Simple`` canvas.""" -from matplotlib.ticker import FixedLocator +from matplotlib.ticker import FixedLocator, FuncFormatter from base_canvas import BaseCanvas +from ..formatters import force_ndp class Simple(BaseCanvas): @@ -8,19 +9,20 @@ class Simple(BaseCanvas): def __init__(self, shape="square", **kwargs): shape_dict = {"square": {"dimensions": (0.15, 0.1, 0.8, 0.85), "y_label_offset": -0.13}, - "landscape": {"dimensions": (0.12, 0.1, 0.84, 0.85), "y_label_offset": -0.0975}} + "landscape": {"dimensions": (0.12, 0.1, 0.84, 0.85), "y_label_offset": -0.0975}, + "portrait": {"dimensions": (0.12, 0.1, 0.84, 0.85), "y_label_offset": -0.13}} self.shape_dict = shape_dict[shape] super(Simple, self).__init__(shape=shape, **kwargs) self.subplots["main"] = self.figure.add_axes(self.shape_dict["dimensions"]) self.main_subplot = "main" def plot_dataset(self, *args, **kwargs): - axes = kwargs.get("axes", self.main_subplot) + subplot_name = kwargs.get("axes", self.main_subplot) super(Simple, self).plot_dataset(*args, **kwargs) if "x" not in self.axis_ranges: - self.set_axis_range("x", self.subplots[axes].get_xlim()) + self.set_axis_range("x", self.subplots[subplot_name].get_xlim()) if "y" not in self.axis_ranges: - self.set_axis_range("y", self.subplots[axes].get_ylim()) + self.set_axis_range("y", self.subplots[subplot_name].get_ylim()) def _apply_axis_limits(self): if "x" in self.axis_ranges: @@ -28,6 +30,15 @@ def _apply_axis_limits(self): if "y" in self.axis_ranges: self.subplots["main"].set_ylim(self.axis_ranges["y"]) + def _apply_final_formatting(self): + """Apply final formatting.""" + # Set axis decimal places + for axis_name, ndp in self.axis_tick_ndps.items(): + if axis_name == "x": + self.subplots["main"].xaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + elif axis_name == "y": + self.subplots["main"].yaxis.set_major_formatter(FuncFormatter(force_ndp(nplaces=ndp))) + def get_axis_label(self, axis_name): if axis_name == "x": return self.subplots["main"].get_xlabel() diff --git a/mATLASplotlib/decorations/legend.py b/mATLASplotlib/decorations/legend.py index 1a92309..b4e5f40 100644 --- a/mATLASplotlib/decorations/legend.py +++ b/mATLASplotlib/decorations/legend.py @@ -27,7 +27,7 @@ def add_dataset(self, label, is_stack=False, sort_as=None): if sort_as is not None: self.sort_overrides[sort_as] = legend_text - def plot(self, x, y, axes, anchor_to, fontsize): + def plot(self, x, y, axes, anchor_to, fontsize, use_axes=False): """Plot the legend at (x, y) on the chosen axes. :param x: x-position of legend @@ -40,9 +40,15 @@ def plot(self, x, y, axes, anchor_to, fontsize): :type anchor_to: str :param fontsize: fontsize of legend contents :type fontsize: float + :param use_axes: get handles and labels from all axes in list + :type use_axes: list[matplotlib.axes.Axes] """ transform = axes.transAxes - handles, labels = self.__get_legend_handles_labels(axes) + if use_axes: + handles, labels = zip(*[self.__get_legend_handles_labels(subplot) for subplot in use_axes]) + handles, labels = sum(handles, []), sum(labels, []) + else: + handles, labels = self.__get_legend_handles_labels(axes) _legend = axes.legend(handles, labels, numpoints=1, loc=anchor_to, bbox_to_anchor=(x, y), bbox_transform=transform, borderpad=0, borderaxespad=0, columnspacing=0) _legend.get_frame().set_linewidth(0) diff --git a/mATLASplotlib/formatters/__init__.py b/mATLASplotlib/formatters/__init__.py new file mode 100644 index 0000000..078f914 --- /dev/null +++ b/mATLASplotlib/formatters/__init__.py @@ -0,0 +1,4 @@ +"""This subpackage contains utility label formatters""" +from label import force_extra_ticks, force_ndp + +__all__ = ["force_extra_ticks", "force_ndp"] diff --git a/mATLASplotlib/formatters/label.py b/mATLASplotlib/formatters/label.py new file mode 100644 index 0000000..39f0a7d --- /dev/null +++ b/mATLASplotlib/formatters/label.py @@ -0,0 +1,34 @@ +"""This module provides the ``get_plotter()`` convenience function.""" +from functools import partial + + +def force_extra_ticks(x_ticks_extra): + """Implement user-defined tick positions. + + :param x: tick value. + :type x: float + :param pos: position. + :type pos: float + :return: formatted tick position string + :rtype: str + """ + def inner(x, pos, x_ticks_extra): + del pos # this function signature is required by FuncFormatter + if any(int(x) == elem for elem in x_ticks_extra): + return "{0:.0f}".format(x) + return "" + return partial(inner, x_ticks_extra=x_ticks_extra) + + +def force_ndp(nplaces): + """Force rounding for all labels. + + :param nplaces: how many decimal places to use. + :type nplaces: int + :return: tick formatter + :rtype: function + """ + def inner(x, pos, nplaces): + del pos # this function signature is required by FuncFormatter + return "{0:.{1}f}".format(x, nplaces) + return partial(inner, nplaces=nplaces) diff --git a/setup.py b/setup.py index aaf9a42..fdf0e73 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name = "mATLASplotlib", - version = "1.0", + version = "1.1.0", description = "Wrappers around matplotlib functionality to produce plots compatible with the style guidelines for the ATLAS experiment at the LHC", long_description = long_description, long_description_content_type = "text/markdown", diff --git a/tests/canvases/test_canvases_base.py b/tests/canvases/test_canvases_base.py index 1871d3e..1150c20 100644 --- a/tests/canvases/test_canvases_base.py +++ b/tests/canvases/test_canvases_base.py @@ -1,3 +1,4 @@ +import os import matplotlib import pytest import mATLASplotlib @@ -54,19 +55,27 @@ def test_base_set_title(): canvas.set_title("title") -def test_simple_apply_axis_limits(): +def test_base_apply_axis_limits(): with pytest.raises(NotImplementedError): with mATLASplotlib.canvases.base_canvas.BaseCanvas() as canvas: canvas._apply_axis_limits() -def test_simple_get_axis_label(): +def test_base_apply_final_formatting(): + with mATLASplotlib.canvases.base_canvas.BaseCanvas() as canvas: + canvas._apply_final_formatting() + canvas.save("blank_test_output") + assert os.path.isfile("blank_test_output.pdf") + os.remove("blank_test_output.pdf") + + +def test_base_get_axis_label(): with pytest.raises(NotImplementedError): with mATLASplotlib.canvases.base_canvas.BaseCanvas() as canvas: canvas.get_axis_label("x") -def test_simple_get_axis_range(): +def test_base_get_axis_range(): with pytest.raises(ValueError): with mATLASplotlib.canvases.base_canvas.BaseCanvas() as canvas: canvas.get_axis_range("x") diff --git a/tests/canvases/test_canvases_panelled.py b/tests/canvases/test_canvases_panelled.py new file mode 100644 index 0000000..2d68a89 --- /dev/null +++ b/tests/canvases/test_canvases_panelled.py @@ -0,0 +1,114 @@ +import os +import matplotlib +import numpy as np +import pytest +import mATLASplotlib + +def test_panelled_constructor(): + with mATLASplotlib.canvases.Panelled() as canvas: + assert canvas.figure.get_figheight() == 8.0 + assert canvas.figure.get_figwidth() == 6.0 + + +def test_panelled_constructor_n_panels(): + with mATLASplotlib.canvases.Panelled(n_panels=5) as canvas: + assert canvas.n_panels == 5 + assert len([c for c in canvas.figure.get_children() if isinstance(c, matplotlib.axes._axes.Axes)]) == 6 + + +def test_panelled_axis_labels(): + with mATLASplotlib.canvases.Panelled() as canvas: + for axis, label in zip(["x", "y"], ["xlabel", "ylabel"]): + canvas.set_axis_label(axis, label) + assert canvas.get_axis_label(axis) == label + + +def test_panelled_axis_labels_unknown(): + with pytest.raises(ValueError): + with mATLASplotlib.canvases.Panelled() as canvas: + canvas.set_axis_label("imaginary", "test") + with pytest.raises(ValueError): + with mATLASplotlib.canvases.Panelled() as canvas: + canvas.get_axis_label("imaginary") + + +def test_panelled_axis_ranges(): + with mATLASplotlib.canvases.Panelled() as canvas: + for axis, ax_range in zip(["x", "y_plot1", "y_plot0"], [(5, 10), [0, 100], [0, 2]]): + canvas.set_axis_range(axis, ax_range) + assert np.array_equal(canvas.get_axis_range(axis), ax_range) + canvas.set_axis_min(axis, 3) + assert np.array_equal(canvas.get_axis_range(axis), (3, ax_range[1])) + canvas.set_axis_max(axis, 7) + assert np.array_equal(canvas.get_axis_range(axis), [3, 7]) + canvas.save("blank_test_output") + assert os.path.isfile("blank_test_output.pdf") + os.remove("blank_test_output.pdf") + + +def test_panelled_axis_ranges_unknown(): + with mATLASplotlib.canvases.Panelled() as canvas: + with pytest.raises(ValueError): + canvas.set_axis_range("imaginary", (0, 5)) + with pytest.raises(ValueError): + canvas.set_axis_min("imaginary", 0) + with pytest.raises(ValueError): + canvas.set_axis_max("imaginary", 5) + + +def test_panelled_tick_ndp(): + with mATLASplotlib.canvases.Panelled() as canvas: + for axis, ax_range in zip(["x", "y_plot0"], [(5, 10), [0, 100]]): + canvas.set_axis_range(axis, ax_range) + canvas.set_axis_tick_ndp(axis, 2) + assert axis in canvas.axis_tick_ndps + assert canvas.axis_tick_ndps[axis] == 2 + canvas.save("blank_test_output") + assert os.path.isfile("blank_test_output.pdf") + os.remove("blank_test_output.pdf") + + +def test_panelled_plot_dataset(): + with mATLASplotlib.canvases.Panelled() as canvas: + canvas.plot_dataset([0, 1, 2], [0.5, 0.5, 0.5], [5, 10, 12], None, style="line") + assert "x" in canvas.axis_ranges.keys() + + +def test_panelled_plot_dataset_in_panel(): + with mATLASplotlib.canvases.Panelled() as canvas: + canvas.plot_dataset([0, 1, 2], [0.5, 0.5, 0.5], [5, 10, 12], None, style="line", axes="plot1") + assert "x" in canvas.axis_ranges.keys() + + +def test_panelled_set_axis_ticks(): + with mATLASplotlib.canvases.Panelled(n_panels=3) as canvas: + canvas.set_axis_ticks("x", [1, 2, 3]) + assert np.array_equal(canvas.subplots["plot0"].xaxis.get_major_locator()(), [1, 2, 3]) + assert np.array_equal(canvas.subplots["plot1"].xaxis.get_major_locator()(), [1, 2, 3]) + assert np.array_equal(canvas.subplots["plot2"].xaxis.get_major_locator()(), [1, 2, 3]) + canvas.set_axis_ticks("y_plot0", [4, 5, 6]) + canvas.set_axis_range("y_plot0", (0, 10)) + assert np.array_equal(canvas.subplots["plot0"].yaxis.get_major_locator()(), [4, 5, 6]) + with pytest.raises(ValueError): + canvas.set_axis_ticks("imaginary", [0.8, 1.0, 1.2]) + + +def test_panelled_title(): + with mATLASplotlib.canvases.Panelled() as canvas: + canvas.set_title("title") + title_text = [c for c in canvas.subplots["top"].get_children() if isinstance(c, matplotlib.text.Text)][0] + assert title_text.get_text() == "title" + + +def test_panelled_plot_datasets_legends(): + with mATLASplotlib.canvases.Panelled(n_panels=4) as canvas: + canvas.plot_dataset([0, 1, 2], [0.5, 0.5, 0.5], [5, 10, 12], None, style="line", colour="red", axes="plot0", label="red") + canvas.plot_dataset([0, 1, 2], [0.5, 0.5, 0.5], [12, 10, 5], None, style="line", colour="blue", axes="plot1", label="blue") + canvas.plot_dataset([0, 1, 2], [0.5, 0.5, 0.5], [3, 5, 7], None, style="line", colour="green", axes="plot2", label="green") + canvas.plot_dataset([0, 1, 2], [0.5, 0.5, 0.5], [9, 8, 6], None, style="line", colour="orange", axes="plot3", label="orange") + canvas.add_legend(0.1, 0.1, axes="top") + canvas.add_legend(0.1, 0.1, axes="plot0") + legend_element = [c for c in canvas.subplots["top"].get_children() if isinstance(c, matplotlib.legend.Legend)][0] + assert len(legend_element.get_texts()) == 4 + texts = [t.get_text() for t in legend_element.get_texts()] + assert np.array_equal(texts, ["red", "blue", "green", "orange"]) diff --git a/tests/canvases/test_canvases_ratio.py b/tests/canvases/test_canvases_ratio.py index b79a24e..04c3564 100644 --- a/tests/canvases/test_canvases_ratio.py +++ b/tests/canvases/test_canvases_ratio.py @@ -7,16 +7,13 @@ def test_ratio_constructor(): with mATLASplotlib.canvases.Ratio() as canvas: - # Default shape should be square assert canvas.figure.get_figheight() == 6.0 assert canvas.figure.get_figwidth() == 6.0 - # Subplot should be main assert canvas.main_subplot == "top" def test_ratio_constructor_shape(): with mATLASplotlib.canvases.Ratio(shape="landscape") as canvas: - # Default shape should be square assert canvas.figure.get_figheight() == 6.0 assert canvas.figure.get_figwidth() == 8.0 @@ -49,6 +46,7 @@ def test_simple_set_axis_ticks(): with mATLASplotlib.canvases.Ratio() as canvas: canvas.set_axis_ticks("x", [1, 2, 3]) assert np.array_equal(canvas.subplots["top"].xaxis.get_major_locator()(), [1, 2, 3]) + assert np.array_equal(canvas.subplots["bottom"].xaxis.get_major_locator()(), [1, 2, 3]) canvas.set_axis_ticks("y", [4, 5, 6]) assert np.array_equal(canvas.subplots["top"].yaxis.get_major_locator()(), [4, 5, 6]) canvas.set_axis_ticks("y_ratio", [0.8, 1.0, 1.2]) @@ -93,17 +91,25 @@ def test_ratio_axis_labels_unknown(): canvas.get_axis_label("imaginary") +def test_ratio_tick_ndp(): + with mATLASplotlib.canvases.Ratio() as canvas: + for axis, ax_range in zip(["x", "y", "y_ratio"], [(5, 10), [0, 100], [0, 2]]): + canvas.set_axis_range(axis, ax_range) + canvas.set_axis_tick_ndp(axis, 2) + assert axis in canvas.axis_tick_ndps + assert canvas.axis_tick_ndps[axis] == 2 + canvas.save("blank_test_output") + assert os.path.isfile("blank_test_output.pdf") + os.remove("blank_test_output.pdf") + + def test_ratio_axis_ranges(): with mATLASplotlib.canvases.Ratio() as canvas: for axis, ax_range in zip(["x", "y", "y_ratio"], [(5, 10), [0, 100], [0, 2]]): - # Test set_axis_range() canvas.set_axis_range(axis, ax_range) assert np.array_equal(canvas.get_axis_range(axis), ax_range) - # Test set_axis_min() with a tuple canvas.set_axis_min(axis, 3) - assert np.array_equal( - canvas.get_axis_range(axis), (3, ax_range[1])) - # Test set_axis_max() with a list + assert np.array_equal(canvas.get_axis_range(axis), (3, ax_range[1])) canvas.set_axis_max(axis, 7) assert np.array_equal(canvas.get_axis_range(axis), [3, 7]) canvas.save("blank_test_output") @@ -129,17 +135,14 @@ def test_ratio_title(): def test_ratio_save(): - # Test pdf output with mATLASplotlib.canvases.Ratio() as canvas: canvas.save("blank_test_output") assert os.path.isfile("blank_test_output.pdf") os.remove("blank_test_output.pdf") - # Test png output with mATLASplotlib.canvases.Ratio() as canvas: canvas.save("blank_test_output", extension="png") assert os.path.isfile("blank_test_output.png") os.remove("blank_test_output.png") - # Test eps output with mATLASplotlib.canvases.Ratio() as canvas: canvas.save("blank_test_output", extension="eps") assert os.path.isfile("blank_test_output.eps") @@ -154,6 +157,7 @@ def test_ratio_plot_dataset(): assert "y" in canvas.axis_ranges.keys() assert "y_ratio" in canvas.axis_ranges.keys() + def test_ratio_plot_dataset_bottom(): with mATLASplotlib.canvases.Ratio() as canvas: assert canvas.axis_ranges.keys() == ["y_ratio"] diff --git a/tests/canvases/test_canvases_simple.py b/tests/canvases/test_canvases_simple.py index c0c1065..c1e5d13 100644 --- a/tests/canvases/test_canvases_simple.py +++ b/tests/canvases/test_canvases_simple.py @@ -122,6 +122,18 @@ def test_simple_axis_ranges_unknown(): canvas.set_axis_max("imaginary", 5) +def test_simple_tick_ndp(): + with mATLASplotlib.canvases.Simple() as canvas: + for axis, ax_range in zip(["x", "y"], [(5, 10), [0, 100]]): + canvas.set_axis_range(axis, ax_range) + canvas.set_axis_tick_ndp(axis, 2) + assert axis in canvas.axis_tick_ndps + assert canvas.axis_tick_ndps[axis] == 2 + canvas.save("blank_test_output") + assert os.path.isfile("blank_test_output.pdf") + os.remove("blank_test_output.pdf") + + def test_simple_save(): # Test pdf output with mATLASplotlib.canvases.Simple() as canvas: @@ -175,6 +187,7 @@ def test_simple_save_internal_header_fraction(): assert os.path.isfile("blank_test_output.pdf") os.remove("blank_test_output.pdf") + def test_simple_save_internal_header_fraction_log(): with mATLASplotlib.canvases.Simple(x_ticks_extra=[20, 60]) as canvas: canvas.plot_dataset([10, 50, 100], [5, 5, 5], [5, 10, 12], None, style="line", marker="o", markersize=10)