Files
aloha/build_endoscope_act_dataset.py
2026-02-20 14:13:25 +08:00

681 lines
22 KiB
Python

#!/usr/bin/env python3
import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Bilinear resampling filter used when resizing frames. It is looked up on
# Image.Resampling when the installed Pillow exposes that attribute, and on the
# Image module itself otherwise.
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
@dataclass
class CropBox:
    """Crop rectangle described by two corner points.

    The four fields are stored as given; width and height are plain
    coordinate differences and no ordering of the corners is enforced.
    """

    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def w(self) -> int:
        """Horizontal extent of the box (``x2`` minus ``x1``)."""
        left, right = self.x1, self.x2
        return right - left

    @property
    def h(self) -> int:
        """Vertical extent of the box (``y2`` minus ``y1``)."""
        top, bottom = self.y1, self.y2
        return bottom - top
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace. ``--segment_dir`` and ``--output_dir`` are
        required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
        )
    )
    add = parser.add_argument

    # Input / output selection.
    add(
        "--segment_dir",
        type=str,
        required=True,
        help="Path to one raw segment, e.g. data/follow_seg_001",
    )
    add(
        "--output_dir",
        type=str,
        required=True,
        help="Output dir for episode_*.hdf5",
    )
    add(
        "--episode_idx",
        type=int,
        default=0,
        help="Output episode index (default: 0)",
    )
    add(
        "--max_frames",
        type=int,
        default=-1,
        help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
    )

    # Image geometry.
    add(
        "--camera_name",
        type=str,
        default="top",
        help="Camera name written to /observations/images/<camera_name>",
    )
    add(
        "--crop",
        type=int,
        nargs=4,
        default=[733, 30, 1754, 1051],
        metavar=("X1", "Y1", "X2", "Y2"),
        help="Crop box in original image coordinates",
    )
    add(
        "--resize",
        type=int,
        nargs=2,
        default=[224, 224],
        metavar=("W", "H"),
        help="Output image size",
    )

    # Instruction text.
    add(
        "--instruction_template",
        type=str,
        default="Move toward the {label} at {region}.",
        help="Template for per-frame instruction",
    )
    add(
        "--instruction_empty",
        type=str,
        default="No target visible.",
        help="Instruction when no valid target after crop",
    )
    add(
        "--stop_instruction",
        type=str,
        default="Stop move.",
        help="Instruction used for stationary head/tail frames",
    )

    # Normalization of state and action.
    add(
        "--state_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for qpos (motor_pos_y, motor_pos_x)",
    )
    add(
        "--action_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for action (motor_command_0, motor_command_1)",
    )

    # Optional text feature encoding.
    add(
        "--encode_text_features",
        action="store_true",
        help="Encode per-frame instruction into 768-dim DistilBERT features",
    )
    add(
        "--text_model_name",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace model name for DistilBERT",
    )
    add(
        "--text_batch_size",
        type=int,
        default=32,
        help="Batch size for text feature extraction",
    )

    # Stationary head/tail detection. The detector compares a per-second speed
    # of the normalized motor position against the threshold and requires a run
    # of consecutive above-threshold frames, so the help text says exactly that.
    add(
        "--motion_window",
        type=int,
        default=5,
        help=(
            "Number of consecutive above-threshold frames required to decide that motion "
            "has started (scanning from the beginning) or ended (scanning from the end)"
        ),
    )
    add(
        "--motion_threshold",
        type=float,
        default=0.002,
        help=(
            "Speed threshold on the normalized motor position, in normalized units per second "
            "(per-frame delta multiplied by the 30 fps frame rate). "
            "Smaller value means stricter stationary detection"
        ),
    )
    add(
        "--disable_stop_override",
        action="store_true",
        help="Disable head/tail stationary instruction override",
    )
    add(
        "--trim_stationary_edges",
        action="store_true",
        help="Trim stationary head/tail segments and keep only the middle moving segment",
    )
    add(
        "--no_text_instruction",
        action="store_true",
        help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
    )
    return parser.parse_args()
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """List the ``*.json`` files directly inside *frames_dir* in frame order.

    Files are ordered by the integer following ``frame_`` in their name, with
    the file name as a tie-breaker. Names without such a number sort after all
    numbered ones.
    """

    def frame_order(path: Path) -> Tuple[int, str]:
        found = re.search(r"frame_(\d+)", path.name)
        if found is None:
            # Large sentinel pushes unnumbered files to the end.
            return 10**9, path.name
        return int(found.group(1)), path.name

    return sorted(frames_dir.glob("*.json"), key=frame_order)
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Read a CSV file with a header row into a list of per-row dictionaries.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        One ``{column name: cell text}`` mapping per data row, in file order.
    """
    # "utf-8-sig" strips a leading byte-order mark (written by some spreadsheet
    # exports) so the first column name is not corrupted, and still reads plain
    # UTF-8 files. newline="" lets the csv module do its own newline handling,
    # as its documentation requires.
    with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
        return list(csv.DictReader(f))
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Rescale *x* from the range ``[min_v, max_v]`` according to *mode*.

    Args:
        x: Values to rescale.
        min_v: Value mapped to the lower end of the output range.
        max_v: Value mapped to the upper end of the output range.
        mode: ``"raw"`` returns the values unchanged, ``"0_1"`` maps the range
            to ``[0, 1]``; any other value maps it to ``[-1, 1]``.

    Returns:
        A float32 array with the same shape as *x*. Inputs outside the range
        are not clipped.

    Raises:
        ValueError: If a rescaling mode is requested and ``max_v == min_v``.
    """
    if mode == "raw":
        return x.astype(np.float32)

    span = max_v - min_v
    if span == 0:
        # Dividing by a zero-width range would silently produce inf/NaN values.
        raise ValueError(
            f"Cannot normalize with an empty range: min_v={min_v}, max_v={max_v}"
        )

    unit = (x - min_v) / span
    if mode == "0_1":
        return unit.astype(np.float32)
    # minus1_1
    return (unit * 2.0 - 1.0).astype(np.float32)
def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Translate a box into crop-local coordinates and clamp it to the crop.

    The crop origin is subtracted from every coordinate. The lower corner is
    then clamped to 0 and the upper corner to ``crop.w - 1`` / ``crop.h - 1``.

    Returns:
        ``(x1, y1, x2, y2)`` in crop coordinates, or ``None`` when the clamped
        box has no positive width or height.
    """
    left = max(x_min - crop.x1, 0.0)
    top = max(y_min - crop.y1, 0.0)
    right = min(x_max - crop.x1, float(crop.w - 1))
    bottom = min(y_max - crop.y1, float(crop.h - 1))

    degenerate = right <= left or bottom <= top
    return None if degenerate else (left, top, right, bottom)
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the midpoint ``(cx, cy)`` of an ``(x1, y1, x2, y2)`` box."""
    left, top, right, bottom = box
    center_x = (left + right) * 0.5
    center_y = (top + bottom) * 0.5
    return center_x, center_y
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Name the cell of a 3x3 grid over a ``w`` x ``h`` area that contains a point.

    Returns:
        ``"<row>-<column>"`` with row in top/middle/bottom and column in
        left/center/right. Points outside the area map to the nearest edge cell.
    """
    column_names = ["left", "center", "right"]
    row_names = ["top", "middle", "bottom"]

    def third_index(coord: float, extent: int) -> int:
        # Which third of the extent the coordinate falls in, limited to 0..2.
        cell = int(coord / (extent / 3.0))
        if cell < 0:
            return 0
        if cell > 2:
            return 2
        return cell

    column = third_index(cx, w)
    row = third_index(cy, h)
    return f"{row_names[row]}-{column_names[column]}"
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Compute the axis-aligned bounding box of one annotation shape.

    Args:
        shape: Mapping with a ``"points"`` list of ``[x, y]`` pairs and an
            optional ``"label"`` (``"target"`` is used when it is absent).

    Returns:
        ``(label, x_min, y_min, x_max, y_max, area)``, or ``None`` when the
        shape has fewer than two points.
    """
    raw_points = shape.get("points", None)
    if not raw_points or len(raw_points) < 2:
        return None

    coords = np.array(raw_points, dtype=np.float32)
    xs, ys = coords[:, 0], coords[:, 1]
    left, right = float(xs.min()), float(xs.max())
    top, bottom = float(ys.min()), float(ys.max())

    width = max(0.0, right - left)
    height = max(0.0, bottom - top)
    return shape.get("label", "target"), left, top, right, bottom, width * height
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the annotated shape that covers the largest area inside the crop.

    Every entry of ``annotation["shapes"]`` is reduced to a bounding box and
    clipped to *crop*. Shapes that cannot be parsed or that vanish after
    clipping are skipped.

    Returns:
        ``(label, clipped_box)`` of the shape with the largest clipped area
        (the earliest one on ties), or ``None`` when no shape survives.
    """
    candidates = []
    for shape in annotation.get("shapes", []):
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, _unclipped_area = parsed

        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue

        width = max(0.0, clipped[2] - clipped[0])
        height = max(0.0, clipped[3] - clipped[1])
        candidates.append((label, clipped, width * height))

    if not candidates:
        return None

    # max() keeps the first item among equal keys, i.e. the earliest shape wins ties.
    winner_label, winner_box, _ = max(candidates, key=lambda item: item[2])
    return winner_label, winner_box
def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Turn one frame annotation into an instruction sentence.

    Args:
        annotation: Parsed annotation of one frame.
        crop: Crop box the target must fall into.
        template: Format string filled with ``label`` and ``region``.
        empty_instruction: Text returned when no target lies inside the crop.

    Returns:
        The formatted instruction, or *empty_instruction* when no target is found.
    """
    target = select_target_box(annotation, crop)
    if target is None:
        return empty_instruction

    target_label, target_box = target
    center_x, center_y = bbox_center(target_box)
    return template.format(
        label=target_label,
        region=region_3x3(center_x, center_y, crop.w, crop.h),
    )
def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode instruction strings into float32 feature vectors.

    The strings are tokenized with ``DistilBertTokenizerFast`` and passed
    through the project's ``DistilBERTTextEncoder`` (created with
    ``output_dim=768`` and ``freeze=True``) in batches, on CUDA when available.

    Args:
        instructions: Strings to encode.
        model_name: Model name given to both the tokenizer and the encoder.
        batch_size: Number of strings encoded per forward pass.

    Returns:
        The per-batch encoder outputs concatenated along axis 0 as float32;
        presumably shaped ``(len(instructions), 768)`` (the encoder is defined
        outside this file, TODO confirm). Empty input gives an array of shape
        ``(0, 768)``.

    Raises:
        ValueError: If *batch_size* is not positive.
        ImportError: If torch or transformers cannot be imported.
    """
    feature_dim = 768

    # Cheap checks first: no need to import torch or load a model for these.
    if len(instructions) == 0:
        return np.zeros((0, feature_dim), dtype=np.float32)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires torch and transformers. "
            "Please install: pip install torch transformers"
        ) from exc

    # The encoder lives in the repository's `models` package, one directory
    # above this script's folder; make that root importable.
    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=feature_dim, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        for start in range(0, len(instructions), batch_size):
            batch = list(instructions[start:start + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            encoded = model(
                input_ids=tok["input_ids"].to(device),
                attention_mask=tok["attention_mask"].to(device),
            )
            feats.append(encoded.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(feats, axis=0)
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
    """Map *v* from ``[min_v, max_v]`` to ``[0, 1]`` as float32, without clipping.

    A range whose width is zero or negative yields an all-zero array of the
    same shape instead of dividing.
    """
    span = max_v - min_v
    degenerate = span <= 0
    if degenerate:
        return np.zeros_like(v, dtype=np.float32)
    offset = v - min_v
    return (offset / span).astype(np.float32)
def override_stationary_edge_instructions(
    instructions: List[str],
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motor_cmd_0: np.ndarray,
    motor_cmd_1: np.ndarray,
    stop_instruction: str,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[List[str], int, int]:
    """
    Replace the instructions of stationary head/tail frames with stop_instruction.

    The number of stationary frames on each side comes from
    detect_stationary_edge_counts_from_qpos, which looks at the motor positions
    only. motor_cmd_0 and motor_cmd_1 are accepted but not used here.

    Returns (instructions, head_count, tail_count). The list is a modified copy,
    except for empty input, which is returned as-is with both counts 0.
    """
    total = len(instructions)
    if total == 0:
        return instructions, 0, 0

    head_count, tail_count = detect_stationary_edge_counts_from_qpos(
        motor_pos_y=motor_pos_y,
        motor_pos_x=motor_pos_x,
        motion_window=motion_window,
        motion_threshold=motion_threshold,
    )

    relabeled = list(instructions)
    head_span = range(head_count)
    tail_span = range(total - tail_count, total)
    for span in (head_span, tail_span):
        for idx in span:
            relabeled[idx] = stop_instruction
    return relabeled, head_count, tail_count
def detect_stationary_edge_counts_from_qpos(
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[int, int]:
    """Count the stationary frames at the start and at the end of a series.

    Both position series are scaled to 0..1 with fixed ranges, turned into a
    per-frame speed (absolute frame-to-frame change divided by 1/30 s, larger
    of the two axes), and compared with *motion_threshold*. Frame 0 always has
    speed 0. Motion starts at the first run of ``max(1, motion_window)``
    consecutive above-threshold frames seen from the front, and ends at the
    first such run seen from the back.

    Returns:
        ``(head_count, tail_count)``. Each count equals the series length when
        no qualifying run exists; a single-frame series gives ``(1, 1)`` and an
        empty one ``(0, 0)``.
    """
    total = int(len(motor_pos_y))
    if total < 2:
        # No frame-to-frame change can be computed: everything is stationary.
        return total, total

    norm_y = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
    norm_x = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
    required_run = max(1, int(motion_window))
    seconds_per_frame = 1.0 / 30.0

    speed = np.zeros((total,), dtype=np.float32)
    speed_y = np.abs(np.diff(norm_y)) / seconds_per_frame
    speed_x = np.abs(np.diff(norm_x)) / seconds_per_frame
    speed[1:] = np.maximum(speed_y, speed_x)

    # Forward scan: index of the first frame of the first long-enough moving run.
    head_count = total
    streak = 0
    for idx in range(1, total):
        if not (speed[idx] > motion_threshold):
            streak = 0
            continue
        streak += 1
        if streak >= required_run:
            head_count = idx - required_run + 1
            break

    # Backward scan: frames after the last long-enough moving run are stationary.
    tail_count = total
    streak = 0
    for idx in reversed(range(1, total)):
        if not (speed[idx] > motion_threshold):
            streak = 0
            continue
        streak += 1
        if streak >= required_run:
            first_still = idx + required_run
            tail_count = max(0, total - first_still)
            break

    return head_count, tail_count
def find_segment_csv(segment_dir: Path) -> Path:
    """Return the ``*.csv`` file in *segment_dir* that sorts first.

    Raises:
        FileNotFoundError: If the directory contains no CSV file.
    """
    first_csv = min(segment_dir.glob("*.csv"), default=None)
    if first_csv is None:
        raise FileNotFoundError(f"No csv file found in {segment_dir}")
    return first_csv
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
    """List the runs of truthy entries in *mask* as half-open ``(start, end)`` pairs."""
    runs: List[Tuple[int, int]] = []
    run_start: Optional[int] = None

    for idx, flag in enumerate(mask.tolist()):
        if flag:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            runs.append((run_start, idx))
            run_start = None

    # A run still open at the end extends to the end of the mask.
    if run_start is not None:
        runs.append((run_start, int(mask.size)))
    return runs
def save_episode_plot_with_stop_segments(
    qpos: np.ndarray,
    action: np.ndarray,
    instructions: Sequence[str],
    stop_instruction: str,
    plot_path: Path,
) -> None:
    """
    Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.

    One subplot is drawn per qpos column, showing that qpos column and the
    action column with the same index over the timestep axis. Runs of frames
    whose instruction equals stop_instruction are shaded.

    Nothing is written when qpos or action is not 2-D, or when qpos has no
    rows or no columns.

    Args:
        qpos: Array of shape (timesteps, dims).
        action: Array with at least as many columns as qpos.
        instructions: Per-frame instruction strings; may be empty.
        stop_instruction: Instruction text that marks a stop frame.
        plot_path: Destination image file.
    """
    qpos = np.asarray(qpos)
    action = np.asarray(action)
    if qpos.ndim != 2 or action.ndim != 2:
        return
    num_ts, num_dim = qpos.shape
    if num_ts == 0 or num_dim == 0:
        return

    stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
    stop_segments = _mask_to_segments(stop_mask)

    fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
    try:
        # subplots returns a bare Axes for a single row; normalize to a flat list.
        axs_list = np.atleast_1d(axs).reshape(-1).tolist()
        for dim_idx, ax in enumerate(axs_list):
            ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
            ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
            for seg_idx, (st, ed) in enumerate(stop_segments):
                # Segments are half-open; shade up to the last frame inside the run.
                # Only the first span gets a label so the legend has one entry.
                ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
                           label='stop instruction' if seg_idx == 0 else None)
            ax.set_ylabel(f'dim {dim_idx}')
            ax.legend(loc='best')
            ax.grid(alpha=0.25, linestyle='--')
        axs_list[-1].set_xlabel('timestep')
        fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
        fig.tight_layout()
        # The title sits at y=1.02, partly above the figure canvas. A tight
        # bounding box grows the saved area so the title is not cut off.
        fig.savefig(str(plot_path), dpi=140, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
def main() -> None:
    """Convert one raw segment into an HDF5 episode and a diagnostics plot.

    Steps:
      1. Parse arguments and locate the ``frames`` directory and the segment CSV.
      2. Pair the i-th frame JSON (in frame order) with the i-th CSV row; crop
         and resize each image, read motor positions and commands, and build
         the per-frame instruction unless ``--no_text_instruction`` is set.
      3. Normalize positions into ``qpos`` and commands into ``action``.
      4. Detect stationary frames at both ends; either trim them
         (``--trim_stationary_edges``) or relabel their instructions with the
         stop instruction (unless disabled).
      5. Optionally encode the instructions into text features.
      6. Write ``episode_<idx>.hdf5`` and a same-named ``.png`` plot.

    Raises:
        ValueError: On conflicting flags, an empty CSV, a crop box without
            positive size, no aligned frames, or nothing left after trimming.
        FileNotFoundError: When the frames directory, the frame JSONs, the CSV
            or a frame image is missing.
    """
    args = parse_args()
    if args.no_text_instruction and args.encode_text_features:
        raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
    use_text = not args.no_text_instruction

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if not csv_rows:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    if crop.w <= 0 or crop.h <= 0:
        # Fail before any image work; a non-positive size breaks cropping and
        # the region computation further down.
        raise ValueError(
            f"Invalid --crop {list(args.crop)}: X2 must be greater than X1 and Y2 greater than Y1."
        )
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    # Frames and CSV rows are aligned by position; use the common prefix.
    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    instructions: List[str] = []
    motor_pos_y_series = np.zeros((num,), dtype=np.float32)
    motor_pos_x_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_1_series = np.zeros((num,), dtype=np.float32)

    y_min, y_max = 8000.0, 18884.0
    x_min, x_max = 7000.0, 17384.0
    cmd_min, cmd_max = 0.0, 65535.0

    for i, json_path in enumerate(json_files[:num]):
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        image_path = frames_dir / ann["imagePath"]
        if not image_path.exists():
            # Fall back to a .jpg that shares the annotation's base name.
            alt = json_path.with_suffix(".jpg")
            if not alt.exists():
                raise FileNotFoundError(f"Image not found for {json_path.name}")
            image_path = alt

        # Close the image file as soon as the pixels have been copied out.
        with Image.open(image_path) as img:
            img_crop = img.convert("RGB").crop((crop.x1, crop.y1, crop.x2, crop.y2))
            img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
            images[i] = np.asarray(img_resize, dtype=np.uint8)

        row = csv_rows[i]
        motor_pos_y_series[i] = float(row["motor_pos_y"])
        motor_pos_x_series[i] = float(row["motor_pos_x"])
        motor_cmd_0_series[i] = float(row["motor_command_0"])
        motor_cmd_1_series[i] = float(row["motor_command_1"])

        if use_text:
            instructions.append(
                instruction_from_annotation(
                    ann,
                    crop,
                    args.instruction_template,
                    args.instruction_empty,
                )
            )

    # Normalize whole series at once; columns are [y, x] and [cmd0(y), cmd1(x)].
    qpos = np.stack(
        [
            normalize_value(motor_pos_y_series, y_min, y_max, args.state_norm),
            normalize_value(motor_pos_x_series, x_min, x_max, args.state_norm),
        ],
        axis=1,
    )
    action = np.stack(
        [
            normalize_value(motor_cmd_0_series, cmd_min, cmd_max, args.action_norm),
            normalize_value(motor_cmd_1_series, cmd_min, cmd_max, args.action_norm),
        ],
        axis=1,
    )

    start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
        motor_pos_y=motor_pos_y_series,
        motor_pos_x=motor_pos_x_series,
        motion_window=args.motion_window,
        motion_threshold=args.motion_threshold,
    )

    if args.trim_stationary_edges:
        keep_start = int(start_stop_count)
        keep_end = int(num - end_stop_count)
        if keep_end <= keep_start:
            raise ValueError(
                f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
                f'Consider lowering --motion_threshold or --motion_window.'
            )
        images = images[keep_start:keep_end]
        qpos = qpos[keep_start:keep_end]
        action = action[keep_start:keep_end]
        motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
        motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
        motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
        motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
        if use_text:
            instructions = instructions[keep_start:keep_end]
        print(
            f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
            f'kept={keep_end - keep_start}'
        )
        # After trimming, full kept segment is the moving region.
        start_stop_count, end_stop_count = 0, 0
    elif (not args.disable_stop_override) and use_text:
        # Only relabel when nothing was trimmed. Re-running the detector on a
        # trimmed series would mark its first frame as stationary again,
        # because the speed of frame 0 is always zero.
        instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
            instructions=instructions,
            motor_pos_y=motor_pos_y_series,
            motor_pos_x=motor_pos_x_series,
            motor_cmd_0=motor_cmd_0_series,
            motor_cmd_1=motor_cmd_1_series,
            stop_instruction=args.stop_instruction,
            motion_window=args.motion_window,
            motion_threshold=args.motion_threshold,
        )

    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    frame_rate = 30
    dt = 1.0 / frame_rate
    with h5py.File(out_path, "w") as root:
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = frame_rate
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)
        root.create_dataset("action", data=action, dtype=np.float32)

        if use_text:
            str_dtype = h5py.string_dtype(encoding="utf-8")
            root.create_dataset(
                "instruction_timestep",
                shape=(len(instructions),),
                dtype=str_dtype,
                data=np.asarray(instructions, dtype=object),
            )
            # Episode-level instruction is the one of the first kept frame.
            root.create_dataset(
                "instruction",
                shape=(),
                dtype=str_dtype,
                data=instructions[0] if len(instructions) > 0 else "",
            )
        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    # Report what was actually written, which is fewer than `num` after trimming.
    print(f"Frames used: {qpos.shape[0]}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if not args.disable_stop_override:
        print(
            f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
            f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
            f"threshold={args.motion_threshold}, "
            f"instruction='{args.stop_instruction}'"
        )
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")

    # Save a same-basename plot next to the generated hdf5
    plot_path = out_path.with_suffix('.png')
    save_episode_plot_with_stop_segments(
        qpos=qpos,
        action=action,
        instructions=instructions,
        stop_instruction=args.stop_instruction,
        plot_path=plot_path,
    )
    print(f"Saved episode plot to: {plot_path}")
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()