import igl
import numpy as np
import gstaichi as ti
import torch
import genesis as gs
import genesis.utils.element as eu
import genesis.utils.geom as gu
import genesis.utils.mesh as mu
from genesis.engine.entities.rigid_entity import RigidLink
from genesis.engine.couplers import SAPCoupler
from genesis.engine.states.cache import QueriedStates
from genesis.engine.states.entities import FEMEntityState
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, to_gs_tensor, tensor_to_array
from .base_entity import Entity
[docs]@ti.data_oriented
class FEMEntity(Entity):
"""
A finite element method (FEM)-based entity for deformable simulation.
This class represents a deformable object using tetrahedral elements. It interfaces with
the physics solver to handle state updates, checkpointing, gradients, and actuation
for physics-based simulation in batched environments.
Parameters
----------
scene : Scene
The simulation scene that this entity belongs to.
solver : Solver
The physics solver instance used for simulation.
material : Material
The material properties defining elasticity, density, etc.
morph : Morph
The morph specification that defines the entity's shape.
surface : Surface
The surface mesh associated with the entity (for rendering or collision).
idx : int
Unique identifier of the entity within the scene.
v_start : int, optional
Starting index of this entity's vertices in the global vertex array (default is 0).
el_start : int, optional
Starting index of this entity's elements in the global element array (default is 0).
s_start : int, optional
Starting index of this entity's surface triangles in the global surface array (default is 0).
"""
def __init__(self, scene, solver, material, morph, surface, idx, v_start=0, el_start=0, s_start=0):
super().__init__(idx, scene, morph, solver, material, surface)
self._v_start = v_start # offset for vertex index of elements
self._el_start = el_start # offset for element index
self._s_start = s_start # offset for surface triangles
self._step_global_added = None
self._surface.update_texture()
self.sample()
el2tri = np.array(
[ # follow the order with correct normal
[[v[0], v[2], v[1]], [v[1], v[2], v[3]], [v[0], v[1], v[3]], [v[0], v[3], v[2]]] for v in self.elems
],
dtype=gs.np_int,
)
all_tri = el2tri.reshape((-1, 3))
all_tri_sorted = np.sort(all_tri, axis=1)
_, unique_idcs, cnt = np.unique(all_tri_sorted, axis=0, return_counts=True, return_index=True)
unique_tri = all_tri[unique_idcs]
surface_tri = unique_tri[cnt == 1]
self._surface_tri_np = surface_tri
self._n_surfaces = len(self._surface_tri_np)
if self._n_surfaces > 0:
self._n_surface_vertices = len(np.unique(self._surface_tri_np))
else:
self._n_surface_vertices = 0
tri2el = np.repeat(np.arange(self.elems.shape[0], dtype=gs.np_int)[:, np.newaxis], 4, axis=1)
unique_el = tri2el.flat[unique_idcs]
self._surface_el_np = unique_el[cnt == 1]
if isinstance(self.sim.coupler, SAPCoupler):
self.compute_pressure_field()
self.init_tgt_vars()
self.init_ckpt()
self._queried_states = QueriedStates()
self.active = False # This attribute is only used in forward pass. It should NOT be used during backward pass.
# ------------------------------------------------------------------------------------
# ----------------------------------- basic entity ops -------------------------------
# ------------------------------------------------------------------------------------
[docs] def set_position(self, pos):
"""
Set the target position(s) for the FEM entity.
Parameters
----------
pos : torch.Tensor or array-like
The desired position(s). Can be:
- (3,): a single COM offset vector.
- (n_vertices, 3): per-vertex positions for all vertices.
- (n_envs, 3): per-environment COM offsets.
- (n_envs, n_vertices, 3): full batched per-vertex positions.
Raises
------
Exception
If the tensor shape is not supported.
"""
self._assert_active()
gs.logger.warning("Manally setting element positions. This is not recommended and could break gradient flow.")
pos = to_gs_tensor(pos)
is_valid = False
if pos.ndim == 1:
if pos.shape == (3,):
pos = self.init_positions_COM_offset + pos
self._tgt["pos"] = pos.unsqueeze(0).tile((self._sim._B, 1, 1))
is_valid = True
elif pos.ndim == 2:
if pos.shape == (self.n_vertices, 3):
self._tgt["pos"] = pos.unsqueeze(0).tile((self._sim._B, 1, 1))
is_valid = True
elif pos.shape == (self._sim._B, 3):
pos = self.init_positions_COM_offset.unsqueeze(0) + pos.unsqueeze(1)
self._tgt["pos"] = pos
is_valid = True
elif pos.ndim == 3:
if pos.shape == (self._sim._B, self.n_vertices, 3):
self._tgt["pos"] = pos
is_valid = True
if not is_valid:
gs.raise_exception("Tensor shape not supported.")
[docs] def set_velocity(self, vel):
"""
Set the target velocity(ies) for the FEM entity.
Parameters
----------
vel : torch.Tensor or array-like
The desired velocity(ies). Can be:
- (3,): a global velocity vector for all vertices.
- (n_vertices, 3): per-vertex velocities.
- (n_envs, 3): per-environment velocities broadcast to all vertices.
- (n_envs, n_vertices, 3): full batched per-vertex velocities.
Raises
------
Exception
If the tensor shape is not supported.
"""
self._assert_active()
gs.logger.warning("Manally setting element velocities. This is not recommended and could break gradient flow.")
vel = to_gs_tensor(vel)
is_valid = False
if vel.ndim == 1:
if vel.shape == (3,):
self._tgt["vel"] = vel.tile((self._sim._B, self.n_vertices, 1))
is_valid = True
elif vel.ndim == 2:
if vel.shape == (self.n_vertices, 3):
self._tgt["vel"] = vel.unsqueeze(0).tile((self._sim._B, 1, 1))
is_valid = True
elif vel.shape == (self._sim._B, 3):
self._tgt["vel"] = vel.unsqueeze(1).tile((1, self.n_vertices, 1))
is_valid = True
elif vel.ndim == 3:
if vel.shape == (self._sim._B, self.n_vertices, 3):
self._tgt["vel"] = vel
is_valid = True
if not is_valid:
gs.raise_exception("Tensor shape not supported.")
[docs] def set_actuation(self, actu):
"""
Set the actuation signal for the FEM entity.
Parameters
----------
actu : torch.Tensor or array-like
The actuation tensor. Can be:
- (): a single scalar for all groups.
- (n_groups,): group-level actuation.
- (n_envs, n_groups): batch of group-level actuation signals.
Raises
------
Exception
If the tensor shape is not supported or per-element actuation is attempted.
"""
self._assert_active()
actu = to_gs_tensor(actu)
n_groups = getattr(self.material, "n_groups", 1)
is_valid = False
if actu.ndim == 0:
self._tgt["actu"] = actu.tile((self._sim._B, n_groups))
is_valid = True
elif actu.ndim == 1:
if actu.shape == (n_groups,):
self._tgt["actu"] = actu.unsqueeze(0).tile((self._sim._B, 1))
is_valid = True
elif actu.shape == (self.n_elements,):
gs.raise_exception("Cannot set per-element actuation.")
elif actu.ndim == 2:
if actu.shape == (self._sim._B, n_groups):
self._tgt["actu"] = actu
is_valid = True
if not is_valid:
gs.raise_exception("Tensor shape not supported.")
[docs] def set_muscle(self, muscle_group=None, muscle_direction=None):
"""
Set the muscle group and/or muscle direction for the FEM entity.
Parameters
----------
muscle_group : torch.Tensor or array-like, optional
Tensor of shape (n_elements,) specifying the muscle group ID for each element.
muscle_direction : torch.Tensor or array-like, optional
Tensor of shape (n_elements, 3) specifying unit direction vectors for muscle forces.
Raises
------
AssertionError
If tensor shapes are incorrect or normalization fails.
"""
self._assert_active()
if muscle_group is not None:
n_groups = getattr(self.material, "n_groups", 1)
max_group_id = muscle_group.max().item()
muscle_group = to_gs_tensor(muscle_group)
assert muscle_group.shape == (self.n_elements,)
assert isinstance(max_group_id, int) and max_group_id < n_groups
self.set_muscle_group(muscle_group)
if muscle_direction is not None:
muscle_direction = to_gs_tensor(muscle_direction)
assert muscle_direction.shape == (self.n_elements, 3)
assert ((1.0 - muscle_direction.norm(dim=-1)).abs() < gs.EPS).all()
self.set_muscle_direction(muscle_direction)
[docs] def get_state(self):
state = FEMEntityState(self, self._sim.cur_step_global)
self.get_frame(
self._sim.cur_substep_local,
state.pos,
state.vel,
state.active,
)
# we store all queried states to track gradient flow
self._queried_states.append(state)
return state
[docs] def deactivate(self):
gs.logger.info(f"{self.__class__.__name__} <{self.id}> deactivated.")
self._tgt["act"] = gs.INACTIVE
self.active = False
[docs] def activate(self):
gs.logger.info(f"{self.__class__.__name__} <{self.id}> activated.")
self._tgt["act"] = gs.ACTIVE
self.active = True
# ------------------------------------------------------------------------------------
# ----------------------------------- instantiation ----------------------------------
# ------------------------------------------------------------------------------------
[docs] def instantiate(self, verts, elems):
"""
Initialize FEM entity with given vertices and elements.
Parameters
----------
verts : np.ndarray
Array of vertex positions with shape (n_vertices, 3).
elems : np.ndarray
Array of tetrahedral elements with shape (n_elements, 4), indexing into verts.
Raises
------
Exception
If no vertices are provided.
"""
verts = verts.astype(gs.np_float, copy=False)
elems = elems.astype(gs.np_int, copy=False)
# rotate
R = gu.quat_to_R(np.array(self.morph.quat, dtype=gs.np_float))
verts_COM = verts.mean(axis=0)
init_positions = (verts - verts_COM) @ R.T + verts_COM
if not init_positions.shape[0] > 0:
gs.raise_exception(f"Entity has zero vertices.")
self.init_positions = gs.tensor(init_positions)
self.init_positions_COM_offset = self.init_positions - gs.tensor(verts_COM)
self.elems = elems
[docs] def sample(self):
"""
Sample mesh and elements based on the entity's morph type.
Raises
------
Exception
If the morph type is unsupported.
"""
if isinstance(self.morph, gs.options.morphs.Sphere):
verts, elems = eu.sphere_to_elements(
pos=self._morph.pos,
radius=self._morph.radius,
tet_cfg=self.tet_cfg,
)
elif isinstance(self.morph, gs.options.morphs.Box):
verts, elems = eu.box_to_elements(
pos=self._morph.pos,
size=self._morph.size,
tet_cfg=self.tet_cfg,
)
elif isinstance(self.morph, gs.options.morphs.Cylinder):
verts, elems = eu.cylinder_to_elements()
elif isinstance(self.morph, gs.options.morphs.Mesh):
verts, elems = eu.mesh_to_elements(
file=self._morph.file,
pos=self._morph.pos,
scale=self._morph.scale,
tet_cfg=self.tet_cfg,
)
else:
gs.raise_exception(f"Unsupported morph: {self.morph}.")
self.instantiate(*eu.split_all_surface_tets(verts, elems))
def _add_to_solver(self, in_backward=False):
if not in_backward:
self._step_global_added = self._sim.cur_step_global
gs.logger.info(
f"Entity {self.uid} added. class: {self.__class__.__name__}, morph: {self.morph.__class__.__name__}, size: ({self.n_elements}, {self.n_vertices}), material: {self.material}."
)
# Convert to appropriate numpy array types
elems_np = self.elems.astype(gs.np_int, copy=False)
verts_numpy = tensor_to_array(self.init_positions, dtype=gs.np_float)
self._solver._kernel_add_elements(
f=self._sim.cur_substep_local,
mat_idx=self._material.idx,
mat_mu=self._material.mu,
mat_lam=self._material.lam,
mat_rho=self._material.rho,
mat_friction_mu=self._material.friction_mu,
n_surfaces=self._n_surfaces,
v_start=self._v_start,
el_start=self._el_start,
s_start=self._s_start,
verts=verts_numpy,
elems=elems_np,
tri2v=self._surface_tri_np,
tri2el=self._surface_el_np,
)
self.active = True
[docs] def compute_pressure_field(self):
"""
Compute the pressure field for the FEM entity based on its tetrahedral elements.
For hydroelastic contact: https://drake.mit.edu/doxygen_cxx/group__hydroelastic__user__guide.html
Notes
-----
https://github.com/RobotLocomotion/drake/blob/master/geometry/proximity/make_mesh_field.cc
TODO: Add margin support
Drake's implementation of margin seems buggy.
"""
init_positions = tensor_to_array(self.init_positions)
signed_distance, *_ = igl.signed_distance(init_positions, init_positions, self._surface_tri_np)
signed_distance = signed_distance.astype(gs.np_float, copy=False)
unsigned_distance = np.abs(signed_distance)
max_distance = np.max(unsigned_distance)
if max_distance < gs.EPS:
gs.raise_exception(
f"Pressure field max distance is too small: {max_distance}. "
"This might be due to a mesh having no internal vertices."
)
self.pressure_field_np = unsigned_distance / max_distance * self.material._hydroelastic_modulus # normalize
# ------------------------------------------------------------------------------------
# ---------------------------- checkpoint and buffer ---------------------------------
# ------------------------------------------------------------------------------------
[docs] def init_tgt_keys(self):
"""
Initialize the keys used in target state management.
This defines which physical properties (e.g., position, velocity) will be tracked for checkpointing and buffering.
"""
self._tgt_keys = ["vel", "pos", "act", "actu"]
[docs] def init_tgt_vars(self):
"""
Initialize the target state variables and their buffers.
This sets up internal dictionaries to store per-step target values for properties like velocity, position, actuation, and activation.
"""
# temp variable to store targets for next step
self._tgt = dict()
self._tgt_buffer = dict()
self.init_tgt_keys()
for key in self._tgt_keys:
self._tgt[key] = None
self._tgt_buffer[key] = list()
[docs] def init_ckpt(self):
"""
Initialize the checkpoint storage dictionary.
Creates an empty container for storing simulation checkpoints.
"""
self._ckpt = dict()
[docs] def save_ckpt(self, ckpt_name):
"""
Save the current target state buffers to a named checkpoint.
Parameters
----------
ckpt_name : str
The name to identify the checkpoint.
Notes
-----
After saving, the internal target buffers are cleared to prepare for new input.
"""
if not ckpt_name in self._ckpt:
self._ckpt[ckpt_name] = {
"_tgt_buffer": dict(),
}
for key in self._tgt_keys:
self._ckpt[ckpt_name]["_tgt_buffer"][key] = list(self._tgt_buffer[key])
self._tgt_buffer[key].clear()
[docs] def load_ckpt(self, ckpt_name):
"""
Load a previously saved target state buffer from a named checkpoint.
Parameters
----------
ckpt_name : str
The name of the checkpoint to load.
Raises
------
KeyError
If the checkpoint name is not found.
"""
for key in self._tgt_keys:
self._tgt_buffer[key] = list(self._ckpt[ckpt_name]["_tgt_buffer"][key])
[docs] def reset_grad(self):
"""
Clear all stored gradient-related buffers.
This resets the target buffer and clears any queried states used for gradient tracking.
"""
for key in self._tgt_keys:
self._tgt_buffer[key].clear()
self._queried_states.clear()
def _assert_active(self):
if not self.active:
gs.raise_exception(f"{self.__class__.__name__} is inactive. Call `entity.activate()` first.")
# ------------------------------------------------------------------------------------
# ---------------------------- interfacing with solver -------------------------------
# ------------------------------------------------------------------------------------
[docs] def set_pos(self, f, pos):
"""
Set element positions in the solver.
Parameters
----------
f : int
Current substep/frame index.
pos : gs.Tensor
Tensor of shape (n_envs, n_vertices, 3) containing new positions.
"""
self._solver._kernel_set_elements_pos(
f=f,
element_v_start=self._v_start,
n_vertices=self.n_vertices,
pos=pos,
)
[docs] def set_pos_grad(self, f, pos_grad):
"""
Set gradient of element positions in the solver.
Parameters
----------
f : int
Current substep/frame index.
pos_grad : gs.Tensor
Tensor of shape (n_envs, n_vertices, 3) containing gradients of positions.
"""
self._solver._kernel_set_elements_pos_grad(
f=f,
element_v_start=self._v_start,
n_vertices=self.n_vertices,
pos_grad=pos_grad,
)
[docs] def set_vel(self, f, vel):
"""
Set element velocities in the solver.
Parameters
----------
f : int
Current substep/frame index.
vel : gs.Tensor
Tensor of shape (n_envs, n_vertices, 3) containing velocities.
"""
self._solver._kernel_set_elements_vel(
f=f,
element_v_start=self._v_start,
n_vertices=self.n_vertices,
vel=vel,
)
[docs] def set_vel_grad(self, f, vel_grad):
"""
Set gradient of element velocities in the solver.
Parameters
----------
f : int
Current substep/frame index.
vel_grad : gs.Tensor
Tensor of shape (n_envs, n_vertices, 3) containing gradients of velocities.
"""
self._solver._kernel_set_elements_vel_grad(
f=f,
element_v_start=self._v_start,
n_vertices=self.n_vertices,
vel_grad=vel_grad,
)
[docs] def set_actu(self, f, actu):
"""
Set actuation values for elements in the solver.
Parameters
----------
f : int
Current substep/frame index.
actu : gs.Tensor
Tensor of shape (n_envs, n_groups) specifying actuation values.
"""
self._solver._kernel_set_elements_actu(
f=f,
element_el_start=self._el_start,
n_elements=self.n_elements,
n_groups=self.material.n_groups,
actu=actu,
)
[docs] def set_actu_grad(self, f, actu_grad):
"""
Set gradient of actuation values in the solver.
Parameters
----------
f : int
Current substep/frame index.
actu_grad : gs.Tensor
Tensor of shape (n_envs, n_groups) specifying gradients of actuation.
"""
self._solver._kernel_set_elements_actu(
f=f,
element_el_start=self._el_start,
n_elements=self.n_elements,
actu_grad=actu_grad,
)
[docs] def set_active(self, f, active):
"""
Set the active status of each element.
Parameters
----------
f : int
Current substep/frame index.
active : int
Activity flag (gs.ACTIVE or gs.INACTIVE).
"""
self._solver._kernel_set_active(
f=f,
element_el_start=self._el_start,
n_elements=self.n_elements,
active=active,
)
[docs] def set_muscle_group(self, muscle_group):
"""
Set muscle group index for each element.
Parameters
----------
muscle_group : torch.Tensor
Tensor of shape (n_elements,) specifying muscle group IDs.
"""
self._solver._kernel_set_muscle_group(
element_el_start=self._el_start,
n_elements=self.n_elements,
muscle_group=muscle_group,
)
[docs] def set_muscle_direction(self, muscle_direction):
"""
Set muscle force direction for each element.
Parameters
----------
muscle_direction : torch.Tensor
Tensor of shape (n_elements, 3) with unit direction vectors.
"""
self._solver._kernel_set_muscle_direction(
element_el_start=self._el_start,
n_elements=self.n_elements,
muscle_direction=muscle_direction,
)
def _sanitize_input_tensor(self, tensor, dtype, unbatched_ndim=1):
_tensor = torch.as_tensor(tensor, dtype=dtype, device=gs.device)
if _tensor.ndim < unbatched_ndim + 1:
_tensor = _tensor.repeat((self._sim._B, *((1,) * max(1, _tensor.ndim))))
if self._sim._B > 1:
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
else:
_tensor = _tensor.contiguous()
if _tensor is not tensor:
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
if len(_tensor) != self._sim._B:
gs.raise_exception("Input tensor batch size must match the number of environments.")
if _tensor.ndim != unbatched_ndim + 1:
gs.raise_exception(f"Input tensor ndim is {_tensor.ndim}, should be {unbatched_ndim + 1}.")
return _tensor
def _sanitize_input_verts_idx(self, verts_idx_local):
verts_idx = self._sanitize_input_tensor(verts_idx_local, dtype=gs.tc_int, unbatched_ndim=1) + self._v_start
assert ((verts_idx >= 0) & (verts_idx < self._solver.n_vertices)).all(), "Vertex indices out of bounds."
return verts_idx
def _sanitize_input_poss(self, poss):
poss = self._sanitize_input_tensor(poss, dtype=gs.tc_float, unbatched_ndim=2)
assert poss.ndim == 3 and poss.shape[2] == 3, "Position tensor must have shape (B, num_verts, 3)."
return poss
[docs] def set_vertex_constraints(
self, verts_idx, target_poss=None, link=None, is_soft_constraint=False, stiffness=0.0, envs_idx=None
):
"""
Set vertex constraints for specified vertices.
Parameters
----------
verts_idx : array_like
List of vertex indices to constrain.
target_poss : array_like, shape (len(verts_idx), 3), optional
List of target positions [x, y, z] for each vertex. If not provided, the initial positions are used.
link : RigidLink
Optional rigid link for the vertices to follow, maintaining relative position.
is_soft_constraint: bool
By default, use a hard constraint directly sets position and zero velocity.
A soft constraint uses a spring force to pull the vertex towards the target position.
stiffness : float
Specify a spring stiffness for a soft constraint. Critical damping is applied.
envs_idx : array_like, optional
List of environment indices to apply the constraints to. If None, applies to all environments.
"""
if self._solver._use_implicit_solver:
if not self._solver._enable_vertex_constraints:
gs.logger.warning("Ignoring vertex constraint; FEM implicit solver needs to enable vertex constraints.")
return
if not self._solver._constraints_initialized:
self._solver.init_constraints()
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
verts_idx = self._sanitize_input_verts_idx(verts_idx)
if target_poss is None:
target_poss = torch.zeros(
(verts_idx.shape[0], verts_idx.shape[1], 3), dtype=gs.tc_float, device=gs.device, requires_grad=False
)
self._kernel_get_verts_pos(self._sim.cur_substep_local, target_poss, verts_idx)
target_poss = self._sanitize_input_poss(target_poss)
assert (
len(envs_idx) == len(target_poss) == len(verts_idx)
), "First dimension should match number of environments."
assert target_poss.shape[1] == verts_idx.shape[1], "Target position should be provided for each vertex."
if link is None:
link_init_pos = torch.zeros((self._sim._B, 3), dtype=gs.tc_float, device=gs.device)
link_init_quat = torch.zeros((self._sim._B, 4), dtype=gs.tc_float, device=gs.device)
link_idx = -1
else:
assert isinstance(link, RigidLink), "Only RigidLink is supported for vertex constraints."
link_init_pos = self._sanitize_input_tensor(link.get_pos(), dtype=gs.tc_float)
link_init_quat = self._sanitize_input_tensor(link.get_quat(), dtype=gs.tc_float)
link_idx = link.idx
self._solver._kernel_set_vertex_constraints(
self._sim.cur_substep_local,
verts_idx,
target_poss,
is_soft_constraint,
stiffness,
link_idx,
link_init_pos,
link_init_quat,
envs_idx,
)
[docs] def update_constraint_targets(self, verts_idx, target_poss, envs_idx=None):
"""Update target positions for existing constraints."""
if not self._solver._constraints_initialized:
gs.logger.warning("Ignoring update_constraint_targets; constraints have not been initialized.")
return
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
verts_idx = self._sanitize_input_verts_idx(verts_idx)
target_poss = self._sanitize_input_poss(target_poss)
assert target_poss.shape[1] == verts_idx.shape[1], "Target position should be provided for each vertex."
self._solver._kernel_update_constraint_targets(verts_idx, target_poss, envs_idx)
[docs] def remove_vertex_constraints(self, verts_idx=None, envs_idx=None):
"""Remove constraints from specified vertices, or all if None."""
if not self._solver._constraints_initialized:
gs.logger.warning("Ignoring remove_vertex_constraints; constraints have not been initialized.")
return
if verts_idx is None:
self._solver.vertex_constraints.is_constrained.fill(0)
else:
verts_idx = self._sanitize_input_verts_idx(verts_idx)
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
self._solver._kernel_remove_specific_constraints(verts_idx, envs_idx)
@ti.kernel
def _kernel_get_verts_pos(self, f: ti.i32, pos: ti.types.ndarray(), verts_idx: ti.types.ndarray()):
# get current position of vertices
for i_v, i_b in ti.ndrange(verts_idx.shape[0], verts_idx.shape[1]):
i_global = verts_idx[i_v, i_b] + self.v_start
for j in ti.static(range(3)):
pos[i_b, i_v, j] = self._solver.elements_v[f, i_global, i_b].pos[j]
[docs] def get_el2v(self):
"""
Retrieve the element-to-vertex mapping.
Returns
-------
el2v : gs.Tensor
Tensor of shape (n_elements, 4) mapping each element to its vertex indices.
"""
el2v = gs.zeros((self.n_elements, 4), dtype=int, requires_grad=False, scene=self.scene)
self._solver._kernel_get_el2v(
element_el_start=self._el_start,
n_elements=self.n_elements,
el2v=el2v,
)
return el2v
[docs] @ti.kernel
def get_frame(self, f: ti.i32, pos: ti.types.ndarray(), vel: ti.types.ndarray(), active: ti.types.ndarray()):
"""
Fetch the position, velocity, and activation state of the FEM entity at a specific substep.
Parameters
----------
f : int
The substep/frame index to fetch the state from.
pos : np.ndarray
Output array of shape (n_envs, n_vertices, 3) to store positions.
vel : np.ndarray
Output array of shape (n_envs, n_vertices, 3) to store velocities.
active : np.ndarray
Output array of shape (n_envs, n_elements) to store active flags.
"""
for i_v, i_b in ti.ndrange(self.n_vertices, self._sim._B):
i_global = i_v + self.v_start
for j in ti.static(range(3)):
pos[i_b, i_v, j] = self._solver.elements_v[f, i_global, i_b].pos[j]
vel[i_b, i_v, j] = self._solver.elements_v[f, i_global, i_b].vel[j]
for i_v, i_b in ti.ndrange(self.n_elements, self._sim._B):
i_global = i_v + self.el_start
active[i_b, i_v] = self._solver.elements_el_ng[f, i_global, i_b].active
[docs] @ti.kernel
def clear_grad(self, f: ti.i32):
"""
Zero out the gradients of position, velocity, and actuation for the current substep.
Parameters
----------
f : int
The substep/frame index for which to clear gradients.
Notes
-----
This method is primarily used during backward passes to manually reset gradients
that may be corrupted by explicit state setting.
"""
# TODO: not well-tested
for i_v, i_b in ti.ndrange(self.n_vertices, self._sim._B):
i_global = i_v + self.v_start
self._solver.elements_v.grad[f, i_global, i_b].pos = 0
self._solver.elements_v.grad[f, i_global, i_b].vel = 0
for i_v, i_b in ti.ndrange(self.n_elements, self._sim._B):
i_global = i_v + self.el_start
self._solver.elements_el.grad[f, i_global, i_b].actu = 0
# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
@property
def n_vertices(self):
"""Number of vertices in the FEM entity."""
return len(self.init_positions)
@property
def n_elements(self):
"""Number of tetrahedral elements in the FEM entity."""
return len(self.elems)
@property
def n_surfaces(self):
"""Number of surface triangles extracted from the FEM mesh."""
return self._n_surfaces
@property
def v_start(self):
"""Global vertex index offset for this entity."""
return self._v_start
@property
def el_start(self):
"""Global element index offset for this entity."""
return self._el_start
@property
def s_start(self):
"""Global surface triangle index offset for this entity."""
return self._s_start
@property
def morph(self):
"""Morph specification used to generate the FEM mesh."""
return self._morph
@property
def material(self):
"""Material properties of the FEM entity."""
return self._material
@property
def surface(self):
"""Surface for rendering."""
return self._surface
@property
def n_surface_vertices(self):
"""Number of unique vertices involved in surface triangles."""
return self._n_surface_vertices
@property
def surface_triangles(self):
"""Surface triangles of the FEM mesh."""
return self._surface_tri_np
@property
def tet_cfg(self):
"""Configuration of tetrahedralization."""
tet_cfg = mu.generate_tetgen_config_from_morph(self.morph)
return tet_cfg