# Source code for genesis.engine.entities.rigid_entity.rigid_joint

import taichi as ti
import torch

import genesis as gs
import genesis.utils.geom as gu
from genesis.repr_base import RBC


@ti.data_oriented
class RigidJoint(RBC):
    """
    Joint class for rigid body entities. Each RigidLink is connected to its parent link via a RigidJoint.
    """

    def __init__(
        self,
        entity,
        name,
        idx,
        q_start,
        dof_start,
        n_qs,
        n_dofs,
        type,
        pos,
        quat,
        dofs_motion_ang,
        dofs_motion_vel,
        dofs_limit,
        dofs_invweight,
        dofs_stiffness,
        dofs_sol_params,
        dofs_damping,
        dofs_armature,
        dofs_kp,
        dofs_kv,
        dofs_force_range,
        init_qpos,
    ):
        """
        Store the joint description.

        Every argument is kept as-is on a private attribute and exposed through a read-only
        property of the same name. The rigid solver is taken from ``entity.solver``. A fresh
        ``gs.UID`` is assigned to the joint.
        """
        self._name = name
        self._entity = entity
        self._solver = entity.solver
        self._uid = gs.UID()
        self._idx = idx
        self._q_start = q_start
        self._dof_start = dof_start
        self._n_qs = n_qs
        self._n_dofs = n_dofs
        self._type = type
        self._pos = pos
        self._quat = quat
        self._init_qpos = init_qpos

        self._dofs_motion_ang = dofs_motion_ang
        self._dofs_motion_vel = dofs_motion_vel
        self._dofs_limit = dofs_limit
        self._dofs_invweight = dofs_invweight
        self._dofs_stiffness = dofs_stiffness
        self._dofs_sol_params = dofs_sol_params
        self._dofs_damping = dofs_damping
        self._dofs_armature = dofs_armature
        self._dofs_kp = dofs_kp
        self._dofs_kv = dofs_kv
        self._dofs_force_range = dofs_force_range

        # NOTE: temp hack to use 0 damping/armature for drone
        if isinstance(self._entity, gs.engine.entities.DroneEntity) and self._type == gs.JOINT_TYPE.FREE:
            # Local import: numpy is only needed on this rare path and is not imported at module level.
            import numpy as np

            self._dofs_damping = np.zeros_like(self._dofs_damping)
            self._dofs_armature = np.zeros_like(self._dofs_armature)

    # ------------------------------------------------------------------------------------
    # -------------------------------- real-time state -----------------------------------
    # ------------------------------------------------------------------------------------

    @gs.assert_built
    def get_pos(self):
        """
        Get the position of the joint in the world frame.

        Returns
        -------
        torch.Tensor
            A ``(B, 3)`` tensor, one row per solver batch element. The leading axis is squeezed
            away when the solver has ``n_envs == 0``.
        """
        tensor = torch.empty(self._solver._batch_shape(3, True), dtype=gs.tc_float, device=gs.device)
        self._kernel_get_pos(tensor)
        if self._solver.n_envs == 0:
            tensor = tensor.squeeze(0)
        return tensor

    # Compute the world-frame (pos, quat) of this joint for batch element `i_b`.
    # The child link's world pose is its link-info pos/quat composed with the parent link's
    # current state (identity when the link has no parent, i.e. parent_idx == -1).
    # The joint's joint_pos/joint_quat offset is then composed on top of that pose.
    # Shared by _kernel_get_pos and _kernel_get_quat so the two stay consistent.
    @ti.func
    def _func_get_joint_pos_quat(self, i_b):
        l_info = self._solver.links_info[self._idx, i_b]
        i_p = l_info.parent_idx

        p_pos = ti.Vector.zero(gs.ti_float, 3)
        p_quat = gu.ti_identity_quat()
        if i_p != -1:
            p_pos = self._solver.links_state[i_p, i_b].pos
            p_quat = self._solver.links_state[i_p, i_b].quat

        link_pos, link_quat = gu.ti_transform_pos_quat_by_trans_quat(l_info.pos, l_info.quat, p_pos, p_quat)
        joint_pos, joint_quat = gu.ti_transform_pos_quat_by_trans_quat(
            l_info.joint_pos, l_info.joint_quat, link_pos, link_quat
        )
        return joint_pos, joint_quat

    # Write the world-frame joint position of every batch element into `tensor[i_b, 0:3]`.
    @ti.kernel
    def _kernel_get_pos(self, tensor: ti.types.ndarray()):
        for i_b in range(self._solver._B):
            joint_pos, joint_quat = self._func_get_joint_pos_quat(i_b)
            for i in ti.static(range(3)):
                tensor[i_b, i] = joint_pos[i]

    @gs.assert_built
    def get_quat(self):
        """
        Get the quaternion of the joint in the world frame.

        Returns
        -------
        torch.Tensor
            A ``(B, 4)`` tensor, one row per solver batch element. The leading axis is squeezed
            away when the solver has ``n_envs == 0``.
        """
        tensor = torch.empty(self._solver._batch_shape(4, True), dtype=gs.tc_float, device=gs.device)
        self._kernel_get_quat(tensor)
        if self._solver.n_envs == 0:
            tensor = tensor.squeeze(0)
        return tensor

    # Write the world-frame joint quaternion of every batch element into `tensor[i_b, 0:4]`.
    @ti.kernel
    def _kernel_get_quat(self, tensor: ti.types.ndarray()):
        for i_b in range(self._solver._B):
            joint_pos, joint_quat = self._func_get_joint_pos_quat(i_b)
            for i in ti.static(range(4)):
                tensor[i_b, i] = joint_quat[i]

    # ------------------------------------------------------------------------------------
    # ----------------------------------- properties -------------------------------------
    # ------------------------------------------------------------------------------------

    @property
    def uid(self):
        """
        Returns the unique id of the joint.
        """
        return self._uid

    @property
    def name(self):
        """
        Returns the name of the joint.
        """
        return self._name

    @property
    def entity(self):
        """
        Returns the entity that the joint belongs to.
        """
        return self._entity

    @property
    def solver(self):
        """
        The RigidSolver object that the joint belongs to.
        """
        return self._solver

    @property
    def link(self):
        """
        Returns the child link of the joint.
        """
        return self._solver.links[self._idx]

    @property
    def idx(self):
        """
        Returns the global index of the joint in the rigid solver.
        """
        return self._idx

    @property
    def idx_local(self):
        """
        Returns the local index of the joint in the entity.
        """
        return self._idx - self._entity._joint_start

    @property
    def init_qpos(self):
        """
        Returns the initial joint position.
        """
        return self._init_qpos

    @property
    def n_qs(self):
        """
        Returns the number of `q` (generalized coordinate) variables that the joint has.
        """
        return self._n_qs

    @property
    def n_dofs(self):
        """
        Returns the number of dofs that the joint has.
        """
        return self._n_dofs

    @property
    def type(self):
        """
        Returns the type of the joint.
        """
        return self._type

    @property
    def pos(self):
        """
        Returns the initial position of the joint in the world frame.

        NOTE(review): this is the value passed at construction; get_pos() treats the joint offset
        as relative to its link, so confirm which frame this value is expressed in.
        """
        return self._pos

    @property
    def quat(self):
        """
        Returns the initial quaternion of the joint in the world frame.

        NOTE(review): this is the value passed at construction; confirm the frame (see `pos`).
        """
        return self._quat

    @property
    def q_start(self):
        """
        Returns the starting index of the `q` variables of the joint in the rigid solver.
        """
        return self._q_start

    @property
    def dof_start(self):
        """
        Returns the starting index of the dofs of the joint in the rigid solver.
        """
        return self._dof_start

    @property
    def q_end(self):
        """
        Returns the ending (exclusive) index of the `q` variables of the joint in the rigid solver.
        """
        return self._n_qs + self.q_start

    @property
    def dof_end(self):
        """
        Returns the ending (exclusive) index of the dofs of the joint in the rigid solver.
        """
        return self._n_dofs + self.dof_start

    @property
    def dof_idx(self):
        """
        Returns all the dof indices of the joint in the rigid solver.

        The return type depends on the dof count: an ``int`` for one dof, ``None`` for zero dofs,
        and a ``list`` of ints otherwise.
        """
        if self.n_dofs == 1:
            return self.dof_start
        elif self.n_dofs == 0:
            return None
        else:
            return list(range(self.dof_start, self.dof_end))

    @property
    def dof_idx_local(self):
        """
        Returns the local dof index of the joint in the entity.

        Same int / None / list convention as `dof_idx`, offset by the entity's first dof.
        """
        if self.n_dofs == 1:
            return self.dof_idx - self._entity._dof_start
        elif self.n_dofs == 0:
            return None
        else:
            return [dof_idx - self._entity._dof_start for dof_idx in self.dof_idx]

    @property
    def q_idx(self):
        """
        Returns all the `q` indices of the joint in the rigid solver.

        The return type depends on the `q` count: an ``int`` for one `q`, ``None`` for zero,
        and a ``list`` of ints otherwise.
        """
        if self.n_qs == 1:
            return self.q_start
        elif self.n_qs == 0:
            return None
        else:
            return list(range(self.q_start, self.q_end))

    @property
    def q_idx_local(self):
        """
        Returns all the local `q` indices of the joint in the entity.

        Same int / None / list convention as `q_idx`, offset by the entity's first `q`.
        """
        if self.n_qs == 1:
            return self.q_start - self._entity._q_start
        elif self.n_qs == 0:
            return None
        else:
            return [q_idx - self._entity._q_start for q_idx in self.q_idx]

    @property
    def dofs_motion_ang(self):
        """
        Returns the `dofs_motion_ang` values given at construction.

        NOTE(review): presumably the angular part of each dof's motion axis; confirm.
        """
        return self._dofs_motion_ang

    @property
    def dofs_motion_vel(self):
        """
        Returns the `dofs_motion_vel` values given at construction.

        NOTE(review): presumably the linear part of each dof's motion axis; confirm.
        """
        return self._dofs_motion_vel

    @property
    def dofs_limit(self):
        """
        Returns the range limit of the dofs of the joint.
        """
        return self._dofs_limit

    @property
    def dofs_invweight(self):
        """
        Returns the invweight of the dofs of the joint.
        """
        return self._dofs_invweight

    @property
    def dofs_stiffness(self):
        """
        Returns the stiffness of the dofs of the joint.
        """
        return self._dofs_stiffness

    @property
    def dofs_sol_params(self):
        """
        Returns the solver parameters of the dofs of the joint.
        """
        return self._dofs_sol_params

    @property
    def dofs_damping(self):
        """
        Returns the damping of the dofs of the joint.
        """
        return self._dofs_damping

    @property
    def dofs_armature(self):
        """
        Returns the armature of the dofs of the joint.
        """
        return self._dofs_armature

    @property
    def dofs_kp(self):
        """
        Returns the kp (positional gain) of the dofs of the joint.
        """
        return self._dofs_kp

    @property
    def dofs_kv(self):
        """
        Returns the kv (velocity gain) of the dofs of the joint.
        """
        return self._dofs_kv

    @property
    def dofs_force_range(self):
        """
        Returns the force range of the dofs of the joint.
        """
        return self._dofs_force_range

    @property
    def is_built(self):
        """
        Returns whether the entity the joint belongs to is built.
        """
        return self.entity.is_built

    # ------------------------------------------------------------------------------------
    # -------------------------------------- repr ----------------------------------------
    # ------------------------------------------------------------------------------------

    def _repr_brief(self):
        return f"{self._repr_type()}: {self._uid}, name: '{self._name}', idx: {self._idx}, type: {self._type}"