Source code for genesis.vis.viewer

import importlib
import os
import sys
import threading
from traceback import TracebackException
from typing import TYPE_CHECKING

import numpy as np
import OpenGL.error
import OpenGL.platform

import genesis as gs
import genesis.utils.geom as gu
from genesis.ext import pyrender
from genesis.ext.pyrender.overlay import ImGuiOverlayPlugin
from genesis.repr_base import RBC
from genesis.utils.misc import redirect_libc_stderr, tensor_to_array
from genesis.utils.tools import Rate
from genesis.vis.keybindings import Key, KeyAction, Keybind, KeyMod
from genesis.vis.viewer_plugins import DefaultControlsPlugin

if TYPE_CHECKING:
    from genesis.options.vis import ViewerOptions
    from genesis.vis.viewer_plugins import ViewerPlugin


class ViewerLock:
    """Context manager that holds the wrapped pyrender viewer's ``render_lock`` for the duration of a block.

    The lock is looked up on the viewer every time the context is entered or exited instead of being cached
    when this object is created. Entering yields ``None``, and exceptions raised inside the block are never
    suppressed.
    """

    def __init__(self, pyrender_viewer):
        # Any object exposing a ``render_lock`` attribute with ``acquire()`` / ``release()`` methods.
        self._pyrender_viewer = pyrender_viewer

    def _current_lock(self):
        """Return the lock currently attached to the wrapped viewer."""
        return self._pyrender_viewer.render_lock

    def __enter__(self):
        # Block until the render lock is held; nothing is bound by ``with ... as``.
        self._current_lock().acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release; the implicit ``None`` return lets an in-flight exception propagate.
        self._current_lock().release()


class Viewer(RBC):
    """On-screen viewer wrapping the bundled ``pyrender`` viewer.

    Owns the viewer camera, optional camera-following of an entity, keybinds and viewer plugins, and
    throttles ``update`` calls to ``max_FPS`` when one is configured.
    """

    def __init__(self, options: "ViewerOptions", context):
        self._is_built = False
        self._res = options.res
        self._run_in_thread = options.run_in_thread
        self._refresh_rate = options.refresh_rate
        self._max_FPS = options.max_FPS
        self._camera_init_pos = np.asarray(options.camera_pos, dtype=gs.np_float)
        self._camera_init_lookat = np.asarray(options.camera_lookat, dtype=gs.np_float)
        self._camera_up = np.asarray(options.camera_up, dtype=gs.np_float)
        self._camera_fov = options.camera_fov
        self._enable_help_text = options.enable_help_text

        self._viewer_plugins: list["ViewerPlugin"] = []
        if options.enable_default_keybinds:
            self._viewer_plugins.append(DefaultControlsPlugin())
        if options.enable_gui:
            self._viewer_plugins.append(ImGuiOverlayPlugin())

        # Validate viewer options
        if any(e.shape != (3,) for e in (self._camera_init_pos, self._camera_init_lookat, self._camera_up)):
            gs.raise_exception("ViewerOptions.camera_(pos|lookat|up) must be sequences of length 3.")

        self._pyrender_viewer = None
        self.context = context

        # Camera-following state, configured by `follow_entity`.
        self._followed_entity = None
        self._follow_fixed_axis = None
        self._follow_smoothing = None
        self._follow_fix_orientation = None
        self._follow_lookat = None

        # Only created when FPS limiting is requested; `update` checks `_max_FPS` before using it.
        if self._max_FPS is not None:
            self.rate = Rate(self._max_FPS)

    def build(self, scene):
        """Attach the viewer to ``scene`` and make sure an on-screen window exists.

        Adds the viewer camera to the rendering context. If a pyrender viewer already exists (rebuild), it
        is re-bound to the rebuilt context instead of opening a new window. Otherwise every candidate
        OpenGL platform is tried in turn until one yields a working viewer.

        Parameters
        ----------
        scene : genesis.Scene
            The scene being visualized.

        Raises
        ------
        OpenGL.error.Error or RuntimeError
            Re-raised from the last candidate platform if none of them could create a viewer.
        """
        self.scene = scene

        # set viewer camera
        self.setup_camera()

        # Reuse an existing window across an InteractiveScene rebuild instead of opening a new one (which
        # would close and reopen the OS window). The preserved pyrender viewer is re-pointed at the rebuilt
        # scene graph in place.
        if self._pyrender_viewer is not None:
            self._pyrender_viewer.rebind(self.context, self._viewer_plugins)
            self.lock = ViewerLock(self._pyrender_viewer)
            self._is_built = True
            return

        # Try all candidate onscreen OpenGL "platforms" if none is specifically requested
        opengl_platform_orig = os.environ.get("PYOPENGL_PLATFORM")
        if opengl_platform_orig is None:
            if sys.platform == "win32":
                all_opengl_platforms = ("wgl",)  # same as "native"
            elif sys.platform == "linux":
                # "native" is platform-specific ("egl" or "glx")
                all_opengl_platforms = ("native", "egl", "glx", "osmesa")
            else:
                all_opengl_platforms = ("native",)
        else:
            if opengl_platform_orig == "osmesa" and sys.platform != "linux":
                gs.raise_exception("PYOPENGL_PLATFORM='osmesa' is only supported on Linux OS for now.")
            all_opengl_platforms = (opengl_platform_orig,)

        for i, platform in enumerate(all_opengl_platforms):
            # Force re-import OpenGL platform
            os.environ["PYOPENGL_PLATFORM"] = platform
            importlib.reload(OpenGL.platform)

            try:
                gs.logger.debug(f"Trying to create OpenGL Context for PYOPENGL_PLATFORM='{platform}'...")
                # Silence libc-level stderr noise emitted by failing OpenGL backends while probing.
                with open(os.devnull, "w") as stderr, redirect_libc_stderr(stderr):
                    self._pyrender_viewer = pyrender.Viewer(
                        context=self.context,
                        viewport_size=self._res,
                        run_in_thread=self._run_in_thread,
                        auto_start=False,
                        view_center=self._camera_init_lookat,
                        shadow=self.context.shadow,
                        plane_reflection=self.context.plane_reflection,
                        env_separate_rigid=self.context.env_separate_rigid,
                        enable_help_text=self._enable_help_text,
                        plugins=self._viewer_plugins,
                        viewer_flags={
                            "window_title": f"Genesis {gs.__version__}",
                            "refresh_rate": self._refresh_rate,
                        },
                    )
                    if not self._run_in_thread:
                        self._pyrender_viewer.start(auto_refresh=False)
                    self._pyrender_viewer.wait_until_initialized()
                break
            except (OpenGL.error.Error, RuntimeError) as e:
                # Invalid OpenGL context. Trying another platform if any...
                tb_exc = TracebackException.from_exception(e)
                gs.logger.debug("".join(tb_exc.format()))

                # Clear broken OpenGL context if it went this far. This must not raise, otherwise the
                # original error would be masked and the remaining platforms would never be tried.
                self._discard_pyrender_viewer()

                if i == len(all_opengl_platforms) - 1:
                    raise
            finally:
                # Restore original platform systematically
                del os.environ["PYOPENGL_PLATFORM"]
                if opengl_platform_orig is not None:
                    os.environ["PYOPENGL_PLATFORM"] = opengl_platform_orig

        self.lock = ViewerLock(self._pyrender_viewer)

        gs.logger.info(f"Viewer created. Resolution: ~<{self._res[0]}×{self._res[1]}>~, max_FPS: ~<{self._max_FPS}>~.")

        self._is_built = True

    def _discard_pyrender_viewer(self):
        """Best-effort teardown of a partially created pyrender viewer; always clears the reference."""
        if self._pyrender_viewer is None:
            return
        try:
            self._pyrender_viewer.close()
        except Exception as e:
            # Closing a broken OpenGL context can itself fail; log it and keep going.
            gs.logger.debug(f"Failed to close partially initialized viewer: {e!r}")
        finally:
            # Never leave a dead viewer behind, or a later `build()` would try to rebind it.
            self._pyrender_viewer = None

    def run(self):
        """Delegate to the underlying pyrender viewer's ``run()``.

        Raises
        ------
        Exception
            Via ``gs.raise_exception`` if the viewer has not been built successfully.
        """
        if self._pyrender_viewer is None:
            gs.raise_exception("Viewer must be built successfully before calling this method.")
        self._pyrender_viewer.run()

    def stop(self):
        """Close the pyrender viewer if it exists and is still active."""
        if self._pyrender_viewer is not None and self._pyrender_viewer.is_active:
            self._pyrender_viewer.close()

    def is_alive(self):
        """Return whether the pyrender viewer exists and is active.

        If the pyrender viewer recorded an exception, the viewer is closed (best effort) and that exception
        is re-raised through ``gs.raise_exception_from``.
        """
        if self._pyrender_viewer is None:
            return False

        if self._pyrender_viewer._exception is not None:
            if self._pyrender_viewer.is_active:
                try:
                    self._pyrender_viewer.close()
                except Exception:
                    # Best-effort cleanup: the recorded viewer error raised below is what matters.
                    pass
            gs.raise_exception_from("Unexpected viewer error.", self._pyrender_viewer._exception)

        return self._pyrender_viewer.is_active

    def setup_camera(self):
        """Create the viewer's perspective camera from the initial pos/lookat/up and add it to the context.

        ``_camera_up`` is replaced by the y-axis (second column) of the resulting camera pose.
        """
        yfov = self._camera_fov / 180.0 * np.pi
        pose = gu.pos_lookat_up_to_T(self._camera_init_pos, self._camera_init_lookat, self._camera_up)
        self._camera_up = pose[:3, 1].copy()
        self._camera_node = self.context.add_node(pyrender.PerspectiveCamera(yfov=yfov), pose=pose)

    def update(self, auto_refresh=None, force=False):
        """Synchronize the viewer with the current simulation state.

        Follows the tracked entity (if any), notifies the pyrender viewer of the simulation step, updates
        the rendering context under the render lock, optionally refreshes the window, and finally sleeps
        to honor ``max_FPS`` when set.

        Parameters
        ----------
        auto_refresh : bool, optional
            Whether to refresh the window. If None, refresh only when called from the viewer's own thread
            (the main thread if the viewer has no dedicated thread). Never refreshes when the viewer runs
            in its own thread.
        force : bool
            Forwarded to ``context.update``.

        Raises
        ------
        Exception
            Via ``gs.raise_exception`` if the viewer is not alive.
        """
        if not self.is_alive():
            gs.raise_exception("Viewer closed.")

        if self._followed_entity is not None:
            self.update_following()

        self._pyrender_viewer.update_on_sim_step()

        with self.lock:
            # Update context
            self.context.update(force)

            # Refresh viewer by default if and only if this is possible
            if auto_refresh is None:
                viewer_thread = self._pyrender_viewer._thread or threading.main_thread()
                auto_refresh = viewer_thread == threading.current_thread()
            if auto_refresh and not self._pyrender_viewer.run_in_thread:
                self._pyrender_viewer.refresh()

        # lock FPS
        if self._max_FPS is not None:
            self.rate.sleep()

    def close_offscreen(self, render_target):
        """Forward to the pyrender viewer's ``close_offscreen`` for ``render_target``."""
        return self._pyrender_viewer.close_offscreen(render_target)

    def render_offscreen(
        self,
        camera_node,
        render_target,
        rgb=True,
        depth=False,
        seg=False,
        normal=False,
        skip_markers=False,
        env_separate_rigid=None,
    ):
        """Render ``camera_node`` into ``render_target`` using the viewer's OpenGL context.

        All arguments are forwarded unchanged to the pyrender viewer's ``render_offscreen``.
        """
        return self._pyrender_viewer.render_offscreen(
            camera_node,
            render_target,
            rgb,
            depth,
            seg,
            normal,
            skip_markers=skip_markers,
            env_separate_rigid=env_separate_rigid,
        )

    def set_camera_pose(self, pose=None, pos=None, lookat=None):
        """
        Set viewer camera pose.

        Parameters
        ----------
        pose : [4,4] float, optional
            Camera-to-world pose. If provided, `pos` and `lookat` will be ignored.
        pos : (3,) float, optional
            Camera position. Defaults to the initial camera position.
        lookat : (3,) float, optional
            Camera lookat point. Defaults to the initial lookat point.
        """
        if pose is None:
            if pos is None:
                pos = self._camera_init_pos
            if lookat is None:
                lookat = self._camera_init_lookat
            up = self._camera_up
            pose = gu.pos_lookat_up_to_T(pos, lookat, up)
            # Keep the stored up vector consistent with the pose actually applied.
            self._camera_up = pose[:3, 1].copy()
        else:
            if np.array(pose).shape != (4, 4):
                gs.raise_exception("pose should be a 4x4 matrix.")
        self._pyrender_viewer._trackball.set_camera_pose(pose)

    def follow_entity(self, entity, fixed_axis=(None, None, None), smoothing=None, fix_orientation=False):
        """
        Set the viewer to follow a specified entity.

        Parameters
        ----------
        entity : genesis.Entity
            The entity to follow.
        fixed_axis : (float, float, float), optional
            The fixed axis for the viewer's movement. For each axis, if None, the viewer will move freely. If a
            float, the viewer will be fixed at that value. For example, [None, None, None] will allow the viewer
            to move freely while following, [None, None, 0.5] will fix the viewer's z-axis at 0.5.
        smoothing : float, optional
            The smoothing factor in ]0,1[ for the viewer's movement, i.e. the weight kept on the previous camera
            position and lookat point at each update. If None, no smoothing will be applied.
        fix_orientation : bool, optional
            If True, the viewer will keep its current orientation and only translate. If False, the viewer will
            look at the entity's position (as returned by `entity.get_pos()`).
        """
        self._followed_entity = entity
        self._follow_fixed_axis = fixed_axis
        self._follow_smoothing = smoothing
        self._follow_fix_orientation = fix_orientation
        self._follow_lookat = self._camera_init_lookat

    def update_following(self):
        """
        Update the viewer position to follow the specified entity.

        The target camera position is the entity position offset by the initial camera position. When
        smoothing is enabled, both the camera position and the lookat point are low-pass filtered.
        """
        entity_pos = tensor_to_array(self._followed_entity.get_pos())
        if entity_pos.ndim > 1:  # check for multiple envs
            entity_pos = entity_pos[0]
        # numpy < 2.0 doesn't support the copy keyword argument in np.asarray()
        camera_transform = np.array(self._pyrender_viewer._trackball.pose, copy=True)
        camera_pos = np.array(self._pyrender_viewer._trackball.pose[:3, 3])

        if self._follow_smoothing is not None:
            # Smooth viewer movement with a low-pass filter
            camera_pos = self._follow_smoothing * camera_pos + (1 - self._follow_smoothing) * (
                entity_pos + self._camera_init_pos
            )
            self._follow_lookat = (
                self._follow_smoothing * self._follow_lookat + (1 - self._follow_smoothing) * entity_pos
            )
        else:
            camera_pos = entity_pos + self._camera_init_pos
            self._follow_lookat = entity_pos

        for i, fixed_axis in enumerate(self._follow_fixed_axis):
            # Fix the camera's position along the specified axis
            if fixed_axis is not None:
                camera_pos[i] = fixed_axis

        if self._follow_fix_orientation:
            # Keep the current camera orientation and only override its translation
            camera_transform[:3, 3] = camera_pos
            self.set_camera_pose(pose=camera_transform)
        else:
            self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat)

    @gs.assert_built
    def register_keybinds(self, /, *keybinds: Keybind, overwrite: bool = False) -> None:
        """
        Register one or more keybinds whose callbacks are invoked when their key combination is triggered.

        Parameters
        ----------
        keybinds : Keybind
            One or more Keybind objects to register. See Keybind documentation for usage.
        overwrite : bool, optional
            Forwarded to the pyrender viewer's ``register_keybinds``; presumably allows replacing an existing
            conflicting keybind (verify against the pyrender viewer implementation).
        """
        self._pyrender_viewer.register_keybinds(*keybinds, overwrite=overwrite)

    @gs.assert_built
    def remap_keybind(
        self,
        keybind_name: str,
        new_key: Key,
        new_key_mods: tuple[KeyMod] | None,
        new_key_action: KeyAction = KeyAction.PRESS,
    ) -> None:
        """
        Remap an existing keybind by name to a new key combination.

        Parameters
        ----------
        keybind_name : str
            The name of the keybind to remap.
        new_key : Key
            The new key.
        new_key_mods : tuple[KeyMod] | None
            The new modifier keys pressed.
        new_key_action : KeyAction, optional
            The new type of key action. Defaults to ``KeyAction.PRESS``.
        """
        self._pyrender_viewer.remap_keybind(
            keybind_name,
            new_key,
            new_key_mods,
            new_key_action,
        )

    @gs.assert_built
    def remove_keybind(self, keybind_name: str) -> None:
        """
        Remove an existing keybind by name.

        Parameters
        ----------
        keybind_name : str
            The name of the keybind to remove.
        """
        self._pyrender_viewer.remove_keybind(keybind_name)

    def add_plugin(self, plugin: "ViewerPlugin") -> "ViewerPlugin":
        """
        Add a viewer plugin to the viewer.

        The plugin is registered with the pyrender viewer right away if the viewer is already built;
        otherwise it is handed over when the viewer gets built.

        Parameters
        ----------
        plugin : ViewerPlugin
            The viewer plugin to add.

        Returns
        -------
        ViewerPlugin
            The same plugin, for chaining.
        """
        self._viewer_plugins.append(plugin)
        if self.is_built:
            self._pyrender_viewer.register_plugin(plugin)
        return plugin

    def consume_rebuild_requests(self) -> bool:
        """Let any plugin perform a pending scene rebuild on the calling (main) thread.

        Returns True if one did, in which case this viewer has just been torn down and the caller must stop
        using it."""
        # Iterate over a snapshot: a rebuild may mutate the plugin list.
        for plugin in tuple(self._viewer_plugins):
            if plugin.consume_rebuild_request():
                return True
        return False

    def should_advance_simulation(self) -> bool:
        """Whether the simulation may advance this frame, i.e. no plugin (e.g. the GUI play/pause control)
        vetoes it. Every plugin is polled so stateful single-step controls always observe the frame."""
        # A list (not a generator) is deliberate: `all` must not short-circuit before every plugin is polled.
        return all([plugin.should_step() for plugin in self._viewer_plugins])

    # ------------------------------------------------------------------------------------
    # ----------------------------------- properties -------------------------------------
    # ------------------------------------------------------------------------------------

    @property
    def is_built(self):
        """Whether `build` has completed."""
        return self._is_built

    @property
    def res(self):
        """Viewer resolution as given in the options."""
        return self._res

    @property
    def refresh_rate(self):
        """Refresh rate as given in the options."""
        return self._refresh_rate

    @property
    def max_FPS(self):
        """Maximum update rate, or None if unlimited."""
        return self._max_FPS

    @property
    def camera_pos(self):
        """
        Get the camera's current position.
        """
        return np.array(self._pyrender_viewer._trackball._n_pose[:3, 3])

    @property
    def camera_lookat(self):
        """
        Get a point on the camera's current viewing ray: the camera position minus the pose's z-axis column
        (the viewing direction is taken as the negative local z-axis). This is one unit ahead of the camera
        for an orthonormal pose, not necessarily the original lookat target.
        """
        pos = np.array(self._pyrender_viewer._trackball._n_pose[:3, 3])
        z = self._pyrender_viewer._trackball._n_pose[:3, 2]
        return pos - z

    @property
    def camera_pose(self):
        """
        Get the camera's current pose represented by a 4x4 matrix.
        """
        return np.array(self._pyrender_viewer._trackball._n_pose)

    @property
    def camera_up(self):
        """Up vector last used to compute a camera pose from position and lookat."""
        return self._camera_up

    @property
    def camera_fov(self):
        """Vertical field of view, in degrees."""
        return self._camera_fov