#!/usr/bin/env python3
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
import h5py
|
|
import numpy as np
|
|
from PIL import Image
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Bilinear resampling filter: looked up on ``Image.Resampling`` when that
# attribute exists, otherwise on the ``Image`` module itself.
_RESAMPLING_NAMESPACE = getattr(Image, "Resampling", Image)
RESAMPLE_BILINEAR = getattr(_RESAMPLING_NAMESPACE, "BILINEAR")
|
|
|
|
|
|
@dataclass
class CropBox:
    """Rectangle described by its two corners ``(x1, y1)`` and ``(x2, y2)``."""

    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def w(self) -> int:
        """Horizontal extent of the box, i.e. ``x2 - x1``."""
        near, far = self.x1, self.x2
        return far - near

    @property
    def h(self) -> int:
        """Vertical extent of the box, i.e. ``y2 - y1``."""
        near, far = self.y1, self.y2
        return far - near
|
|
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the converter options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of the previous
            zero-argument signature.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` when the crop box is empty or
            inverted, when a resize dimension is not positive, or when text
            encoding is requested with a non-positive batch size.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
        )
    )
    parser.add_argument(
        "--segment_dir",
        type=str,
        required=True,
        help="Path to one raw segment, e.g. data/follow_seg_001",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output dir for episode_*.hdf5",
    )
    parser.add_argument(
        "--episode_idx",
        type=int,
        default=0,
        help="Output episode index (default: 0)",
    )
    parser.add_argument(
        "--max_frames",
        type=int,
        default=-1,
        help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
    )
    parser.add_argument(
        "--camera_name",
        type=str,
        default="top",
        help="Camera name written to /observations/images/<camera_name>",
    )
    parser.add_argument(
        "--crop",
        type=int,
        nargs=4,
        default=[733, 30, 1754, 1051],
        metavar=("X1", "Y1", "X2", "Y2"),
        help="Crop box in original image coordinates",
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs=2,
        default=[224, 224],
        metavar=("W", "H"),
        help="Output image size",
    )
    parser.add_argument(
        "--instruction_template",
        type=str,
        default="Move toward the {label} at {region}.",
        help="Template for per-frame instruction",
    )
    parser.add_argument(
        "--instruction_empty",
        type=str,
        default="No target visible.",
        help="Instruction when no valid target after crop",
    )
    parser.add_argument(
        "--stop_instruction",
        type=str,
        default="Stop move.",
        help="Instruction used for stationary head/tail frames",
    )
    parser.add_argument(
        "--state_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for qpos (motor_pos_y, motor_pos_x)",
    )
    parser.add_argument(
        "--action_norm",
        choices=["minus1_1", "0_1", "raw"],
        default="minus1_1",
        help="Normalization for action (motor_command_0, motor_command_1)",
    )
    parser.add_argument(
        "--encode_text_features",
        action="store_true",
        help="Encode per-frame instruction into 768-dim DistilBERT features",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace model name for DistilBERT",
    )
    parser.add_argument(
        "--text_batch_size",
        type=int,
        default=32,
        help="Batch size for text feature extraction",
    )
    parser.add_argument(
        "--motion_window",
        type=int,
        default=5,
        help="Sliding window size used for stationary detection at beginning/end",
    )
    parser.add_argument(
        "--motion_threshold",
        type=float,
        default=0.002,
        help=(
            "Motion threshold in normalized delta space (0~1). "
            "Smaller value means stricter stationary detection"
        ),
    )
    parser.add_argument(
        "--disable_stop_override",
        action="store_true",
        help="Disable head/tail stationary instruction override",
    )

    args = parser.parse_args(argv)

    # Reject values that would otherwise only fail deep inside processing
    # (zero-size crop, zero-size output image, zero-step batching loop).
    x1, y1, x2, y2 = args.crop
    if x2 <= x1 or y2 <= y1:
        parser.error("--crop requires X2 > X1 and Y2 > Y1")
    if min(args.resize) <= 0:
        parser.error("--resize W and H must be positive")
    if args.encode_text_features and args.text_batch_size <= 0:
        parser.error("--text_batch_size must be positive")
    return args
|
|
|
|
|
|
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    """Return the ``*.json`` files of ``frames_dir`` ordered by frame number.

    The number is taken from the first ``frame_<digits>`` match in the file
    name. Names without such a match get the index ``10**9``. Ties are broken
    by file name.
    """

    def _order(path: Path) -> Tuple[int, str]:
        hit = re.search(r"frame_(\d+)", path.name)
        if hit is None:
            return 10**9, path.name
        return int(hit.group(1)), path.name

    return sorted(frames_dir.glob("*.json"), key=_order)
|
|
|
|
|
|
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    """Read a CSV file with a header row into a list of per-row dicts.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        One dict per data row, keyed by the header names.
    """
    # newline="" lets the csv module do its own line handling, so line breaks
    # embedded in quoted fields are kept intact. "utf-8-sig" decodes plain
    # UTF-8 unchanged but drops a leading BOM, which would otherwise become
    # part of the first column name.
    with csv_path.open("r", encoding="utf-8-sig", newline="") as handle:
        return list(csv.DictReader(handle))
|
|
|
|
|
|
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    """Rescale values from the range ``[min_v, max_v]`` and return float32.

    Args:
        x: Values to rescale; any array-like is accepted.
        min_v: Value mapped to the lower end of the output range.
        max_v: Value mapped to the upper end of the output range.
        mode: ``"raw"`` returns the values unchanged (cast to float32),
            ``"0_1"`` maps linearly onto [0, 1]. Any other value is treated
            as ``"minus1_1"`` and maps linearly onto [-1, 1].

    Returns:
        A float32 array with the same shape as ``x``.

    Raises:
        ValueError: if scaling is requested and ``max_v == min_v``; dividing
            by the empty range would silently produce inf/NaN values.
    """
    values = np.asarray(x)
    if mode == "raw":
        return values.astype(np.float32)

    span = max_v - min_v
    if span == 0:
        raise ValueError(
            f"Cannot normalize with an empty range: min_v == max_v == {min_v}"
        )

    unit = (values - min_v) / span
    if mode == "0_1":
        return unit.astype(np.float32)
    # minus1_1
    return (unit * 2.0 - 1.0).astype(np.float32)
|
|
|
|
|
|
def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    """Shift a box into crop-local coordinates and clamp it to the crop.

    The box is translated by ``(crop.x1, crop.y1)`` and clamped to
    ``[0, crop.w - 1]`` horizontally and ``[0, crop.h - 1]`` vertically.

    Returns:
        ``(x1, y1, x2, y2)`` in crop-local coordinates, or ``None`` when the
        clamped box has no positive width or height.
    """
    offset_x, offset_y = crop.x1, crop.y1
    last_col = float(crop.w - 1)
    last_row = float(crop.h - 1)

    left = max(x_min - offset_x, 0.0)
    top = max(y_min - offset_y, 0.0)
    right = min(x_max - offset_x, last_col)
    bottom = min(y_max - offset_y, last_row)

    collapsed = right <= left or bottom <= top
    if collapsed:
        return None
    return left, top, right, bottom
|
|
|
|
|
|
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the midpoint ``(cx, cy)`` of an ``(x1, y1, x2, y2)`` box."""
    left, top, right, bottom = box
    center_x = (left + right) * 0.5
    center_y = (top + bottom) * 0.5
    return center_x, center_y
|
|
|
|
|
|
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    """Name the cell of a 3x3 grid over a ``w`` x ``h`` area containing a point.

    Returns:
        ``"<row>-<column>"`` with row in top/middle/bottom and column in
        left/center/right. Points outside the area are assigned to the
        nearest edge cell.
    """

    def _third(coord: float, extent: int) -> int:
        # Index of the third containing ``coord``, clamped to 0..2.
        raw = int(coord / (extent / 3.0))
        if raw < 0:
            return 0
        if raw > 2:
            return 2
        return raw

    column = ("left", "center", "right")[_third(cx, w)]
    row = ("top", "middle", "bottom")[_third(cy, h)]
    return f"{row}-{column}"
|
|
|
|
|
|
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    """Compute the axis-aligned bounding box of one annotation shape.

    Args:
        shape: Mapping with a ``"points"`` list of ``[x, y]`` pairs and an
            optional ``"label"`` (defaults to ``"target"``).

    Returns:
        ``(label, x_min, y_min, x_max, y_max, area)``, or ``None`` when the
        shape has fewer than two points or the points cannot be read as an
        ``(N, >=2)`` numeric array.
    """
    points = shape.get("points", None)
    label = shape.get("label", "target")
    if not points or len(points) < 2:
        return None

    # Malformed point lists (flat, ragged, non-numeric) are treated like any
    # other unusable shape instead of aborting the whole conversion.
    try:
        pts = np.array(points, dtype=np.float32)
    except (TypeError, ValueError):
        return None
    if pts.ndim != 2 or pts.shape[1] < 2:
        return None

    xs, ys = pts[:, 0], pts[:, 1]
    x_min, y_min = float(xs.min()), float(ys.min())
    x_max, y_max = float(xs.max()), float(ys.max())
    area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
    return label, x_min, y_min, x_max, y_max, area
|
|
|
|
|
|
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    """Pick the annotated shape with the largest area remaining after the crop.

    Returns:
        ``(label, clipped_box)`` for the winning shape, or ``None`` when no
        shape yields a usable box inside the crop. On equal areas the shape
        listed first wins.
    """
    scored = []
    for shape in annotation.get("shapes", []):
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, bx1, by1, bx2, by2, _full_area = parsed

        visible = clip_bbox_to_crop(bx1, by1, bx2, by2, crop)
        if visible is None:
            continue
        vx1, vy1, vx2, vy2 = visible
        visible_area = max(0.0, vx2 - vx1) * max(0.0, vy2 - vy1)
        scored.append((visible_area, label, visible))

    if not scored:
        return None
    # max() keeps the first entry among equal keys.
    _, best_label, best_box = max(scored, key=lambda entry: entry[0])
    return best_label, best_box
|
|
|
|
|
|
def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    """Build the instruction text for one frame's annotation.

    Returns:
        ``empty_instruction`` when no target box survives the crop; otherwise
        ``template`` formatted with the target's ``label`` and the 3x3
        ``region`` name of its box center.
    """
    target = select_target_box(annotation, crop)
    if target is None:
        return empty_instruction

    target_label, target_box = target
    center_x, center_y = bbox_center(target_box)
    return template.format(
        label=target_label,
        region=region_3x3(center_x, center_y, crop.w, crop.h),
    )
|
|
|
|
|
|
def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """Encode instruction strings with the project's DistilBERT text encoder.

    Args:
        instructions: Texts to encode, one feature row per text.
        model_name: Model name passed to the tokenizer and the encoder.
        batch_size: Number of texts encoded per forward pass; must be >= 1.

    Returns:
        A float32 array with one row per instruction. For empty input a
        ``(0, 768)`` array is returned without loading any model.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        ImportError: if torch or transformers cannot be imported.
    """
    # Matches the output_dim requested from the encoder below.
    # NOTE(review): assumes the encoder returns (batch, output_dim) - confirm.
    feature_dim = 768

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(instructions) == 0:
        # np.concatenate rejects an empty list; return an empty matrix instead.
        return np.zeros((0, feature_dim), dtype=np.float32)

    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires torch and transformers. "
            "Please install: pip install torch transformers"
        ) from exc

    # The encoder is imported from the repository's models package, so the
    # directory two levels above this file must be importable.
    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=feature_dim, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        for start in range(0, len(instructions), batch_size):
            batch = list(instructions[start:start + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            input_ids = tok["input_ids"].to(device)
            attention_mask = tok["attention_mask"].to(device)
            encoded = model(input_ids=input_ids, attention_mask=attention_mask)
            feats.append(encoded.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(feats, axis=0)
|
|
|
|
|
|
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
|
|
scale = max_v - min_v
|
|
if scale <= 0:
|
|
return np.zeros_like(v, dtype=np.float32)
|
|
return ((v - min_v) / scale).astype(np.float32)
|
|
|
|
|
|
def override_stationary_edge_instructions(
    instructions: List[str],
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motor_cmd_0: np.ndarray,
    motor_cmd_1: np.ndarray,
    stop_instruction: str,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[List[str], int, int]:
    """Replace instructions on the stationary head and tail of an episode.

    Each frame is described by four signals (two motor positions, two motor
    commands), each rescaled by a fixed range. Scanning forward from the first
    frame, frames are marked while their largest per-signal deviation from the
    first frame stays within ``motion_threshold``. The same scan runs backward
    from the last frame against the last frame. Marked frames receive
    ``stop_instruction``.

    ``motion_window`` is accepted for backward CLI compatibility and is not
    used.

    Returns:
        ``(updated_instructions, head_count, tail_count)``. Empty input is
        returned unchanged with zero counts. A single-frame episode is always
        marked as stationary.
    """
    total = len(instructions)
    if total == 0:
        return instructions, 0, 0

    def _unit(series: np.ndarray, low: float, high: float) -> np.ndarray:
        # Rescale onto a comparable 0~1 scale; every range used below is positive.
        as_f32 = series.astype(np.float32)
        return ((as_f32 - low) / (high - low)).astype(np.float32)

    py = _unit(motor_pos_y, 8000.0, 18884.0)
    px = _unit(motor_pos_x, 7000.0, 17384.0)
    c0 = _unit(motor_cmd_0, 0.0, 65535.0)
    c1 = _unit(motor_cmd_1, 0.0, 65535.0)

    if total == 1:
        return [stop_instruction], 1, 1

    def _snapshot(idx: int) -> np.ndarray:
        return np.array([py[idx], px[idx], c0[idx], c1[idx]], dtype=np.float32)

    def _still_run(order, anchor: np.ndarray) -> int:
        # Count frames, in the given order, before the first one that drifts
        # from the anchor by more than the threshold.
        run = 0
        for idx in order:
            drift = float(np.max(np.abs(_snapshot(idx) - anchor)))
            if not drift <= motion_threshold:
                break
            run += 1
        return run

    head_anchor = _snapshot(0)
    tail_anchor = _snapshot(-1)
    start_count = _still_run(range(total), head_anchor)
    end_count = _still_run(reversed(range(total)), tail_anchor)

    updated = list(instructions)
    updated[:start_count] = [stop_instruction] * start_count
    updated[total - end_count:] = [stop_instruction] * end_count
    return updated, start_count, end_count
|
|
|
|
|
|
def find_segment_csv(segment_dir: Path) -> Path:
    """Return the first ``*.csv`` file in ``segment_dir`` in sorted order.

    Raises:
        FileNotFoundError: if the directory contains no CSV file.
    """
    first_csv = next(iter(sorted(segment_dir.glob("*.csv"))), None)
    if first_csv is None:
        raise FileNotFoundError(f"No csv file found in {segment_dir}")
    return first_csv
|
|
|
|
|
|
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
|
|
"""Convert boolean mask to closed-open segments [start, end)."""
|
|
segments: List[Tuple[int, int]] = []
|
|
if mask.size == 0:
|
|
return segments
|
|
in_seg = False
|
|
start = 0
|
|
for i, v in enumerate(mask.tolist()):
|
|
if v and not in_seg:
|
|
in_seg = True
|
|
start = i
|
|
elif (not v) and in_seg:
|
|
in_seg = False
|
|
segments.append((start, i))
|
|
if in_seg:
|
|
segments.append((start, int(mask.size)))
|
|
return segments
|
|
|
|
|
|
def save_episode_plot_with_stop_segments(
    qpos: np.ndarray,
    action: np.ndarray,
    instructions: Sequence[str],
    stop_instruction: str,
    plot_path: Path,
) -> None:
    """Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.

    One subplot is drawn per qpos column, with the matching action column
    overlaid when it exists. Timesteps whose instruction equals
    ``stop_instruction`` are shaded.

    Args:
        qpos: 2-D array, one row per timestep.
        action: 2-D array, one row per timestep.
        instructions: Per-timestep instruction strings.
        stop_instruction: Instruction text that marks a stop timestep.
        plot_path: Destination image path.

    Nothing is written when either array is not 2-D or ``qpos`` is empty.
    """
    qpos = np.asarray(qpos)
    action = np.asarray(action)
    if qpos.ndim != 2 or action.ndim != 2:
        return

    num_ts, num_dim = qpos.shape
    if num_ts == 0 or num_dim == 0:
        return

    stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
    stop_segments = _mask_to_segments(stop_mask)

    fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
    try:
        axs_list = np.atleast_1d(axs).reshape(-1).tolist()

        for dim_idx, ax in enumerate(axs_list):
            ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
            # Skip the overlay instead of failing when action has fewer columns.
            if dim_idx < action.shape[1]:
                ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
            for seg_idx, (st, ed) in enumerate(stop_segments):
                # Segments are closed-open [st, ed). Padding by half a sample
                # keeps a one-frame span visible; (st, ed - 1) would be
                # zero-width in that case.
                ax.axvspan(st - 0.5, ed - 0.5, color='orange', alpha=0.2,
                           label='stop instruction' if seg_idx == 0 else None)
            ax.set_ylabel(f'dim {dim_idx}')
            ax.legend(loc='best')
            ax.grid(alpha=0.25, linestyle='--')

        axs_list[-1].set_xlabel('timestep')
        fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
        fig.tight_layout()
        # The title sits above the figure area (y=1.02); a tight bounding box
        # keeps it in the saved image.
        fig.savefig(str(plot_path), dpi=140, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
def main() -> None:
    """Convert one raw segment directory into a single HDF5 episode file.

    Steps:
      1. Parse CLI options and locate ``<segment_dir>/frames`` and the
         segment CSV.
      2. Pair the i-th frame JSON (sorted by frame number) with the i-th CSV
         row; the shorter of the two, optionally capped by ``--max_frames``,
         sets the episode length.
      3. Per frame: load the image, crop and resize it, read the motor
         position/command columns and derive the instruction text.
      4. Normalise qpos/action, optionally override head/tail instructions and
         optionally encode text features.
      5. Write ``episode_<idx>.hdf5`` and a same-basename ``.png`` plot.

    Raises:
        FileNotFoundError: missing frames dir, CSV, frame JSONs or image.
        ValueError: empty CSV or no aligned frames.
    """
    args = parse_args()

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if not csv_rows:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    instructions: List[str] = []
    motor_pos_y_series = np.zeros((num,), dtype=np.float32)
    motor_pos_x_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_1_series = np.zeros((num,), dtype=np.float32)

    # Fixed value ranges used for normalisation.
    y_min, y_max = 8000.0, 18884.0
    x_min, x_max = 7000.0, 17384.0
    cmd_min, cmd_max = 0.0, 65535.0

    for i, (json_path, row) in enumerate(zip(json_files[:num], csv_rows[:num])):
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        # Prefer the image named in the annotation; fall back to a .jpg with
        # the same stem as the JSON when the name is missing or the file is
        # absent.
        image_name = ann.get("imagePath")
        image_path = frames_dir / image_name if image_name else None
        if image_path is None or not image_path.exists():
            alt = json_path.with_suffix(".jpg")
            if not alt.exists():
                raise FileNotFoundError(f"Image not found for {json_path.name}")
            image_path = alt

        # The context manager closes the underlying file handle per frame.
        with Image.open(image_path) as img:
            img_crop = img.convert("RGB").crop((crop.x1, crop.y1, crop.x2, crop.y2))
            img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
            images[i] = np.asarray(img_resize, dtype=np.uint8)

        motor_pos_y_series[i] = float(row["motor_pos_y"])
        motor_pos_x_series[i] = float(row["motor_pos_x"])
        motor_cmd_0_series[i] = float(row["motor_command_0"])
        motor_cmd_1_series[i] = float(row["motor_command_1"])

        instructions.append(
            instruction_from_annotation(
                ann,
                crop,
                args.instruction_template,
                args.instruction_empty,
            )
        )

    # Normalise each channel in one vectorised call; the series are already
    # float32, so the per-element results match per-frame normalisation.
    qpos = np.zeros((num, 2), dtype=np.float32)  # [y, x]
    action = np.zeros((num, 2), dtype=np.float32)  # [cmd0(y), cmd1(x)]
    qpos[:, 0] = normalize_value(motor_pos_y_series, y_min, y_max, args.state_norm)
    qpos[:, 1] = normalize_value(motor_pos_x_series, x_min, x_max, args.state_norm)
    action[:, 0] = normalize_value(motor_cmd_0_series, cmd_min, cmd_max, args.action_norm)
    action[:, 1] = normalize_value(motor_cmd_1_series, cmd_min, cmd_max, args.action_norm)

    start_stop_count = 0
    end_stop_count = 0
    if not args.disable_stop_override:
        instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
            instructions=instructions,
            motor_pos_y=motor_pos_y_series,
            motor_pos_x=motor_pos_x_series,
            motor_cmd_0=motor_cmd_0_series,
            motor_cmd_1=motor_cmd_1_series,
            stop_instruction=args.stop_instruction,
            motion_window=args.motion_window,
            motion_threshold=args.motion_threshold,
        )

    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    frame_rate = 30
    dt = 1.0 / frame_rate

    with h5py.File(out_path, "w") as root:
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = frame_rate
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)

        root.create_dataset("action", data=action, dtype=np.float32)

        str_dtype = h5py.string_dtype(encoding="utf-8")
        root.create_dataset(
            "instruction_timestep",
            shape=(num,),
            dtype=str_dtype,
            data=np.asarray(instructions, dtype=object),
        )
        # NOTE(review): with the stop override enabled, frame 0 always matches
        # its own reference, so for any non-negative threshold the
        # episode-level instruction (and instruction_features) below is the
        # stop instruction - confirm this is intended.
        root.create_dataset(
            "instruction",
            shape=(),
            dtype=str_dtype,
            data=instructions[0] if len(instructions) > 0 else "",
        )

        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if not args.disable_stop_override:
        print(
            f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
            f"mode=endpoint_reference, threshold={args.motion_threshold}, "
            f"instruction='{args.stop_instruction}'"
        )
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")

    # Save a same-basename plot next to the generated hdf5
    plot_path = out_path.with_suffix('.png')
    save_episode_plot_with_stop_segments(
        qpos=qpos,
        action=action,
        instructions=instructions,
        stop_instruction=args.stop_instruction,
        plot_path=plot_path,
    )
    print(f"Saved episode plot to: {plot_path}")
|
|
|
|
|
|
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
|