暂时可以生成hdf5数据 (temporarily able to generate HDF5 data)
This commit is contained in:
412
build_endoscope_act_dataset.py
Normal file
412
build_endoscope_act_dataset.py
Normal file
@@ -0,0 +1,412 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# Pillow >= 9.1 moved resampling filters to the Image.Resampling enum; fall back
# to the legacy module-level Image.BILINEAR constant on older Pillow versions.
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
||||
@dataclass
class CropBox:
    """Axis-aligned crop rectangle in original-image pixel coordinates.

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner,
    matching the (left, upper, right, lower) box convention used by
    ``PIL.Image.crop``.
    """

    x1: int  # left edge
    y1: int  # top edge
    x2: int  # right edge
    y2: int  # bottom edge

    @property
    def w(self) -> int:
        """Width of the crop region in pixels."""
        return self.x2 - self.x1

    @property
    def h(self) -> int:
        """Height of the crop region in pixels."""
        return self.y2 - self.y1
||||
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface of the converter.

    Returns:
        argparse.Namespace carrying every conversion option (paths, crop,
        resize, normalization modes, instruction templates, text encoding).
    """
    p = argparse.ArgumentParser(
        description=(
            "Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
        )
    )
    # Input / output locations.
    p.add_argument("--segment_dir", type=str, required=True,
                   help="Path to one raw segment, e.g. data/follow_seg_001")
    p.add_argument("--output_dir", type=str, required=True,
                   help="Output dir for episode_*.hdf5")
    p.add_argument("--episode_idx", type=int, default=0,
                   help="Output episode index (default: 0)")
    p.add_argument("--max_frames", type=int, default=-1,
                   help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)")
    p.add_argument("--camera_name", type=str, default="top",
                   help="Camera name written to /observations/images/<camera_name>")
    # Image preprocessing.
    p.add_argument("--crop", type=int, nargs=4, default=[733, 30, 1754, 1051],
                   metavar=("X1", "Y1", "X2", "Y2"),
                   help="Crop box in original image coordinates")
    p.add_argument("--resize", type=int, nargs=2, default=[224, 224],
                   metavar=("W", "H"),
                   help="Output image size")
    # Per-frame language instruction.
    p.add_argument("--instruction_template", type=str,
                   default="Move toward the {label} at {region}.",
                   help="Template for per-frame instruction")
    p.add_argument("--instruction_empty", type=str,
                   default="No target visible.",
                   help="Instruction when no valid target after crop")
    # Normalization of proprioception and actions.
    p.add_argument("--state_norm", choices=["minus1_1", "0_1", "raw"],
                   default="minus1_1",
                   help="Normalization for qpos (motor_pos_y, motor_pos_x)")
    p.add_argument("--action_norm", choices=["minus1_1", "0_1", "raw"],
                   default="minus1_1",
                   help="Normalization for action (motor_command_0, motor_command_1)")
    # Optional text-feature extraction.
    p.add_argument("--encode_text_features", action="store_true",
                   help="Encode per-frame instruction into 768-dim DistilBERT features")
    p.add_argument("--text_model_name", type=str,
                   default="distilbert-base-uncased",
                   help="HuggingFace model name for DistilBERT")
    p.add_argument("--text_batch_size", type=int, default=32,
                   help="Batch size for text feature extraction")
    return p.parse_args()
||||
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """List frame annotation jsons in *frames_dir*, sorted by frame index.

    Files named ``frame_<N>...`` sort numerically by N; anything without a
    ``frame_<digits>`` token sorts after them, alphabetically by file name.
    """
    def sort_key(path: Path) -> Tuple[int, str]:
        match = re.search(r"frame_(\d+)", path.name)
        frame_idx = int(match.group(1)) if match else 10**9
        return frame_idx, path.name

    return sorted(frames_dir.glob("*.json"), key=sort_key)
||||
|
||||
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Read every row of *csv_path* as a dict keyed by the header columns."""
    with csv_path.open("r", encoding="utf-8") as handle:
        return list(csv.DictReader(handle))
||||
|
||||
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Normalize raw values against a [min_v, max_v] calibration range.

    mode "raw" passes values through unchanged, "0_1" maps the range to
    [0, 1], and any other mode (i.e. "minus1_1") maps it to [-1, 1].
    The result is always float32.
    """
    if mode == "raw":
        return x.astype(np.float32)
    scaled = (x - min_v) / (max_v - min_v)
    if mode == "0_1":
        return scaled.astype(np.float32)
    # Default branch: "minus1_1".
    return (scaled * 2.0 - 1.0).astype(np.float32)
||||
|
||||
def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Translate a bbox into crop-local coordinates and clip it to the crop.

    Returns the clipped (x1, y1, x2, y2) box, or None when the box falls
    entirely outside the crop (no positive width or height remains).
    """
    left = max(x_min - crop.x1, 0.0)
    top = max(y_min - crop.y1, 0.0)
    right = min(x_max - crop.x1, float(crop.w - 1))
    bottom = min(y_max - crop.y1, float(crop.h - 1))
    if right > left and bottom > top:
        return left, top, right, bottom
    return None
||||
|
||||
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the (cx, cy) midpoint of an (x1, y1, x2, y2) box."""
    left, top, right, bottom = box
    return (left + right) / 2.0, (top + bottom) / 2.0
||||
|
||||
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Name the 3x3-grid cell containing point (cx, cy) in a w x h image.

    Returns e.g. "top-left" or "middle-center"; coordinates outside the
    image are clamped into the nearest edge cell.
    """
    cols = ["left", "center", "right"]
    rows = ["top", "middle", "bottom"]
    col = int(cx / (w / 3.0))
    row = int(cy / (h / 3.0))
    col = 0 if col < 0 else (2 if col > 2 else col)
    row = 0 if row < 0 else (2 if row > 2 else row)
    return f"{rows[row]}-{cols[col]}"
||||
|
||||
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Extract (label, x_min, y_min, x_max, y_max, area) from one labelme shape.

    The bounding box is the min/max envelope of the shape's points; the
    label defaults to "target" when missing. Returns None for shapes with
    fewer than two points.
    """
    raw_points = shape.get("points", None)
    if not raw_points or len(raw_points) < 2:
        return None
    label = shape.get("label", "target")
    coords = np.array(raw_points, dtype=np.float32)
    x_lo = float(coords[:, 0].min())
    y_lo = float(coords[:, 1].min())
    x_hi = float(coords[:, 0].max())
    y_hi = float(coords[:, 1].max())
    area = max(0.0, x_hi - x_lo) * max(0.0, y_hi - y_lo)
    return label, x_lo, y_lo, x_hi, y_hi, area
||||
|
||||
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the annotated shape with the largest area after clipping to the crop.

    Returns (label, clipped_box) in crop-local coordinates, or None when no
    shape survives the crop.
    """
    best_label = None
    best_box = None
    best_area = -1.0  # any surviving box has positive area, so the first wins
    for shape in annotation.get("shapes", []):
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, _area = parsed
        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue
        clipped_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
        if clipped_area > best_area:
            best_label, best_box, best_area = label, clipped, clipped_area
    if best_box is None:
        return None
    return best_label, best_box
||||
|
||||
def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Render the per-frame language instruction for one frame annotation.

    Fills *template* with the selected target's label and its 3x3-grid
    region inside the crop; falls back to *empty_instruction* when no
    target remains after cropping.
    """
    target = select_target_box(annotation, crop)
    if target is None:
        return empty_instruction
    label, box = target
    center_x, center_y = bbox_center(box)
    cell = region_3x3(center_x, center_y, crop.w, crop.h)
    return template.format(label=label, region=cell)
||||
|
||||
def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode instructions into per-frame DistilBERT features.

    Returns a float32 array of shape (len(instructions), 768), one feature
    vector per instruction, computed in batches of *batch_size*.

    Raises:
        ImportError: when the optional transformers dependency is missing.
    """
    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires transformers. Please install: pip install transformers"
        ) from exc

    # Make the repo root importable so models.text_encoder resolves no matter
    # where the script is launched from.
    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    encoder = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
    encoder.eval()

    chunks: List[np.ndarray] = []
    with torch.no_grad():
        for start in range(0, len(instructions), batch_size):
            batch = list(instructions[start:start + batch_size])
            encoded = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            ids = encoded["input_ids"].to(device)
            mask = encoded["attention_mask"].to(device)
            features = encoder(input_ids=ids, attention_mask=mask).detach().cpu().numpy().astype(np.float32)
            chunks.append(features)
    return np.concatenate(chunks, axis=0)
||||
|
||||
def find_segment_csv(segment_dir: Path) -> Path:
    """Return the first csv (sorted by name) inside a raw segment directory.

    Raises:
        FileNotFoundError: when the segment contains no csv file.
    """
    candidates = sorted(segment_dir.glob("*.csv"))
    if candidates:
        return candidates[0]
    raise FileNotFoundError(f"No csv file found in {segment_dir}")
||||
|
||||
def main() -> None:
    """Convert one raw endoscope segment into a single ACT-style HDF5 episode.

    Pipeline: parse CLI args, align per-frame json annotations with csv motor
    rows, crop/resize each frame image, normalize states and actions, build
    per-frame instructions (optionally DistilBERT-encoded), and write
    everything into ``episode_<idx>.hdf5`` under the output directory.
    """
    args = parse_args()

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Raw frames (images + labelme-style jsons) live in <segment>/frames.
    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if len(csv_rows) == 0:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    # Frames and csv rows are paired by index, so only the common prefix is
    # usable; --max_frames can further limit it.
    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    # Pre-allocate episode buffers.
    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    qpos = np.zeros((num, 2), dtype=np.float32)  # [y, x]
    action = np.zeros((num, 2), dtype=np.float32)  # [cmd0(y), cmd1(x)]
    instructions: List[str] = []

    # Normalization ranges — presumably device-specific motor calibration
    # limits (positions) and the full uint16 command range; TODO confirm.
    y_min, y_max = 8000.0, 18884.0
    x_min, x_max = 7000.0, 17384.0
    cmd_min, cmd_max = 0.0, 65535.0

    for i in range(num):
        json_path = json_files[i]
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        # Resolve the image referenced by the annotation, falling back to a
        # sibling .jpg with the same stem as the json.
        image_path = frames_dir / ann["imagePath"]
        if not image_path.exists():
            alt = json_path.with_suffix(".jpg")
            if alt.exists():
                image_path = alt
            else:
                raise FileNotFoundError(f"Image not found for {json_path.name}")

        # Crop to the configured box, then resize to the network input size.
        img = Image.open(image_path).convert("RGB")
        img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
        img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
        images[i] = np.asarray(img_resize, dtype=np.uint8)

        # Motor state (positions) and commanded actions for this frame.
        row = csv_rows[i]
        motor_pos_y = float(row["motor_pos_y"])
        motor_pos_x = float(row["motor_pos_x"])
        motor_cmd_0 = float(row["motor_command_0"])
        motor_cmd_1 = float(row["motor_command_1"])

        qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
        qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
        action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
        action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]

        # Per-frame language instruction derived from the largest target box.
        ins = instruction_from_annotation(
            ann,
            crop,
            args.instruction_template,
            args.instruction_empty,
        )
        instructions.append(ins)

    # Optionally pre-compute DistilBERT features for each instruction.
    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    dt = 1.0 / 30.0  # per-step duration at the assumed 30 fps capture rate

    with h5py.File(out_path, "w") as root:
        # Episode-level metadata recorded as HDF5 attributes.
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = 30
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        # ACT-style layout: /observations/{qpos, images/<camera>} + /action.
        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)

        root.create_dataset("action", data=action, dtype=np.float32)

        # Instructions: one string per timestep plus a scalar episode-level
        # instruction (the first frame's).
        str_dtype = h5py.string_dtype(encoding="utf-8")
        root.create_dataset(
            "instruction_timestep",
            shape=(num,),
            dtype=str_dtype,
            data=np.asarray(instructions, dtype=object),
        )
        root.create_dataset(
            "instruction",
            shape=(),
            dtype=str_dtype,
            data=instructions[0] if len(instructions) > 0 else "",
        )

        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user