增加了stop命令

This commit is contained in:
2026-02-19 22:11:10 +08:00
parent 7023d5dde4
commit ee257bcb6c

View File

@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
@@ -96,6 +97,12 @@ def parse_args() -> argparse.Namespace:
default="No target visible.",
help="Instruction when no valid target after crop",
)
parser.add_argument(
"--stop_instruction",
type=str,
default="Stop move.",
help="Instruction used for stationary head/tail frames",
)
parser.add_argument(
"--state_norm",
choices=["minus1_1", "0_1", "raw"],
@@ -125,6 +132,26 @@ def parse_args() -> argparse.Namespace:
default=32,
help="Batch size for text feature extraction",
)
parser.add_argument(
"--motion_window",
type=int,
default=5,
help="Sliding window size used for stationary detection at beginning/end",
)
parser.add_argument(
"--motion_threshold",
type=float,
default=0.002,
help=(
"Motion threshold in normalized delta space (0~1). "
"Smaller value means stricter stationary detection"
),
)
parser.add_argument(
"--disable_stop_override",
action="store_true",
help="Disable head/tail stationary instruction override",
)
return parser.parse_args()
@@ -272,6 +299,76 @@ def extract_text_features(
return np.concatenate(feats, axis=0)
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
scale = max_v - min_v
if scale <= 0:
return np.zeros_like(v, dtype=np.float32)
return ((v - min_v) / scale).astype(np.float32)
def override_stationary_edge_instructions(
instructions: List[str],
motor_pos_y: np.ndarray,
motor_pos_x: np.ndarray,
motor_cmd_0: np.ndarray,
motor_cmd_1: np.ndarray,
stop_instruction: str,
motion_window: int,
motion_threshold: float,
) -> Tuple[List[str], int, int]:
"""
Override instruction text at head/tail when deviation is below threshold.
Start side: use the first frame as reference and expand forward until
deviation exceeds threshold.
End side: use the last frame as reference and expand backward until
deviation exceeds threshold.
"""
num = len(instructions)
if num == 0:
return instructions, 0, 0
# normalize to comparable scales (0~1)
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
if num == 1:
return [stop_instruction], 1, 1
# keep argument for backward CLI compatibility
_ = motion_window
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
return float(np.max(np.abs(cur - ref)))
start_count = 0
for i in range(num):
if deviation_to_ref(i, start_ref) <= motion_threshold:
start_count += 1
else:
break
end_count = 0
for i in range(num - 1, -1, -1):
if deviation_to_ref(i, end_ref) <= motion_threshold:
end_count += 1
else:
break
updated = list(instructions)
for i in range(start_count):
updated[i] = stop_instruction
for i in range(num - end_count, num):
updated[i] = stop_instruction
return updated, start_count, end_count
def find_segment_csv(segment_dir: Path) -> Path:
csvs = sorted(segment_dir.glob("*.csv"))
if not csvs:
@@ -279,6 +376,68 @@ def find_segment_csv(segment_dir: Path) -> Path:
return csvs[0]
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
"""Convert boolean mask to closed-open segments [start, end)."""
segments: List[Tuple[int, int]] = []
if mask.size == 0:
return segments
in_seg = False
start = 0
for i, v in enumerate(mask.tolist()):
if v and not in_seg:
in_seg = True
start = i
elif (not v) and in_seg:
in_seg = False
segments.append((start, i))
if in_seg:
segments.append((start, int(mask.size)))
return segments
def save_episode_plot_with_stop_segments(
qpos: np.ndarray,
action: np.ndarray,
instructions: Sequence[str],
stop_instruction: str,
plot_path: Path,
) -> None:
"""
Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
"""
qpos = np.asarray(qpos)
action = np.asarray(action)
if qpos.ndim != 2 or action.ndim != 2:
return
num_ts, num_dim = qpos.shape
if num_ts == 0 or num_dim == 0:
return
stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
stop_segments = _mask_to_segments(stop_mask)
fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
axs_list = np.atleast_1d(axs).reshape(-1).tolist()
for dim_idx in range(num_dim):
ax = axs_list[dim_idx]
ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
for seg_idx, (st, ed) in enumerate(stop_segments):
ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
label='stop instruction' if seg_idx == 0 else None)
ax.set_ylabel(f'dim {dim_idx}')
ax.legend(loc='best')
ax.grid(alpha=0.25, linestyle='--')
axs_list[-1].set_xlabel('timestep')
fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
fig.tight_layout()
fig.savefig(str(plot_path), dpi=140)
plt.close(fig)
def main() -> None:
args = parse_args()
@@ -311,6 +470,10 @@ def main() -> None:
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
instructions: List[str] = []
motor_pos_y_series = np.zeros((num,), dtype=np.float32)
motor_pos_x_series = np.zeros((num,), dtype=np.float32)
motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
motor_cmd_1_series = np.zeros((num,), dtype=np.float32)
y_min, y_max = 8000.0, 18884.0
x_min, x_max = 7000.0, 17384.0
@@ -340,6 +503,11 @@ def main() -> None:
motor_cmd_0 = float(row["motor_command_0"])
motor_cmd_1 = float(row["motor_command_1"])
motor_pos_y_series[i] = motor_pos_y
motor_pos_x_series[i] = motor_pos_x
motor_cmd_0_series[i] = motor_cmd_0
motor_cmd_1_series[i] = motor_cmd_1
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
@@ -353,6 +521,20 @@ def main() -> None:
)
instructions.append(ins)
start_stop_count = 0
end_stop_count = 0
if not args.disable_stop_override:
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
instructions=instructions,
motor_pos_y=motor_pos_y_series,
motor_pos_x=motor_pos_x_series,
motor_cmd_0=motor_cmd_0_series,
motor_cmd_1=motor_cmd_1_series,
stop_instruction=args.stop_instruction,
motion_window=args.motion_window,
motion_threshold=args.motion_threshold,
)
text_features = None
if args.encode_text_features:
text_features = extract_text_features(
@@ -404,9 +586,26 @@ def main() -> None:
print(f"Frames used: {num}")
print(f"Image shape: {images.shape}")
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
if not args.disable_stop_override:
print(
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
f"instruction='{args.stop_instruction}'"
)
if text_features is not None:
print(f"instruction_features_timestep shape: {text_features.shape}")
# Save a same-basename plot next to the generated hdf5
plot_path = out_path.with_suffix('.png')
save_episode_plot_with_stop_segments(
qpos=qpos,
action=action,
instructions=instructions,
stop_instruction=args.stop_instruction,
plot_path=plot_path,
)
print(f"Saved episode plot to: {plot_path}")
if __name__ == "__main__":
main()