增加了stop命令

This commit is contained in:
2026-02-19 22:11:10 +08:00
parent 7023d5dde4
commit ee257bcb6c

View File

@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
import h5py import h5py
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import matplotlib.pyplot as plt
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR") RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
@@ -96,6 +97,12 @@ def parse_args() -> argparse.Namespace:
default="No target visible.", default="No target visible.",
help="Instruction when no valid target after crop", help="Instruction when no valid target after crop",
) )
parser.add_argument(
"--stop_instruction",
type=str,
default="Stop move.",
help="Instruction used for stationary head/tail frames",
)
parser.add_argument( parser.add_argument(
"--state_norm", "--state_norm",
choices=["minus1_1", "0_1", "raw"], choices=["minus1_1", "0_1", "raw"],
@@ -125,6 +132,26 @@ def parse_args() -> argparse.Namespace:
default=32, default=32,
help="Batch size for text feature extraction", help="Batch size for text feature extraction",
) )
parser.add_argument(
"--motion_window",
type=int,
default=5,
help="Sliding window size used for stationary detection at beginning/end",
)
parser.add_argument(
"--motion_threshold",
type=float,
default=0.002,
help=(
"Motion threshold in normalized delta space (0~1). "
"Smaller value means stricter stationary detection"
),
)
parser.add_argument(
"--disable_stop_override",
action="store_true",
help="Disable head/tail stationary instruction override",
)
return parser.parse_args() return parser.parse_args()
@@ -272,6 +299,76 @@ def extract_text_features(
return np.concatenate(feats, axis=0) return np.concatenate(feats, axis=0)
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
scale = max_v - min_v
if scale <= 0:
return np.zeros_like(v, dtype=np.float32)
return ((v - min_v) / scale).astype(np.float32)
def override_stationary_edge_instructions(
instructions: List[str],
motor_pos_y: np.ndarray,
motor_pos_x: np.ndarray,
motor_cmd_0: np.ndarray,
motor_cmd_1: np.ndarray,
stop_instruction: str,
motion_window: int,
motion_threshold: float,
) -> Tuple[List[str], int, int]:
"""
Override instruction text at head/tail when deviation is below threshold.
Start side: use the first frame as reference and expand forward until
deviation exceeds threshold.
End side: use the last frame as reference and expand backward until
deviation exceeds threshold.
"""
num = len(instructions)
if num == 0:
return instructions, 0, 0
# normalize to comparable scales (0~1)
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
if num == 1:
return [stop_instruction], 1, 1
# keep argument for backward CLI compatibility
_ = motion_window
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
return float(np.max(np.abs(cur - ref)))
start_count = 0
for i in range(num):
if deviation_to_ref(i, start_ref) <= motion_threshold:
start_count += 1
else:
break
end_count = 0
for i in range(num - 1, -1, -1):
if deviation_to_ref(i, end_ref) <= motion_threshold:
end_count += 1
else:
break
updated = list(instructions)
for i in range(start_count):
updated[i] = stop_instruction
for i in range(num - end_count, num):
updated[i] = stop_instruction
return updated, start_count, end_count
def find_segment_csv(segment_dir: Path) -> Path: def find_segment_csv(segment_dir: Path) -> Path:
csvs = sorted(segment_dir.glob("*.csv")) csvs = sorted(segment_dir.glob("*.csv"))
if not csvs: if not csvs:
@@ -279,6 +376,68 @@ def find_segment_csv(segment_dir: Path) -> Path:
return csvs[0] return csvs[0]
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
"""Convert boolean mask to closed-open segments [start, end)."""
segments: List[Tuple[int, int]] = []
if mask.size == 0:
return segments
in_seg = False
start = 0
for i, v in enumerate(mask.tolist()):
if v and not in_seg:
in_seg = True
start = i
elif (not v) and in_seg:
in_seg = False
segments.append((start, i))
if in_seg:
segments.append((start, int(mask.size)))
return segments
def save_episode_plot_with_stop_segments(
qpos: np.ndarray,
action: np.ndarray,
instructions: Sequence[str],
stop_instruction: str,
plot_path: Path,
) -> None:
"""
Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
"""
qpos = np.asarray(qpos)
action = np.asarray(action)
if qpos.ndim != 2 or action.ndim != 2:
return
num_ts, num_dim = qpos.shape
if num_ts == 0 or num_dim == 0:
return
stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
stop_segments = _mask_to_segments(stop_mask)
fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
axs_list = np.atleast_1d(axs).reshape(-1).tolist()
for dim_idx in range(num_dim):
ax = axs_list[dim_idx]
ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
for seg_idx, (st, ed) in enumerate(stop_segments):
ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
label='stop instruction' if seg_idx == 0 else None)
ax.set_ylabel(f'dim {dim_idx}')
ax.legend(loc='best')
ax.grid(alpha=0.25, linestyle='--')
axs_list[-1].set_xlabel('timestep')
fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
fig.tight_layout()
fig.savefig(str(plot_path), dpi=140)
plt.close(fig)
def main() -> None: def main() -> None:
args = parse_args() args = parse_args()
@@ -311,6 +470,10 @@ def main() -> None:
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x] qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)] action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
instructions: List[str] = [] instructions: List[str] = []
motor_pos_y_series = np.zeros((num,), dtype=np.float32)
motor_pos_x_series = np.zeros((num,), dtype=np.float32)
motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
motor_cmd_1_series = np.zeros((num,), dtype=np.float32)
y_min, y_max = 8000.0, 18884.0 y_min, y_max = 8000.0, 18884.0
x_min, x_max = 7000.0, 17384.0 x_min, x_max = 7000.0, 17384.0
@@ -340,6 +503,11 @@ def main() -> None:
motor_cmd_0 = float(row["motor_command_0"]) motor_cmd_0 = float(row["motor_command_0"])
motor_cmd_1 = float(row["motor_command_1"]) motor_cmd_1 = float(row["motor_command_1"])
motor_pos_y_series[i] = motor_pos_y
motor_pos_x_series[i] = motor_pos_x
motor_cmd_0_series[i] = motor_cmd_0
motor_cmd_1_series[i] = motor_cmd_1
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0] qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0] qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0] action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
@@ -353,6 +521,20 @@ def main() -> None:
) )
instructions.append(ins) instructions.append(ins)
start_stop_count = 0
end_stop_count = 0
if not args.disable_stop_override:
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
instructions=instructions,
motor_pos_y=motor_pos_y_series,
motor_pos_x=motor_pos_x_series,
motor_cmd_0=motor_cmd_0_series,
motor_cmd_1=motor_cmd_1_series,
stop_instruction=args.stop_instruction,
motion_window=args.motion_window,
motion_threshold=args.motion_threshold,
)
text_features = None text_features = None
if args.encode_text_features: if args.encode_text_features:
text_features = extract_text_features( text_features = extract_text_features(
@@ -404,9 +586,26 @@ def main() -> None:
print(f"Frames used: {num}") print(f"Frames used: {num}")
print(f"Image shape: {images.shape}") print(f"Image shape: {images.shape}")
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}") print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
if not args.disable_stop_override:
print(
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
f"instruction='{args.stop_instruction}'"
)
if text_features is not None: if text_features is not None:
print(f"instruction_features_timestep shape: {text_features.shape}") print(f"instruction_features_timestep shape: {text_features.shape}")
# Save a same-basename plot next to the generated hdf5
plot_path = out_path.with_suffix('.png')
save_episode_plot_with_stop_segments(
qpos=qpos,
action=action,
instructions=instructions,
stop_instruction=args.stop_instruction,
plot_path=plot_path,
)
print(f"Saved episode plot to: {plot_path}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()