import numpy as np
import taichi as ti
import torch
import genesis as gs
import genesis.utils.element as eu
import genesis.utils.geom as gu
from genesis.engine.states.cache import QueriedStates
from genesis.engine.states.entities import FEMEntityState
from genesis.utils.misc import to_gs_tensor
from .base_entity import Entity
[文档]@ti.data_oriented
class FEMEntity(Entity):
"""
FEM-based entity.
"""
def __init__(self, scene, solver, material, morph, surface, idx, v_start=0, el_start=0, s_start=0):
    """
    Construct the entity, tetrahedralize its morph and set up its buffers.

    Parameters
    ----------
    scene, solver, material, morph, surface, idx
        Passed on to the parent constructor (in the order
        ``idx, scene, morph, solver, material, surface``).
    v_start : int
        Offset for the vertex indices of this entity's elements.
    el_start : int
        Offset for this entity's element indices.
    s_start : int
        Offset for this entity's surface triangles.
    """
    super().__init__(idx, scene, morph, solver, material, surface)

    # index offsets: vertices, elements, surface triangles
    self._v_start, self._el_start, self._s_start = v_start, el_start, s_start
    self._step_global_added = None

    self._surface.update_texture()
    self.sample()

    self.init_tgt_vars()
    self.init_ckpt()
    self._queried_states = QueriedStates()

    # This attribute is only used in forward pass. It should NOT be used during backward pass.
    self.active = False
# ------------------------------------------------------------------------------------
# ----------------------------------- basic entity ops -------------------------------
# ------------------------------------------------------------------------------------
def set_position(self, pos):
    """
    Queue a position target for this entity.

    Parameters
    ----------
    pos : array-like
        Either a single ``(3,)`` vector, which is added to
        ``init_positions_COM_offset``, or a full ``(n_vertices, 3)`` array
        that is stored unchanged.

    The entity must be active. A warning is logged on every call. Any
    other tensor rank is rejected through ``gs.raise_exception``.
    """
    self._assert_active()
    gs.logger.warning("Manally setting element positions. This is not recommended and could break gradient flow.")

    pos = to_gs_tensor(pos)
    rank = pos.ndim

    if rank == 2:
        # per-vertex positions
        assert pos.shape == (self.n_vertices, 3)
        self._tgt["pos"] = pos
    elif rank == 1:
        # one vector applied on top of the stored per-vertex offsets
        assert pos.shape == (3,)
        self._tgt["pos"] = self.init_positions_COM_offset + pos
    else:
        gs.raise_exception("Tensor shape not supported.")
def set_velocity(self, vel):
    """
    Queue a velocity target for this entity.

    Parameters
    ----------
    vel : array-like
        Either a single ``(3,)`` vector, which is tiled to every vertex,
        or a full ``(n_vertices, 3)`` array that is stored unchanged.

    The entity must be active. A warning is logged on every call. Any
    other tensor rank is rejected through ``gs.raise_exception``.
    """
    self._assert_active()
    gs.logger.warning("Manally setting element velocities. This is not recommended and could break gradient flow.")

    vel = to_gs_tensor(vel)
    rank = vel.ndim

    if rank == 2:
        # per-vertex velocities
        assert vel.shape == (self.n_vertices, 3)
        self._tgt["vel"] = vel
    elif rank == 1:
        # same velocity for every vertex
        assert vel.shape == (3,)
        self._tgt["vel"] = torch.tile(vel, [self.n_vertices, 1])
    else:
        gs.raise_exception("Tensor shape not supported.")
def set_actuation(self, actu):
    """
    Queue an actuation target for this entity.

    Parameters
    ----------
    actu : array-like
        A scalar (tiled to ``(n_elements, n_groups)``) or a 1-D tensor of
        length ``n_groups`` (tiled to one row per element). ``n_groups``
        is read from the material and defaults to 1.

    A 1-D tensor of any other length is rejected: per-element actuation
    cannot be set. Tensors of rank 2 or more are rejected as well.
    """
    self._assert_active()

    actu = to_gs_tensor(actu)
    n_groups = getattr(self.material, "n_groups", 1)
    rank = actu.ndim

    if rank == 0:
        assert actu.shape == ()
        self._tgt["actu"] = torch.tile(actu, [self.n_elements, n_groups])
    elif rank == 1:
        if actu.shape[0] != n_groups:
            assert actu.shape == (self.n_elements,)
            gs.raise_exception("Cannot set per-element actuation")
        else:
            # a length that matches both counts would be ambiguous
            assert self.n_elements != n_groups
            actu = actu.tile([self.n_elements, 1])
        self._tgt["actu"] = actu
    else:
        gs.raise_exception("Tensor shape not supported.")
def set_muscle(self, muscle_group=None, muscle_direction=None):
    """
    Assign muscle groups and/or muscle directions to the elements.

    Parameters
    ----------
    muscle_group : array-like, optional
        One integer group id per element, shape ``(n_elements,)``. Every
        id must be smaller than the material's ``n_groups`` (1 when the
        material does not define it).
    muscle_direction : array-like, optional
        One unit vector per element, shape ``(n_elements, 3)``.

    Raises
    ------
    AssertionError
        If a shape, the group-id range/type or the unit-norm requirement
        is violated.
    """
    self._assert_active()

    if muscle_group is not None:
        n_groups = getattr(self.material, "n_groups", 1)

        # Convert before calling `.max()`: a plain list has no such method,
        # so the previous order failed on list input.
        # NOTE(review): relies on `to_gs_tensor` keeping integer dtypes so
        # that `.item()` still yields an `int` -- confirm.
        muscle_group = to_gs_tensor(muscle_group)
        assert muscle_group.shape == (self.n_elements,)

        max_group_id = muscle_group.max().item()
        assert isinstance(max_group_id, int) and max_group_id < n_groups

        self.set_muscle_group(muscle_group)

    if muscle_direction is not None:
        muscle_direction = to_gs_tensor(muscle_direction)
        assert muscle_direction.shape == (self.n_elements, 3)
        # every direction must be normalized
        assert torch.allclose(muscle_direction.norm(dim=-1), torch.Tensor([1.0]).to(muscle_direction))

        self.set_muscle_direction(muscle_direction)
def get_state(self):
    """
    Capture the entity's current state.

    A new ``FEMEntityState`` is created for the current global step and
    filled by ``get_frame`` from the current local substep. The state is
    also appended to ``_queried_states`` before being returned.
    """
    sim = self._sim
    snapshot = FEMEntityState(self, sim.cur_step_global)
    self.get_frame(sim.cur_substep_local, snapshot.pos, snapshot.vel, snapshot.active)

    # we store all queried states to track gradient flow
    self._queried_states.append(snapshot)
    return snapshot
def deactivate(self):
    """Log the change, queue ``gs.INACTIVE`` as the ``"act"`` target and clear ``active``."""
    gs.logger.info(f"{self.__class__.__name__} <{self.id}> deactivated.")
    new_state = gs.INACTIVE
    self._tgt["act"] = new_state
    self.active = False
def activate(self):
    """Log the change, queue ``gs.ACTIVE`` as the ``"act"`` target and set ``active``."""
    gs.logger.info(f"{self.__class__.__name__} <{self.id}> activated.")
    new_state = gs.ACTIVE
    self._tgt["act"] = new_state
    self.active = True
# ------------------------------------------------------------------------------------
# ----------------------------------- instantiation ----------------------------------
# ------------------------------------------------------------------------------------
def instantiate(self, verts, elems):
    """
    Store the entity's initial vertex positions and its elements.

    The vertices are rotated about their centroid by the rotation matrix
    built from ``self.morph.quat``.

    Parameters
    ----------
    verts : array
        Vertex coordinates, one row per vertex.
    elems : array
        Elements; stored unchanged in ``self.elems``.

    Sets ``init_positions`` (rotated vertices), ``init_positions_COM_offset``
    (rotated vertices minus the centroid) and ``elems``. An empty vertex
    array is rejected through ``gs.raise_exception``.
    """
    # Validate first: the mean of an empty array is NaN and emits a numpy
    # warning, which used to precede the real error.
    if not verts.shape[0] > 0:
        gs.raise_exception("Entity has zero vertices.")

    # rotate about the centroid
    R = gu.quat_to_R(np.array(self.morph.quat))
    verts_COM = verts.mean(0)
    init_positions = (R @ (verts - verts_COM).T).T + verts_COM

    self.init_positions = gs.tensor(init_positions).contiguous()
    self.init_positions_COM_offset = (self.init_positions - gs.tensor(verts_COM)).contiguous()
    self.elems = elems
def sample(self):
    """
    Convert the morph into vertices and elements and pass them to
    ``instantiate``.

    Sphere, Box, Cylinder and Mesh morphs are dispatched to the matching
    ``eu.*_to_elements`` helper. Any other morph is rejected through
    ``gs.raise_exception``.
    """
    morph = self.morph
    morphs = gs.options.morphs

    if isinstance(morph, morphs.Sphere):
        mesh = eu.sphere_to_elements(
            pos=morph.pos,
            radius=morph.radius,
            tet_cfg=self.tet_cfg,
        )
    elif isinstance(morph, morphs.Box):
        mesh = eu.box_to_elements(
            pos=morph.pos,
            size=morph.size,
            tet_cfg=self.tet_cfg,
        )
    elif isinstance(morph, morphs.Cylinder):
        # NOTE(review): called without any morph parameters, unlike the
        # other branches -- confirm this helper is actually usable.
        mesh = eu.cylinder_to_elements()
    elif isinstance(morph, morphs.Mesh):
        mesh = eu.mesh_to_elements(
            file=morph.file,
            pos=morph.pos,
            scale=morph.scale,
            tet_cfg=self.tet_cfg,
        )
    else:
        gs.raise_exception(f"Unsupported morph: {self.morph}.")

    verts, elems = mesh
    self.instantiate(verts, elems)
def _add_to_solver(self, in_backward=False):
    """
    Extract the surface of the tetrahedral mesh and register the entity
    with the solver.

    Parameters
    ----------
    in_backward : bool
        When False, the current global step is recorded in
        ``_step_global_added`` and an info message is logged.

    A triangle is treated as a surface triangle when it occurs exactly once
    among the faces of all elements. Sets ``_n_surfaces``,
    ``_n_surface_vertices`` and ``active``.
    """
    # NOTE(review): the source's indentation was lost; the log call is
    # assumed to belong to the forward-only branch -- confirm.
    if not in_backward:
        self._step_global_added = self._sim.cur_step_global
        gs.logger.info(
            f"Entity {self.uid} added. class: {self.__class__.__name__}, morph: {self.morph.__class__.__name__}, size: ({self.n_elements}, {self.n_vertices}), material: {self.material}."
        )

    # Local corner indices of the four faces of one element
    # (follow the order with correct normal).
    face_corners = np.array([[0, 2, 1], [1, 2, 3], [0, 1, 3], [0, 3, 2]])

    # Gather all faces at once instead of looping over elements in Python.
    elems = np.asarray(self.elems)
    all_tri = elems[:, face_corners].reshape(-1, 3)

    # Sorting each triangle's vertex ids makes shared faces compare equal.
    all_tri_sorted = np.sort(all_tri, axis=1)
    _, unique_idcs, cnt = np.unique(all_tri_sorted, axis=0, return_counts=True, return_index=True)
    on_surface = cnt == 1

    surface_tri = all_tri[unique_idcs][on_surface].astype(gs.np_int)
    self._n_surfaces = len(surface_tri)
    self._n_surface_vertices = len(np.unique(surface_tri))

    # Element index owning each face, in the same order as `all_tri`.
    all_el = np.repeat(np.arange(self.elems.shape[0]), 4)
    surface_el = all_el[unique_idcs][on_surface].astype(gs.np_int)

    self._solver._kernel_add_elements(
        f=self._sim.cur_substep_local,
        mat_idx=self._material.idx,
        mat_mu=self._material.mu,
        mat_lam=self._material.lam,
        mat_rho=self._material.rho,
        n_surfaces=self._n_surfaces,
        v_start=self._v_start,
        el_start=self._el_start,
        s_start=self._s_start,
        verts=self.init_positions,
        elems=self.elems,
        tri2v=surface_tri,
        tri2el=surface_el,
    )
    self.active = True
# ------------------------------------------------------------------------------------
# ---------------------------- checkpoint and buffer ---------------------------------
# ------------------------------------------------------------------------------------
def init_tgt_keys(self):
    """Define the names of the per-step targets tracked by this entity."""
    names = ("vel", "pos", "act", "actu")
    self._tgt_keys = list(names)
def init_tgt_vars(self):
    """
    Create the target containers.

    ``_tgt`` maps each target key to ``None`` (temp variable to store
    targets for next step). ``_tgt_buffer`` maps each key to its own
    empty list.
    """
    self._tgt = {}
    self._tgt_buffer = {}
    self.init_tgt_keys()

    self._tgt.update(dict.fromkeys(self._tgt_keys))
    self._tgt_buffer.update((name, []) for name in self._tgt_keys)
def init_ckpt(self):
    """Start with an empty checkpoint store."""
    self._ckpt = {}
def save_ckpt(self, ckpt_name):
    """
    Copy every target buffer into the checkpoint ``ckpt_name`` and then
    empty the live buffers.

    The checkpoint entry is created on first use; saving again under the
    same name overwrites the stored buffers.
    """
    if ckpt_name not in self._ckpt:
        self._ckpt[ckpt_name] = {"_tgt_buffer": {}}

    for name in self._tgt_keys:
        live = self._tgt_buffer[name]
        self._ckpt[ckpt_name]["_tgt_buffer"][name] = list(live)
        live.clear()
def load_ckpt(self, ckpt_name):
    """Replace every target buffer with a copy of the one stored under ``ckpt_name``."""
    self._tgt_buffer.update(
        (name, list(self._ckpt[ckpt_name]["_tgt_buffer"][name])) for name in self._tgt_keys
    )
def reset_grad(self):
    """Empty every target buffer and forget all queried states."""
    buffers = self._tgt_buffer
    for name in self._tgt_keys:
        buffers[name].clear()
    self._queried_states.clear()
def _assert_active(self):
if not self.active:
gs.raise_exception(f"{self.__class__.__name__} is inactive. Call `entity.activate()` first.")
# ------------------------------------------------------------------------------------
# ---------------------------- interfacing with solver -------------------------------
# ------------------------------------------------------------------------------------
def set_pos(self, f, pos):
    """Forward ``pos`` for frame ``f`` to the solver's position kernel, scoped to this entity's vertices."""
    kernel = self.solver._kernel_set_elements_pos
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, pos=pos)
def set_pos_grad(self, f, pos_grad):
    """Forward ``pos_grad`` for frame ``f`` to the solver's position-gradient kernel, scoped to this entity's vertices."""
    kernel = self.solver._kernel_set_elements_pos_grad
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, pos_grad=pos_grad)
def set_vel(self, f, vel):
    """Forward ``vel`` for frame ``f`` to the solver's velocity kernel, scoped to this entity's vertices."""
    kernel = self.solver._kernel_set_elements_vel
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, vel=vel)
def set_vel_grad(self, f, vel_grad):
    """Forward ``vel_grad`` for frame ``f`` to the solver's velocity-gradient kernel, scoped to this entity's vertices."""
    kernel = self.solver._kernel_set_elements_vel_grad
    kernel(f=f, element_v_start=self._v_start, n_vertices=self.n_vertices, vel_grad=vel_grad)
def set_actu(self, f, actu):
    """
    Forward ``actu`` for frame ``f`` to the solver's actuation kernel,
    scoped to this entity's elements.

    ``n_groups`` is read from the material and falls back to 1 when the
    material does not define it, matching ``set_actuation`` and
    ``set_muscle`` (previously a missing attribute raised AttributeError
    here).
    """
    self.solver._kernel_set_elements_actu(
        f=f,
        element_el_start=self._el_start,
        n_elements=self.n_elements,
        n_groups=getattr(self.material, "n_groups", 1),
        actu=actu,
    )
def set_actu_grad(self, f, actu_grad):
    """Forward ``actu_grad`` for frame ``f`` to the solver, scoped to this entity's elements."""
    # NOTE(review): this targets the plain actuation kernel (no `_grad`
    # suffix, no `n_groups`), unlike the pos/vel gradient setters --
    # verify against the solver before relying on this method.
    kernel = self.solver._kernel_set_elements_actu
    kernel(f=f, element_el_start=self._el_start, n_elements=self.n_elements, actu_grad=actu_grad)
def set_active(self, f, active):
    """Forward the ``active`` flag for frame ``f`` to the solver, scoped to this entity's elements."""
    kernel = self.solver._kernel_set_active
    kernel(f=f, element_el_start=self._el_start, n_elements=self.n_elements, active=active)
def set_muscle_group(self, muscle_group):
    """Forward ``muscle_group`` to the solver, scoped to this entity's elements."""
    kernel = self.solver._kernel_set_muscle_group
    kernel(element_el_start=self._el_start, n_elements=self.n_elements, muscle_group=muscle_group)
def set_muscle_direction(self, muscle_direction):
    """Forward ``muscle_direction`` to the solver, scoped to this entity's elements."""
    kernel = self.solver._kernel_set_muscle_direction
    kernel(element_el_start=self._el_start, n_elements=self.n_elements, muscle_direction=muscle_direction)
def get_el2v(self):
    """
    Return an ``(n_elements, 4)`` integer tensor filled by the solver's
    ``_kernel_get_el2v`` for this entity's elements.
    """
    out = gs.zeros((self.n_elements, 4), dtype=int, requires_grad=False, scene=self.scene)
    kernel = self.solver._kernel_get_el2v
    kernel(element_el_start=self._el_start, n_elements=self.n_elements, el2v=out)
    return out
@ti.kernel
def get_frame(self, f: ti.i32, pos: ti.types.ndarray(), vel: ti.types.ndarray(), active: ti.types.ndarray()):
    # Copy this entity's data at frame `f` out of the solver fields:
    # vertex positions/velocities into `pos`/`vel`, the per-element
    # active flag into `active`. Output rows use entity-local indices.
    for i_v in range(self.n_vertices):
        v_idx = i_v + self.v_start  # index into the solver's vertex field
        for axis in ti.static(range(3)):
            pos[i_v, axis] = self._solver.elements_v[f, v_idx].pos[axis]
            vel[i_v, axis] = self._solver.elements_v[f, v_idx].vel[axis]

    for i_el in range(self.n_elements):
        el_idx = i_el + self.el_start  # index into the solver's element field
        active[i_el] = self._solver.elements_el_ng[f, el_idx].active
@ti.kernel
def clear_grad(self, f: ti.i32):
    # Zero the gradients stored for this entity at frame `f`:
    # vertex pos/vel gradients and element actuation gradients.
    # TODO: not well-tested
    for i_v in range(self.n_vertices):
        v_idx = i_v + self.v_start
        self._solver.elements_v.grad[f, v_idx].pos = 0
        self._solver.elements_v.grad[f, v_idx].vel = 0

    for i_el in range(self.n_elements):
        el_idx = i_el + self.el_start
        self._solver.elements_el.grad[f, el_idx].actu = 0
# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
@property
def n_vertices(self):
    """Number of vertices, i.e. the length of ``init_positions``."""
    return len(self.init_positions)
@property
def n_elements(self):
    """Number of elements, i.e. the length of ``elems``."""
    return len(self.elems)
@property
def n_surfaces(self):
    """Number of surface triangles (assigned in ``_add_to_solver``)."""
    return self._n_surfaces
@property
def v_start(self):
    """Offset for the vertex indices of this entity's elements."""
    return self._v_start
@property
def el_start(self):
    """Offset for this entity's element indices."""
    return self._el_start
@property
def s_start(self):
    """Offset for this entity's surface triangles."""
    return self._s_start
@property
def morph(self):
    """The morph this entity was created from (``_morph``)."""
    return self._morph
@property
def material(self):
    """The entity's material (``_material``)."""
    return self._material
@property
def surface(self):
    """The entity's surface (``_surface``)."""
    return self._surface
@property
def n_surface_vertices(self):
    """Number of distinct vertices on the surface triangles (assigned in ``_add_to_solver``)."""
    return self._n_surface_vertices
@property
def tet_cfg(self):
    """
    Tetrahedralization settings as a dict.

    Each option is read from the morph; when the morph does not define
    it, the default listed below is used.
    """
    defaults = {
        "order": 1,
        "mindihedral": 10,
        "minratio": 1.1,
        "nobisect": True,
        "quality": True,
        "verbose": 0,
    }
    morph = self.morph
    return {option: getattr(morph, option, fallback) for option, fallback in defaults.items()}