Source code for genesis.recorders.plotters

import io
import itertools
import logging
import sys
import threading
import time
from collections import defaultdict
from collections.abc import Sequence
from functools import partial, cached_property
from typing import Any, Callable, TypeVar

import numpy as np
import torch
from PIL import Image

import genesis as gs
import genesis.utils.geom as gu
from genesis.options.recorders import (
    BasePlotterOptions,
    LinePlotterMixinOptions,
    PyQtLinePlot as PyQtLinePlotterOptions,
    MPLLinePlot as MPLLinePlotterOptions,
    MPLImagePlot as MPLImagePlotterOptions,
    MPLVectorFieldPlot as MPLVectorFieldPlotterOptions,
)
from genesis.utils import has_display, tensor_to_array

from .base_recorder import Recorder
from .recorder_manager import RecorderManager, register_recording

# pyqtgraph is an optional dependency, only needed by the PyQt-based plotters.
try:
    import pyqtgraph as pg
except ImportError:
    IS_PYQTGRAPH_AVAILABLE = False
else:
    IS_PYQTGRAPH_AVAILABLE = True

IS_MATPLOTLIB_AVAILABLE = False
try:
    import matplotlib as mpl

    IS_MATPLOTLIB_AVAILABLE = tuple(map(int, mpl.__version__.replace("+", ".").split(".")[:3])) >= (3, 7, 0)
except ImportError:
    pass


# Module logger, used as a fallback wherever `gs.logger` may be unset.
LOGGER = logging.getLogger(__name__)


# Matplotlib line-plot autoscaling parameters:
# - MIN_X: minimum amount of x-axis room (in x-data units) added to the right of the latest sample;
# - RATIO_X / RATIO_Y: margin added around the data, as a fraction of the data range on that axis.
MPL_PLOTTER_RESCALE_MIN_X = 0.5
MPL_PLOTTER_RESCALE_RATIO_X = MPL_PLOTTER_RESCALE_RATIO_Y = 0.15

# Endless color sequence shared by every plotter in this module to color successive channels.
COLORS = itertools.cycle(("r", "g", "b", "c", "m", "y"))


# Return type of a recorder's data function.
T = TypeVar("T")


def _data_to_array(data: Sequence) -> np.ndarray:
    if isinstance(data, torch.Tensor):
        data = tensor_to_array(data)
    return np.atleast_1d(data)


class BasePlotter(Recorder):
    """
    Base class for recorders that render their data as a live plot.

    Subclasses must implement `_update_plot` (refresh the plot from the latest data) and `get_image_array` (capture
    the current plot as an image). When `save_to_filename` is set in the options, a `VideoFile` recorder is
    registered during `build` and fed with the frames captured after each processed step.
    """

    def __init__(self, manager: "RecorderManager", options: BasePlotterOptions, data_func: Callable[[], T]):
        # Only open a window by default if a display is available
        if options.show_window is None:
            options.show_window = has_display()
        # Defined up-front so that `cleanup` works even if `build` has never been called
        self.video_writer = None
        super().__init__(manager, options, data_func)
        # Captured frames waiting to be consumed by the video writer, oldest first
        self._frames_buffer: list[np.ndarray] = []

    def build(self):
        super().build()

        self.video_writer = None
        if self._options.save_to_filename:

            def _get_video_frame_buffer(plotter):
                # Wait for the plotter to produce a frame. When the plotter runs in a background thread,
                # it may have already dequeued data but not yet appended the rendered frame to the buffer.
                # When not threaded, frames are produced synchronously before this call, so an empty
                # buffer means something went wrong — the None check handles that case too.
                while not plotter._frames_buffer:
                    if plotter._processor_thread is None or not plotter._processor_thread.is_alive():
                        gs.raise_exception(
                            f"[{type(plotter).__name__}] No frame available and plotter thread is not running."
                        )
                    time.sleep(0.01)
                return plotter._frames_buffer.pop(0)

            self.video_writer = self._manager.add_recorder(
                data_func=partial(_get_video_frame_buffer, self),
                rec_options=gs.recorders.VideoFile(
                    filename=self._options.save_to_filename,
                    hz=self._options.hz,
                ),
            )

    def process(self, data, cur_time):
        """Refresh the plot, then capture a video frame if the plot is being saved to a file."""
        # Update plot
        self._update_plot()

        # Render frame if necessary
        if self._options.save_to_filename:
            self._frames_buffer.append(self.get_image_array())

    def cleanup(self):
        """Stop the video writer (if any) and drop the frames that were not consumed yet."""
        if self.video_writer is not None:
            self.video_writer.stop()
        self._frames_buffer.clear()
        self.video_writer = None

    def _update_plot(self):
        """
        Update plot.
        """
        raise NotImplementedError(f"[{type(self).__name__}] _update_plot() is not implemented.")

    def get_image_array(self):
        """
        Capture the plot image as a video frame.

        Returns
        -------
        image_array : np.ndarray
            The RGB image as a numpy array.
        """
        raise NotImplementedError(f"[{type(self).__name__}] get_image_array() is not implemented.")


class LinePlotHelper:
    """
    Helper class that manages line plot data. Use composition pattern.

    The layout of the plot ("subplot structure") is inferred once, at construction, from a sample of the data:

    - dict data: one subplot per key, with one channel per element of the corresponding value;
    - any other data: a single subplot named "main", with one channel per element.

    Channels are labelled from `options.labels` if provided, otherwise automatically (`{key}_{i}` or `data_{i}`).
    Only the last `options.history_length` samples are kept.
    """

    def __init__(self, options: LinePlotterMixinOptions, data: dict[str, Sequence] | Sequence):
        self.x_data: list[float] = []
        self.y_data: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(lambda: defaultdict(list))
        self._history_length = options.history_length

        # Note that these attributes will be set during first data processing or initialization
        self._is_dict_data: bool | None = None
        self._subplot_structure: dict[str, tuple[str, ...]] = {}

        if isinstance(data, dict):
            self._is_dict_data = True
            if options.labels is not None:
                assert isinstance(options.labels, dict), (
                    f"[{type(self).__name__}] Labels must be a dict when data is a dict"
                )
                assert set(options.labels.keys()) == set(data.keys()), (
                    f"[{type(self).__name__}] Label keys must match data keys"
                )
                for key in data.keys():
                    data_values = _data_to_array(data[key])
                    label_values = options.labels[key]
                    # A bare string is a single label, not a sequence of one-character labels
                    if isinstance(label_values, str):
                        label_values = (label_values,)
                    assert len(label_values) == len(data_values), (
                        f"[{type(self).__name__}] Label count must match data count for key '{key}'"
                    )
                    self._subplot_structure[key] = tuple(label_values)
            else:
                self._subplot_structure = {}
                for key, values in data.items():
                    values = _data_to_array(values)
                    self._subplot_structure[key] = tuple(f"{key}_{i}" for i in range(len(values)))
        else:
            self._is_dict_data = False
            data = _data_to_array(data)
            if options.labels is not None:
                # A bare string is a single label, even though `str` is itself a `Sequence`
                if isinstance(options.labels, Sequence) and not isinstance(options.labels, str):
                    labels = options.labels
                else:
                    labels = (options.labels,)
                assert len(labels) == len(data), f"[{type(self).__name__}] Label count must match data count"
                plot_labels = tuple(labels)
            else:
                plot_labels = tuple(f"data_{i}" for i in range(len(data)))
            self._subplot_structure = {"main": plot_labels}

    def clear_data(self):
        """Drop all the recorded samples, keeping the subplot structure."""
        self.x_data.clear()
        self.y_data.clear()

    def process(self, data, cur_time):
        """
        Process new data point and update plot.

        The sample is validated as a whole before anything is recorded: if the data of any subplot is missing or has
        a number of values different from its number of channels, the whole sample is dropped with a warning. This
        keeps `x_data` and every channel of `y_data` the same length, so that they can always be plotted against
        each other.
        """
        if self._is_dict_data:
            processed_data = {}
            for key, values in data.items():
                if key not in self._subplot_structure:
                    continue  # skip keys not included in subplot structure
                processed_data[key] = _data_to_array(values)
        else:
            processed_data = {"main": _data_to_array(data)}

        # Validate the whole sample before mutating anything, so that a partial sample never desynchronizes x and y
        for subplot_key, channel_labels in self._subplot_structure.items():
            subplot_data = processed_data.get(subplot_key)
            if subplot_data is None:
                gs.logger.warning(f"[{type(self).__name__}] No data for subplot '{subplot_key}', skipping...")
                return
            if len(subplot_data) != len(channel_labels):
                gs.logger.warning(
                    f"[{type(self).__name__}] Data length ({len(subplot_data)}) doesn't match "
                    f"expected number of channels ({len(channel_labels)}) for subplot '{subplot_key}', skipping..."
                )
                return

        # Update time data
        self.x_data.append(cur_time)

        # Update y data for each subplot
        for subplot_key, subplot_data in processed_data.items():
            subplot_y_data = self.y_data[subplot_key]
            for channel_label, value in zip(self._subplot_structure[subplot_key], subplot_data):
                subplot_y_data[channel_label].append(float(value))

        # Maintain rolling history window: keep only the last `history_length` samples
        n_excess = len(self.x_data) - self._history_length
        if n_excess > 0:
            del self.x_data[:n_excess]
            for subplot_y_data in self.y_data.values():
                for channel_y_data in subplot_y_data.values():
                    del channel_y_data[:n_excess]

    @property
    def history_length(self):
        """Maximum number of samples kept in the history."""
        return self._history_length

    @property
    def is_dict_data(self):
        """Whether the data is a dict (one subplot per key)."""
        return self._is_dict_data

    @property
    def subplot_structure(self):
        """Mapping from subplot key to the tuple of its channel labels."""
        return self._subplot_structure


class BasePyQtPlotter(BasePlotter):
    """
    Base class for PyQt based plotters.
    """

    def __init__(self, manager: "RecorderManager", options: BasePlotterOptions, data_func: Callable[[], T]):
        super().__init__(manager, options, data_func)
        if threading.current_thread() is not threading.main_thread():
            gs.raise_exception("Impossible to run PyQtPlotter in background thread.")

    def build(self):
        if not IS_PYQTGRAPH_AVAILABLE:
            gs.raise_exception(
                f"{type(self).__name__} pyqtgraph is not installed. Please install it with `pip install pyqtgraph`."
            )

        super().build()

        self.app: pg.QtWidgets.QApplication | None = None
        self.widget: pg.GraphicsLayoutWidget | None = None
        self.plot_widgets: list[pg.PlotWidget] = []

        # Reuse the existing Qt application instance if there is one
        if not pg.QtWidgets.QApplication.instance():
            self.app = pg.QtWidgets.QApplication([])
        else:
            self.app = pg.QtWidgets.QApplication.instance()

        self.widget = pg.GraphicsLayoutWidget(show=self._options.show_window, title=self._options.title)
        if self._options.show_window:
            gs.logger.info(f"[{type(self).__name__}] created PyQtGraph window")
        self.widget.resize(*self._options.window_size)

    def cleanup(self):
        """Close the PyQtGraph window, if it was ever created."""
        super().cleanup()

        # `widget` does not exist if `build` was never called
        if getattr(self, "widget", None) is not None:
            try:
                self.widget.close()
                (gs.logger or LOGGER).debug(f"[{type(self).__name__}] closed PyQtGraph window")
            except Exception as e:
                (gs.logger or LOGGER).warning(f"[{type(self).__name__}] Error closing window: {e}")
            finally:
                self.plot_widgets.clear()
                self.widget = None

    @property
    def run_in_thread(self) -> bool:
        return False

    def get_image_array(self):
        """
        Capture the plot image as a video frame.

        Returns
        -------
        image_array : np.ndarray
            The RGB image as a numpy array of shape (height, width, 3), consistent with the other plotters.
        """
        pixmap = self.widget.grab()
        qimage = pixmap.toImage()
        # pyqtgraph provides imageToArray but it always outputs (b,g,r,a) format, and with its default
        # `transpose=True` the array is column-major, i.e. (width, height, 4). Request the row-major (height, width)
        # layout instead, then reorder the channels from BGRA to RGB (dropping alpha).
        # https://pyqtgraph.readthedocs.io/en/latest/api_reference/functions.html#pyqtgraph.functions.imageToArray
        bgra_array = pg.imageToArray(qimage, copy=True, transpose=False)
        return np.ascontiguousarray(bgra_array[..., 2::-1])


@register_recording(PyQtLinePlotterOptions)
class PyQtLinePlotter(BasePyQtPlotter):
    """
    Live line plot drawn with pyqtgraph.

    One plot is stacked per subplot of the data layout (one per key for dict data, a single one otherwise), each
    holding one curve per channel.
    """

    def build(self):
        super().build()

        self.line_plot = LinePlotHelper(options=self._options, data=self._data_func())
        self.curves: dict[str, list[pg.PlotCurveItem]] = {}

        opts = self._options
        for row, (subplot_key, channel_labels) in enumerate(self.line_plot.subplot_structure.items()):
            # Stack subplots vertically: every plot after the first one starts a new row
            if row > 0:
                self.widget.nextRow()

            plot_title = subplot_key if self.line_plot.is_dict_data else opts.title
            plot_widget = self.widget.addPlot(title=plot_title)
            plot_widget.setLabel("bottom", opts.x_label)
            plot_widget.setLabel("left", opts.y_label)
            plot_widget.showGrid(x=True, y=True, alpha=0.3)
            plot_widget.addLegend()

            # One curve per channel, colored from the shared color cycle
            channel_curves = [
                plot_widget.plot(pen=pg.mkPen(color=color, width=2), name=channel_label)
                for color, channel_label in zip(COLORS, channel_labels)
            ]

            self.plot_widgets.append(plot_widget)
            if opts.show_window:
                plot_widget.show()
            self.curves[subplot_key] = channel_curves

    def process(self, data, cur_time):
        """Record the new sample, then let the base class refresh the plot and capture a frame if needed."""
        self.line_plot.process(data, cur_time)
        super().process(data, cur_time)

    def _update_plot(self):
        helper = self.line_plot
        for subplot_key, subplot_curves in self.curves.items():
            labels = helper.subplot_structure[subplot_key]
            for curve, label in zip(subplot_curves, labels):
                curve.setData(x=helper.x_data, y=helper.y_data[subplot_key][label])
        # Let Qt process pending events
        if self.app:
            self.app.processEvents()

    def cleanup(self):
        super().cleanup()
        self.line_plot.clear_data()
        self.curves.clear()


class BaseMPLPlotter(BasePlotter):
    """
    Base class for matplotlib based plotters.
    """

    def __init__(self, manager: "RecorderManager", options: BasePlotterOptions, data_func: Callable[[], T]):
        super().__init__(manager, options, data_func)
        if threading.current_thread() is not threading.main_thread():
            gs.raise_exception("Impossible to run MPLPlotter in background thread.")

    def build(self):
        if not IS_MATPLOTLIB_AVAILABLE:
            gs.raise_exception(
                f"{type(self).__name__} matplotlib is not installed. Please install it with `pip install matplotlib>=3.7.0`."
            )

        super().build()

        import matplotlib.pyplot as plt

        self.fig: plt.Figure | None = None
        # Guards the figure canvas against concurrent redraws and frame captures
        self._lock = threading.Lock()

        # matplotlib figsize uses inches
        dpi = mpl.rcParams.get("figure.dpi", 100)
        self.figsize = (self._options.window_size[0] / dpi, self._options.window_size[1] / dpi)

    def _show_fig(self):
        if self._options.show_window:
            self.fig.show()
            gs.logger.info(f"[{type(self).__name__}] created matplotlib window")

    def cleanup(self):
        """Clean up matplotlib resources."""
        super().cleanup()

        # Logger may not be available anymore, fall back to the module logger in that case
        logger = getattr(gs, "logger", None) or LOGGER

        if getattr(self, "fig", None) is not None:
            try:
                import matplotlib.pyplot as plt

                plt.close(self.fig)
                logger.debug(f"[{type(self).__name__}] Closed matplotlib window")
            except Exception as e:
                logger.warning(f"[{type(self).__name__}] Error closing window: {e}")
            finally:
                self.fig = None

    def get_image_array(self):
        """
        Capture the plot image as a video frame.

        The image is rescaled to `window_size` if necessary, so that every frame has the same size whichever
        capture path is used.

        Returns
        -------
        image_array : np.ndarray
            The RGB image as a numpy array.
        """
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        target_size = tuple(self._options.window_size)

        # The lock is released even on failure, otherwise every subsequent update would dead-lock
        with self._lock:
            if isinstance(self.fig.canvas, FigureCanvasAgg):
                # Read internal buffer (a view that must be copied before releasing the lock)
                width, height = self.fig.canvas.get_width_height(physical=True)
                rgba_array_flat = np.frombuffer(self.fig.canvas.buffer_rgba(), dtype=np.uint8)
                rgb_array = rgba_array_flat.reshape((height, width, 4))[..., :3]
            else:
                # Slower but more generic fallback only if necessary
                buffer = io.BytesIO()
                self.fig.canvas.print_figure(buffer, format="png", dpi="figure")
                buffer.seek(0)
                rgb_array = np.asarray(Image.open(buffer).convert("RGB"))

            # Rescale image if necessary
            height, width = rgb_array.shape[:2]
            if (width, height) != target_size:
                img = Image.fromarray(rgb_array)
                img = img.resize(target_size, resample=Image.BILINEAR)
                rgb_array = np.asarray(img)
            else:
                rgb_array = rgb_array.copy()

        return rgb_array

    @cached_property
    def run_in_thread(self) -> bool:
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        if sys.platform == "darwin":
            return False
        if self._is_built:
            assert self.fig is not None
            # All Agg-based backends derives from the surfaceless Agg backend, so 'isinstance' cannot be used to
            # discriminate the latter from others.
            return type(self.fig.canvas) is FigureCanvasAgg
        return not self._options.show_window


@register_recording(MPLLinePlotterOptions)
class MPLLinePlotter(BaseMPLPlotter):
    """
    Live line plot drawn with matplotlib, with one subplot per data key (or a single subplot).

    Lines are updated by blitting over a cached background of each subplot; the whole figure is only redrawn when
    the axis limits must be extended to fit the data.
    """

    def build(self):
        super().build()

        self.line_plot = LinePlotHelper(options=self._options, data=self._data_func())

        import matplotlib.pyplot as plt

        self.axes: list[plt.Axes] = []
        self.lines: dict[str, list[plt.Line2D]] = {}
        # Per-subplot background snapshots used for blitting, and the `x_data[-2]` value recorded when the snapshots
        # were last captured
        self.caches_bbox: list[Any] = []
        self.cache_xmax: float = -1

        # Create figure and subplots (always as a column of axes sharing the x-axis).
        # Constrained layout is enabled in every case, including a single subplot, because the figure legend below
        # only reserves its own room with "outside" placement when constrained layout is active.
        n_subplots = len(self.line_plot.subplot_structure)
        self.fig, axes = plt.subplots(
            n_subplots, 1, figsize=self.figsize, sharex=True, constrained_layout=True, squeeze=False
        )
        self.axes = list(axes[:, 0])
        self.fig.suptitle(self._options.title)

        # Create lines for each subplot
        for subplot_idx, (subplot_key, channel_labels) in enumerate(self.line_plot.subplot_structure.items()):
            ax = self.axes[subplot_idx]
            ax.set_xlabel(self._options.x_label)
            ax.set_ylabel(self._options.y_label)
            ax.grid(True, alpha=0.3)
            if self.line_plot.is_dict_data and n_subplots > 1:
                ax.set_title(subplot_key)

            subplot_lines = []
            for color, channel_label in zip(COLORS, channel_labels):
                (line,) = ax.plot([], [], color=color, label=channel_label, linewidth=2)
                subplot_lines.append(line)
            self.lines[subplot_key] = subplot_lines

        # Legend must be outside, otherwise it will not play well with blitting
        self.fig.legend(ncol=sum(map(len, self.lines.values())), loc="outside lower center")

        self.fig.canvas.draw()
        for ax in self.axes:
            self.caches_bbox.append(self.fig.canvas.copy_from_bbox(ax.bbox))

        self._show_fig()

    def process(self, data, cur_time):
        """Record the new sample, then let the base class refresh the plot and capture a frame if needed."""
        self.line_plot.process(data, cur_time)
        super().process(data, cur_time)

    def _update_limits(self) -> bool:
        """
        Extend the axis limits if necessary to display all the available data.

        Must be called with the lock held. Returns whether the limits have been changed.
        """
        x_data = self.line_plot.x_data
        if len(x_data) < 2:
            return False

        # First, check if the limits on y-axis must be extended to display all the available data
        subplots_ylim_data = []
        must_update_limit_y = False
        for ax, subplot_key in zip(self.axes, self.lines.keys()):
            subplot_ylim_data = None
            all_y_values = list(itertools.chain.from_iterable(self.line_plot.y_data[subplot_key].values()))
            # Channels may exist without any value yet, in which case there is nothing to fit
            if all_y_values:
                subplot_ylim_data = y_min_data, y_max_data = min(all_y_values), max(all_y_values)
                y_min_plot, y_max_plot = ax.get_ylim()
                if y_min_data < y_min_plot or y_max_plot < y_max_data:
                    must_update_limit_y = True
            subplots_ylim_data.append(subplot_ylim_data)

        # Next, adjust the limits on x-axis if they must be extended or adjusting y-axis is already planned.
        # All the subplots share the x-axis, so it is enough to consider the last one.
        x_ax = self.axes[-1]
        x_min_plot, x_max_plot = x_ax.get_xlim()
        x_min_data, x_max_data = x_data[0], x_data[-1]
        if not (must_update_limit_y or x_min_plot < 0.0 or x_max_plot < x_max_data):
            return False
        x_min_plot = max(0.0, x_min_data)
        x_max_plot = x_max_data + max(
            MPL_PLOTTER_RESCALE_RATIO_X * (x_max_data - x_min_data), MPL_PLOTTER_RESCALE_MIN_X
        )
        x_ax.set_xlim((x_min_plot - gs.EPS, x_max_plot + gs.EPS))

        # Finally, adjust the limits on y-axis since the x-axis has been updated
        for ax, subplot_ylim_data in zip(self.axes, subplots_ylim_data):
            if subplot_ylim_data is None:
                continue
            y_min_data, y_max_data = subplot_ylim_data
            y_margin = MPL_PLOTTER_RESCALE_RATIO_Y * (y_max_data - y_min_data)
            ax.set_ylim((y_min_data - y_margin - gs.EPS, y_max_data + y_margin + gs.EPS))

        return True

    def _update_plot(self):
        # The lock is released even on failure, otherwise every subsequent update would dead-lock
        with self._lock:
            x_data = self.line_plot.x_data
            limits_changed = self._update_limits()

            # Must redraw the entire figure if the limits have changed
            if limits_changed:
                self.fig.canvas.draw()

            # Update background if the entire figure has been updated, or the buffer size has been exceeded
            if limits_changed or (len(x_data) > 1 and self.cache_xmax < x_data[0] + gs.EPS):
                self.caches_bbox = [self.fig.canvas.copy_from_bbox(ax.bbox) for ax in self.axes]
                self.cache_xmax = x_data[-2]

            # Update lines for each subplot
            for ax, cache_bbox, (subplot_key, subplot_lines) in zip(self.axes, self.caches_bbox, self.lines.items()):
                # Restore background, then draw the updated lines of this subplot on top of it
                self.fig.canvas.restore_region(cache_bbox)
                channel_labels = self.line_plot.subplot_structure[subplot_key]
                for line, channel_label in zip(subplot_lines, channel_labels):
                    line.set_data(x_data, self.line_plot.y_data[subplot_key][channel_label])
                    ax.draw_artist(line)

                # Blit the updated subplot
                self.fig.canvas.blit(ax.bbox)

            self.fig.canvas.flush_events()

    def cleanup(self):
        super().cleanup()
        self.line_plot.clear_data()
        self.lines.clear()
        self.caches_bbox.clear()
        self.cache_xmax = -1


@register_recording(MPLImagePlotterOptions)
class MPLImagePlotter(BaseMPLPlotter):
    """
    Live image viewer using matplotlib.

    The image data should be an array-like object with shape (H, W), (H, W, 1), (H, W, 3), or (H, W, 4).
    """

    def build(self):
        super().build()

        import matplotlib.pyplot as plt

        self.image_plot = None
        # Cached axes background used for blitting; None until the first full redraw
        self.background = None

        self.fig, self.ax = plt.subplots(figsize=self.figsize)
        self.fig.tight_layout(pad=0)
        self.ax.set_axis_off()
        self.fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        self.image_plot = self.ax.imshow(np.zeros((1, 1)), cmap="plasma", origin="upper", aspect="auto")

        self._show_fig()

    def process(self, data, cur_time):
        """Process new image data and update display."""
        if isinstance(data, torch.Tensor):
            img_data = tensor_to_array(data)
        else:
            img_data = np.asarray(data)

        with self._lock:
            # Redraw the whole figure (and refresh the cached background) whenever the color range changes, and also
            # when no background has been cached yet, e.g. if the first image matches the initial color limits.
            vmin, vmax = np.min(img_data), np.max(img_data)
            current_vmin, current_vmax = self.image_plot.get_clim()
            if self.background is None or vmin != current_vmin or vmax != current_vmax:
                self.image_plot.set_clim(vmin, vmax)
                self.fig.canvas.draw()
                self.background = self.fig.canvas.copy_from_bbox(self.ax.bbox)

            self.fig.canvas.restore_region(self.background)
            self.image_plot.set_data(img_data)
            self.ax.draw_artist(self.image_plot)
            self.fig.canvas.blit(self.ax.bbox)

        self.fig.canvas.flush_events()

        # `BasePlotter.process` is bypassed here, so the video frame must be recorded explicitly. This is done after
        # releasing the lock, which `get_image_array` acquires itself.
        if self._options.save_to_filename:
            self._frames_buffer.append(self.get_image_array())

    def cleanup(self):
        super().cleanup()
        self.ax = None
        self.image_plot = None
        self.background = None


def _project_to_plane(normal: np.ndarray, *arrays: np.ndarray) -> tuple[np.ndarray, ...]:
    """
    Express 3D points or vectors in the 2D coordinates of the plane orthogonal to `normal`.

    The in-plane basis is built from `gu.orthogonals` applied to the normalized `normal` (presumably two orthogonal
    vectors spanning the plane), stacked as columns. Each input array is right-multiplied by that basis and the
    results are returned in the same order as the inputs.
    """
    unit_normal = normal / np.linalg.norm(normal)
    basis = np.stack(gu.orthogonals(unit_normal), axis=1)
    return tuple(arr @ basis for arr in arrays)


@register_recording(MPLVectorFieldPlotterOptions)
class MPLVectorFieldPlotter(BaseMPLPlotter):
    """
    Live 3D vector field viewer: projects positions and vectors onto a 2D plane and plots arrows colored by magnitude.

    The data_func should return an array of shape (N, 3) with the 3D vector at each position given in options.
    """

    def build(self):
        super().build()

        import matplotlib.pyplot as plt

        opts = self._options
        positions = np.array(opts.positions, dtype=float)
        if positions.ndim != 2 or positions.shape[1] != 3:
            gs.raise_exception(f"[{type(self).__name__}] positions must have shape (N, 3), got {positions.shape}.")
        if positions.shape[0] == 0:
            gs.raise_exception(f"[{type(self).__name__}] at least one position is required.")
        normal = np.array(opts.normal, dtype=float)
        n_norm = np.linalg.norm(normal)
        if normal.size != 3 or n_norm < gs.EPS:
            gs.raise_exception(f"[{type(self).__name__}] normal must be a non-zero 3D vector.")
        normal = normal / n_norm

        # Fit the view to the projected positions, with a 10% margin of the largest extent
        (xy,) = _project_to_plane(normal, positions)
        (x_min, y_min), (x_max, y_max) = xy.min(axis=0), xy.max(axis=0)
        margin = 0.1 * max(np.max(np.ptp(xy, axis=0)), gs.EPS)

        self.fig, self.ax = plt.subplots(figsize=self.figsize)
        self.fig.suptitle(opts.title)
        self.ax.set_xlim(x_min - margin, x_max + margin)
        self.ax.set_ylim(y_min - margin, y_max + margin)
        self.ax.set_aspect("equal")
        self.ax.set_axis_off()

        self._positions = positions
        self._normal = normal
        self._scale_factor = opts.scale_factor
        self._max_magnitude = opts.max_magnitude

        n = len(xy)
        self._scatter = self.ax.scatter(
            xy[:, 0],
            xy[:, 1],
            s=8,
            c=np.zeros(n),
            cmap="plasma",
            vmin=0,
            vmax=self._max_magnitude,
            zorder=0,
        )
        self._quiver = self.ax.quiver(
            xy[:, 0],
            xy[:, 1],
            np.zeros_like(xy[:, 0]),
            np.zeros_like(xy[:, 1]),
            np.zeros(n),
            cmap="plasma",
            clim=(0, self._max_magnitude),
            zorder=1,
            scale_units="xy",
            scale=1,
        )
        self.fig.colorbar(self._quiver, ax=self.ax, label="Magnitude")

        # Cache the static part of the figure, on top of which the scatter and quiver are blitted at every update
        self._background = None
        self._refresh_background()

        self._show_fig()
        self.fig.canvas.mpl_connect("resize_event", self.on_resize)

    def _refresh_background(self):
        """
        Redraw the figure without the dynamic artists and cache the resulting axes background for blitting.

        The scatter and quiver are hidden during the redraw, otherwise the cached background would contain their
        current state and stale arrows would remain visible underneath the ones drawn by subsequent updates.
        """
        dynamic_artists = (self._scatter, self._quiver)
        for artist in dynamic_artists:
            artist.set_visible(False)
        try:
            self.fig.canvas.draw()
            self._background = self.fig.canvas.copy_from_bbox(self.ax.bbox)
        finally:
            for artist in dynamic_artists:
                artist.set_visible(True)

    def on_resize(self, event):
        """Refresh the cached background after the canvas has been resized."""
        with self._lock:
            if self.fig is not None and self.ax is not None:
                self._refresh_background()

    def process(self, data, cur_time):
        """Process new vector data and update the quiver plot."""
        if isinstance(data, torch.Tensor):
            vectors = tensor_to_array(data)
        else:
            vectors = np.asarray(data, dtype=float)

        # Malformed samples are ignored: the plot keeps showing the last valid vector field
        if vectors.ndim == 2 and vectors.shape[1] == 3 and vectors.shape[0] == len(self._positions):
            self._draw(vectors)

        # `BasePlotter.process` is bypassed here, so the video frame must be recorded explicitly. It is recorded even
        # for ignored samples, so that a frame is produced for every processed step.
        if self._options.save_to_filename:
            self._frames_buffer.append(self.get_image_array())

    def _draw(self, vectors: np.ndarray):
        """Update the scatter and quiver artists from (N, 3) vectors and blit them over the cached background."""
        magnitudes = np.linalg.norm(vectors, axis=-1)
        xy, uv = _project_to_plane(self._normal, self._positions, vectors)

        with self._lock:
            if self._background is None:
                return
            self._scatter.set_offsets(xy)
            self._scatter.set_array(magnitudes)
            self._quiver.set_offsets(xy)
            self._quiver.set_UVC(*(uv * self._scale_factor).T)
            self._quiver.set_array(magnitudes)

            self.fig.canvas.restore_region(self._background)
            self.ax.draw_artist(self._scatter)
            self.ax.draw_artist(self._quiver)
            self.fig.canvas.blit(self.ax.bbox)

        # Pending GUI events are processed outside of the lock: they may dispatch `on_resize`, which acquires the
        # (non-reentrant) lock as well and would otherwise dead-lock.
        self.fig.canvas.flush_events()

    def cleanup(self):
        super().cleanup()
        self.ax = None
        self._scatter = None
        self._quiver = None
        self._positions = None
        self._normal = None
        self._scale_factor = None
        self._max_magnitude = None
        self._background = None