#!/usr/bin/env python3
"""Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s).

Reads one raw segment directory containing per-frame LabelMe-style JSON
annotations plus a motor-log CSV, and writes a single ``episode_*.hdf5``
with cropped/resized images, normalized qpos/action, per-frame language
instructions (optionally DistilBERT-encoded), and a diagnostics plot.
"""

import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Pillow >= 9.1 moved resampling filters to Image.Resampling; fall back to the
# legacy module-level attribute on older Pillow versions.
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")

# Device-specific raw motor ranges, shared by qpos/action normalization and by
# stationary detection (previously these literals were duplicated in main()
# and override_stationary_edge_instructions()).
MOTOR_POS_Y_MIN, MOTOR_POS_Y_MAX = 8000.0, 18884.0
MOTOR_POS_X_MIN, MOTOR_POS_X_MAX = 7000.0, 17384.0
MOTOR_CMD_MIN, MOTOR_CMD_MAX = 0.0, 65535.0


@dataclass
class CropBox:
    """Axis-aligned crop box in original image coordinates."""

    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def w(self) -> int:
        """Crop width in pixels."""
        return self.x2 - self.x1

    @property
    def h(self) -> int:
        """Crop height in pixels."""
        return self.y2 - self.y1


def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the conversion script."""
    parser = argparse.ArgumentParser(
        description=(
            "Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
        )
    )
    parser.add_argument(
        "--segment_dir",
        type=str,
        required=True,
        help="Path to one raw segment, e.g. data/follow_seg_001",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output dir for episode_*.hdf5",
    )
    parser.add_argument(
        "--episode_idx",
        type=int,
        default=0,
        help="Output episode index (default: 0)",
    )
    parser.add_argument(
        "--max_frames",
        type=int,
        default=-1,
        help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
    )
    parser.add_argument(
        "--camera_name",
        type=str,
        default="top",
        help="Camera name written to /observations/images/",
    )
    parser.add_argument(
        "--crop",
        type=int,
        nargs=4,
        default=[733, 30, 1754, 1051],
        metavar=("X1", "Y1", "X2", "Y2"),
        help="Crop box in original image coordinates",
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs=2,
        default=[224, 224],
        metavar=("W", "H"),
        help="Output image size",
    )
    parser.add_argument(
        "--instruction_template",
        type=str,
        default="Move toward the {label} at {region}.",
        help="Template for per-frame instruction",
    )
    parser.add_argument(
        "--instruction_empty",
        type=str,
        default="No target visible.",
        help="Instruction when no valid target after crop",
    )
    parser.add_argument(
        "--stop_instruction",
        type=str,
        default="Stop move.",
        help="Instruction used for stationary head/tail frames",
    )
    parser.add_argument(
        "--state_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for qpos (motor_pos_y, motor_pos_x)",
    )
    parser.add_argument(
        "--action_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for action (motor_command_0, motor_command_1)",
    )
    parser.add_argument(
        "--encode_text_features",
        action="store_true",
        help="Encode per-frame instruction into 768-dim DistilBERT features",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace model name for DistilBERT",
    )
    parser.add_argument(
        "--text_batch_size",
        type=int,
        default=32,
        help="Batch size for text feature extraction",
    )
    parser.add_argument(
        "--motion_window",
        type=int,
        default=5,
        help="Sliding window size used for stationary detection at beginning/end",
    )
    parser.add_argument(
        "--motion_threshold",
        type=float,
        default=0.002,
        help=(
            "Motion threshold in normalized delta space (0~1). "
            "Smaller value means stricter stationary detection"
        ),
    )
    parser.add_argument(
        "--disable_stop_override",
        action="store_true",
        help="Disable head/tail stationary instruction override",
    )
    return parser.parse_args()


def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """Return frame annotation JSONs sorted by the numeric index in ``frame_N``.

    Files without a parseable ``frame_N`` index sort last (then by name).
    """
    json_files = list(frames_dir.glob("*.json"))

    def key_fn(p: Path) -> Tuple[int, str]:
        m = re.search(r"frame_(\d+)", p.name)
        idx = int(m.group(1)) if m else 10**9
        return idx, p.name

    json_files.sort(key=key_fn)
    return json_files


def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Load the motor-log CSV as a list of header-keyed string dicts."""
    with csv_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        return list(reader)


def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Normalize ``x`` from [min_v, max_v] per ``mode``.

    mode "raw": passthrough (cast to float32); "0_1": scale to [0, 1];
    "minus1_1": scale to [-1, 1].
    """
    if mode == "raw":
        return x.astype(np.float32)
    x01 = (x - min_v) / (max_v - min_v)
    if mode == "0_1":
        return x01.astype(np.float32)
    # minus1_1
    return (x01 * 2.0 - 1.0).astype(np.float32)


def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Translate a bbox into crop-local coordinates and clip it to the crop.

    Returns None when the clipped box has no positive area.
    """
    nx1 = max(x_min - crop.x1, 0.0)
    ny1 = max(y_min - crop.y1, 0.0)
    nx2 = min(x_max - crop.x1, float(crop.w - 1))
    ny2 = min(y_max - crop.y1, float(crop.h - 1))
    if nx2 <= nx1 or ny2 <= ny1:
        return None
    return nx1, ny1, nx2, ny2


def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the (cx, cy) center of an (x1, y1, x2, y2) box."""
    x1, y1, x2, y2 = box
    return (x1 + x2) * 0.5, (y1 + y2) * 0.5


def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Map a point in a w x h image to one of nine named regions (3x3 grid)."""
    x_bin = min(2, max(0, int(cx / (w / 3.0))))
    y_bin = min(2, max(0, int(cy / (h / 3.0))))
    xs = ["left", "center", "right"]
    ys = ["top", "middle", "bottom"]
    return f"{ys[y_bin]}-{xs[x_bin]}"


def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Extract (label, x_min, y_min, x_max, y_max, area) from a LabelMe shape.

    Returns None when the shape has fewer than two points.
    """
    points = shape.get("points", None)
    label = shape.get("label", "target")
    if not points or len(points) < 2:
        return None
    pts = np.array(points, dtype=np.float32)
    x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
    x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
    area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
    return label, x_min, y_min, x_max, y_max, area


def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the shape with the largest post-crop area; None if nothing survives the crop."""
    shapes = annotation.get("shapes", [])
    best = None
    for shape in shapes:
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, area = parsed
        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue
        # Rank candidates by their clipped (visible-in-crop) area, not raw area.
        c_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
        if best is None or c_area > best[2]:
            best = (label, clipped, c_area)
    if best is None:
        return None
    return best[0], best[1]


def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Render the per-frame instruction from the annotation's best target box."""
    picked = select_target_box(annotation, crop)
    if picked is None:
        return empty_instruction
    label, box = picked
    cx, cy = bbox_center(box)
    region = region_3x3(cx, cy, crop.w, crop.h)
    return template.format(label=label, region=region)


def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode instructions into (N, 768) float32 DistilBERT features.

    Raises ImportError when torch/transformers are not installed.
    """
    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires transformers. Please install: pip install transformers"
        ) from exc

    # Make the repo root importable so models.text_encoder resolves when this
    # script is run from anywhere.
    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        for i in range(0, len(instructions), batch_size):
            batch = list(instructions[i:i + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            input_ids = tok["input_ids"].to(device)
            attention_mask = tok["attention_mask"].to(device)
            cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
            feats.append(cls)
    return np.concatenate(feats, axis=0)


def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
    """Scale a series to [0, 1]; degenerate ranges collapse to zeros."""
    scale = max_v - min_v
    if scale <= 0:
        return np.zeros_like(v, dtype=np.float32)
    return ((v - min_v) / scale).astype(np.float32)


def override_stationary_edge_instructions(
    instructions: List[str],
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motor_cmd_0: np.ndarray,
    motor_cmd_1: np.ndarray,
    stop_instruction: str,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[List[str], int, int]:
    """
    Override instruction text at head/tail when deviation is below threshold.

    Start side: use the first frame as reference and expand forward until
    deviation exceeds threshold.
    End side: use the last frame as reference and expand backward until
    deviation exceeds threshold.
    """
    num = len(instructions)
    if num == 0:
        return instructions, 0, 0

    # normalize to comparable scales (0~1)
    py = _normalize_series(motor_pos_y.astype(np.float32), MOTOR_POS_Y_MIN, MOTOR_POS_Y_MAX)
    px = _normalize_series(motor_pos_x.astype(np.float32), MOTOR_POS_X_MIN, MOTOR_POS_X_MAX)
    c0 = _normalize_series(motor_cmd_0.astype(np.float32), MOTOR_CMD_MIN, MOTOR_CMD_MAX)
    c1 = _normalize_series(motor_cmd_1.astype(np.float32), MOTOR_CMD_MIN, MOTOR_CMD_MAX)

    if num == 1:
        return [stop_instruction], 1, 1

    # keep argument for backward CLI compatibility
    _ = motion_window

    start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
    end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)

    def deviation_to_ref(i: int, ref: np.ndarray) -> float:
        # Chebyshev (max-abs) distance in the normalized 4-D motor space.
        cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
        return float(np.max(np.abs(cur - ref)))

    start_count = 0
    for i in range(num):
        if deviation_to_ref(i, start_ref) <= motion_threshold:
            start_count += 1
        else:
            break

    end_count = 0
    for i in range(num - 1, -1, -1):
        if deviation_to_ref(i, end_ref) <= motion_threshold:
            end_count += 1
        else:
            break

    updated = list(instructions)
    for i in range(start_count):
        updated[i] = stop_instruction
    for i in range(num - end_count, num):
        updated[i] = stop_instruction
    return updated, start_count, end_count


def find_segment_csv(segment_dir: Path) -> Path:
    """Return the first (alphabetically) CSV in the segment dir; raise if none."""
    csvs = sorted(segment_dir.glob("*.csv"))
    if not csvs:
        raise FileNotFoundError(f"No csv file found in {segment_dir}")
    return csvs[0]


def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
    """Convert boolean mask to closed-open segments [start, end)."""
    segments: List[Tuple[int, int]] = []
    if mask.size == 0:
        return segments
    in_seg = False
    start = 0
    for i, v in enumerate(mask.tolist()):
        if v and not in_seg:
            in_seg = True
            start = i
        elif (not v) and in_seg:
            in_seg = False
            segments.append((start, i))
    if in_seg:
        segments.append((start, int(mask.size)))
    return segments


def save_episode_plot_with_stop_segments(
    qpos: np.ndarray,
    action: np.ndarray,
    instructions: Sequence[str],
    stop_instruction: str,
    plot_path: Path,
) -> None:
    """
    Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
    """
    qpos = np.asarray(qpos)
    action = np.asarray(action)
    # Silently skip plotting on malformed/empty inputs (best-effort diagnostics).
    if qpos.ndim != 2 or action.ndim != 2:
        return
    num_ts, num_dim = qpos.shape
    if num_ts == 0 or num_dim == 0:
        return

    stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
    stop_segments = _mask_to_segments(stop_mask)

    fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
    axs_list = np.atleast_1d(axs).reshape(-1).tolist()
    for dim_idx in range(num_dim):
        ax = axs_list[dim_idx]
        ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
        ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
        for seg_idx, (st, ed) in enumerate(stop_segments):
            # Label only the first span so the legend shows one entry.
            ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
                       label='stop instruction' if seg_idx == 0 else None)
        ax.set_ylabel(f'dim {dim_idx}')
        ax.legend(loc='best')
        ax.grid(alpha=0.25, linestyle='--')
    axs_list[-1].set_xlabel('timestep')
    fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
    fig.tight_layout()
    fig.savefig(str(plot_path), dpi=140)
    plt.close(fig)


def main() -> None:
    """Convert one raw segment into a single ACT-style HDF5 episode."""
    args = parse_args()

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if len(csv_rows) == 0:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    # Frames and CSV rows are aligned by index; use only the overlapping prefix.
    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    qpos = np.zeros((num, 2), dtype=np.float32)  # [y, x]
    action = np.zeros((num, 2), dtype=np.float32)  # [cmd0(y), cmd1(x)]
    instructions: List[str] = []

    motor_pos_y_series = np.zeros((num,), dtype=np.float32)
    motor_pos_x_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_1_series = np.zeros((num,), dtype=np.float32)

    for i in range(num):
        json_path = json_files[i]
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        image_path = frames_dir / ann["imagePath"]
        if not image_path.exists():
            # Fall back to a sibling .jpg with the same basename as the json.
            alt = json_path.with_suffix(".jpg")
            if alt.exists():
                image_path = alt
            else:
                raise FileNotFoundError(f"Image not found for {json_path.name}")

        img = Image.open(image_path).convert("RGB")
        img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
        img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
        images[i] = np.asarray(img_resize, dtype=np.uint8)

        row = csv_rows[i]
        motor_pos_y_series[i] = float(row["motor_pos_y"])
        motor_pos_x_series[i] = float(row["motor_pos_x"])
        motor_cmd_0_series[i] = float(row["motor_command_0"])
        motor_cmd_1_series[i] = float(row["motor_command_1"])

        ins = instruction_from_annotation(
            ann,
            crop,
            args.instruction_template,
            args.instruction_empty,
        )
        instructions.append(ins)

    # Normalize the whole series at once (same math as per-frame, vectorized).
    qpos[:, 0] = normalize_value(motor_pos_y_series, MOTOR_POS_Y_MIN, MOTOR_POS_Y_MAX, args.state_norm)
    qpos[:, 1] = normalize_value(motor_pos_x_series, MOTOR_POS_X_MIN, MOTOR_POS_X_MAX, args.state_norm)
    action[:, 0] = normalize_value(motor_cmd_0_series, MOTOR_CMD_MIN, MOTOR_CMD_MAX, args.action_norm)
    action[:, 1] = normalize_value(motor_cmd_1_series, MOTOR_CMD_MIN, MOTOR_CMD_MAX, args.action_norm)

    start_stop_count = 0
    end_stop_count = 0
    if not args.disable_stop_override:
        instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
            instructions=instructions,
            motor_pos_y=motor_pos_y_series,
            motor_pos_x=motor_pos_x_series,
            motor_cmd_0=motor_cmd_0_series,
            motor_cmd_1=motor_cmd_1_series,
            stop_instruction=args.stop_instruction,
            motion_window=args.motion_window,
            motion_threshold=args.motion_threshold,
        )

    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    dt = 1.0 / 30.0
    with h5py.File(out_path, "w") as root:
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = 30
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)
        root.create_dataset("action", data=action, dtype=np.float32)

        str_dtype = h5py.string_dtype(encoding="utf-8")
        root.create_dataset(
            "instruction_timestep",
            shape=(num,),
            dtype=str_dtype,
            data=np.asarray(instructions, dtype=object),
        )
        root.create_dataset(
            "instruction",
            shape=(),
            dtype=str_dtype,
            data=instructions[0] if len(instructions) > 0 else "",
        )
        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if not args.disable_stop_override:
        print(
            f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
            f"mode=endpoint_reference, threshold={args.motion_threshold}, "
            f"instruction='{args.stop_instruction}'"
        )
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")

    # Save a same-basename plot next to the generated hdf5
    plot_path = out_path.with_suffix('.png')
    save_episode_plot_with_stop_segments(
        qpos=qpos,
        action=action,
        instructions=instructions,
        stop_instruction=args.stop_instruction,
        plot_path=plot_path,
    )
    print(f"Saved episode plot to: {plot_path}")


if __name__ == "__main__":
    main()