#!/usr/bin/env python3
"""Convert one raw endoscope segment into an ACT-compatible HDF5 episode.

A segment directory is expected to contain a ``frames/`` sub-directory holding
per-frame annotation json files (read via ``shapes`` / ``imagePath`` keys, the
layout LabelMe produces) plus their images, and at least one csv file (the
first one in sorted order is used). Frame files are ordered by the integer in
``frame_<n>`` and paired with csv rows purely by position.
"""
import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import h5py
import numpy as np
from PIL import Image

# Pillow >= 9.1 exposes resampling filters under ``Image.Resampling``; older
# versions only provide the module-level constant.
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")

# Raw (min, max) ranges handed to ``normalize_value`` for each signal.
MOTOR_POS_Y_RANGE = (8000.0, 18884.0)
MOTOR_POS_X_RANGE = (7000.0, 17384.0)
MOTOR_CMD_RANGE = (0.0, 65535.0)

# CSV columns read for every frame, in qpos order followed by action order.
STATE_COLUMNS = ("motor_pos_y", "motor_pos_x")
ACTION_COLUMNS = ("motor_command_0", "motor_command_1")

# Width of the per-frame text feature vector requested from the text encoder.
TEXT_FEATURE_DIM = 768

# Frame files are ordered by the integer following ``frame_`` in the name.
_FRAME_INDEX_RE = re.compile(r"frame_(\d+)")

# Suffixes tried (in order) next to the json when ``imagePath`` does not resolve.
_IMAGE_FALLBACK_SUFFIXES = (".jpg", ".jpeg", ".png")


@dataclass
class CropBox:
    """Axis-aligned crop rectangle in original-image pixel coordinates.

    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right
    corner, in the same ``(left, upper, right, lower)`` convention that
    ``PIL.Image.crop`` takes.
    """

    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def w(self) -> int:
        """Crop width in pixels (``x2 - x1``)."""
        return self.x2 - self.x1

    @property
    def h(self) -> int:
        """Crop height in pixels (``y2 - y1``)."""
        return self.y2 - self.y1


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description=(
            "Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
        )
    )
    parser.add_argument(
        "--segment_dir",
        type=str,
        required=True,
        help="Path to one raw segment, e.g. data/follow_seg_001",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output dir for episode_*.hdf5",
    )
    parser.add_argument(
        "--episode_idx",
        type=int,
        default=0,
        help="Output episode index (default: 0)",
    )
    parser.add_argument(
        "--max_frames",
        type=int,
        default=-1,
        help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
    )
    parser.add_argument(
        "--camera_name",
        type=str,
        default="top",
        help="Camera name written to /observations/images/",
    )
    parser.add_argument(
        "--crop",
        type=int,
        nargs=4,
        default=[733, 30, 1754, 1051],
        metavar=("X1", "Y1", "X2", "Y2"),
        help="Crop box in original image coordinates",
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs=2,
        default=[224, 224],
        metavar=("W", "H"),
        help="Output image size",
    )
    parser.add_argument(
        "--instruction_template",
        type=str,
        default="Move toward the {label} at {region}.",
        help="Template for per-frame instruction",
    )
    parser.add_argument(
        "--instruction_empty",
        type=str,
        default="No target visible.",
        help="Instruction when no valid target after crop",
    )
    parser.add_argument(
        "--state_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for qpos (motor_pos_y, motor_pos_x)",
    )
    parser.add_argument(
        "--action_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for action (motor_command_0, motor_command_1)",
    )
    parser.add_argument(
        "--encode_text_features",
        action="store_true",
        help="Encode per-frame instruction into 768-dim DistilBERT features",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace model name for DistilBERT",
    )
    parser.add_argument(
        "--text_batch_size",
        type=int,
        default=32,
        help="Batch size for text feature extraction",
    )
    parser.add_argument(
        "--frame_rate",
        type=int,
        default=30,
        help="Frame rate stored in the episode attrs; dt = 1 / frame_rate (default: 30)",
    )
    return parser.parse_args()


def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """Return the ``*.json`` files of *frames_dir* ordered by frame index.

    The index is the integer following ``frame_`` in the file name. Files
    without such an index sort after every indexed file; ties are broken by
    file name.
    """

    def key_fn(p: Path) -> Tuple[int, int, str]:
        m = _FRAME_INDEX_RE.search(p.name)
        if m is None:
            # Unindexed files go last regardless of how large real indices get.
            return 1, 0, p.name
        return 0, int(m.group(1)), p.name

    return sorted(frames_dir.glob("*.json"), key=key_fn)


def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Read *csv_path* into a list of header-keyed row dicts.

    ``newline=""`` is what the csv module requires for correct handling of
    quoted fields; ``utf-8-sig`` drops a leading byte-order mark that would
    otherwise be glued onto the first header name.
    """
    with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
        return list(csv.DictReader(f))


def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Map *x* from ``[min_v, max_v]`` into the range selected by *mode*.

    Args:
        x: Values to normalize (any array-like).
        min_v: Raw value mapped to the low end of the target range.
        max_v: Raw value mapped to the high end of the target range.
        mode: ``"raw"`` (no scaling), ``"0_1"`` or ``"minus1_1"``.

    Returns:
        A float32 array. Values outside ``[min_v, max_v]`` are not clipped.

    Raises:
        ValueError: If *mode* is unknown, or ``max_v == min_v`` for a scaling mode.
    """
    x = np.asarray(x)
    if mode == "raw":
        return x.astype(np.float32)
    if mode not in ("0_1", "minus1_1"):
        raise ValueError(f"Unknown normalization mode: {mode!r}")
    span = max_v - min_v
    if span == 0:
        raise ValueError(f"Cannot normalize with an empty range [{min_v}, {max_v}]")
    x01 = (x - min_v) / span
    if mode == "0_1":
        return x01.astype(np.float32)
    return (x01 * 2.0 - 1.0).astype(np.float32)


def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Translate a box into crop-local coordinates and clip it to the crop.

    Returns ``(x1, y1, x2, y2)`` relative to the crop's top-left corner,
    clipped to ``[0, w - 1] x [0, h - 1]``, or ``None`` when nothing of the
    box (with positive width and height) remains inside the crop.
    """
    nx1 = max(x_min - crop.x1, 0.0)
    ny1 = max(y_min - crop.y1, 0.0)
    nx2 = min(x_max - crop.x1, float(crop.w - 1))
    ny2 = min(y_max - crop.y1, float(crop.h - 1))
    if nx2 <= nx1 or ny2 <= ny1:
        return None
    return nx1, ny1, nx2, ny2


def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the ``(cx, cy)`` midpoint of an ``(x1, y1, x2, y2)`` box."""
    x1, y1, x2, y2 = box
    return (x1 + x2) * 0.5, (y1 + y2) * 0.5


def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Name the cell of a 3x3 grid over a ``w`` x ``h`` area containing ``(cx, cy)``.

    Returns ``"<row>-<col>"`` with rows ``top/middle/bottom`` and columns
    ``left/center/right``; points outside the area map to the nearest edge cell.
    """
    x_bin = min(2, max(0, int(cx / (w / 3.0))))
    y_bin = min(2, max(0, int(cy / (h / 3.0))))
    xs = ["left", "center", "right"]
    ys = ["top", "middle", "bottom"]
    return f"{ys[y_bin]}-{xs[x_bin]}"


def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Convert one annotation shape into ``(label, x_min, y_min, x_max, y_max, area)``.

    The box is the axis-aligned bounds of the shape's ``points``. Shapes with
    ``shape_type == "circle"`` are treated as ``[center, point on circumference]``
    (how LabelMe stores circles) and bounded by the full circle. A missing or
    empty label becomes ``"target"``. Returns ``None`` for shapes with fewer
    than two points.
    """
    points = shape.get("points")
    label = shape.get("label") or "target"
    if not points or len(points) < 2:
        return None
    pts = np.array(points, dtype=np.float32)
    if shape.get("shape_type") == "circle":
        cx, cy = float(pts[0, 0]), float(pts[0, 1])
        radius = float(np.hypot(pts[1, 0] - pts[0, 0], pts[1, 1] - pts[0, 1]))
        x_min, y_min, x_max, y_max = cx - radius, cy - radius, cx + radius, cy + radius
    else:
        x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
        x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
    area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
    return label, x_min, y_min, x_max, y_max, area


def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the shape with the largest area remaining inside *crop*.

    Returns ``(label, crop_local_box)`` for that shape, or ``None`` when no
    shape overlaps the crop. On equal areas the earlier shape wins.
    """
    best: Optional[Tuple[str, Tuple[float, float, float, float], float]] = None
    # ``or []`` also covers an explicit ``"shapes": null``.
    for shape in annotation.get("shapes") or []:
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, _ = parsed
        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue
        # clip_bbox_to_crop guarantees positive width and height here.
        c_area = (clipped[2] - clipped[0]) * (clipped[3] - clipped[1])
        if best is None or c_area > best[2]:
            best = (label, clipped, c_area)
    if best is None:
        return None
    return best[0], best[1]


def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Build the per-frame instruction text for one annotation.

    Fills *template*'s ``{label}`` and ``{region}`` fields from the selected
    target box (region = 3x3 cell of its centre within the crop), or returns
    *empty_instruction* when no target lies inside the crop.
    """
    picked = select_target_box(annotation, crop)
    if picked is None:
        return empty_instruction
    label, box = picked
    cx, cy = bbox_center(box)
    region = region_3x3(cx, cy, crop.w, crop.h)
    return template.format(label=label, region=region)


def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode each instruction into a ``TEXT_FEATURE_DIM``-wide float32 vector.

    Uses the project's ``models.text_encoder.DistilBERTTextEncoder`` (imported
    from the repository root, one level above this script) with a DistilBERT
    tokenizer, on CUDA when available.

    Returns:
        Array with one row per instruction (empty ``(0, TEXT_FEATURE_DIM)``
        array for no instructions).

    Raises:
        ValueError: If *batch_size* is not positive.
        ImportError: If torch or transformers is not installed.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(instructions) == 0:
        return np.zeros((0, TEXT_FEATURE_DIM), dtype=np.float32)

    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires torch and transformers. "
            "Please install: pip install torch transformers"
        ) from exc

    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=TEXT_FEATURE_DIM, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        for start in range(0, len(instructions), batch_size):
            batch = list(instructions[start:start + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            input_ids = tok["input_ids"].to(device)
            attention_mask = tok["attention_mask"].to(device)
            out = model(input_ids=input_ids, attention_mask=attention_mask)
            feats.append(out.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(feats, axis=0)


def find_segment_csv(segment_dir: Path) -> Path:
    """Return the first ``*.csv`` in *segment_dir* (sorted by path).

    Raises:
        FileNotFoundError: If the directory contains no csv file.
    """
    csvs = sorted(segment_dir.glob("*.csv"))
    if not csvs:
        raise FileNotFoundError(f"No csv file found in {segment_dir}")
    return csvs[0]


def _check_csv_columns(csv_rows: List[Dict[str, str]], csv_path: Path) -> None:
    """Raise KeyError naming every required column missing from the csv header."""
    header = csv_rows[0].keys()
    missing = [c for c in STATE_COLUMNS + ACTION_COLUMNS if c not in header]
    if missing:
        raise KeyError(f"CSV {csv_path} is missing required column(s): {', '.join(missing)}")


def _row_floats(row: Dict[str, str], columns: Sequence[str], csv_path: Path, row_idx: int) -> Tuple[float, ...]:
    """Parse *columns* of one csv row as floats, reporting the row on failure."""
    try:
        return tuple(float(row[c]) for c in columns)
    except (TypeError, ValueError) as exc:
        # TypeError covers short rows, where DictReader fills missing cells with None.
        raise ValueError(f"Bad numeric value in {csv_path}, data row {row_idx + 1}: {exc}") from exc


def _resolve_image_path(frames_dir: Path, json_path: Path, ann: Dict) -> Path:
    """Locate the image belonging to one annotation json.

    Tries ``frames_dir / ann["imagePath"]`` first, then a file with the json's
    stem and each suffix of ``_IMAGE_FALLBACK_SUFFIXES``.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    candidates: List[Path] = []
    image_rel = ann.get("imagePath")
    if image_rel:
        candidates.append(frames_dir / image_rel)
    candidates.extend(json_path.with_suffix(suffix) for suffix in _IMAGE_FALLBACK_SUFFIXES)
    for candidate in candidates:
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"Image not found for {json_path.name}")


def _load_frame_image(image_path: Path, crop: CropBox, size: Tuple[int, int]) -> np.ndarray:
    """Load an image as RGB, crop it to *crop*, resize to *size* ``(W, H)``.

    Returns a uint8 array of shape ``(H, W, 3)``.
    """
    # Image.open is lazy and keeps the file open; the context manager closes it
    # once convert() has loaded the pixels.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    cropped = rgb.crop((crop.x1, crop.y1, crop.x2, crop.y2))
    resized = cropped.resize(size, RESAMPLE_BILINEAR)
    return np.asarray(resized, dtype=np.uint8)


def _write_episode(
    out_path: Path,
    args: argparse.Namespace,
    segment_dir: Path,
    images: np.ndarray,
    qpos: np.ndarray,
    action: np.ndarray,
    instructions: List[str],
    text_features: Optional[np.ndarray],
) -> None:
    """Write (overwriting) one episode HDF5 file with attrs, observations, actions and text."""
    str_dtype = h5py.string_dtype(encoding="utf-8")
    with h5py.File(out_path, "w") as root:
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = args.frame_rate
        root.attrs["dt"] = 1.0 / args.frame_rate
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)

        root.create_dataset("action", data=action, dtype=np.float32)

        root.create_dataset(
            "instruction_timestep",
            shape=(len(instructions),),
            dtype=str_dtype,
            data=np.asarray(instructions, dtype=object),
        )
        root.create_dataset(
            "instruction",
            shape=(),
            dtype=str_dtype,
            data=instructions[0] if instructions else "",
        )

        if text_features is not None and len(text_features) > 0:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)


def main() -> None:
    """Convert the segment given on the command line into ``episode_<idx>.hdf5``."""
    args = parse_args()
    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])
    if crop.w <= 0 or crop.h <= 0:
        raise ValueError(f"Invalid crop {args.crop}: require X2 > X1 and Y2 > Y1")
    if resize_w <= 0 or resize_h <= 0:
        raise ValueError(f"Invalid resize {args.resize}: W and H must be positive")
    if args.frame_rate <= 0:
        raise ValueError(f"frame_rate must be positive, got {args.frame_rate}")

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if len(csv_rows) == 0:
        raise ValueError(f"CSV has no rows: {csv_path}")
    _check_csv_columns(csv_rows, csv_path)

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    if len(json_files) != len(csv_rows):
        print(
            f"Warning: {len(json_files)} frame json(s) vs {len(csv_rows)} csv row(s); "
            "using the shorter length",
            file=sys.stderr,
        )
    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    raw_state = np.zeros((num, 2), dtype=np.float32)  # [motor_pos_y, motor_pos_x]
    raw_action = np.zeros((num, 2), dtype=np.float32)  # [motor_command_0, motor_command_1]
    instructions: List[str] = []

    for i, (json_path, row) in enumerate(zip(json_files[:num], csv_rows[:num])):
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        image_path = _resolve_image_path(frames_dir, json_path, ann)
        images[i] = _load_frame_image(image_path, crop, (resize_w, resize_h))

        raw_state[i] = _row_floats(row, STATE_COLUMNS, csv_path, i)
        raw_action[i] = _row_floats(row, ACTION_COLUMNS, csv_path, i)

        instructions.append(
            instruction_from_annotation(
                ann,
                crop,
                args.instruction_template,
                args.instruction_empty,
            )
        )

    # Normalize whole columns at once; each axis has its own raw range.
    qpos = np.stack(
        [
            normalize_value(raw_state[:, 0], *MOTOR_POS_Y_RANGE, args.state_norm),
            normalize_value(raw_state[:, 1], *MOTOR_POS_X_RANGE, args.state_norm),
        ],
        axis=1,
    ).astype(np.float32)
    action = np.stack(
        [
            normalize_value(raw_action[:, 0], *MOTOR_CMD_RANGE, args.action_norm),
            normalize_value(raw_action[:, 1], *MOTOR_CMD_RANGE, args.action_norm),
        ],
        axis=1,
    ).astype(np.float32)

    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    # Created only after all inputs have been read successfully.
    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    _write_episode(out_path, args, segment_dir, images, qpos, action, instructions, text_features)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")


if __name__ == "__main__":
    main()