暂时可以生成hdf5数据 (For now, HDF5 data can be generated)
This commit is contained in:
412
build_endoscope_act_dataset.py
Normal file
412
build_endoscope_act_dataset.py
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Bilinear resampling filter resolved across Pillow versions:
# Image.Resampling.BILINEAR on Pillow >= 9.1, Image.BILINEAR on older releases.
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CropBox:
    """Axis-aligned crop rectangle in original-image pixel coordinates.

    Corners are (x1, y1) top-left and (x2, y2) bottom-right, exclusive of
    nothing in particular — width/height are simply the coordinate deltas.
    """

    # Left edge (pixels, original image coordinates).
    x1: int
    # Top edge.
    y1: int
    # Right edge.
    x2: int
    # Bottom edge.
    y2: int

    @property
    def w(self) -> int:
        """Crop width in pixels (x2 - x1)."""
        return self.x2 - self.x1

    @property
    def h(self) -> int:
        """Crop height in pixels (y2 - y1)."""
        return self.y2 - self.y1
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the raw-segment -> ACT HDF5 conversion.

    Returns:
        argparse.Namespace with all converter settings (paths, crop/resize
        geometry, normalization modes, and optional text-feature flags).
    """
    p = argparse.ArgumentParser(
        description="Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
    )
    # Input/output locations.
    p.add_argument("--segment_dir", type=str, required=True,
                   help="Path to one raw segment, e.g. data/follow_seg_001")
    p.add_argument("--output_dir", type=str, required=True,
                   help="Output dir for episode_*.hdf5")
    p.add_argument("--episode_idx", type=int, default=0,
                   help="Output episode index (default: 0)")
    p.add_argument("--max_frames", type=int, default=-1,
                   help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)")
    # Image geometry and naming.
    p.add_argument("--camera_name", type=str, default="top",
                   help="Camera name written to /observations/images/<camera_name>")
    p.add_argument("--crop", type=int, nargs=4, default=[733, 30, 1754, 1051],
                   metavar=("X1", "Y1", "X2", "Y2"),
                   help="Crop box in original image coordinates")
    p.add_argument("--resize", type=int, nargs=2, default=[224, 224],
                   metavar=("W", "H"),
                   help="Output image size")
    # Per-frame language instruction synthesis.
    p.add_argument("--instruction_template", type=str,
                   default="Move toward the {label} at {region}.",
                   help="Template for per-frame instruction")
    p.add_argument("--instruction_empty", type=str, default="No target visible.",
                   help="Instruction when no valid target after crop")
    # Normalization of proprioception and actions.
    p.add_argument("--state_norm", choices=["minus1_1", "0_1", "raw"], default="minus1_1",
                   help="Normalization for qpos (motor_pos_y, motor_pos_x)")
    p.add_argument("--action_norm", choices=["minus1_1", "0_1", "raw"], default="minus1_1",
                   help="Normalization for action (motor_command_0, motor_command_1)")
    # Optional DistilBERT text-feature extraction.
    p.add_argument("--encode_text_features", action="store_true",
                   help="Encode per-frame instruction into 768-dim DistilBERT features")
    p.add_argument("--text_model_name", type=str, default="distilbert-base-uncased",
                   help="HuggingFace model name for DistilBERT")
    p.add_argument("--text_batch_size", type=int, default=32,
                   help="Batch size for text feature extraction")
    return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """Return every *.json annotation file in frames_dir ordered by frame index.

    Filenames containing "frame_<N>" sort numerically by N; files without
    that pattern are pushed to the end (index 10**9) and then sorted by name.
    """
    def sort_key(path: Path) -> Tuple[int, str]:
        match = re.search(r"frame_(\d+)", path.name)
        return (int(match.group(1)) if match else 10**9, path.name)

    return sorted(frames_dir.glob("*.json"), key=sort_key)
|
||||||
|
|
||||||
|
|
||||||
|
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Read a UTF-8 CSV file and return its rows as header-keyed dicts."""
    with csv_path.open("r", encoding="utf-8") as handle:
        return [row for row in csv.DictReader(handle)]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Rescale x from the range [min_v, max_v] according to mode.

    mode == "raw":      no rescaling, just cast to float32.
    mode == "0_1":      map [min_v, max_v] -> [0, 1].
    anything else:      map [min_v, max_v] -> [-1, 1] ("minus1_1" branch).
    """
    if mode == "raw":
        return x.astype(np.float32)
    unit = (x - min_v) / (max_v - min_v)
    out = unit if mode == "0_1" else unit * 2.0 - 1.0
    return out.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Translate a bbox into crop-local coordinates and clip it to the crop.

    Returns (x1, y1, x2, y2) in crop coordinates, or None when the clipped
    box collapses (zero or negative width/height).
    """
    cx1 = max(x_min - crop.x1, 0.0)
    cy1 = max(y_min - crop.y1, 0.0)
    cx2 = min(x_max - crop.x1, float(crop.w - 1))
    cy2 = min(y_max - crop.y1, float(crop.h - 1))
    if cx2 > cx1 and cy2 > cy1:
        return cx1, cy1, cx2, cy2
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the (cx, cy) midpoint of an (x1, y1, x2, y2) box."""
    left, top, right, bottom = box
    return ((left + right) / 2.0, (top + bottom) / 2.0)
|
||||||
|
|
||||||
|
|
||||||
|
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Name the cell of a 3x3 grid over a w x h image that contains (cx, cy).

    Returns strings like "top-left" or "middle-center"; points outside the
    image clamp to the nearest edge cell.
    """
    col_names = ["left", "center", "right"]
    row_names = ["top", "middle", "bottom"]
    col = int(cx / (w / 3.0))
    row = int(cy / (h / 3.0))
    col = 0 if col < 0 else (2 if col > 2 else col)
    row = 0 if row < 0 else (2 if row > 2 else row)
    return f"{row_names[row]}-{col_names[col]}"
|
||||||
|
|
||||||
|
|
||||||
|
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Extract (label, x_min, y_min, x_max, y_max, area) from one annotation shape.

    NOTE(review): assumes shape["points"] is a sequence of (x, y) pairs as in
    labelme-style JSON — confirm against the annotation tool's output.
    Returns None when fewer than two points are present.
    """
    label = shape.get("label", "target")
    points = shape.get("points", None)
    if not points or len(points) < 2:
        return None
    arr = np.array(points, dtype=np.float32)
    lo_x, lo_y = float(arr[:, 0].min()), float(arr[:, 1].min())
    hi_x, hi_y = float(arr[:, 0].max()), float(arr[:, 1].max())
    area = max(0.0, hi_x - lo_x) * max(0.0, hi_y - lo_y)
    return label, lo_x, lo_y, hi_x, hi_y, area
|
||||||
|
|
||||||
|
|
||||||
|
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the annotated shape with the largest area after clipping to crop.

    Iterates over annotation["shapes"], clips each parseable bbox to the crop
    window, and keeps the one with the biggest clipped area (first wins ties).
    Returns (label, clipped_box) or None when nothing survives clipping.
    """
    best_label = None
    best_box = None
    best_area = -1.0
    for shape in annotation.get("shapes", []):
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, _area = parsed
        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue
        clip_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
        if clip_area > best_area:
            best_label, best_box, best_area = label, clipped, clip_area
    if best_box is None:
        return None
    return best_label, best_box
|
||||||
|
|
||||||
|
|
||||||
|
def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Build the per-frame language instruction from one annotation dict.

    Fills {label} and {region} in template using the largest in-crop target;
    falls back to empty_instruction when no target remains after cropping.
    """
    target = select_target_box(annotation, crop)
    if target is None:
        return empty_instruction
    label, box = target
    center_x, center_y = bbox_center(box)
    region = region_3x3(center_x, center_y, crop.w, crop.h)
    return template.format(label=label, region=region)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode instruction strings into DistilBERT [CLS] features.

    Args:
        instructions: Per-frame instruction strings to encode.
        model_name: HuggingFace model/tokenizer name (e.g. "distilbert-base-uncased").
        batch_size: Number of strings tokenized and encoded per forward pass.

    Returns:
        float32 array of shape (len(instructions), D), where D is the encoder's
        output width (768 is requested via output_dim below).

    Raises:
        ImportError: when the optional `transformers` dependency is missing.
    """
    # Heavy optional deps are imported lazily so the converter works without them
    # unless --encode_text_features is requested.
    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires transformers. Please install: pip install transformers"
        ) from exc

    # Make the repo root importable so `models.text_encoder` resolves when this
    # script is run directly (assumes this file lives one level below the root —
    # TODO confirm if the file is ever moved).
    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    # Frozen encoder: weights are not trained here, only used for feature extraction.
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        # Encode in fixed-size batches; instructions are short, so max_length=32.
        for i in range(0, len(instructions), batch_size):
            batch = list(instructions[i:i + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            input_ids = tok["input_ids"].to(device)
            attention_mask = tok["attention_mask"].to(device)
            # Encoder returns the [CLS] embedding per string (see models/text_encoder.py).
            cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
            feats.append(cls)
    return np.concatenate(feats, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def find_segment_csv(segment_dir: Path) -> Path:
    """Return the alphabetically-first *.csv file inside segment_dir.

    Raises:
        FileNotFoundError: when the directory contains no CSV file.
    """
    for candidate in sorted(segment_dir.glob("*.csv")):
        return candidate
    raise FileNotFoundError(f"No csv file found in {segment_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Convert one raw endoscope segment into a single ACT-style HDF5 episode.

    Pipeline: read per-frame labelme-style JSON + image from <segment>/frames,
    align them index-wise with rows of the segment CSV, crop/resize images,
    normalize motor positions/commands, synthesize per-frame instructions, and
    write everything to <output_dir>/episode_<idx>.hdf5.
    """
    args = parse_args()

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if len(csv_rows) == 0:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    # Frames and CSV rows are aligned purely by index; the shorter side caps
    # the episode length, optionally further limited by --max_frames.
    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    qpos = np.zeros((num, 2), dtype=np.float32)  # [y, x]
    action = np.zeros((num, 2), dtype=np.float32)  # [cmd0(y), cmd1(x)]
    instructions: List[str] = []

    # Hardcoded normalization ranges — presumably the motors' physical position
    # limits and the 16-bit command range. TODO confirm against the hardware spec.
    y_min, y_max = 8000.0, 18884.0
    x_min, x_max = 7000.0, 17384.0
    cmd_min, cmd_max = 0.0, 65535.0

    for i in range(num):
        json_path = json_files[i]
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        # Resolve the frame image: prefer the annotation's imagePath, fall back
        # to a .jpg named like the JSON file.
        image_path = frames_dir / ann["imagePath"]
        if not image_path.exists():
            alt = json_path.with_suffix(".jpg")
            if alt.exists():
                image_path = alt
            else:
                raise FileNotFoundError(f"Image not found for {json_path.name}")

        img = Image.open(image_path).convert("RGB")
        img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
        img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
        images[i] = np.asarray(img_resize, dtype=np.uint8)

        row = csv_rows[i]
        motor_pos_y = float(row["motor_pos_y"])
        motor_pos_x = float(row["motor_pos_x"])
        motor_cmd_0 = float(row["motor_command_0"])
        motor_cmd_1 = float(row["motor_command_1"])

        # Normalize state and action per the CLI-selected modes.
        qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
        qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
        action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
        action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]

        ins = instruction_from_annotation(
            ann,
            crop,
            args.instruction_template,
            args.instruction_empty,
        )
        instructions.append(ins)

    # Optional: encode every per-frame instruction into DistilBERT features.
    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    dt = 1.0 / 30.0  # assumes 30 fps capture — TODO confirm camera frame rate

    with h5py.File(out_path, "w") as root:
        # Episode-level metadata recorded as HDF5 attributes.
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = 30
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        # ACT-compatible layout: /observations/{qpos, images/<cam>} and /action.
        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)

        root.create_dataset("action", data=action, dtype=np.float32)

        # Instructions: one variable-length UTF-8 string per timestep, plus a
        # scalar episode-level instruction (first frame's instruction).
        str_dtype = h5py.string_dtype(encoding="utf-8")
        root.create_dataset(
            "instruction_timestep",
            shape=(num,),
            dtype=str_dtype,
            data=np.asarray(instructions, dtype=object),
        )
        root.create_dataset(
            "instruction",
            shape=(),
            dtype=str_dtype,
            data=instructions[0] if len(instructions) > 0 else "",
        )

        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")


if __name__ == "__main__":
    main()
|
||||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
31
models/text_encoder.py
Normal file
31
models/text_encoder.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class DistilBERTTextEncoder(nn.Module):
    """Thin DistilBERT wrapper returning the last-layer [CLS] embedding.

    With freeze=True the encoder's parameters are detached from training and
    the module is kept in eval mode.
    """

    def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
        super().__init__()
        # transformers is an optional dependency; fail with a clear hint.
        try:
            from transformers import DistilBertModel
        except ImportError as exc:
            raise ImportError(
                'transformers is required for DistilBERT text encoding. '
                'Install it with: pip install transformers'
            ) from exc

        self.encoder = DistilBertModel.from_pretrained(model_name)
        self.output_dim = output_dim
        self.freeze = freeze

        if self.freeze:
            # Disable gradients for every encoder parameter and switch to eval.
            self.encoder.requires_grad_(False)
            self.encoder.eval()

    def forward(self, input_ids, attention_mask):
        """Encode token ids; returns the [CLS] position of the last hidden state."""
        if self.freeze:
            # Re-assert eval mode in case a parent module called .train().
            self.encoder.eval()
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # DistilBERT has no pooled output; use the first ([CLS]) token embedding.
        return hidden[:, 0, :]
|
||||||
Reference in New Issue
Block a user