Source code for genesis.engine.entities.rigid_entity.rigid_link

from itertools import starmap
from typing import TYPE_CHECKING, Sequence

import numpy as np
import torch
import trimesh

import genesis as gs
from genesis.repr_base import RBC
from genesis.typing import LaxPositiveFArrayType, Matrix3x3Type, UnitVec4FType, Vec3FType
from genesis.utils import geom as gu
from genesis.utils.misc import DeprecationError, qd_to_torch, tensor_to_array

from .rigid_geom import RigidGeom, RigidVisGeom

if TYPE_CHECKING:
    from genesis.engine.solvers.rigid.rigid_solver import RigidSolver

    from .rigid_entity import KinematicEntity, RigidEntity
    from .rigid_joint import RigidJoint


# Reference material densities. Presumably kg/m^3 and used as default densities for generic objects, robots and
# MuJoCo-imported assets respectively -- TODO confirm at call sites.
RHO_OBJECT: float = 600.0
RHO_ROBOT: float = 1500.0
RHO_MUJOCO: float = 1000.0

# Mass threshold below which discrepancies in spatial inertia are not worth caring about.
MASS_EPS: float = 0.005
# NOTE(review): the following tolerances are not used in this section -- confirm their exact semantics at call sites.
AABB_EPS: float = 0.002
INERTIA_RATIO_MAX: float = 100.0


def get_local_inertial_from_geom(geom: RigidGeom | RigidVisGeom, rho: float) -> tuple[float, Vec3FType, Matrix3x3Type]:
    """
    Extract the local inertial properties (mass, center of mass, inertia tensor) of a given rigid geometry.

    Primitive collision geometries (PLANE, SPHERE, ELLIPSOID, CYLINDER, CAPSULE, BOX) are handled analytically from
    `geom.data`. Visual geometries (`RigidVisGeom`) are always treated as meshes, as is any other geometry type. Meshes
    are processed with trimesh: a non-watertight mesh is replaced by its convex hull, and a mesh with negative signed
    volume is inverted before computing its mass properties. Planes are massless.

    Parameters
    ----------
    geom : RigidGeom or RigidVisGeom
        The geometry to process.
    rho : float
        Material density (kg/m^3).

    Returns
    -------
    tuple[float, np.ndarray, np.ndarray]
        `(mass, com_local, inertia_local)`. The center of mass is a 3-vector and the inertia a 3x3 matrix, both
        expressed in the geometry's local frame. For primitives the center of mass is the local origin. For meshes,
        they are trimesh's `center_mass` and `moment_inertia`.
    """
    geom_type = gs.GEOM_TYPE.MESH if isinstance(geom, RigidVisGeom) else geom.type

    geom_com_local = np.zeros(3)
    if geom_type == gs.GEOM_TYPE.PLANE:
        # Planes are massless. The inertia must still be a 3x3 matrix like in every other branch, otherwise composing
        # it with the inertia of other geometries of the same link would fail or silently broadcast.
        geom_mass = 0.0
        geom_inertia_local = np.zeros((3, 3), dtype=gs.np_float)
    elif geom_type == gs.GEOM_TYPE.SPHERE:
        radius = geom.data[0]
        geom_mass = (4.0 / 3.0) * np.pi * radius**3 * rho
        I = (2.0 / 5.0) * geom_mass * radius**2
        geom_inertia_local = np.diag([I, I, I])
    elif geom_type == gs.GEOM_TYPE.ELLIPSOID:
        # geom.data[:3] are the semi-axes
        hx, hy, hz = geom.data[:3]
        geom_mass = (4.0 / 3.0) * np.pi * hx * hy * hz * rho
        geom_inertia_local = (geom_mass / 5.0) * np.diag([hy**2 + hz**2, hx**2 + hz**2, hx**2 + hy**2])
    elif geom_type == gs.GEOM_TYPE.CYLINDER:
        # 'height' is the full length of the cylinder along its local z-axis
        radius, height = geom.data[:2]
        geom_mass = np.pi * radius**2 * height * rho
        I_r = (geom_mass / 12.0) * (3.0 * radius**2 + height**2)
        I_z = 0.5 * geom_mass * radius**2
        geom_inertia_local = np.diag([I_r, I_r, I_z])
    elif geom_type == gs.GEOM_TYPE.CAPSULE:
        # Cylinder of full length 'height' capped by two hemispheres (together, one full sphere of mass 'm_sph')
        radius, height = geom.data[:2]
        m_cyl = np.pi * radius**2 * height * rho
        m_sph = (4.0 / 3.0) * np.pi * radius**3 * rho
        geom_mass = m_cyl + m_sph
        # Hemisphere term: own inertia (83/320 m r^2) shifted by the parallel axis theorem to the capsule center,
        # located at a distance 'height / 2 + 3 radius / 8' from each hemisphere's center of mass.
        I_r = (m_cyl * radius**2 / 12.0 * (3.0 + height**2 / radius**2)) + (
            m_sph * radius**2 / 4.0 * (83.0 / 80.0 + (height / radius + 3.0 / 4.0) ** 2)
        )
        I_h = 0.5 * m_cyl * radius**2 + (2.0 / 5.0) * m_sph * radius**2
        geom_inertia_local = np.diag([I_r, I_r, I_h])
    elif geom_type == gs.GEOM_TYPE.BOX:
        # geom.data[:3] are the full side lengths of the box (volume = sx * sy * sz)
        sx, sy, sz = geom.data[:3]
        geom_mass = (sx * sy * sz) * rho
        geom_inertia_local = (geom_mass / 12.0) * np.diag([sy**2 + sz**2, sx**2 + sz**2, sx**2 + sy**2])
    else:
        # MESH type
        if isinstance(geom, RigidVisGeom):
            inertia_mesh = trimesh.Trimesh(geom.init_vverts, geom.init_vfaces, process=False)
        else:
            inertia_mesh = trimesh.Trimesh(geom.init_verts, geom.init_faces, process=False)

        # Mass properties are only meaningful for closed surfaces
        if not inertia_mesh.is_watertight:
            inertia_mesh = inertia_mesh.convex_hull

        # FIXME: without this check, some geom will have negative volume even after the above convex
        # hull operation, e.g. 'tests/test_examples.py::test_example[rigid/terrain_from_mesh.py-None]'
        if inertia_mesh.volume < 0.0:
            inertia_mesh.invert()

        inertia_mesh.density = rho
        geom_mass = inertia_mesh.mass
        geom_com_local = inertia_mesh.center_mass
        geom_inertia_local = inertia_mesh.moment_inertia

    return geom_mass, geom_com_local, geom_inertia_local


def compose_inertial_properties(
    geoms_inertial_info: Sequence[tuple[float, Vec3FType, Matrix3x3Type, Vec3FType, UnitVec4FType]],
) -> tuple[float, Vec3FType, Matrix3x3Type]:
    """
    Compose mass, center of mass, and inertia tensor from multiple geometries.

    Each entry of `geoms_inertial_info` is `(mass, com_local, inertia_local, pos, quat)`. `com_local` and
    `inertia_local` are expressed in the geometry frame, and `(pos, quat)` place that frame in the common parent frame.

    Returns
    -------
    tuple[float, np.ndarray, np.ndarray]
        `(total_mass, com, inertia)`:
        - `total_mass` is the summed mass.
        - `com` is the mass-weighted center of mass in the parent frame.
        - `inertia` is the accumulated inertia about that center of mass.
        When the input is empty or the summed mass is exactly zero, `(0.0, zeros(3), zeros((3, 3)))` is returned,
        in float64.
    """

    def _zero_inertial():
        # Shared "no mass" result
        return 0.0, np.zeros(3, dtype=np.float64), np.zeros((3, 3), dtype=np.float64)

    if not geoms_inertial_info:
        return _zero_inertial()

    masses, coms_local, inertias_local, positions, quats = zip(*geoms_inertial_info)
    masses = np.asarray(masses)
    total_mass = masses.sum()
    if total_mass == 0.0:
        return _zero_inertial()

    # Centers of mass of every geometry, expressed in the parent frame
    coms_world = np.stack(
        [gu.transform_by_trans_quat(com, pos, quat) for com, pos, quat in zip(coms_local, positions, quats)],
        axis=0,
    )

    # Mass-weighted average of the geometry centers of mass
    com = np.sum(masses[:, np.newaxis] * coms_world, axis=0) / total_mass

    # Rotate each local inertia into the parent frame and shift it to the global center of mass (parallel axis theorem)
    inertia = np.zeros((3, 3), dtype=np.float64)
    for mass, inertia_local, quat, com_world in zip(masses, inertias_local, quats, coms_world):
        offset_T = gu.trans_quat_to_T(com - com_world, quat)
        inertia += gu.transform_inertia_by_T(inertia_local, offset_T, mass)

    return total_mass, com, inertia


def compute_inertial_from_geoms(
    geoms: Sequence[RigidGeom | RigidVisGeom], rho: float
) -> tuple[float, Vec3FType, Matrix3x3Type]:
    """
    Compose inertial properties (mass, center of mass, inertia tensor) from multiple rigid geometries.

    Every geometry is first processed individually: primitive collision types (SPHERE, ELLIPSOID, CYLINDER, CAPSULE,
    BOX) analytically, planes as massless, and meshes through trimesh. The results are then combined in the parent
    link frame using each geometry's initial position and quaternion.

    Parameters
    ----------
    geoms : list[RigidGeom] or list[RigidVisGeom]
        List of geometry objects to compute inertial from.
    rho : float
        Material density (kg/m^3).

    Returns
    -------
    tuple[float, np.ndarray, np.ndarray]
        (total_mass, center_of_mass, inertia_tensor)
    """
    inertial_entries = []
    for geom in geoms:
        mass, com_local, inertia_local = get_local_inertial_from_geom(geom, rho)
        # Attach the geometry pose in the link frame so that all entries can be composed in a common frame
        inertial_entries.append((mass, com_local, inertia_local, geom._init_pos, geom._init_quat))

    return compose_inertial_properties(tuple(inertial_entries))


class KinematicLink(RBC):
    """
    Kinematic class. One KinematicEntity consists of multiple KinematicLinks, each of which is a rigid body and could
    consist of multiple RigidVisGeoms (`link.vgeoms` for visualization).

    A KinematicLink has no collision geometry and no mass:
    - `geom_start`, `geom_end` and `n_verts` are always 0.
    - The inertial mass, inertia matrix and inverse weight are zero.
    - The inertial frame is the identity.
    """

    def __init__(
        self,
        entity: "KinematicEntity",
        name: str,
        idx: int,
        joint_start: int,
        n_joints: int,
        vgeom_start: int,
        vvert_start: int,
        vface_start: int,
        pos: "np.typing.ArrayLike",
        quat: "np.typing.ArrayLike",
        parent_idx: int,
        root_idx: int | None,
    ):
        """
        Parameters
        ----------
        entity : KinematicEntity
            The entity the link belongs to. The link shares the entity's solver.
        name : str
            The name of the link.
        idx : int
            The global index of the link in the RigidSolver.
        joint_start : int
            The start index of the link's joints in the RigidSolver.
        n_joints : int
            Number of joints connecting the link to its parent link.
        vgeom_start, vvert_start, vface_start : int
            Start offsets of the link's visual geometries, vertices and faces. They are added to the current per-link
            counts when new vgeoms are created.
        pos, quat : array_like
            Position and quaternion of the link at creation time.
        parent_idx : int
            The global index of the parent link, or -1 if the link has no parent.
        root_idx : int or None
            The global index of the root link. If None, it is found by walking up the parent chain.
        """
        self._name: str = name
        self._entity: "KinematicEntity" = entity
        self._solver: "RigidSolver" = entity.solver
        self._entity_idx_in_solver = entity._idx_in_solver

        self._uid = gs.UID()
        self._idx: int = idx
        self._parent_idx: int = parent_idx  # -1 if no parent

        # 'is_fixed' attribute specifies whether the link is free to move.
        # In practice, this attributes determines whether the geometry vertices associated with the entity are stored
        # per batch-element and updated at every simulation step, or computed once at build time and shared among the
        # entire batch. This affects correct processing of collision detection and sensor raycasting as a side-effect.
        # The link is fixed only if every joint from itself up to the root of its kinematic chain is FIXED. The walk
        # always goes all the way up, because the last visited link is also the root link.
        # NOTE(review): assumes this link's joints and all its ancestors already exist in the entity -- confirm.
        is_fixed = True
        link = self
        while True:
            is_fixed &= all(joint.type is gs.JOINT_TYPE.FIXED for joint in link.joints)
            if link.parent_idx == -1:
                break
            link = self.entity.links[link.parent_idx - self.entity.link_start]
        if root_idx is None:
            root_idx = link.idx
        self._root_idx: int = root_idx
        self._is_fixed: bool = is_fixed

        self._joint_start: int = joint_start
        self._n_joints: int = n_joints

        self._vgeom_start: int = vgeom_start
        self._vvert_start: int = vvert_start
        self._vface_start: int = vface_start

        # Link position & rotation at creation time:
        self._pos: "np.typing.ArrayLike" = pos
        self._quat: "np.typing.ArrayLike" = quat

        self._vgeoms: list[RigidVisGeom] = gs.List()

        # Heterogeneous variant tracking (None = not heterogeneous)
        self._variant_vgeom_ranges: list[tuple[int, int]] | None = None

    def _init_variant_tracking(self):
        """Start tracking heterogeneous variants. Records first variant from current state."""
        self._variant_vgeom_ranges = [(self._vgeom_start, self._vgeom_start + self.n_vgeoms)]

    def _record_variant_vgeom_range(self, n_new_vgeoms: int):
        """
        Record a new variant's vgeom range, placed right after the previously recorded one.

        Must be called after `_init_variant_tracking`; otherwise no range has been recorded yet.
        """
        prev_end = self._variant_vgeom_ranges[-1][1]
        self._variant_vgeom_ranges.append((prev_end, prev_end + n_new_vgeoms))

    def _build(self):
        """Build every visual geometry of the link."""
        for vgeom in self._vgeoms:
            vgeom._build()

    def _add_vgeom(self, vmesh, init_pos, init_quat):
        """
        Create a new visual geometry and append it to the link.

        Its global vgeom / vvert / vface indices follow directly after the ones already attached to this link.
        """
        vgeom = RigidVisGeom(
            link=self,
            idx=self.n_vgeoms + self._vgeom_start,
            vvert_start=self.n_vverts + self._vvert_start,
            vface_start=self.n_vfaces + self._vface_start,
            vmesh=vmesh,
            init_pos=init_pos,
            init_quat=init_quat,
        )
        self._vgeoms.append(vgeom)

    # ------------------------------------------------------------------------------------
    # -------------------------------- real-time state -----------------------------------
    # ------------------------------------------------------------------------------------

    @gs.assert_built
    def get_pos(self, envs_idx=None):
        """
        Get the position of the link in the world frame.

        Parameters
        ----------
        envs_idx : int or array of int, optional
            The indices of the environments to get the position. If None, get the position of all environments. Default is None.
        """
        return self._solver.get_links_pos(self._idx, envs_idx)[..., 0, :]

    @gs.assert_built
    def get_quat(self, envs_idx=None):
        """
        Get the quaternion of the link in the world frame.

        Parameters
        ----------
        envs_idx : int or array of int, optional
            The indices of the environments to get the quaternion. If None, get the quaternion of all environments. Default is None.
        """
        return self._solver.get_links_quat(self._idx, envs_idx)[..., 0, :]

    @gs.assert_built
    def get_vel(self, envs_idx=None) -> torch.Tensor:
        """
        Get the linear velocity of the link in the world frame.

        Parameters
        ----------
        envs_idx : int or array of int, optional
            The indices of the environments to get the linear velocity. If None, get the linear velocity of all environments. Default is None.
        """
        return self._solver.get_links_vel(self._idx, envs_idx)[..., 0, :]

    @gs.assert_built
    def get_ang(self, envs_idx=None) -> torch.Tensor:
        """
        Get the angular velocity of the link in the world frame.

        Parameters
        ----------
        envs_idx : int or array of int, optional
            The indices of the environments to get the angular velocity. If None, get the angular velocity of all environments. Default is None.
        """
        return self._solver.get_links_ang(self._idx, envs_idx)[..., 0, :]

    @gs.assert_built
    def get_vAABB(self, envs_idx=None):
        """
        Get the axis-aligned bounding box (AABB) of the link's visual body in the world frame by aggregating all
        the visual geometries associated with this link (`link.vgeoms`).

        Parameters
        ----------
        envs_idx : int or array of int, optional
            The indices of the environments to get the AABB. If None, get the AABB of all environments. Default is None.

        Returns
        -------
        torch.Tensor
            The minimum and maximum corners, stacked along the second-to-last dimension.
            For heterogeneous entities, environments where no vgeom is active keep (+inf, -inf) bounds.

        Raises
        ------
        GenesisException
            If the link has no visual geometries.
        """
        if self.n_vgeoms == 0:
            gs.raise_exception("Link has no visual geometries.")

        # For heterogeneous entities, compute AABB per-environment respecting active_envs_idx
        if self.entity._enable_heterogeneous:
            # The link does not hold a reference to the scene itself, so go through its solver
            envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx)
            n_envs = len(envs_idx)
            aabb_min = torch.full((n_envs, 3), float("inf"), dtype=gs.tc_float, device=gs.device)
            aabb_max = torch.full((n_envs, 3), float("-inf"), dtype=gs.tc_float, device=gs.device)
            for vgeom in self.vgeoms:
                vgeom_aabb = vgeom.get_vAABB(envs_idx)
                # A vgeom without activity mask is active in every environment. Use a full slice for it: an empty
                # tuple would select everything in `aabb_min[()]` but nothing in `vgeom_aabb[(), 0]` (empty advanced
                # index), leading to a shape mismatch.
                if vgeom.active_envs_mask is None:
                    active_mask = slice(None)
                else:
                    active_mask = vgeom.active_envs_mask[envs_idx]
                aabb_min[active_mask] = torch.minimum(aabb_min[active_mask], vgeom_aabb[active_mask, 0])
                aabb_max[active_mask] = torch.maximum(aabb_max[active_mask], vgeom_aabb[active_mask, 1])
            return torch.stack((aabb_min, aabb_max), dim=-2)

        # Stack per-vgeom AABBs along a new dimension, then reduce min corners with min and max corners with max
        aabbs = torch.stack([vgeom.get_vAABB(envs_idx) for vgeom in self._vgeoms], dim=-3)
        return torch.stack((aabbs[..., 0, :].min(dim=-2).values, aabbs[..., 1, :].max(dim=-2).values), dim=-2)

    # ------------------------------------------------------------------------------------
    # ----------------------------------- properties -------------------------------------
    # ------------------------------------------------------------------------------------

    @property
    def uid(self):
        """
        The unique ID of the link.
        """
        return self._uid

    @property
    def name(self) -> str:
        """
        The name of the link.
        """
        return self._name

    @property
    def entity(self) -> "KinematicEntity":
        """
        The entity that the link belongs to.
        """
        return self._entity

    @property
    def solver(self) -> "RigidSolver":
        """
        The solver that the link belongs to.
        """
        return self._solver

    @property
    def joints(self) -> list["RigidJoint"]:
        """
        The sequence of joints that connects the link to its parent link.
        """
        return self.entity.joints_by_links[self.idx_local]

    @property
    def n_joints(self):
        """
        Number of the joints that connects the link to its parent link.
        """
        return self._n_joints

    @property
    def joint_start(self):
        """
        The start index of the link's joints in the RigidSolver.
        """
        return self._joint_start

    @property
    def joint_end(self):
        """
        The end index of the link's joints in the RigidSolver.
        """
        return self._joint_start + self.n_joints

    @property
    def n_dofs(self):
        """The number of degrees of freedom (DOFs) of the link, summed over all its joints."""
        return sum(joint.n_dofs for joint in self.joints)

    @property
    def dof_start(self):
        """The index of the link's first degree of freedom (DOF) in the scene, or -1 if the link has no joint."""
        if not self.joints:
            return -1
        return self.joints[0].dof_start

    @property
    def dof_end(self):
        """The index of the link's last degree of freedom (DOF) in the scene *plus one*, or -1 if the link has no joint."""
        if not self.joints:
            return -1
        return self.joints[-1].dof_end

    @property
    def n_qs(self):
        """Returns the number of `q` variables of the link."""
        return sum(joint.n_qs for joint in self.joints)

    @property
    def q_start(self):
        """Returns the starting index of the `q` variables of the link in the rigid solver, or -1 if it has no joint."""
        if not self.joints:
            return -1
        return self.joints[0].q_start

    @property
    def q_end(self):
        """Returns the last index of the `q` variables of the link in the rigid solver *plus one*, or -1 if it has no joint."""
        if not self.joints:
            return -1
        return self.joints[-1].q_end

    @property
    def idx(self):
        """
        The global index of the link in the RigidSolver.
        """
        return self._idx

    @property
    def parent_idx(self):
        """
        The global index of the link's parent link in the RigidSolver. If the link is the root link, return -1.
        """
        return self._parent_idx

    @property
    def root_idx(self):
        """
        The global index of the link's root link in the RigidSolver.
        """
        return self._root_idx

    @property
    def idx_local(self):
        """
        The local index of the link in the entity.
        """
        return self._idx - self._entity.link_start

    @property
    def is_fixed(self):
        """
        Whether the link is fixed wrt the world.
        """
        return self._is_fixed

    @property
    def invweight(self):
        """Inverse weight of the link. Always zero for KinematicLink (infinite mass)."""
        return np.zeros(2, dtype=gs.np_float)

    @property
    def pos(self) -> "np.typing.ArrayLike":
        """
        The initial position of the link. For real-time position, use `link.get_pos()`.
        """
        return self._pos

    @property
    def quat(self) -> "np.typing.ArrayLike":
        """
        The initial quaternion of the link. For real-time quaternion, use `link.get_quat()`.
        """
        return self._quat

    @property
    def inertial_pos(self):
        """Initial position of the link's inertial frame. Zero for KinematicLink."""
        return np.zeros(3, dtype=gs.np_float)

    @property
    def inertial_quat(self):
        """Initial quaternion of the link's inertial frame. Identity for KinematicLink."""
        return np.array([1.0, 0.0, 0.0, 0.0], dtype=gs.np_float)

    @property
    def inertial_mass(self):
        """Mass of the link. Always 0.0 for KinematicLink."""
        return 0.0

    @property
    def inertial_i(self):
        """Inertia matrix of the link. Zero for KinematicLink."""
        return np.zeros((3, 3), dtype=gs.np_float)

    @property
    def vgeoms(self) -> list[RigidVisGeom]:
        """
        The list of the link's visualization geometries (`RigidVisGeom`).
        """
        return self._vgeoms

    @property
    def geom_start(self) -> int:
        """Start index of collision geoms. Always 0 for KinematicLink."""
        return 0

    @property
    def geom_end(self) -> int:
        """End index of collision geoms. Always 0 for KinematicLink."""
        return 0

    @property
    def n_vgeoms(self) -> int:
        """
        Number of the link's visualization geometries (`vgeom`).
        """
        return len(self._vgeoms)

    @property
    def vgeom_start(self) -> int:
        """
        The start index of the link's vgeom in the RigidSolver.
        """
        return self._vgeom_start

    @property
    def vgeom_end(self) -> int:
        """
        The end index of the link's vgeom in the RigidSolver.
        """
        return self._vgeom_start + self.n_vgeoms

    @property
    def n_verts(self) -> int:
        """Number of collision vertices. Always 0 for KinematicLink."""
        return 0

    @property
    def n_vverts(self) -> int:
        """
        Number of vertices of all the link's vgeoms.
        """
        return sum(vgeom.n_vverts for vgeom in self._vgeoms)

    @property
    def n_vfaces(self) -> int:
        """
        Number of faces of all the link's vgeoms.
        """
        return sum(vgeom.n_vfaces for vgeom in self._vgeoms)

    @property
    def is_built(self) -> bool:
        """
        Whether the entity the link belongs to is built.
        """
        return self.entity.is_built

    # ------------------------------------------------------------------------------------
    # -------------------------------------- repr ----------------------------------------
    # ------------------------------------------------------------------------------------

    def _repr_brief(self):
        return f"{self.__repr_name__()}: {self._uid}, name: '{self._name}', idx: {self._idx}"