代码可以跑起来了

This commit is contained in:
2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
dt = float(root.attrs.get('dt', DT))
qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()]
qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
action = root['/action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
return qpos, qvel, action, image_dict
return qpos, qvel, action, image_dict, dt
def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}'
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
def save_videos(video, dt, video_path=None):
def save_videos(video, dt, video_path):
if isinstance(video, list):
cam_names = list(video[0].keys())
h, w, _ = video[0][cam_names[0]].shape
w = w * len(cam_names)
fps = int(1/dt)
fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for ts, image_dict in enumerate(video):
images = []
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
n_frames, h, w, _ = all_cam_videos.shape
fps = int(1 / dt)
fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
image = all_cam_videos[t]