Source code for genesis.recorders.file_writers

import csv
import logging
import os
import tempfile
from collections import defaultdict
from functools import lru_cache
from pathlib import Path

import numpy as np
import torch

import genesis as gs
from genesis.options.recorders import (
    VideoFile as VideoFileWriterOptions,
    CSVFile as CSVFileWriterOptions,
    NPZFile as NPZFileWriterOptions,
)
from genesis.utils import tensor_to_array

from .base_recorder import Recorder
from .recorder_manager import register_recording

try:
    import av
except ImportError:
    pass


# Module-level fallback logger: the writers in this module log through `gs.logger or LOGGER`,
# so this one is used whenever `gs.logger` is not set.
LOGGER = logging.getLogger(name=__name__)


@lru_cache(maxsize=None)
def _probe_h264_codec(codec: str, width: int, height: int) -> bool:
    """Test whether a codec can actually encode a single frame at the given resolution.

    Some hardware encoders (e.g. NVENC) reject small resolutions, so the actual target resolution must be tested.
    Results are memoized per (codec, width, height).

    Args:
        codec: Name of the encoder to probe (e.g. "h264_nvenc").
        width: Frame width in pixels.
        height: Frame height in pixels.

    Returns:
        True if one yuv420p frame could be encoded, flushed, muxed and the file closed without error; False if PyAV
        raised an FFmpegError or a ValueError along the way.
    """
    # Use mkstemp instead of NamedTemporaryFile because Windows cannot open a NamedTemporaryFile from another handle
    path = None
    container = None
    try:
        fd, path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        container = av.open(path, mode="w")
        stream = container.add_stream(codec, rate=30)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"
        frame = av.VideoFrame(width, height, "yuv420p")
        for packet in stream.encode(frame):
            container.mux(packet)
        # Flush the encoder so that any buffered packets are produced and muxed as well
        for packet in stream.encode(None):
            container.mux(packet)
        # Close inside the 'try' so that errors raised while finalizing the file also count as a failed probe
        container.close()
        container = None
        return True
    except (av.error.FFmpegError, ValueError):  # FFmpegError: codec/permission failures, ValueError: invalid parameters
        return False
    finally:
        # On failure the container is still open: close it, otherwise the file handle leaks and removing the file
        # fails on Windows, which would turn a failed probe into an exception instead of returning False.
        if container is not None:
            try:
                container.close()
            except (av.error.FFmpegError, ValueError, OSError):
                pass  # best-effort: the probe has already failed
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of the temporary probe file


class BaseFileWriter(Recorder):
    """
    Common base for recorders that write their samples to a file.

    When ``save_on_reset`` is enabled, every reset finalizes the current file and opens a new one whose name
    carries an incrementing counter suffix. Subclasses open their output in ``_initialize_writer``.
    """

    def build(self):
        super().build()
        self.counter = 0
        # The destination folder must exist before the writer opens its first file
        output_dir = os.path.abspath(os.path.dirname(self._options.filename))
        os.makedirs(output_dir, exist_ok=True)
        self._initialize_writer()

    def reset(self, envs_idx=None):
        super().reset(envs_idx)

        # envs_idx is only forwarded: no envs specific saving supported
        if not self._options.save_on_reset:
            return
        # Finalize the current file, then start the next numbered one
        self.cleanup()
        self.counter += 1
        self._initialize_writer()

    def _get_filename(self):
        """Return the output path, with ``_<counter>`` inserted before the extension when ``save_on_reset`` is set."""
        filename = self._options.filename
        if not self._options.save_on_reset:
            return filename
        stem, extension = os.path.splitext(filename)
        return f"{stem}_{self.counter}{extension}"

    def _initialize_writer(self):
        """Hook for subclasses to open their output; the base implementation does nothing."""
@register_recording(VideoFileWriterOptions)
class VideoFileWriter(BaseFileWriter):
    """
    Encode recorded image frames into a video file with PyAV.

    The output container is opened by ``_initialize_writer``. The video stream, the reusable frame and its numpy
    view are created lazily from the first frame passed to ``process``, because codec selection depends on the
    actual frame resolution. Frames must be integer arrays/tensors of shape [H, W] or [H, W, 3]; they are cast
    to uint8 before encoding.
    """

    video_container: "av.container.OutputContainer | None"
    video_stream: "av.video.stream.VideoStream | None"
    video_frame: "av.video.frame.VideoFrame | None"
    video_buffer: "np.ndarray | None"

    def build(self):
        self.video_container = None
        self.video_stream = None
        self.video_frame = None
        self.video_buffer = None
        # Default to one video frame per recorded sample, i.e. 1 / (steps_per_sample * step_dt)
        self.fps = int(
            round(
                1.0 / (self._steps_per_sample * self._manager._step_dt)
                if self._options.fps is None
                else self._options.fps
            )
        )
        super().build()

    def _initialize_writer(self):
        video_path = self._get_filename()
        video_name = self._options.name or Path(video_path).stem
        # Create ffmpeg video container. The stream itself is added once the first frame is known.
        self.video_container = av.open(video_path, mode="w")
        self.video_container.metadata["title"] = video_name

    def _select_codec(self, width, height):
        """Return the configured codec, or auto-select the first usable H.264 encoder at this resolution.

        If no candidate is available and able to encode at ``width`` x ``height``, the failure is reported via
        ``gs.raise_exception``.
        """
        codec = self._options.codec
        if codec:
            return codec
        # Auto-selection is deferred to the first frame because hardware encoders (e.g. NVENC) have minimum
        # resolution requirements that can only be validated with the real frame size.
        for candidate in (
            "h264_videotoolbox",  # macOS hardware
            "h264_nvenc",  # NVIDIA hardware
            "h264_vaapi",  # Linux VA-API hardware
            "h264_qsv",  # Intel Quick Sync
            "libx264",  # Software fallback
        ):
            if candidate in av.codecs_available and _probe_h264_codec(candidate, width, height):
                return candidate
        gs.raise_exception("No supported H.264 codec found. Please install libx264 or specify a codec explicitly.")

    def _initialize_data(self, data):
        """Validate the first frame, then create the video stream and the reusable frame storage.

        ``data`` must be an integer numpy array or torch tensor, either grayscale [H, W] or color [H, W, 3].
        """
        assert isinstance(data, (np.ndarray, torch.Tensor))
        is_color = data.ndim == 3 and data.shape[-1] == 3
        if isinstance(data, np.ndarray):
            is_dtype_int = np.issubdtype(data.dtype, np.integer)
        else:
            is_dtype_int = not torch.is_floating_point(data)
        if data.ndim != 2 + is_color or not is_dtype_int:
            gs.raise_exception(f"[{type(self).__name__}] Data must be either grayscale [H, W] or color [H, W, RGB]")
        height, width, *_ = data.shape

        codec = self._select_codec(width, height)

        # Apply sensible defaults per codec when no explicit options are provided
        codec_options = self._options.codec_options
        if not codec_options:
            codec_options = {
                "h264_videotoolbox": {"realtime": "1"},
                "h264_nvenc": {"preset": "p1", "tune": "ull"},
                "h264_vaapi": {},
                "h264_qsv": {"preset": "veryfast"},
                "libx264": {"preset": "veryfast", "tune": "zerolatency"},
            }.get(codec, {})
        gs.logger.debug(
            f"Starting video recording using codec '{codec}' ({codec_options}) at {width}x{height} {self.fps}fps."
        )

        # Create ffmpeg video stream
        self.video_stream = self.video_container.add_stream(codec, rate=self.fps)
        assert isinstance(self.video_stream, av.video.stream.VideoStream)
        self.video_stream.width, self.video_stream.height = (width, height)
        self.video_stream.pix_fmt = "yuv420p"
        self.video_stream.bit_rate = int(self._options.bitrate * (8 * 1024**2))
        self.video_stream.codec_context.options = codec_options

        # Create frame storage once for efficiency
        if is_color:
            self.video_frame = av.VideoFrame(width, height, "rgb24")
            frame_shape = (height, width, 3)
        else:
            self.video_frame = av.VideoFrame(width, height, "gray8")
            frame_shape = (height, width)
        frame_plane = self.video_frame.planes[0]
        # Plane rows may be padded: 'line_size' can exceed the bytes actually used per row, and nothing guarantees
        # it is a multiple of 3 for RGB. Build the numpy view with the plane's real row stride (instead of
        # reshaping the flat buffer) so every image row starts at the beginning of its plane row.
        frame_strides = (frame_plane.line_size, 3, 1) if is_color else (frame_plane.line_size, 1)
        self.video_buffer = np.ndarray(
            frame_shape, dtype=np.uint8, buffer=memoryview(frame_plane), strides=frame_strides
        )

    def process(self, data, cur_time):
        if self.video_buffer is None:
            self._initialize_data(data)

        if isinstance(data, torch.Tensor):
            data = tensor_to_array(data)
        data = data.astype(np.uint8)

        # Write frame: copy pixels into the frame memory, then encode it and mux the resulting packets
        self.video_buffer[: data.shape[0], : data.shape[1]] = data
        for packet in self.video_stream.encode(self.video_frame):
            self.video_container.mux(packet)

    def cleanup(self):
        """Flush the encoder, close the container and reset the per-file state.

        Safe to call when nothing was recorded: without a stream (``process`` was never called) the container is
        simply closed.
        """
        container, stream = self.video_container, self.video_stream
        # Drop all references first so the writer is left in a clean state even if finalizing raises
        self.video_container = None
        self.video_stream = None
        self.video_frame = None
        self.video_buffer = None
        if container is None:
            return
        try:
            # Finalize video recording.
            # Note that 'stream' may be None if 'process' was never called.
            if stream is not None:
                for packet in stream.encode(None):
                    container.mux(packet)
        finally:
            # Release the file even if flushing the encoder failed
            container.close()
        # Report the file actually written, which carries a counter suffix when 'save_on_reset' is enabled
        (gs.logger or LOGGER).info(f'Video saved to "~<{self._get_filename()}>~".')

    @property
    def run_in_thread(self) -> bool:
        return False
@register_recording(CSVFileWriterOptions)
class CSVFileWriter(BaseFileWriter):
    """
    Write each recorded sample as one CSV row: the timestamp followed by the flattened data values.

    A header row is written before the first data row, taken from ``options.header`` when given, otherwise
    derived from the data. A file that never received any data is deleted on cleanup.
    """

    def _initialize_writer(self):
        self.wrote_data = False
        self.file_handle = open(self._get_filename(), "w", encoding="utf-8", newline="")
        self.csv_writer = csv.writer(self.file_handle)

    def _sanitize_to_list(self, value):
        """Flatten a scalar, list/tuple, numpy array or torch tensor into a list of CSV cells."""
        if isinstance(value, (torch.Tensor, np.ndarray)):
            return value.reshape((-1,)).tolist()
        elif isinstance(value, (int, float, bool)):
            return [value]
        elif isinstance(value, (list, tuple)):
            return value
        else:
            gs.raise_exception(f"[{type(self).__name__}] Unsupported data type: {type(value)}")

    @staticmethod
    def _column_names(key, value, num_cells):
        """Return the header names of one dict entry whose value flattens to ``num_cells`` cells."""
        # Unsized values and 0-d arrays/tensors keep the bare key. Sized values get one indexed column per
        # *flattened* cell, so multi-dimensional arrays produce as many names as they produce cells.
        if not hasattr(value, "__len__") or getattr(value, "ndim", 1) == 0:
            return [key]
        return [f"{key}_{i}" for i in range(num_cells)]

    def _write_header(self, data, flattened, row_length):
        """Write the header row: "timestamp" followed by user-provided or derived column names."""
        header = ["timestamp"]
        if self._options.header:
            header.extend(self._options.header)
        elif flattened is not None:
            for key, value in data.items():
                header.extend(self._column_names(key, value, len(flattened[key])))
        else:
            header.extend([f"data_{i}" for i in range(1, row_length)])
        if len(header) != row_length:
            gs.raise_exception(f"[{type(self).__name__}] header length does not match data length.")
        self.csv_writer.writerow(header)

    def process(self, data, cur_time):
        row_data = [cur_time]
        if isinstance(data, dict):
            # Flatten each value once; the cell counts are reused to name the default header columns
            flattened = {key: self._sanitize_to_list(value) for key, value in data.items()}
            for cells in flattened.values():
                row_data.extend(cells)
        else:
            flattened = None
            row_data.extend(self._sanitize_to_list(data))

        if not self.wrote_data:
            self._write_header(data, flattened, len(row_data))
            self.wrote_data = True
        self.csv_writer.writerow(row_data)
        if self._options.save_every_write:
            self.file_handle.flush()

    def cleanup(self):
        file_handle = getattr(self, "file_handle", None)
        if file_handle is None:
            return
        filename = self._get_filename()
        file_handle.close()
        # Forget the handle so that a repeated cleanup does not try to close/delete the file again
        self.file_handle = None
        if self.wrote_data:
            (gs.logger or LOGGER).info(f'[CSVFileWriter] Saved to ~<"{filename}">~.')
        else:
            os.remove(filename)  # delete empty file

    @property
    def run_in_thread(self) -> bool:
        return True
@register_recording(NPZFileWriterOptions)
class NPZFileWriter(BaseFileWriter):
    """
    Accumulate recorded samples in memory and write them to a compressed ``.npz`` archive on cleanup.

    Dict data is stored under its own keys, any other data under "data"; timestamps are stored under "timestamp".
    """

    def build(self):
        self.all_data: dict[str, list] = defaultdict(list)
        super().build()

    def process(self, data, cur_time):
        self.all_data["timestamp"].append(cur_time)
        if isinstance(data, dict):
            for key, value in data.items():
                if isinstance(value, torch.Tensor):
                    value = tensor_to_array(value)
                # Explicit check instead of 'assert', which is stripped when running with -O
                if not isinstance(value, (int, float, bool, list, tuple, np.ndarray)):
                    gs.raise_exception(f"[{type(self).__name__}] Unsupported data type: {type(value)}")
                self.all_data[key].append(value)
        else:
            self.all_data["data"].append(tensor_to_array(data))

    @staticmethod
    def _to_array(values):
        """Convert one list of samples to an array.

        Returns ``(array, is_object)``. Samples that cannot form a regular array (ragged shapes) are stored as a
        1-D object array with one entry per sample.
        """
        try:
            return np.asanyarray(values), False
        except ValueError:
            # Fill element by element: 'np.array(values, dtype=object)' still tries to broadcast samples that
            # share leading dimensions and can raise again.
            array = np.empty(len(values), dtype=object)
            for i, value in enumerate(values):
                array[i] = value
            return array, True

    def cleanup(self):
        filename = self._get_filename()
        if self.all_data["timestamp"]:  # at least one data point was collected
            arrays = {}
            object_keys = []
            for key, values in self.all_data.items():
                arrays[key], is_object = self._to_array(values)
                if is_object:
                    object_keys.append(key)
            if object_keys:
                (gs.logger or LOGGER).warning(
                    f"NPZFileWriter: saving keys {object_keys} as dtype=object because their samples have "
                    "inconsistent shapes (load with allow_pickle=True)."
                )
            np.savez_compressed(filename, **arrays)
            (gs.logger or LOGGER).info(
                f'[NPZFileWriter] Saved data with keys {list(self.all_data.keys())} to ~<"{filename}">~.'
            )
        self.all_data.clear()

    @property
    def run_in_thread(self) -> bool:
        return True