暂时可以生成hdf5数据

This commit is contained in:
2026-02-17 22:20:25 +08:00
parent ba006e14c4
commit b701d939c2
3 changed files with 443 additions and 0 deletions

View File

@@ -0,0 +1,412 @@
#!/usr/bin/env python3
import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import h5py
import numpy as np
from PIL import Image
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
@dataclass
class CropBox:
x1: int
y1: int
x2: int
y2: int
@property
def w(self) -> int:
return self.x2 - self.x1
@property
def h(self) -> int:
return self.y2 - self.y1
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
)
)
parser.add_argument(
"--segment_dir",
type=str,
required=True,
help="Path to one raw segment, e.g. data/follow_seg_001",
)
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Output dir for episode_*.hdf5",
)
parser.add_argument(
"--episode_idx",
type=int,
default=0,
help="Output episode index (default: 0)",
)
parser.add_argument(
"--max_frames",
type=int,
default=-1,
help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
)
parser.add_argument(
"--camera_name",
type=str,
default="top",
help="Camera name written to /observations/images/<camera_name>",
)
parser.add_argument(
"--crop",
type=int,
nargs=4,
default=[733, 30, 1754, 1051],
metavar=("X1", "Y1", "X2", "Y2"),
help="Crop box in original image coordinates",
)
parser.add_argument(
"--resize",
type=int,
nargs=2,
default=[224, 224],
metavar=("W", "H"),
help="Output image size",
)
parser.add_argument(
"--instruction_template",
type=str,
default="Move toward the {label} at {region}.",
help="Template for per-frame instruction",
)
parser.add_argument(
"--instruction_empty",
type=str,
default="No target visible.",
help="Instruction when no valid target after crop",
)
parser.add_argument(
"--state_norm",
choices=["minus1_1", "0_1", "raw"],
default="minus1_1",
help="Normalization for qpos (motor_pos_y, motor_pos_x)",
)
parser.add_argument(
"--action_norm",
choices=["minus1_1", "0_1", "raw"],
default="minus1_1",
help="Normalization for action (motor_command_0, motor_command_1)",
)
parser.add_argument(
"--encode_text_features",
action="store_true",
help="Encode per-frame instruction into 768-dim DistilBERT features",
)
parser.add_argument(
"--text_model_name",
type=str,
default="distilbert-base-uncased",
help="HuggingFace model name for DistilBERT",
)
parser.add_argument(
"--text_batch_size",
type=int,
default=32,
help="Batch size for text feature extraction",
)
return parser.parse_args()
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
json_files = list(frames_dir.glob("*.json"))
def key_fn(p: Path) -> Tuple[int, str]:
m = re.search(r"frame_(\d+)", p.name)
idx = int(m.group(1)) if m else 10**9
return idx, p.name
json_files.sort(key=key_fn)
return json_files
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
return list(reader)
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
if mode == "raw":
return x.astype(np.float32)
x01 = (x - min_v) / (max_v - min_v)
if mode == "0_1":
return x01.astype(np.float32)
# minus1_1
return (x01 * 2.0 - 1.0).astype(np.float32)
def clip_bbox_to_crop(
x_min: float,
y_min: float,
x_max: float,
y_max: float,
crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
nx1 = max(x_min - crop.x1, 0.0)
ny1 = max(y_min - crop.y1, 0.0)
nx2 = min(x_max - crop.x1, float(crop.w - 1))
ny2 = min(y_max - crop.y1, float(crop.h - 1))
if nx2 <= nx1 or ny2 <= ny1:
return None
return nx1, ny1, nx2, ny2
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
x1, y1, x2, y2 = box
return (x1 + x2) * 0.5, (y1 + y2) * 0.5
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
x_bin = min(2, max(0, int(cx / (w / 3.0))))
y_bin = min(2, max(0, int(cy / (h / 3.0))))
xs = ["left", "center", "right"]
ys = ["top", "middle", "bottom"]
return f"{ys[y_bin]}-{xs[x_bin]}"
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
points = shape.get("points", None)
label = shape.get("label", "target")
if not points or len(points) < 2:
return None
pts = np.array(points, dtype=np.float32)
x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
return label, x_min, y_min, x_max, y_max, area
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
shapes = annotation.get("shapes", [])
best = None
for shape in shapes:
parsed = read_shape_bbox(shape)
if parsed is None:
continue
label, x1, y1, x2, y2, area = parsed
clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
if clipped is None:
continue
c_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
if best is None or c_area > best[2]:
best = (label, clipped, c_area)
if best is None:
return None
return best[0], best[1]
def instruction_from_annotation(
annotation: Dict,
crop: CropBox,
template: str,
empty_instruction: str,
) -> str:
picked = select_target_box(annotation, crop)
if picked is None:
return empty_instruction
label, box = picked
cx, cy = bbox_center(box)
region = region_3x3(cx, cy, crop.w, crop.h)
return template.format(label=label, region=region)
def extract_text_features(
instructions: Sequence[str],
model_name: str,
batch_size: int = 32,
) -> np.ndarray:
try:
import torch
from transformers import DistilBertTokenizerFast
except ImportError as exc:
raise ImportError(
"Text feature encoding requires transformers. Please install: pip install transformers"
) from exc
repo_root = Path(__file__).resolve().parents[1]
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
from models.text_encoder import DistilBERTTextEncoder
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
model.eval()
feats: List[np.ndarray] = []
with torch.no_grad():
for i in range(0, len(instructions), batch_size):
batch = list(instructions[i:i + batch_size])
tok = tokenizer(
batch,
padding=True,
truncation=True,
max_length=32,
return_tensors="pt",
)
input_ids = tok["input_ids"].to(device)
attention_mask = tok["attention_mask"].to(device)
cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
feats.append(cls)
return np.concatenate(feats, axis=0)
def find_segment_csv(segment_dir: Path) -> Path:
csvs = sorted(segment_dir.glob("*.csv"))
if not csvs:
raise FileNotFoundError(f"No csv file found in {segment_dir}")
return csvs[0]
def main() -> None:
args = parse_args()
segment_dir = Path(args.segment_dir).resolve()
output_dir = Path(args.output_dir).resolve()
output_dir.mkdir(parents=True, exist_ok=True)
frames_dir = segment_dir / "frames"
if not frames_dir.exists():
raise FileNotFoundError(f"frames dir not found: {frames_dir}")
csv_path = find_segment_csv(segment_dir)
csv_rows = load_csv_rows(csv_path)
if len(csv_rows) == 0:
raise ValueError(f"CSV has no rows: {csv_path}")
crop = CropBox(*args.crop)
resize_w, resize_h = int(args.resize[0]), int(args.resize[1])
json_files = sorted_frame_jsons(frames_dir)
if not json_files:
raise FileNotFoundError(f"No frame json found in: {frames_dir}")
max_aligned = min(len(json_files), len(csv_rows))
num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
if num <= 0:
raise ValueError("No aligned frames available.")
images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
instructions: List[str] = []
y_min, y_max = 8000.0, 18884.0
x_min, x_max = 7000.0, 17384.0
cmd_min, cmd_max = 0.0, 65535.0
for i in range(num):
json_path = json_files[i]
with json_path.open("r", encoding="utf-8") as f:
ann = json.load(f)
image_path = frames_dir / ann["imagePath"]
if not image_path.exists():
alt = json_path.with_suffix(".jpg")
if alt.exists():
image_path = alt
else:
raise FileNotFoundError(f"Image not found for {json_path.name}")
img = Image.open(image_path).convert("RGB")
img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
images[i] = np.asarray(img_resize, dtype=np.uint8)
row = csv_rows[i]
motor_pos_y = float(row["motor_pos_y"])
motor_pos_x = float(row["motor_pos_x"])
motor_cmd_0 = float(row["motor_command_0"])
motor_cmd_1 = float(row["motor_command_1"])
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
ins = instruction_from_annotation(
ann,
crop,
args.instruction_template,
args.instruction_empty,
)
instructions.append(ins)
text_features = None
if args.encode_text_features:
text_features = extract_text_features(
instructions,
model_name=args.text_model_name,
batch_size=args.text_batch_size,
)
out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
dt = 1.0 / 30.0
with h5py.File(out_path, "w") as root:
root.attrs["sim"] = False
root.attrs["source_segment"] = str(segment_dir)
root.attrs["frame_rate"] = 30
root.attrs["dt"] = dt
root.attrs["state_norm_mode"] = args.state_norm
root.attrs["action_norm_mode"] = args.action_norm
root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)
obs = root.create_group("observations")
obs.create_dataset("qpos", data=qpos, dtype=np.float32)
images_group = obs.create_group("images")
images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)
root.create_dataset("action", data=action, dtype=np.float32)
str_dtype = h5py.string_dtype(encoding="utf-8")
root.create_dataset(
"instruction_timestep",
shape=(num,),
dtype=str_dtype,
data=np.asarray(instructions, dtype=object),
)
root.create_dataset(
"instruction",
shape=(),
dtype=str_dtype,
data=instructions[0] if len(instructions) > 0 else "",
)
if text_features is not None:
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)
print(f"Saved: {out_path}")
print(f"Frames used: {num}")
print(f"Image shape: {images.shape}")
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
if text_features is not None:
print(f"instruction_features_timestep shape: {text_features.shape}")
if __name__ == "__main__":
main()

0
models/__init__.py Normal file
View File

31
models/text_encoder.py Normal file
View File

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
class DistilBERTTextEncoder(nn.Module):
def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
super().__init__()
try:
from transformers import DistilBertModel
except ImportError as exc:
raise ImportError(
'transformers is required for DistilBERT text encoding. '
'Install it with: pip install transformers'
) from exc
self.encoder = DistilBertModel.from_pretrained(model_name)
self.output_dim = output_dim
self.freeze = freeze
if self.freeze:
for param in self.encoder.parameters():
param.requires_grad = False
self.encoder.eval()
def forward(self, input_ids, attention_mask):
if self.freeze:
self.encoder.eval()
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
# DistilBERT has no pooled output; use [CLS] token embedding
cls_feature = outputs.last_hidden_state[:, 0, :]
return cls_feature