Source code for marxs.visualization.utils

# Licensed under GPL version 3 - see LICENSE.rst
'''This module collects helper functions for visualization backends.

The functions here are not intended to be called directly by the
user. Instead, they refractor common tasks that are used in several
visualization backends.
'''
import warnings
import numpy as np

from marxs.math import utils as mutils


[docs] class MARXSVisualizationWarning(Warning): '''Warning class for MARXS objects missing from plotting''' pass
[docs] def get_obj_name(obj): '''Return printable name for objects or functions.''' if hasattr(obj, 'name'): return obj.name elif hasattr(obj, 'func_name'): return obj.func_name else: return str(obj)
[docs] class DisplayDict(dict): '''A dictionary to store how an element is displayed in plotting. A dictionary of this type works just like a normal dictionary, except for an additional look-up step for keys that are not found in the dictionary itself. A ``DisplayDict`` is initialized with a reference to the object it describes and any parameter accessed from ``DisplayDict`` that is not found in the dictionary will be searched for in the object's geometry. This allows us to set any and all display settings in the ``DisplayDict`` to customize plotting in any way without affecting how the ray-trace is run (which uses only the parameters set in the geoemtry), but for those values that are not set, fall back to the settings of the geometry (e.g. the shape of an object is typically taken from the geometry, while the color is not). Parameters ---------- parent : `marxs.base.MarxsElement` Reference to the object that is described by this ``DisplayDict`` args, kwargs: see `dict` ''' def __init__(self, parent, *args, **kwargs): self.parent = parent super().__init__(*args, **kwargs) def __getitem__(self, key): if (key not in self) and hasattr(self.parent, 'geometry'): try: return getattr(self.parent.geometry, key) except AttributeError: raise KeyError(key) else: return super().__getitem__(key)
[docs] def get(self, k, d=None): '''D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.''' try: return self[k] except KeyError: return d
[docs] def plot_object_general(plot_registry, obj, display=None, **kwargs): '''Look up a plotting routine for an object and execute it. This function is not meant to be called directly by the user, instead, it is designed to simplify the implementation of new plotting backends. Parameters ---------- plot_registry : dict Keys are the names of the shape of an object and values in this dictionary are functions that know how to plot this type of shape. The appropriate plotting function is then called with the input `obj`, `display` and any other keyword arguments. If the shape is ``"None"`` (as a string), no plotting function is called. obj : `marxs.base.MarxsElement` The element that should be plotted. display : dict of None Dictionary with display settings. If this is ``None``, ``obj.display`` is used. If that is also ``None`` then the objects is skipped. kwargs : other keyword arguments These arguments are just passed through to the plotting function. Returns ------- out : backend-dependent The output from the plotting function that was executed is passed through. Different plotting backends return different kinds of output. ''' if display is None: if hasattr(obj, 'display') and (obj.display is not None): display = obj.display else: warnings.warn('Skipping {0}: No display dictionary found.'.format( get_obj_name(obj)), MARXSVisualizationWarning) return None try: shape = display['shape'] except KeyError: warnings.warn('Skipping {0}: "shape" not set in display dict.'.format( get_obj_name(obj)), MARXSVisualizationWarning) return None shapes = [s.strip() for s in shape.split(';')] for s in shapes: if s == 'None': return None elif s in plot_registry: # turn into valid color tuple display['color'] = get_color(display) return plot_registry[s](obj, display, **kwargs) else: warnings.warn('Skipping {0}: No function to plot {1}.'.format( get_obj_name(obj), shape), MARXSVisualizationWarning) return None
[docs] def get_color(d): '''Look for color information in dictionary. If missing, return white. This function checks if the `d['color']` is a valid RGB tuple and if not it imports a `matplotlib.colors.ColorConverter` to convert any matplotlib compatible string to an RGB tuple. Parameters ---------- d : dict Color information should be present in ``d['color']``. Returns ------- color : tuple RGB tuple with each element in the range 0..1 ''' if 'color' not in d: return (1., 1., 1.) else: c = d['color'] # check if this is a tuple of three floats n the range 0..1 # or can be converted to one try: cout = tuple(c) if len(cout) != 3: raise TypeError for a in cout: if not isinstance(a, float) or (a < 0.) or (a > 1.): raise TypeError return cout except TypeError: # It's a hex or string. Let matplotlib deal with that. import matplotlib.colors return matplotlib.colors.colorConverter.to_rgb(c)
[docs] def color_tuple_to_hex(color): '''Convert color tuple to hex string. Parameters ---------- color : tuple tuple has three elements (rgb) that are floats between 0 and 1 or ints between 0 and 255. Returns ------- hexstring : string string encoding that number as hex ''' if all([isinstance(a, float) for a in color]): if any(i < 0. for i in color) or any(i > 1. for i in color): raise ValueError('Float values in color tuple must be between 0 and 1.') out = hex(int(color[0] * 256**2 * 255 + color[1] * 256 * 255 + color[2] * 255)) elif all([isinstance(a, int) for a in color]): if any(i < 0 for i in color) or any(i > 255 for i in color): raise ValueError('Int values in color tuple must be between 0 and 255.') out = hex(color[0] * 256**2 + color[1] * 256 + color[2]) else: raise ValueError('Input tuple must be all float or all int.') # Now pad with zeros if required return out[:2] + out[2:].zfill(6)
[docs] def plane_with_hole(outer, inner): '''Triangulation of a plane with an inner hole This function constructs a triangulation for a plane with an inner hole, e.g. a rectangular plane where an inner circle is cut out. Parameters ---------- outer, inner : np.ndarray of shape (n, 3) Coordinates in x,y,z of points that define the inner and outer boundary. ``outer`` and ``inner`` can have a different number of points, but points need to be listed in the same orientation (e.g. clockwise) for both and the starting points need to have a similar angle as seen from the center (e.g. for a plane with z=0, both ``outer`` and ``inner`` could list a point close to the y-axis first. Returns ------- xyz : nd.array stacked ``outer`` and ``inner``. triangles : nd.array List of the indices. Each row has the index of three points in ``xyz``. Examples -------- In this example, we make a square and cut out a smaller square in the middle. >>> import numpy as np >>> from marxs.visualization.utils import plane_with_hole >>> outer = np.array([[-1, -1, 1, 1], [-1, 1, 1, -1], [0,0,0,0]]).T >>> inner = 0.5 * outer >>> xyz, triangles = plane_with_hole(outer, inner) >>> triangles array([[0, 4, 5], [0, 1, 5], [1, 5, 6], [1, 2, 6], [2, 6, 7], [2, 3, 7], [3, 7, 4], [3, 0, 4]]) ''' n_out = outer.shape[0] n_in = inner.shape[0] n = n_out + n_in triangles = np.zeros((n, 3), dtype=int) xyz = np.vstack([outer, inner]) i_in = 0 i_out = 0 for i in range(n_out + n_in): if i/n >= i_in/n_in: triangles[i, :] = [i_out, n_out + i_in, n_out + ((i_in + 1) % n_in)] i_in += 1 else: triangles[i, :] = [i_out, (i_out + 1) % n_out, n_out + (i_in % n_in)] i_out = (i_out + 1) % n_out return xyz, triangles
[docs] def combine_disjoint_triangulations(list_xyz, list_triangles): '''Combine two disjoint triangulations into one set of points This function combines two entirely separate triangulations into one set of point and triangles. Plotting the combined triangulation should have the same effect as plotting each triangulation separately. This function is used for plotting apertures where we have e.g. an open ring. This can be plotted as an inner circle plus an outer shape with a hole in it. Parameters ---------- list_xyz : list of `np.array` Each array holds xyz values for one triangulation list_triangles : list of nd.array Each array holds the list of the indices for one triangulation. Returns ------- xyz : nd.array stacked ``outer`` and ``inner``. triangles : nd.array List of the indices. Each row has the index of three points in ``xyz``. ''' xyz = np.vstack(list_xyz) n_offset = np.cumsum([a.shape[0] for a in list_xyz]) n_offset -= n_offset[0] triangles = np.vstack([list_triangles[i] + n_offset[i] for i in range(len(n_offset))]) return xyz, triangles
[docs] def triangulate_parametricsurface(xyz): xyz = mutils.h2e(xyz) xyz = xyz.reshape(-1, 3) index = np.arange(xyz.shape[0] - 2) indarr = np.vstack([index, index + 1, index + 2]).T return xyz, indarr
[docs] def halfbox_corners(obj, display): corners = np.array([[-1, -1, -1], [-1,+1, -1], [-1, -1, 1], [-1, 1, 1], [ 1, -1, -1], [ 1, 1, -1], [ 1, -1, +1], [ 1, 1, +1]]) if 'box-half' in display: # write in a way that it works with any value for that keyword try: if display['box-half'][0] == '+': factor = +1 elif display['box-half'][0] == '-': factor = -1 else: factor = 0 xyz = {'x': 0, 'y': 1, 'z': 2} if display['box-half'][1] in xyz: j = xyz[display['box-half'][1]] corners[corners[:, j] == factor, j] = 0 except: pass corners = np.einsum('ij,...j->...i', obj.pos4d, mutils.e2h(corners, 1)) corners = mutils.h2e(corners) return corners