22import logging
33import math
44import matplotlib
5+ import numpy as np
56from .. import style
67from ..converters import Dataset
8+ from ..formatters import force_extra_ticks
79from ..plotters import get_plotter
810from ..decorations import draw_ATLAS_text , draw_text , Legend
911
1315class BaseCanvas (object ):
1416 """Base class for canvas properties."""
1517
18+ #: Map of locations to matplotlib coordinates
1619 location_map = {"upper right" : ["right" , "top" ],
1720 "upper left" : ["left" , "top" ],
1821 "centre left" : ["left" , "center" ],
1922 "centre right" : ["right" , "center" ],
2023 "lower right" : ["right" , "bottom" ],
2124 "lower left" : ["left" , "bottom" ]}
2225
26+ #: List of sensible tick intervals
27+ auto_tick_intervals = [0.001 , 0.002 , 0.0025 , 0.004 , 0.005 ,
28+ 0.01 , 0.02 , 0.025 , 0.04 , 0.05 ,
29+ 0.1 , 0.2 , 0.25 , 0.4 , 0.5 ,
30+ 1.0 , 2.0 , 2.5 , 4.0 , 5.0 ]
31+
2332 def __init__ (self , shape = "square" , ** kwargs ):
2433 """Set up universal canvas properties.
2534
26- :param shape: use either the 'square' or 'landscape ' ATLAS proportions
35+ :param shape: use either the 'square', 'landscape' or 'portrait ' ATLAS proportions
2736 :type shape: str
2837
2938 :Keyword Arguments:
@@ -37,7 +46,7 @@ def __init__(self, shape="square", **kwargs):
3746 # Set ATLAS style
3847 style .set_atlas ()
3948 # Set up figure
40- n_pixels = {"square" : (600 , 600 ), "landscape" : (800 , 600 )}[shape ]
49+ n_pixels = {"square" : (600 , 600 ), "landscape" : (800 , 600 ), "portrait" : ( 600 , 800 ) }[shape ]
4150 self .figure = matplotlib .pyplot .figure (figsize = (n_pixels [0 ] / 100.0 , n_pixels [1 ] / 100.0 ), dpi = 100 , facecolor = "white" )
4251 self .main_subplot = None
4352 # Set properties from arguments
@@ -48,6 +57,7 @@ def __init__(self, shape="square", **kwargs):
4857 # Set up value holders
4958 self .legend = Legend ()
5059 self .axis_ranges = {}
60+ self .axis_tick_ndps = {}
5161 self .subplots = {}
5262 self .internal_header_fraction = None
5363
@@ -79,14 +89,14 @@ def plot_dataset(self, *args, **kwargs):
7989 * **label**: (*str*) -- label to use in automatic legend generation
8090 * **sort_as**: (*str*) -- override
8191 """
82- axes = kwargs .pop ("axes" , self .main_subplot )
92+ subplot_name = kwargs .pop ("axes" , self .main_subplot )
8393 plot_style = kwargs .pop ("style" , None )
8494 remove_zeros = kwargs .pop ("remove_zeros" , False )
8595 dataset = Dataset (* args , remove_zeros = remove_zeros , ** kwargs )
8696 plotter = get_plotter (plot_style )
8797 if "label" in kwargs :
8898 self .legend .add_dataset (label = kwargs ["label" ], is_stack = ("stack" in plot_style ), sort_as = kwargs .pop ("sort_as" , None ))
89- plotter .add_to_axes (dataset = dataset , axes = self .subplots [axes ], ** kwargs )
99+ plotter .add_to_axes (dataset = dataset , axes = self .subplots [subplot_name ], ** kwargs )
90100
91101 def add_legend (self , x , y , anchor_to = "lower left" , fontsize = None , axes = None ):
92102 """Add a legend to the canvas at (x, y).
@@ -102,9 +112,8 @@ def add_legend(self, x, y, anchor_to="lower left", fontsize=None, axes=None):
102112 :param axes: which of the different axes in this canvas to use.
103113 :type axes: str
104114 """
105- if axes is None :
106- axes = self .main_subplot
107- self .legend .plot (x , y , self .subplots [axes ], anchor_to , fontsize )
115+ subplot_name = self .main_subplot if axes is None else axes
116+ self .legend .plot (x , y , self .subplots [subplot_name ], anchor_to , fontsize )
108117
109118 def add_ATLAS_label (self , x , y , plot_type = None , anchor_to = "lower left" , fontsize = None , axes = None ):
110119 """Add an ATLAS label to the canvas at (x, y).
@@ -122,11 +131,8 @@ def add_ATLAS_label(self, x, y, plot_type=None, anchor_to="lower left", fontsize
122131 :param axes: which of the different axes in this canvas to use.
123132 :type axes: str
124133 """
125- if axes is None :
126- axes = self .main_subplot
127- # ha, va = self.location_map[anchor_to]
128- # draw_ATLAS_text(x, y, self.subplots[axes], ha=ha, va=va, plot_type=plot_type, fontsize=fontsize)
129- draw_ATLAS_text (self .subplots [axes ], (x , y ), self .location_map [anchor_to ], plot_type = plot_type , fontsize = fontsize )
134+ subplot_name = self .main_subplot if axes is None else axes
135+ draw_ATLAS_text (self .subplots [subplot_name ], (x , y ), self .location_map [anchor_to ], plot_type = plot_type , fontsize = fontsize )
130136
131137 def add_luminosity_label (self , x , y , sqrts_TeV , luminosity , units = "fb-1" , anchor_to = "lower left" , fontsize = 14 , axes = None ):
132138 """Add a luminosity label to the canvas at (x, y).
@@ -148,13 +154,12 @@ def add_luminosity_label(self, x, y, sqrts_TeV, luminosity, units="fb-1", anchor
148154 :param axes: which of the different axes in this canvas to use.
149155 :type axes: str
150156 """
151- if axes is None :
152- axes = self .main_subplot
157+ subplot_name = self .main_subplot if axes is None else axes
153158 text_sqrts = r"$\sqrt{\mathsf{s}} = " + \
154159 str ([sqrts_TeV , int (1000 * sqrts_TeV )][sqrts_TeV < 1.0 ]) + \
155160 r"\,\mathsf{" + ["TeV" , "GeV" ][sqrts_TeV < 1.0 ] + "}"
156161 text_lumi = "$" if luminosity is None else ", $" + str (luminosity ) + " " + units .replace ("-1" , "$^{-1}$" )
157- draw_text (text_sqrts + text_lumi , self .subplots [axes ], (x , y ), self .location_map [anchor_to ], fontsize = fontsize )
162+ draw_text (text_sqrts + text_lumi , self .subplots [subplot_name ], (x , y ), self .location_map [anchor_to ], fontsize = fontsize )
158163
159164 def add_text (self , x , y , text , ** kwargs ):
160165 """Add text to the canvas at (x, y).
@@ -166,9 +171,9 @@ def add_text(self, x, y, text, **kwargs):
166171 :param text: text to add.
167172 :type text: str
168173 """
169- axes = kwargs .pop ("axes" , self .main_subplot )
174+ subplot_name = kwargs .pop ("axes" , self .main_subplot )
170175 anchor_to = kwargs .pop ("anchor_to" , "lower left" )
171- draw_text (text , self .subplots [axes ], (x , y ), self .location_map [anchor_to ], ** kwargs )
176+ draw_text (text , self .subplots [subplot_name ], (x , y ), self .location_map [anchor_to ], ** kwargs )
172177
173178 def save (self , output_name , extension = "pdf" ):
174179 """Save the current state of the canvas to a file.
@@ -242,6 +247,16 @@ def set_axis_ticks(self, axis_name, ticks):
242247 """
243248 raise NotImplementedError ("set_axis_ticks not defined by {0}" .format (type (self )))
244249
250+ def set_axis_tick_ndp (self , axis_name , ndp ):
251+ """Set number of decimal places to show.
252+
253+ :param axis_name: which axis to apply this to.
254+ :type axis_name: str
255+ :param ndp: how many decimal places to show.
256+ :type ndp: int
257+ """
258+ self .axis_tick_ndps [axis_name ] = ndp
259+
245260 def set_axis_log (self , axis_names ):
246261 """Set the specified axis to be on a log-scale.
247262
@@ -280,39 +295,40 @@ def y_tick_label_size(self):
280295
281296 def __finalise_plot_formatting (self ):
282297 """Finalise plot by applying previously requested formatting."""
283- for _ , axes in self .subplots .items ():
298+ for _ , subplot in self .subplots .items ():
284299 # Apply axis limits
285300 self ._apply_axis_limits ()
286301 # Draw x ticks
287302 if self .x_tick_labels is not None :
288- x_interval = (max (axes .get_xlim ()) - min (axes .get_xlim ())) / (len (self .x_tick_labels ))
289- axes .xaxis .set_major_locator (matplotlib .ticker .MultipleLocator (x_interval ))
303+ x_interval = (max (subplot .get_xlim ()) - min (subplot .get_xlim ())) / (len (self .x_tick_labels ))
304+ subplot .xaxis .set_major_locator (matplotlib .ticker .MultipleLocator (x_interval ))
290305 tmp_kwargs = {"fontsize" : self .x_tick_label_size } if self .x_tick_label_size is not None else {}
291- axes .set_xticklabels (["" ] + self .x_tick_labels , ** tmp_kwargs ) # the first and last ticks are off the scale so add a dummy label
306+ subplot .set_xticklabels (["" ] + self .x_tick_labels , ** tmp_kwargs ) # the first and last ticks are off the scale so add a dummy label
292307 # Draw y ticks
293308 if self .y_tick_labels is not None :
294- y_interval = (max (axes .get_ylim ()) - min (axes .get_ylim ())) / (len (self .y_tick_labels ))
295- axes .yaxis .set_major_locator (matplotlib .ticker .MultipleLocator (y_interval ))
309+ y_interval = (max (subplot .get_ylim ()) - min (subplot .get_ylim ())) / (len (self .y_tick_labels ))
310+ subplot .yaxis .set_major_locator (matplotlib .ticker .MultipleLocator (y_interval ))
296311 tmp_kwargs = {"fontsize" : self .y_tick_label_size } if self .y_tick_label_size is not None else {}
297- axes .set_yticklabels (["" ] + self .y_tick_labels , ** tmp_kwargs ) # the first and last ticks are off the scale so add a dummy label
312+ subplot .set_yticklabels (["" ] + self .y_tick_labels , ** tmp_kwargs ) # the first and last ticks are off the scale so add a dummy label
313+
298314 # Set x-axis locators
299315 if "x" in self .log_type :
300- xlocator = axes .xaxis .get_major_locator ()
301- axes .set_xscale ("log" , subsx = [2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ])
302- axes .yaxis .set_major_locator (xlocator )
303- axes .xaxis .set_major_formatter (matplotlib .ticker .ScalarFormatter ())
304- axes .xaxis .set_minor_formatter (matplotlib .ticker .FuncFormatter (self .__force_extra_x_ticks )) # only show certain minor labels
316+ xlocator = subplot .xaxis .get_major_locator ()
317+ subplot .set_xscale ("log" , subsx = [2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ])
318+ subplot .yaxis .set_major_locator (xlocator )
319+ subplot .xaxis .set_major_formatter (matplotlib .ticker .ScalarFormatter ())
320+ subplot .xaxis .set_minor_formatter (matplotlib .ticker .FuncFormatter (force_extra_ticks ( self .x_ticks_extra ) )) # only show certain minor labels
305321 else :
306- axes .xaxis .set_minor_locator (matplotlib .ticker .AutoMinorLocator ())
322+ subplot .xaxis .set_minor_locator (matplotlib .ticker .AutoMinorLocator ())
307323 # Set y-axis locators
308324 if "y" in self .log_type :
309- locator = axes .yaxis .get_major_locator ()
310- axes .set_yscale ("log" )
311- axes .yaxis .set_major_locator (locator )
325+ locator = subplot .yaxis .get_major_locator ()
326+ subplot .set_yscale ("log" )
327+ subplot .yaxis .set_major_locator (locator )
312328 fixed_minor_points = [10 ** x * val for x in range (- 100 , 100 ) for val in [2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]]
313- axes .yaxis .set_minor_locator (matplotlib .ticker .FixedLocator (fixed_minor_points ))
329+ subplot .yaxis .set_minor_locator (matplotlib .ticker .FixedLocator (fixed_minor_points ))
314330 else :
315- axes .yaxis .set_minor_locator (matplotlib .ticker .AutoMinorLocator ())
331+ subplot .yaxis .set_minor_locator (matplotlib .ticker .AutoMinorLocator ())
316332
317333 # Finish by adding internal header
318334 if self .internal_header_fraction is not None :
@@ -333,21 +349,6 @@ def _apply_final_formatting(self):
333349 """Apply any necessary final formatting."""
334350 pass
335351
336- def __force_extra_x_ticks (self , x , pos ):
337- """Implement user-defined tick positions.
338-
339- :param x: tick value.
340- :type x: float
341- :param pos: position.
342- :type pos: float
343- :return: formatted tick position string
344- :rtype: str
345- """
346- del pos # this function signature is required by FuncFormatter
347- if any (int (x ) == elem for elem in self .x_ticks_extra ):
348- return "{0:.0f}" .format (x )
349- return ""
350-
351352 def get_axis_label (self , axis_name ):
352353 """Get the label for the chosen axis
353354
@@ -370,3 +371,21 @@ def get_axis_range(self, axis_name):
370371 return self .axis_ranges [axis_name ]
371372 else :
372373 raise ValueError ("axis {0} not recognised by {1}" .format (axis_name , type (self )))
374+
375+ def _get_auto_axis_ticks (self , axis_name , n_approximate = 4 ):
376+ """Choose axis ticks to be sensibly spaced and always include 1.0.
377+
378+ :param axis_name: name of axis to work on
379+ :type axis_name: str
380+ :param n_approximate: approximate number of ticks to use.
381+ :type n_approximate: int
382+ :return: list of tick positions
383+ :rtype: list
384+ """
385+ # Underestimate the interval size since we might be removing the highest tick
386+ interval = 0.99 * abs (self .axis_ranges [axis_name ][1 ] - self .axis_ranges [axis_name ][0 ])
387+ tick_size = min (self .auto_tick_intervals , key = lambda x : abs ((interval / x ) - n_approximate ))
388+ tick_list = np .arange (1.0 - 10 * tick_size , 1.0 + 10 * tick_size , tick_size )
389+ # Remove topmost tick if it would be at the top of the axis
390+ tick_list = [t for t in tick_list if not np .allclose (t , self .axis_ranges [axis_name ][1 ])]
391+ return tick_list
0 commit comments