#!/usr/bin/env python3
"""Convert endoscope raw data (frames + labelme JSON + motor CSV) into ACT-compatible HDF5 episodes.

NOTE(review): reconstructed from a whitespace-mangled ``git format-patch``
email (commit b701d93) that added this file, ``models/text_encoder.py`` and an
empty ``models/__init__.py``. Formatting restored; behavior unchanged except
where marked BUGFIX.
"""
import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import h5py
import numpy as np
from PIL import Image

# Pillow >= 9.1 moved resampling filters to Image.Resampling; fall back to the
# legacy module-level constant on older Pillow versions.
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")


@dataclass
class CropBox:
    """Axis-aligned crop rectangle in original image pixel coordinates."""

    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def w(self) -> int:
        """Crop width in pixels."""
        return self.x2 - self.x1

    @property
    def h(self) -> int:
        """Crop height in pixels."""
        return self.y2 - self.y1


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the dataset conversion script."""
    parser = argparse.ArgumentParser(
        description=(
            "Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
        )
    )
    parser.add_argument(
        "--segment_dir",
        type=str,
        required=True,
        help="Path to one raw segment, e.g. data/follow_seg_001",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output dir for episode_*.hdf5",
    )
    parser.add_argument(
        "--episode_idx",
        type=int,
        default=0,
        help="Output episode index (default: 0)",
    )
    parser.add_argument(
        "--max_frames",
        type=int,
        default=-1,
        help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
    )
    parser.add_argument(
        "--camera_name",
        type=str,
        default="top",
        help="Camera name written to /observations/images/",
    )
    parser.add_argument(
        "--crop",
        type=int,
        nargs=4,
        default=[733, 30, 1754, 1051],
        metavar=("X1", "Y1", "X2", "Y2"),
        help="Crop box in original image coordinates",
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs=2,
        default=[224, 224],
        metavar=("W", "H"),
        help="Output image size",
    )
    parser.add_argument(
        "--instruction_template",
        type=str,
        default="Move toward the {label} at {region}.",
        help="Template for per-frame instruction",
    )
    parser.add_argument(
        "--instruction_empty",
        type=str,
        default="No target visible.",
        help="Instruction when no valid target after crop",
    )
    parser.add_argument(
        "--state_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for qpos (motor_pos_y, motor_pos_x)",
    )
    parser.add_argument(
        "--action_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for action (motor_command_0, motor_command_1)",
    )
    parser.add_argument(
        "--encode_text_features",
        action="store_true",
        help="Encode per-frame instruction into 768-dim DistilBERT features",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace model name for DistilBERT",
    )
    parser.add_argument(
        "--text_batch_size",
        type=int,
        default=32,
        help="Batch size for text feature extraction",
    )
    return parser.parse_args()
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """Return frame annotation JSONs in *frames_dir*, sorted by frame index.

    Filenames are expected to contain ``frame_<N>``; files without a parsable
    index sort last (then alphabetically by name).
    """
    json_files = list(frames_dir.glob("*.json"))

    def key_fn(p: Path) -> Tuple[int, str]:
        m = re.search(r"frame_(\d+)", p.name)
        idx = int(m.group(1)) if m else 10**9  # unparsable names sort last
        return idx, p.name

    json_files.sort(key=key_fn)
    return json_files


def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Load every row of a CSV file as a list of header-keyed string dicts."""
    with csv_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        return list(reader)


def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Normalize *x* against the range [min_v, max_v].

    mode: "raw" -> unchanged (cast to float32); "0_1" -> mapped to [0, 1];
    "minus1_1" -> mapped to [-1, 1].

    NOTE(review): assumes max_v > min_v — a degenerate range divides by zero;
    confirm against the caller-supplied motor ranges.
    """
    if mode == "raw":
        return x.astype(np.float32)
    x01 = (x - min_v) / (max_v - min_v)
    if mode == "0_1":
        return x01.astype(np.float32)
    # minus1_1
    return (x01 * 2.0 - 1.0).astype(np.float32)


def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Translate a bbox into crop-local coordinates and clip it to the crop.

    Coordinates are clamped to [0, w-1] / [0, h-1] (the last valid pixel
    index). Returns None when nothing of the box remains inside the crop.
    """
    nx1 = max(x_min - crop.x1, 0.0)
    ny1 = max(y_min - crop.y1, 0.0)
    nx2 = min(x_max - crop.x1, float(crop.w - 1))
    ny2 = min(y_max - crop.y1, float(crop.h - 1))
    if nx2 <= nx1 or ny2 <= ny1:
        return None
    return nx1, ny1, nx2, ny2


def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the (cx, cy) center of an (x1, y1, x2, y2) box."""
    x1, y1, x2, y2 = box
    return (x1 + x2) * 0.5, (y1 + y2) * 0.5


def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Map a point inside a (w, h) image to a coarse 3x3 region name.

    Returns strings like "top-left", "middle-center", "bottom-right".
    """
    x_bin = min(2, max(0, int(cx / (w / 3.0))))
    y_bin = min(2, max(0, int(cy / (h / 3.0))))
    xs = ["left", "center", "right"]
    ys = ["top", "middle", "bottom"]
    return f"{ys[y_bin]}-{xs[x_bin]}"


def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Extract (label, x_min, y_min, x_max, y_max, area) from a labelme shape.

    Works for any shape type with at least 2 points by taking the axis-aligned
    bounding box of its points. Returns None for malformed shapes.
    """
    points = shape.get("points", None)
    label = shape.get("label", "target")
    if not points or len(points) < 2:
        return None
    pts = np.array(points, dtype=np.float32)
    x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
    x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
    area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
    return label, x_min, y_min, x_max, y_max, area


def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the annotated shape with the largest area remaining inside *crop*.

    Returns (label, crop-local box) or None when no shape survives clipping.
    """
    shapes = annotation.get("shapes", [])
    best = None
    for shape in shapes:
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, area = parsed
        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue
        # Rank by the clipped (visible) area, not the raw annotation area.
        c_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
        if best is None or c_area > best[2]:
            best = (label, clipped, c_area)
    if best is None:
        return None
    return best[0], best[1]


def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Build the natural-language instruction for one frame's annotation.

    Formats *template* with the selected target's label and its 3x3 region
    within the crop; falls back to *empty_instruction* when no target remains
    visible after cropping.
    """
    picked = select_target_box(annotation, crop)
    if picked is None:
        return empty_instruction
    label, box = picked
    cx, cy = bbox_center(box)
    region = region_3x3(cx, cy, crop.w, crop.h)
    return template.format(label=label, region=region)
def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode instructions into per-frame 768-dim DistilBERT [CLS] features.

    Returns a float32 array of shape (len(instructions), 768).
    Raises ImportError when torch/transformers are not installed.
    """
    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        # BUGFIX: torch is imported in this try-block too, so the message must
        # name both missing dependencies, not just transformers.
        raise ImportError(
            "Text feature encoding requires torch and transformers. "
            "Please install: pip install torch transformers"
        ) from exc

    # BUGFIX: this script lives at the repository root (next to models/), so
    # the repo root is this file's own directory. The previous parents[1]
    # pointed one level *above* the repo and broke the models import.
    repo_root = Path(__file__).resolve().parent
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        for i in range(0, len(instructions), batch_size):
            batch = list(instructions[i:i + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            input_ids = tok["input_ids"].to(device)
            attention_mask = tok["attention_mask"].to(device)
            cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
            feats.append(cls)
    return np.concatenate(feats, axis=0)


def find_segment_csv(segment_dir: Path) -> Path:
    """Return the lexicographically first *.csv in *segment_dir*.

    Raises FileNotFoundError when the segment contains no CSV file.
    """
    csvs = sorted(segment_dir.glob("*.csv"))
    if not csvs:
        raise FileNotFoundError(f"No csv file found in {segment_dir}")
    return csvs[0]


def main() -> None:
    """Convert one raw segment into a single ACT-style episode_*.hdf5 file."""
    args = parse_args()

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if len(csv_rows) == 0:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    # Frames and CSV rows are aligned by index; use only the common prefix.
    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    qpos = np.zeros((num, 2), dtype=np.float32)  # [y, x]
    action = np.zeros((num, 2), dtype=np.float32)  # [cmd0(y), cmd1(x)]
    instructions: List[str] = []

    # Device-specific motor position / command ranges used for normalization.
    # NOTE(review): hard-coded for this endoscope rig — confirm against hardware.
    y_min, y_max = 8000.0, 18884.0
    x_min, x_max = 7000.0, 17384.0
    cmd_min, cmd_max = 0.0, 65535.0

    for i in range(num):
        json_path = json_files[i]
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        # Resolve the frame image: prefer the path recorded in the annotation,
        # falling back to a sibling .jpg with the same stem as the json.
        image_path = frames_dir / ann["imagePath"]
        if not image_path.exists():
            alt = json_path.with_suffix(".jpg")
            if alt.exists():
                image_path = alt
            else:
                raise FileNotFoundError(f"Image not found for {json_path.name}")

        img = Image.open(image_path).convert("RGB")
        img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
        img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
        images[i] = np.asarray(img_resize, dtype=np.uint8)

        row = csv_rows[i]
        motor_pos_y = float(row["motor_pos_y"])
        motor_pos_x = float(row["motor_pos_x"])
        motor_cmd_0 = float(row["motor_command_0"])
        motor_cmd_1 = float(row["motor_command_1"])

        qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
        qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
        action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
        action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]

        ins = instruction_from_annotation(
            ann,
            crop,
            args.instruction_template,
            args.instruction_empty,
        )
        instructions.append(ins)

    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    dt = 1.0 / 30.0  # assumes 30 fps capture — TODO confirm against recorder

    with h5py.File(out_path, "w") as root:
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = 30
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)

        root.create_dataset("action", data=action, dtype=np.float32)

        str_dtype = h5py.string_dtype(encoding="utf-8")
        root.create_dataset(
            "instruction_timestep",
            shape=(num,),
            dtype=str_dtype,
            data=np.asarray(instructions, dtype=object),
        )
        # Episode-level instruction: first frame's instruction (or empty).
        root.create_dataset(
            "instruction",
            shape=(),
            dtype=str_dtype,
            data=instructions[0] if len(instructions) > 0 else "",
        )

        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            # Episode-level feature mirrors the episode-level instruction.
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")


if __name__ == "__main__":
    main()
# ---------------------------------------------------------------------------
# models/text_encoder.py — second file added by the same patch (the patch also
# creates an empty models/__init__.py to make `models` a package).
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class DistilBERTTextEncoder(nn.Module):
    """Frozen DistilBERT encoder returning the [CLS] token embedding.

    Args:
        model_name: HuggingFace checkpoint to load.
        output_dim: Feature dimension of the returned embedding (768 for
            distilbert-base; stored for callers, not used to project).
        freeze: When True, encoder parameters get requires_grad=False and the
            module is kept in eval mode.
    """

    def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
        super().__init__()
        try:
            from transformers import DistilBertModel
        except ImportError as exc:
            raise ImportError(
                'transformers is required for DistilBERT text encoding. '
                'Install it with: pip install transformers'
            ) from exc

        self.encoder = DistilBertModel.from_pretrained(model_name)
        self.output_dim = output_dim
        self.freeze = freeze

        if self.freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()

    def forward(self, input_ids, attention_mask):
        """Return the [CLS] embedding, shape (batch, hidden_size).

        Re-asserts eval mode each call so a parent ``.train()`` cannot
        re-enable dropout in the frozen encoder.
        """
        if self.freeze:
            self.encoder.eval()
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # DistilBERT has no pooled output; use the [CLS] token embedding.
        cls_feature = outputs.last_hidden_state[:, 0, :]
        return cls_feature