first commit
This commit is contained in:
90
README.md
Normal file
90
README.md
Normal file
@@ -0,0 +1,90 @@
|
||||
# ACT: Action Chunking with Transformers
|
||||
|
||||
### Project Website: https://tonyzhaozh.github.io/aloha/
|
||||
|
||||
This repo contains the implementation of ACT, together with 2 simulated environments:
|
||||
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim (tested) or real (ongoing).
|
||||
|
||||
|
||||
### Repo Structure
|
||||
- ``imitate_episodes.py`` Train and Evaluate ACT
|
||||
- ``policy.py`` An adaptor for ACT policy
|
||||
- ``detr`` Model definitions of ACT, modified from DETR
|
||||
- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
|
||||
- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
|
||||
- ``scripted_policy.py`` Scripted policies for sim environments
|
||||
- ``constants.py`` Constants shared across files
|
||||
- ``utils.py`` Utils such as data loading and helper functions
|
||||
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
conda create -n aloha python=3.8
|
||||
conda activate aloha
|
||||
pip install torchvision
|
||||
pip install torch
|
||||
pip install pyquaternion
|
||||
pip install pyyaml
|
||||
pip install rospkg
|
||||
pip install pexpect
|
||||
pip install mujoco
|
||||
pip install dm_control
|
||||
pip install opencv-python
|
||||
pip install matplotlib
|
||||
pip install einops
|
||||
pip install packaging
|
||||
pip install h5py
|
||||
pip install h5py_cache
|
||||
cd act/detr && pip install -e .
|
||||
|
||||
### Example Usages
|
||||
|
||||
To set up a new terminal, run:
|
||||
|
||||
conda activate aloha
|
||||
cd <path to act repo>
|
||||
|
||||
### Simulated experiments
|
||||
|
||||
We use ``transfer_cube`` task in the examples below. Another option is ``insertion``.
|
||||
To generated 50 episodes of scripted data, run:
|
||||
|
||||
python3 record_sim_episodes.py \
|
||||
--task_name transfer_cube \
|
||||
--dataset_dir <data save dir> \
|
||||
--num_episodes 50
|
||||
|
||||
To can add the flag ``--onscreen_render`` to see real-time rendering.
|
||||
To visualize the episode after it is collected, run
|
||||
|
||||
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
|
||||
|
||||
To train ACT:
|
||||
|
||||
# Transfer Cube task
|
||||
python3 imitate_episodes.py \
|
||||
--dataset_dir <data save dir> \
|
||||
--ckpt_dir <ckpt dir> \
|
||||
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \
|
||||
--task_name transfer_cube --seed 0 \
|
||||
--temporal_agg \
|
||||
--num_epochs 1000 --lr 1e-4
|
||||
|
||||
# Bimanual Insertion task
|
||||
python3 imitate_episodes.py \
|
||||
--dataset_dir <data save dir> \
|
||||
--ckpt_dir <ckpt dir> \
|
||||
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \
|
||||
--task_name insertion --seed 0 \
|
||||
--temporal_agg \
|
||||
--num_epochs 2000 --lr 1e-5
|
||||
|
||||
To evaluate the policy, run the same command but add ``--eval``. The success rate
|
||||
should be around 85% for transfer cube, and around 50% for insertion.
|
||||
Videos will be saved to ``<ckpt_dir>`` for each rollout.
|
||||
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
|
||||
|
||||
|
||||
|
||||
|
||||
59
assets/bimanual_viperx_ee_insertion.xml
Normal file
59
assets/bimanual_viperx_ee_insertion.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
48
assets/bimanual_viperx_ee_transfer_cube.xml
Normal file
48
assets/bimanual_viperx_ee_transfer_cube.xml
Normal file
@@ -0,0 +1,48 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
53
assets/bimanual_viperx_insertion.xml
Normal file
53
assets/bimanual_viperx_insertion.xml
Normal file
@@ -0,0 +1,53 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
42
assets/bimanual_viperx_transfer_cube.xml
Normal file
42
assets/bimanual_viperx_transfer_cube.xml
Normal file
@@ -0,0 +1,42 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
39
assets/scene.xml
Normal file
39
assets/scene.xml
Normal file
@@ -0,0 +1,39 @@
|
||||
<mujocoinclude>
|
||||
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
|
||||
|
||||
<asset>
|
||||
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
|
||||
</asset>
|
||||
|
||||
<visual>
|
||||
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
|
||||
<quality shadowsize="4096" offsamples="4"/>
|
||||
<headlight ambient="0.4 0.4 0.4"/>
|
||||
</visual>
|
||||
|
||||
<worldbody>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
|
||||
dir='1 1 -1'/>
|
||||
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
|
||||
dir='0 -1 -1'/>
|
||||
|
||||
<body name="table" pos="0 .6 0">
|
||||
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
|
||||
</body>
|
||||
<body name="midair" pos="0 .6 0.2">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
|
||||
</body>
|
||||
|
||||
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="main" pos="0 -0.2 0.4" fovy="78" mode="targetbody" target="midair"/>
|
||||
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
|
||||
</worldbody>
|
||||
|
||||
|
||||
|
||||
</mujocoinclude>
|
||||
BIN
assets/tabletop.stl
Normal file
BIN
assets/tabletop.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_10_custom_finger_left.stl
Normal file
BIN
assets/vx300s_10_custom_finger_left.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_10_custom_finger_right.stl
Normal file
BIN
assets/vx300s_10_custom_finger_right.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_10_gripper_finger.stl
Normal file
BIN
assets/vx300s_10_gripper_finger.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_11_ar_tag.stl
Normal file
BIN
assets/vx300s_11_ar_tag.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_1_base.stl
Normal file
BIN
assets/vx300s_1_base.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_2_shoulder.stl
Normal file
BIN
assets/vx300s_2_shoulder.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_3_upper_arm.stl
Normal file
BIN
assets/vx300s_3_upper_arm.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_4_upper_forearm.stl
Normal file
BIN
assets/vx300s_4_upper_forearm.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_5_lower_forearm.stl
Normal file
BIN
assets/vx300s_5_lower_forearm.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_6_wrist.stl
Normal file
BIN
assets/vx300s_6_wrist.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_7_gripper.stl
Normal file
BIN
assets/vx300s_7_gripper.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_8_gripper_prop.stl
Normal file
BIN
assets/vx300s_8_gripper_prop.stl
Normal file
Binary file not shown.
BIN
assets/vx300s_9_gripper_bar.stl
Normal file
BIN
assets/vx300s_9_gripper_bar.stl
Normal file
Binary file not shown.
17
assets/vx300s_dependencies.xml
Normal file
17
assets/vx300s_dependencies.xml
Normal file
@@ -0,0 +1,17 @@
|
||||
<mujocoinclude>
|
||||
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
|
||||
<asset>
|
||||
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
|
||||
</asset>
|
||||
|
||||
</mujocoinclude>
|
||||
59
assets/vx300s_left.xml
Normal file
59
assets/vx300s_left.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_left" pos="-0.469 0.5 0">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
|
||||
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
|
||||
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
|
||||
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
|
||||
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
|
||||
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
|
||||
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
|
||||
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
59
assets/vx300s_right.xml
Normal file
59
assets/vx300s_right.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
|
||||
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
|
||||
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
|
||||
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
|
||||
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
|
||||
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
|
||||
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
|
||||
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
1038
commands.txt
Normal file
1038
commands.txt
Normal file
File diff suppressed because it is too large
Load Diff
53
constants.py
Normal file
53
constants.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import pathlib
|
||||
|
||||
### Parameters that changes across tasks
|
||||
EPISODE_LEN = 600
|
||||
|
||||
### ALOHA fixed constants
|
||||
DT = 0.02
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
CAMERA_NAMES = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] # defines the number and ordering of cameras
|
||||
BOX_INIT_POSE = [0.2, 0.5, 0.05, 1, 0, 0, 0]
|
||||
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
||||
SIM_CAMERA_NAMES = ['main']
|
||||
|
||||
SIM_EPISODE_LEN_TRANSFER_CUBE = 400
|
||||
SIM_EPISODE_LEN_INSERTION = 400
|
||||
|
||||
XML_DIR = str(pathlib.Path(__file__).parent.resolve()) + '/assets/' # note: absolute path
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
|
||||
PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2
|
||||
201
detr/LICENSE
Normal file
201
detr/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 - present, Facebook, Inc
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
9
detr/README.md
Normal file
9
detr/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
|
||||
|
||||
@article{Carion2020EndtoEndOD,
|
||||
title={End-to-End Object Detection with Transformers},
|
||||
author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
|
||||
journal={ArXiv},
|
||||
year={2020},
|
||||
volume={abs/2005.12872}
|
||||
}
|
||||
115
detr/main.py
Normal file
115
detr/main.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from .models import build_ACT_model, build_CNNMLP_model
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
||||
parser.add_argument('--lr', default=1e-4, type=float) # will be overridden
|
||||
parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden
|
||||
parser.add_argument('--batch_size', default=2, type=int) # not used
|
||||
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
||||
parser.add_argument('--epochs', default=300, type=int) # not used
|
||||
parser.add_argument('--lr_drop', default=200, type=int) # not used
|
||||
parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used
|
||||
help='gradient clipping max norm')
|
||||
|
||||
# Model parameters
|
||||
# * Backbone
|
||||
parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden
|
||||
help="Name of the convolutional backbone to use")
|
||||
parser.add_argument('--dilation', action='store_true',
|
||||
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
|
||||
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
|
||||
help="Type of positional embedding to use on top of the image features")
|
||||
parser.add_argument('--camera_names', default=[], type=list, # will be overridden
|
||||
help="A list of camera names")
|
||||
|
||||
# * Transformer
|
||||
parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
|
||||
help="Number of encoding layers in the transformer")
|
||||
parser.add_argument('--dec_layers', default=6, type=int, # will be overridden
|
||||
help="Number of decoding layers in the transformer")
|
||||
parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden
|
||||
help="Intermediate size of the feedforward layers in the transformer blocks")
|
||||
parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden
|
||||
help="Size of the embeddings (dimension of the transformer)")
|
||||
parser.add_argument('--dropout', default=0.1, type=float,
|
||||
help="Dropout applied in the transformer")
|
||||
parser.add_argument('--nheads', default=8, type=int, # will be overridden
|
||||
help="Number of attention heads inside the transformer's attentions")
|
||||
parser.add_argument('--num_queries', default=400, type=int, # will be overridden
|
||||
help="Number of query slots")
|
||||
parser.add_argument('--pre_norm', action='store_true')
|
||||
|
||||
# * Segmentation
|
||||
parser.add_argument('--masks', action='store_true',
|
||||
help="Train segmentation head if the flag is provided")
|
||||
|
||||
# repeat args in imitate_episodes just to avoid error. Will not be used
|
||||
parser.add_argument('--eval', action='store_true')
|
||||
parser.add_argument('--onscreen_render', action='store_true')
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
|
||||
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
|
||||
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
|
||||
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
|
||||
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
||||
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
||||
parser.add_argument('--temporal_agg', action='store_true')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def build_ACT_model_and_optimizer(args_override):
|
||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
model = build_ACT_model(args)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def build_CNNMLP_model_and_optimizer(args_override):
|
||||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
model = build_CNNMLP_model(args)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
9
detr/models/__init__.py
Normal file
9
detr/models/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .detr_vae import build as build_vae
|
||||
from .detr_vae import build_cnnmlp as build_cnnmlp
|
||||
|
||||
def build_ACT_model(args):
|
||||
return build_vae(args)
|
||||
|
||||
def build_CNNMLP_model(args):
|
||||
return build_cnnmlp(args)
|
||||
122
detr/models/backbone.py
Normal file
122
detr/models/backbone.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Backbone modules.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from typing import Dict, List
|
||||
|
||||
from util.misc import NestedTensor, is_main_process
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super(FrozenBatchNorm2d, self).__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
num_batches_tracked_key = prefix + 'num_batches_tracked'
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
|
||||
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {'layer4': "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
def __init__(self, name: str,
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for name, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
275
detr/models/detr_vae.py
Normal file
275
detr/models/detr_vae.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR model and criterion classes.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from .backbone import build_backbone
|
||||
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
import numpy as np
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
""" This is the DETR module that performs object detection """
|
||||
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, state_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(14, hidden_dim)
|
||||
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_proj = nn.Linear(14, hidden_dim) # project state to embedding
|
||||
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
|
||||
self.register_buffer('pos_table', get_sinusoid_encoding_table(num_queries+1, hidden_dim))
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_proj(actions) # (bs, seq, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat([cls_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
cls_is_pad = torch.full((bs, 1), False).to(qpos.device) # False: not a padding
|
||||
is_pad = torch.cat([cls_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, :self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim:]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
else:
|
||||
mu = logvar = None
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
""" Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
for backbone in backbones:
|
||||
down_proj = nn.Sequential(
|
||||
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
||||
nn.Conv2d(128, 64, kernel_size=5),
|
||||
nn.Conv2d(64, 32, kernel_size=5)
|
||||
)
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + 14
|
||||
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0] # not used
|
||||
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
||||
# flatten everything
|
||||
flattened_features = []
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for i in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
|
||||
encoder = build_encoder(args)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=state_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||
|
||||
return model
|
||||
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
camera_names=args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||
|
||||
return model
|
||||
|
||||
93
detr/models/position_encoding.py
Normal file
93
detr/models/position_encoding.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from util.misc import NestedTensor
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = torch.cat([
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
N_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ('v2', 'sine'):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif args.position_embedding in ('v3', 'learned'):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
314
detr/models/transformer.py
Normal file
314
detr/models/transformer.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
||||
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False,
|
||||
return_intermediate_dec=False):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
||||
dropout, activation, normalize_before)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
||||
return_intermediate=return_intermediate_dec)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
||||
pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(self, src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, memory, tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos, query_pos=query_pos)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(self, src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(self, src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu", normalize_before=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(self, tgt, memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
||||
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
10
detr/setup.py
Normal file
10
detr/setup.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from distutils.core import setup
|
||||
from setuptools import find_packages
|
||||
|
||||
setup(
|
||||
name='detr',
|
||||
version='0.0.0',
|
||||
packages=find_packages(),
|
||||
license='MIT License',
|
||||
long_description=open('README.md').read(),
|
||||
)
|
||||
1
detr/util/__init__.py
Normal file
1
detr/util/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
88
detr/util/box_ops.py
Normal file
88
detr/util/box_ops.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Utilities for bounding box manipulation and GIoU.
|
||||
"""
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||
and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
area = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
def masks_to_boxes(masks):
|
||||
"""Compute the bounding boxes around the provided masks
|
||||
|
||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||
|
||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||
"""
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device)
|
||||
|
||||
h, w = masks.shape[-2:]
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
|
||||
x_mask = (masks * x.unsqueeze(0))
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
y_mask = (masks * y.unsqueeze(0))
|
||||
y_max = y_mask.flatten(1).max(-1)[0]
|
||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||
468
detr/util/misc.py
Normal file
468
detr/util/misc.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import pickle
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], :img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError('not supported')
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
107
detr/util/plot_utils.py
Normal file
107
detr/util/plot_utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Plotting utilities to visualize training logs.
|
||||
"""
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
|
||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
||||
'''
|
||||
Function to plot specific fields from training log(s). Plots both training and test results.
|
||||
|
||||
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
||||
- fields = which results to plot from each log file - plots both training and test for each field.
|
||||
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
||||
- log_name = optional, name of log file if different than default 'log.txt'.
|
||||
|
||||
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
||||
- solid lines are training results, dashed lines are test results.
|
||||
|
||||
'''
|
||||
func_name = "plot_utils.py::plot_logs"
|
||||
|
||||
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
||||
# convert single Path to list to avoid 'not iterable' error
|
||||
|
||||
if not isinstance(logs, list):
|
||||
if isinstance(logs, PurePath):
|
||||
logs = [logs]
|
||||
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
||||
else:
|
||||
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
||||
Expect list[Path] or single Path obj, received {type(logs)}")
|
||||
|
||||
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
||||
for i, dir in enumerate(logs):
|
||||
if not isinstance(dir, PurePath):
|
||||
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
||||
if not dir.exists():
|
||||
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
||||
# verify log_name exists
|
||||
fn = Path(dir / log_name)
|
||||
if not fn.exists():
|
||||
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
||||
print(f"--> full path of missing log file: {fn}")
|
||||
return
|
||||
|
||||
# load log file(s) and plot
|
||||
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
||||
|
||||
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
||||
|
||||
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
||||
for j, field in enumerate(fields):
|
||||
if field == 'mAP':
|
||||
coco_eval = pd.DataFrame(
|
||||
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
||||
).ewm(com=ewm_col).mean()
|
||||
axs[j].plot(coco_eval, c=color)
|
||||
else:
|
||||
df.interpolate().ewm(com=ewm_col).mean().plot(
|
||||
y=[f'train_{field}', f'test_{field}'],
|
||||
ax=axs[j],
|
||||
color=[color] * 2,
|
||||
style=['-', '--']
|
||||
)
|
||||
for ax, field in zip(axs, fields):
|
||||
ax.legend([Path(p).name for p in logs])
|
||||
ax.set_title(field)
|
||||
|
||||
|
||||
def plot_precision_recall(files, naming_scheme='iter'):
|
||||
if naming_scheme == 'exp_id':
|
||||
# name becomes exp_id
|
||||
names = [f.parts[-3] for f in files]
|
||||
elif naming_scheme == 'iter':
|
||||
names = [f.stem for f in files]
|
||||
else:
|
||||
raise ValueError(f'not supported {naming_scheme}')
|
||||
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
||||
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
||||
data = torch.load(f)
|
||||
# precision is n_iou, n_points, n_cat, n_area, max_det
|
||||
precision = data['precision']
|
||||
recall = data['params'].recThrs
|
||||
scores = data['scores']
|
||||
# take precision for all classes, all areas and 100 detections
|
||||
precision = precision[0, :, :, 0, -1].mean(1)
|
||||
scores = scores[0, :, :, 0, -1].mean(1)
|
||||
prec = precision.mean()
|
||||
rec = data['recall'][0, :, 0, -1].mean()
|
||||
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
||||
f'score={scores.mean():0.3f}, ' +
|
||||
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
||||
)
|
||||
axs[0].plot(recall, precision, c=color)
|
||||
axs[1].plot(recall, scores, c=color)
|
||||
|
||||
axs[0].set_title('Precision / Recall')
|
||||
axs[0].legend(names)
|
||||
axs[1].set_title('Scores / Recall')
|
||||
axs[1].legend(names)
|
||||
return fig, axs
|
||||
264
ee_sim_env.py
Normal file
264
ee_sim_env.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import numpy as np
|
||||
import collections
|
||||
import os
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_CLOSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
from utils import sample_box_pose, sample_insertion_pose
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def make_ee_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_{task_name}.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
if task_name == 'transfer_cube':
|
||||
task = TransferCubeEETask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif task_name == 'insertion':
|
||||
task = InsertionEETask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXEETask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
|
||||
g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array([
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
])
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['main'] = physics.render(height=480, width=640, camera_id='main') # TODO hardcoded camera name
|
||||
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
|
||||
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs['gripper_ctrl'] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEETask(BimanualViperXEETask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id('red_box_joint', 'joint')
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEETask(BimanualViperXEETask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||
("socket-2", "table") in all_contact_pairs or \
|
||||
("socket-3", "table") in all_contact_pairs or \
|
||||
("socket-4", "table") in all_contact_pairs
|
||||
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||
("red_peg", "socket-2") in all_contact_pairs or \
|
||||
("red_peg", "socket-3") in all_contact_pairs or \
|
||||
("red_peg", "socket-4") in all_contact_pairs
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
436
imitate_episodes.py
Normal file
436
imitate_episodes.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
|
||||
from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN
|
||||
from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES
|
||||
from utils import load_data # data functions
|
||||
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
||||
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
||||
from policy import ACTPolicy, CNNMLPPolicy
|
||||
from visualize_episodes import save_videos
|
||||
|
||||
from sim_env import BOX_POSE
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
def main(args):
|
||||
set_seed(1)
|
||||
# command line parameters
|
||||
is_eval = args['eval']
|
||||
ckpt_dir = args['ckpt_dir']
|
||||
dataset_dir = args['dataset_dir']
|
||||
policy_class = args['policy_class']
|
||||
onscreen_render = args['onscreen_render']
|
||||
task_name = args['task_name']
|
||||
batch_size_train = args['batch_size']
|
||||
batch_size_val = args['batch_size']
|
||||
num_epochs = args['num_epochs']
|
||||
|
||||
# fixed parameters
|
||||
num_episodes = 50
|
||||
state_dim = 14
|
||||
lr_backbone = 1e-5
|
||||
backbone = 'resnet18'
|
||||
if policy_class == 'ACT':
|
||||
enc_layers = 4
|
||||
dec_layers = 7
|
||||
nheads = 8
|
||||
policy_config = {'lr': args['lr'],
|
||||
'num_queries': args['chunk_size'],
|
||||
'kl_weight': args['kl_weight'],
|
||||
'hidden_dim': args['hidden_dim'],
|
||||
'dim_feedforward': args['dim_feedforward'],
|
||||
'lr_backbone': lr_backbone,
|
||||
'backbone': backbone,
|
||||
'enc_layers': enc_layers,
|
||||
'dec_layers': dec_layers,
|
||||
'nheads': nheads,
|
||||
}
|
||||
elif policy_class == 'CNNMLP':
|
||||
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
config = {
|
||||
'num_epochs': num_epochs,
|
||||
'ckpt_dir': ckpt_dir,
|
||||
'state_dim': state_dim,
|
||||
'lr': args['lr'],
|
||||
'real_robot': 'TBD',
|
||||
'policy_class': policy_class,
|
||||
'onscreen_render': onscreen_render,
|
||||
'policy_config': policy_config,
|
||||
'task_name': task_name,
|
||||
'seed': args['seed'],
|
||||
'temporal_agg': args['temporal_agg']
|
||||
}
|
||||
|
||||
train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val)
|
||||
|
||||
if is_sim:
|
||||
policy_config['camera_names'] = SIM_CAMERA_NAMES
|
||||
config['camera_names'] = SIM_CAMERA_NAMES
|
||||
config['real_robot'] = False
|
||||
if task_name == 'transfer_cube':
|
||||
config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE
|
||||
elif task_name == 'insertion':
|
||||
config['episode_len'] = SIM_EPISODE_LEN_INSERTION
|
||||
else:
|
||||
policy_config['camera_names'] = CAMERA_NAMES
|
||||
config['camera_names'] = CAMERA_NAMES
|
||||
config['real_robot'] = True
|
||||
config['episode_len'] = EPISODE_LEN
|
||||
|
||||
if is_eval:
|
||||
ckpt_names = [f'policy_best.ckpt']
|
||||
results = []
|
||||
for ckpt_name in ckpt_names:
|
||||
success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
|
||||
results.append([ckpt_name, success_rate, avg_return])
|
||||
|
||||
for ckpt_name, success_rate, avg_return in results:
|
||||
print(f'{ckpt_name}: {success_rate=} {avg_return=}')
|
||||
print()
|
||||
exit()
|
||||
|
||||
# save dataset stats
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
|
||||
with open(stats_path, 'wb') as f:
|
||||
pickle.dump(stats, f)
|
||||
|
||||
best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
|
||||
best_epoch, min_val_loss, best_state_dict = best_ckpt_info
|
||||
|
||||
# save best checkpoint
|
||||
ckpt_path = os.path.join(ckpt_dir, f'policy_best.ckpt')
|
||||
torch.save(best_state_dict, ckpt_path)
|
||||
print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch{best_epoch}')
|
||||
|
||||
|
||||
def make_policy(policy_class, policy_config):
|
||||
if policy_class == 'ACT':
|
||||
policy = ACTPolicy(policy_config)
|
||||
elif policy_class == 'CNNMLP':
|
||||
policy = CNNMLPPolicy(policy_config)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return policy
|
||||
|
||||
|
||||
def make_optimizer(policy_class, policy):
|
||||
if policy_class == 'ACT':
|
||||
optimizer = policy.configure_optimizers()
|
||||
elif policy_class == 'CNNMLP':
|
||||
optimizer = policy.configure_optimizers()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_image(ts, camera_names):
|
||||
curr_images = []
|
||||
for cam_name in camera_names:
|
||||
curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w')
|
||||
curr_images.append(curr_image)
|
||||
curr_image = np.stack(curr_images, axis=0)
|
||||
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
|
||||
return curr_image
|
||||
|
||||
|
||||
def eval_bc(config, ckpt_name, save_episode=True):
|
||||
set_seed(1000)
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
state_dim = config['state_dim']
|
||||
real_robot = config['real_robot']
|
||||
policy_class = config['policy_class']
|
||||
onscreen_render = config['onscreen_render']
|
||||
policy_config = config['policy_config']
|
||||
camera_names = config['camera_names']
|
||||
max_timesteps = config['episode_len']
|
||||
task_name = config['task_name']
|
||||
temporal_agg = config['temporal_agg']
|
||||
onscreen_cam = 'main'
|
||||
|
||||
# load policy and stats
|
||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
loading_status = policy.load_state_dict(torch.load(ckpt_path))
|
||||
print(loading_status)
|
||||
policy.cuda()
|
||||
policy.eval()
|
||||
print(f'Loaded: {ckpt_path}')
|
||||
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
|
||||
with open(stats_path, 'rb') as f:
|
||||
stats = pickle.load(f)
|
||||
|
||||
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
|
||||
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
|
||||
|
||||
# load environment
|
||||
if real_robot:
|
||||
from scripts.utils import move_grippers # requires aloha
|
||||
from scripts.real_env import make_real_env # requires aloha
|
||||
env = make_real_env(init_node=True)
|
||||
env_max_reward = 0
|
||||
else:
|
||||
from sim_env import make_sim_env
|
||||
env = make_sim_env(task_name)
|
||||
env_max_reward = env.task.max_reward
|
||||
|
||||
query_frequency = policy_config['num_queries']
|
||||
if temporal_agg:
|
||||
query_frequency = 1
|
||||
num_queries = policy_config['num_queries']
|
||||
|
||||
max_timesteps = int(max_timesteps * 1) # may increase for real-world tasks
|
||||
|
||||
num_rollouts = 50
|
||||
episode_returns = []
|
||||
highest_rewards = []
|
||||
for rollout_id in range(num_rollouts):
|
||||
rollout_id += 0
|
||||
### set task
|
||||
if task_name == 'transfer_cube':
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif task_name == 'insertion':
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
else:
|
||||
raise NotImplementedError
|
||||
ts = env.reset()
|
||||
|
||||
### onscreen render
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
|
||||
plt.ion()
|
||||
|
||||
### evaluation loop
|
||||
if temporal_agg:
|
||||
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
|
||||
|
||||
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
|
||||
image_list = [] # for visualization
|
||||
qpos_list = []
|
||||
target_qpos_list = []
|
||||
rewards = []
|
||||
with torch.inference_mode():
|
||||
for t in range(max_timesteps):
|
||||
### update onscreen render and wait for DT
|
||||
if onscreen_render:
|
||||
image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
|
||||
plt_img.set_data(image)
|
||||
plt.pause(DT)
|
||||
|
||||
### process previous timestep to get qpos and image_list
|
||||
obs = ts.observation
|
||||
if 'images' in obs:
|
||||
image_list.append(obs['images'])
|
||||
else:
|
||||
image_list.append({'main': obs['image']})
|
||||
qpos_numpy = np.array(obs['qpos'])
|
||||
qpos = pre_process(qpos_numpy)
|
||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||
qpos_history[:, t] = qpos
|
||||
curr_image = get_image(ts, camera_names)
|
||||
|
||||
### query policy
|
||||
if config['policy_class'] == "ACT":
|
||||
if t % query_frequency == 0:
|
||||
all_actions = policy(qpos, curr_image)
|
||||
if temporal_agg:
|
||||
all_time_actions[[t], t:t+num_queries] = all_actions
|
||||
actions_for_curr_step = all_time_actions[:, t]
|
||||
actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
k = 0.01
|
||||
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
exp_weights = exp_weights / exp_weights.sum()
|
||||
exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
else:
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
elif config['policy_class'] == "CNNMLP":
|
||||
raw_action = policy(qpos, curr_image)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
### post-process actions
|
||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||
action = post_process(raw_action)
|
||||
target_qpos = action
|
||||
|
||||
### step the environment
|
||||
ts = env.step(target_qpos)
|
||||
|
||||
### for visualization
|
||||
qpos_list.append(qpos_numpy)
|
||||
target_qpos_list.append(target_qpos)
|
||||
rewards.append(ts.reward)
|
||||
|
||||
plt.close()
|
||||
if real_robot:
|
||||
move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # open
|
||||
pass
|
||||
|
||||
rewards = np.array(rewards)
|
||||
episode_return = np.sum(rewards[rewards!=None])
|
||||
episode_returns.append(episode_return)
|
||||
episode_highest_reward = np.max(rewards)
|
||||
highest_rewards.append(episode_highest_reward)
|
||||
print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward==env_max_reward}')
|
||||
|
||||
if save_episode:
|
||||
save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))
|
||||
|
||||
success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
|
||||
avg_return = np.mean(episode_returns)
|
||||
summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
|
||||
for r in range(env_max_reward+1):
|
||||
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
|
||||
more_or_equal_r_rate = more_or_equal_r / num_rollouts
|
||||
summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'
|
||||
|
||||
print(summary_str)
|
||||
|
||||
# save success rate to txt
|
||||
result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
|
||||
with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
|
||||
f.write(summary_str)
|
||||
f.write(repr(episode_returns))
|
||||
f.write('\n\n')
|
||||
f.write(repr(highest_rewards))
|
||||
|
||||
return success_rate, avg_return
|
||||
|
||||
|
||||
def forward_pass(data, policy):
|
||||
image_data, qpos_data, action_data, is_pad = data
|
||||
image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
|
||||
return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
|
||||
|
||||
|
||||
def train_bc(train_dataloader, val_dataloader, config):
|
||||
num_epochs = config['num_epochs']
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
seed = config['seed']
|
||||
policy_class = config['policy_class']
|
||||
policy_config = config['policy_config']
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
policy.cuda()
|
||||
optimizer = make_optimizer(policy_class, policy)
|
||||
|
||||
train_history = []
|
||||
validation_history = []
|
||||
min_val_loss = np.inf
|
||||
best_ckpt_info = None
|
||||
for epoch in tqdm(range(num_epochs)):
|
||||
print(f'\nEpoch {epoch}')
|
||||
# validation
|
||||
with torch.inference_mode():
|
||||
policy.eval()
|
||||
epoch_dicts = []
|
||||
for batch_idx, data in enumerate(val_dataloader):
|
||||
forward_dict = forward_pass(data, policy)
|
||||
epoch_dicts.append(forward_dict)
|
||||
epoch_summary = compute_dict_mean(epoch_dicts)
|
||||
validation_history.append(epoch_summary)
|
||||
|
||||
epoch_val_loss = epoch_summary['loss']
|
||||
if epoch_val_loss < min_val_loss:
|
||||
min_val_loss = epoch_val_loss
|
||||
best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
|
||||
print(f'Val loss: {epoch_val_loss:.5f}')
|
||||
summary_string = ''
|
||||
for k, v in epoch_summary.items():
|
||||
summary_string += f'{k}: {v.item():.3f} '
|
||||
print(summary_string)
|
||||
|
||||
# training
|
||||
policy.train()
|
||||
optimizer.zero_grad()
|
||||
for batch_idx, data in enumerate(train_dataloader):
|
||||
forward_dict = forward_pass(data, policy)
|
||||
# backward
|
||||
loss = forward_dict['loss']
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
train_history.append(detach_dict(forward_dict))
|
||||
epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
|
||||
epoch_train_loss = epoch_summary['loss']
|
||||
print(f'Train loss: {epoch_train_loss:.5f}')
|
||||
summary_string = ''
|
||||
for k, v in epoch_summary.items():
|
||||
summary_string += f'{k}: {v.item():.3f} '
|
||||
print(summary_string)
|
||||
|
||||
if epoch % 100 == 0:
|
||||
ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
|
||||
torch.save(policy.state_dict(), ckpt_path)
|
||||
plot_history(train_history, validation_history, epoch, ckpt_dir, seed)
|
||||
|
||||
ckpt_path = os.path.join(ckpt_dir, f'policy_last.ckpt')
|
||||
torch.save(policy.state_dict(), ckpt_path)
|
||||
|
||||
best_epoch, min_val_loss, best_state_dict = best_ckpt_info
|
||||
ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')
|
||||
torch.save(best_state_dict, ckpt_path)
|
||||
print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')
|
||||
|
||||
# save training curves
|
||||
plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)
|
||||
|
||||
return best_ckpt_info
|
||||
|
||||
|
||||
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
|
||||
# save training curves
|
||||
for key in train_history[0]:
|
||||
plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
|
||||
plt.figure()
|
||||
train_values = [summary[key].item() for summary in train_history]
|
||||
val_values = [summary[key].item() for summary in validation_history]
|
||||
plt.plot(np.linspace(0, num_epochs-1, len(train_history)), train_values, label='train')
|
||||
plt.plot(np.linspace(0, num_epochs-1, len(validation_history)), val_values, label='validation')
|
||||
# plt.ylim([-0.1, 1])
|
||||
plt.tight_layout()
|
||||
plt.legend()
|
||||
plt.title(key)
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved plots to {ckpt_dir}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--eval', action='store_true')
|
||||
parser.add_argument('--onscreen_render', action='store_true')
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
|
||||
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
|
||||
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True)
|
||||
parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
|
||||
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
|
||||
parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
|
||||
|
||||
# for ACT
|
||||
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
|
||||
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
|
||||
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
|
||||
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
|
||||
parser.add_argument('--temporal_agg', action='store_true')
|
||||
|
||||
main(vars(parser.parse_args()))
|
||||
84
policy.py
Normal file
84
policy.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
class ACTPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_ACT_model_and_optimizer(args_override)
|
||||
self.model = model # CVAE decoder
|
||||
self.optimizer = optimizer
|
||||
self.kl_weight = args_override['kl_weight']
|
||||
print(f'KL Weight {self.kl_weight}')
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, :self.model.num_queries]
|
||||
is_pad = is_pad[:, :self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict = dict()
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
|
||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
loss_dict['l1'] = l1
|
||||
loss_dict['kl'] = total_kld[0]
|
||||
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer
|
||||
|
||||
|
||||
class CNNMLPPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
|
||||
self.model = model # decoder
|
||||
self.optimizer = optimizer
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None # TODO
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, 0]
|
||||
a_hat = self.model(qpos, image, env_state, actions)
|
||||
mse = F.mse_loss(actions, a_hat)
|
||||
loss_dict = dict()
|
||||
loss_dict['mse'] = mse
|
||||
loss_dict['loss'] = loss_dict['mse']
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
187
record_sim_episodes.py
Normal file
187
record_sim_episodes.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import h5py_cache
|
||||
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, SIM_CAMERA_NAMES
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
from sim_env import make_sim_env, BOX_POSE
|
||||
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Generate demonstration data in simulation.
|
||||
First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory.
|
||||
Replace the gripper joint positions with the commanded joint position.
|
||||
Replay this joint trajectory (as action sequence) in sim_env, and record all observations.
|
||||
Save this episode of data, and continue to next episode of data collection.
|
||||
"""
|
||||
|
||||
task_name = args['task_name']
|
||||
dataset_dir = args['dataset_dir']
|
||||
num_episodes = args['num_episodes']
|
||||
onscreen_render = args['onscreen_render']
|
||||
inject_noise = False
|
||||
|
||||
if not os.path.isdir(dataset_dir):
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
if task_name == 'transfer_cube':
|
||||
policy_cls = PickAndTransferPolicy
|
||||
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
|
||||
elif task_name == 'insertion':
|
||||
policy_cls = InsertionPolicy
|
||||
episode_len = SIM_EPISODE_LEN_INSERTION
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
success = []
|
||||
for episode_idx in range(num_episodes):
|
||||
# setup the environment
|
||||
env = make_ee_sim_env(task_name)
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
policy = policy_cls(inject_noise)
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['main'])
|
||||
plt.ion()
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['main'])
|
||||
plt.pause(0.002)
|
||||
plt.close()
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||
episode_max_reward = np.max([ts.reward for ts in episode[1:]])
|
||||
if episode_max_reward == env.task.max_reward:
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
joint_traj = [ts.observation['qpos'] for ts in episode]
|
||||
# replace gripper pose with gripper control
|
||||
gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
|
||||
for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
|
||||
left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
|
||||
right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
|
||||
joint[6] = left_ctrl
|
||||
joint[6+7] = right_ctrl
|
||||
|
||||
subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
|
||||
|
||||
# clear unused variables
|
||||
del env
|
||||
del episode
|
||||
del policy
|
||||
|
||||
# setup the environment
|
||||
print(f'====== Start Replaying ======')
|
||||
env = make_sim_env(task_name)
|
||||
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
|
||||
ts = env.reset()
|
||||
|
||||
episode_replay = [ts]
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['main'])
|
||||
plt.ion()
|
||||
for t in range(len(joint_traj)): # note: this will increase episode length by 1
|
||||
action = joint_traj[t]
|
||||
ts = env.step(action)
|
||||
episode_replay.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['main'])
|
||||
plt.pause(0.02)
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
|
||||
episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
|
||||
if episode_max_reward == env.task.max_reward:
|
||||
success.append(1)
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
success.append(0)
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
plt.close()
|
||||
|
||||
"""
|
||||
For each timestep:
|
||||
observations
|
||||
- images
|
||||
- main (480, 640, 3) 'uint8'
|
||||
- qpos (14,) 'float64'
|
||||
- qvel (14,) 'float64'
|
||||
|
||||
action (14,) 'float64'
|
||||
"""
|
||||
|
||||
data_dict = {
|
||||
'/observations/qpos': [],
|
||||
'/observations/qvel': [],
|
||||
'/action': [],
|
||||
}
|
||||
for cam_name in SIM_CAMERA_NAMES:
|
||||
data_dict[f'/observations/images/{cam_name}'] = []
|
||||
|
||||
# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
|
||||
# truncate here to be consistent
|
||||
joint_traj = joint_traj[:-1]
|
||||
episode_replay = episode_replay[:-1]
|
||||
|
||||
# len(joint_traj) i.e. actions: max_timesteps
|
||||
# len(episode_replay) i.e. time steps: max_timesteps + 1
|
||||
max_timesteps = len(joint_traj)
|
||||
while joint_traj:
|
||||
action = joint_traj.pop(0)
|
||||
ts = episode_replay.pop(0)
|
||||
data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||
data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||
data_dict['/action'].append(action)
|
||||
for cam_name in SIM_CAMERA_NAMES:
|
||||
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||
|
||||
# HDF5
|
||||
t0 = time.time()
|
||||
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
|
||||
with h5py_cache.File(dataset_path + '.hdf5', 'w', chunk_cache_mem_size=1024 ** 2 * 2) as root:
|
||||
# with h5py.File(dataset_path + '.hdf5', 'w') as root:
|
||||
root.attrs['sim'] = True
|
||||
obs = root.create_group('observations')
|
||||
image = obs.create_group('images')
|
||||
cam_main = image.create_dataset('main', (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3), )
|
||||
# compression='gzip',compression_opts=2,)
|
||||
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
|
||||
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
|
||||
qvel = obs.create_dataset('qvel', (max_timesteps, 14))
|
||||
action = root.create_dataset('action', (max_timesteps, 14))
|
||||
|
||||
for name, array in data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f'Saving: {time.time() - t0:.1f} secs\n')
|
||||
|
||||
print(f'Saved to {dataset_dir}')
|
||||
print(f'Success: {np.sum(success)} / {len(success)}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
|
||||
parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
|
||||
parser.add_argument('--onscreen_render', action='store_true')
|
||||
|
||||
main(vars(parser.parse_args()))
|
||||
|
||||
195
scripted_policy.py
Normal file
195
scripted_policy.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pyquaternion import Quaternion
|
||||
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
class BasePolicy:
|
||||
def __init__(self, inject_noise=False):
|
||||
self.inject_noise = inject_noise
|
||||
self.step_count = 0
|
||||
self.left_trajectory = None
|
||||
self.right_trajectory = None
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def interpolate(curr_waypoint, next_waypoint, t):
|
||||
t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
|
||||
curr_xyz = curr_waypoint['xyz']
|
||||
curr_quat = curr_waypoint['quat']
|
||||
curr_grip = curr_waypoint['gripper']
|
||||
next_xyz = next_waypoint['xyz']
|
||||
next_quat = next_waypoint['quat']
|
||||
next_grip = next_waypoint['gripper']
|
||||
xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
|
||||
quat = curr_quat + (next_quat - curr_quat) * t_frac
|
||||
gripper = curr_grip + (next_grip - curr_grip) * t_frac
|
||||
return xyz, quat, gripper
|
||||
|
||||
def __call__(self, ts):
|
||||
# generate trajectory at first timestep, then open-loop execution
|
||||
if self.step_count == 0:
|
||||
self.generate_trajectory(ts)
|
||||
|
||||
# obtain left and right waypoints
|
||||
if self.left_trajectory[0]['t'] == self.step_count:
|
||||
self.curr_left_waypoint = self.left_trajectory.pop(0)
|
||||
next_left_waypoint = self.left_trajectory[0]
|
||||
|
||||
if self.right_trajectory[0]['t'] == self.step_count:
|
||||
self.curr_right_waypoint = self.right_trajectory.pop(0)
|
||||
next_right_waypoint = self.right_trajectory[0]
|
||||
|
||||
# interpolate between waypoints to obtain current pose and gripper command
|
||||
left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
|
||||
right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
|
||||
|
||||
# Inject noise
|
||||
if self.inject_noise:
|
||||
scale = 0.01
|
||||
left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
|
||||
right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
|
||||
|
||||
action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
|
||||
action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
|
||||
|
||||
self.step_count += 1
|
||||
return np.concatenate([action_left, action_right])
|
||||
|
||||
|
||||
class PickAndTransferPolicy(BasePolicy):
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||
|
||||
box_info = np.array(ts_first.observation['env_state'])
|
||||
box_xyz = box_info[:3]
|
||||
box_quat = box_info[3:]
|
||||
# print(f"Generate trajectory for {box_xyz=}")
|
||||
|
||||
gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||
|
||||
meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
|
||||
|
||||
meet_xyz = np.array([0, 0.5, 0.25])
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||
{"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
|
||||
{"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
|
||||
{"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
|
||||
{"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
|
||||
{"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||
{"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
|
||||
{"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
|
||||
{"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
|
||||
{"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
|
||||
{"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
|
||||
{"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
|
||||
]
|
||||
|
||||
|
||||
class InsertionPolicy(BasePolicy):
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||
|
||||
peg_info = np.array(ts_first.observation['env_state'])[:7]
|
||||
peg_xyz = peg_info[:3]
|
||||
peg_quat = peg_info[3:]
|
||||
|
||||
socket_info = np.array(ts_first.observation['env_state'])[7:]
|
||||
socket_xyz = socket_info[:3]
|
||||
socket_quat = socket_info[3:]
|
||||
|
||||
gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||
|
||||
gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)
|
||||
|
||||
meet_xyz = np.array([0, 0.5, 0.15])
|
||||
lift_right = 0.00715
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||
{"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down
|
||||
{"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper
|
||||
{"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion
|
||||
{"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||
{"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down
|
||||
{"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper
|
||||
{"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||
{"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||
|
||||
]
|
||||
|
||||
|
||||
def test_policy(task_name):
|
||||
# example rolling out pick_and_transfer policy
|
||||
onscreen_render = True
|
||||
inject_noise = False
|
||||
|
||||
# setup the environment
|
||||
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION
|
||||
if task_name == 'transfer_cube':
|
||||
env = make_ee_sim_env('transfer_cube')
|
||||
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
|
||||
elif task_name == 'insertion':
|
||||
env = make_ee_sim_env('insertion')
|
||||
episode_len = SIM_EPISODE_LEN_INSERTION
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for episode_idx in range(2):
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['main'])
|
||||
plt.ion()
|
||||
|
||||
policy = PickAndTransferPolicy(inject_noise)
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['main'])
|
||||
plt.pause(0.02)
|
||||
plt.close()
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||
if episode_return > 0:
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_task_name = 'transfer_cube'
|
||||
test_policy(test_task_name)
|
||||
|
||||
274
sim_env.py
Normal file
274
sim_env.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import collections
|
||||
import matplotlib.pyplot as plt
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE, BOX_INIT_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
def make_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_{task_name}.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
if task_name == 'transfer_cube':
|
||||
task = TransferCubeTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif task_name == 'insertion':
|
||||
task = InsertionTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7:7+6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7+6]
|
||||
|
||||
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
|
||||
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['main'] = physics.render(height=480, width=640, camera_id='top') # TODO hardcoded camera name
|
||||
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') # TODO hardcoded camera name
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||
("socket-2", "table") in all_contact_pairs or \
|
||||
("socket-3", "table") in all_contact_pairs or \
|
||||
("socket-4", "table") in all_contact_pairs
|
||||
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||
("red_peg", "socket-2") in all_contact_pairs or \
|
||||
("red_peg", "socket-3") in all_contact_pairs or \
|
||||
("red_peg", "socket-4") in all_contact_pairs
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14)
|
||||
# arm action
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# gripper action
|
||||
left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
|
||||
right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
|
||||
normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
|
||||
normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
|
||||
action[6] = normalized_left_pos
|
||||
action[7+6] = normalized_right_pos
|
||||
return action
|
||||
|
||||
def test_sim_teleop():
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
|
||||
BOX_POSE[0] = BOX_INIT_POSE
|
||||
|
||||
# source of data
|
||||
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||
robot_name=f'master_left', init_node=True)
|
||||
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||
robot_name=f'master_right', init_node=False)
|
||||
|
||||
# setup the environment
|
||||
env = make_sim_env()
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
# setup plotting
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['image'])
|
||||
plt.ion()
|
||||
|
||||
for t in range(1000):
|
||||
action = get_action(master_bot_left, master_bot_right)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
|
||||
plt_img.set_data(ts.observation['image'])
|
||||
plt.pause(0.02)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sim_teleop()
|
||||
|
||||
192
utils.py
Normal file
192
utils.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import h5py
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from constants import SIM_CAMERA_NAMES, CAMERA_NAMES
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
class EpisodicDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, episode_ids, dataset_dir, norm_stats):
|
||||
super(EpisodicDataset).__init__()
|
||||
self.episode_ids = episode_ids
|
||||
self.dataset_dir = dataset_dir
|
||||
self.norm_stats = norm_stats
|
||||
self.is_sim = None
|
||||
self.__getitem__(0) # initialize self.is_sim
|
||||
|
||||
def __len__(self):
|
||||
return len(self.episode_ids)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample_full_episode = False # hardcode
|
||||
|
||||
episode_id = self.episode_ids[index]
|
||||
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
|
||||
with h5py.File(dataset_path, 'r') as root:
|
||||
is_sim = root.attrs['sim']
|
||||
if is_sim:
|
||||
camera_names = SIM_CAMERA_NAMES
|
||||
else:
|
||||
camera_names = CAMERA_NAMES
|
||||
original_action_shape = root['/action'].shape
|
||||
episode_len = original_action_shape[0]
|
||||
if sample_full_episode:
|
||||
start_ts = 0
|
||||
else:
|
||||
start_ts = np.random.choice(episode_len)
|
||||
# get observation at start_ts only
|
||||
qpos = root['/observations/qpos'][start_ts]
|
||||
qvel = root['/observations/qvel'][start_ts]
|
||||
image_dict = dict()
|
||||
for cam_name in camera_names:
|
||||
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
|
||||
# get all actions after and including start_ts
|
||||
if is_sim:
|
||||
action = root['/action'][start_ts:]
|
||||
action_len = episode_len - start_ts
|
||||
else:
|
||||
action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
|
||||
action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
|
||||
|
||||
self.is_sim = is_sim
|
||||
padded_action = np.zeros(original_action_shape, dtype=np.float32)
|
||||
padded_action[:action_len] = action
|
||||
is_pad = np.zeros(episode_len)
|
||||
is_pad[action_len:] = 1
|
||||
|
||||
# new axis for different cameras
|
||||
all_cam_images = []
|
||||
for cam_name in camera_names:
|
||||
all_cam_images.append(image_dict[cam_name])
|
||||
all_cam_images = np.stack(all_cam_images, axis=0)
|
||||
|
||||
# construct observations
|
||||
image_data = torch.from_numpy(all_cam_images)
|
||||
qpos_data = torch.from_numpy(qpos).float()
|
||||
action_data = torch.from_numpy(padded_action).float()
|
||||
is_pad = torch.from_numpy(is_pad).bool()
|
||||
|
||||
# channel last
|
||||
image_data = torch.einsum('k h w c -> k c h w', image_data)
|
||||
|
||||
# normalize image and change dtype to float
|
||||
image_data = image_data / 255.0
|
||||
action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
|
||||
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
|
||||
|
||||
return image_data, qpos_data, action_data, is_pad
|
||||
|
||||
|
||||
def get_norm_stats(dataset_dir, num_episodes):
|
||||
all_qpos_data = []
|
||||
all_action_data = []
|
||||
for episode_idx in range(num_episodes):
|
||||
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
|
||||
with h5py.File(dataset_path, 'r') as root:
|
||||
qpos = root['/observations/qpos'][()]
|
||||
qvel = root['/observations/qvel'][()]
|
||||
action = root['/action'][()]
|
||||
all_qpos_data.append(torch.from_numpy(qpos))
|
||||
all_action_data.append(torch.from_numpy(action))
|
||||
all_qpos_data = torch.stack(all_qpos_data)
|
||||
all_action_data = torch.stack(all_action_data)
|
||||
all_action_data = all_action_data
|
||||
|
||||
# normalize action data
|
||||
action_mean = all_action_data.mean(dim=[0, 1], keepdim=True)
|
||||
action_std = all_action_data.std(dim=[0, 1], keepdim=True)
|
||||
action_std = torch.clip(action_std, 1e-2, 10) # clipping
|
||||
|
||||
# normalize qpos data
|
||||
qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True)
|
||||
qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True)
|
||||
qpos_std = torch.clip(qpos_std, 1e-2, 10) # clipping
|
||||
|
||||
stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
|
||||
"qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
|
||||
"example_qpos": qpos}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val):
|
||||
# obtain train test split
|
||||
train_ratio = 0.8 # TODO
|
||||
shuffled_indices = np.random.permutation(num_episodes)
|
||||
train_indices = shuffled_indices[:int(train_ratio * num_episodes)]
|
||||
val_indices = shuffled_indices[int(train_ratio * num_episodes):]
|
||||
|
||||
# obtain normalization stats for qpos and action
|
||||
norm_stats = get_norm_stats(dataset_dir, num_episodes)
|
||||
|
||||
# construct dataset and dataloader
|
||||
train_dataset = EpisodicDataset(train_indices, dataset_dir, norm_stats)
|
||||
val_dataset = EpisodicDataset(val_indices, dataset_dir, norm_stats)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
|
||||
|
||||
return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
|
||||
|
||||
|
||||
### env utils
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
||||
|
||||
### helper functions
|
||||
|
||||
def compute_dict_mean(epoch_dicts):
|
||||
result = {k: None for k in epoch_dicts[0]}
|
||||
num_items = len(epoch_dicts)
|
||||
for k in result:
|
||||
value_sum = 0
|
||||
for epoch_dict in epoch_dicts:
|
||||
value_sum += epoch_dict[k]
|
||||
result[k] = value_sum / num_items
|
||||
return result
|
||||
|
||||
def detach_dict(d):
|
||||
new_d = dict()
|
||||
for k, v in d.items():
|
||||
new_d[k] = v.detach()
|
||||
return new_d
|
||||
|
||||
def set_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
148
visualize_episodes.py
Normal file
148
visualize_episodes.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import h5py
|
||||
import argparse
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from constants import DT, CAMERA_NAMES, SIM_CAMERA_NAMES
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
STATE_NAMES = JOINT_NAMES + ["gripper"]
|
||||
|
||||
def load_hdf5(dataset_dir, dataset_name):
|
||||
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
|
||||
if not os.path.isfile(dataset_path):
|
||||
print(f'Dataset does not exist at \n{dataset_path}\n')
|
||||
exit()
|
||||
|
||||
with h5py.File(dataset_path, 'r') as root:
|
||||
is_sim = root.attrs['sim']
|
||||
qpos = root['/observations/qpos'][()]
|
||||
qvel = root['/observations/qvel'][()]
|
||||
action = root['/action'][()]
|
||||
image_dict = dict()
|
||||
camera_names = SIM_CAMERA_NAMES if is_sim else CAMERA_NAMES
|
||||
for cam_name in camera_names:
|
||||
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
|
||||
|
||||
return qpos, qvel, action, image_dict
|
||||
|
||||
def main(args):
|
||||
dataset_dir = args['dataset_dir']
|
||||
episode_idx = args['episode_idx']
|
||||
dataset_name = f'episode_{episode_idx}'
|
||||
|
||||
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
|
||||
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
|
||||
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
|
||||
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
|
||||
|
||||
|
||||
def save_videos(video, dt, video_path=None):
|
||||
if isinstance(video, list):
|
||||
cam_names = list(video[0].keys())
|
||||
h, w, _ = video[0][cam_names[0]].shape
|
||||
w = w * len(cam_names)
|
||||
fps = int(1/dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for ts, image_dict in enumerate(video):
|
||||
images = []
|
||||
for cam_name in cam_names:
|
||||
image = image_dict[cam_name]
|
||||
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||
images.append(image)
|
||||
images = np.concatenate(images, axis=1)
|
||||
out.write(images)
|
||||
out.release()
|
||||
print(f'Saved video to: {video_path}')
|
||||
elif isinstance(video, dict):
|
||||
cam_names = list(video.keys())
|
||||
all_cam_videos = []
|
||||
for cam_name in cam_names:
|
||||
all_cam_videos.append(video[cam_name])
|
||||
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
|
||||
|
||||
n_frames, h, w, _ = all_cam_videos.shape
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for t in range(n_frames):
|
||||
image = all_cam_videos[t]
|
||||
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||
out.write(image)
|
||||
out.release()
|
||||
print(f'Saved video to: {video_path}')
|
||||
|
||||
|
||||
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
|
||||
if label_overwrite:
|
||||
label1, label2 = label_overwrite
|
||||
else:
|
||||
label1, label2 = 'State', 'Command'
|
||||
|
||||
qpos = np.array(qpos_list) # ts, dim
|
||||
command = np.array(command_list)
|
||||
num_ts, num_dim = qpos.shape
|
||||
h, w = 2, num_dim
|
||||
num_figs = num_dim
|
||||
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
|
||||
|
||||
# plot joint state
|
||||
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=label1)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
|
||||
# plot arm command
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(command[:, dim_idx], label=label2)
|
||||
ax.legend()
|
||||
|
||||
if ylim:
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved qpos plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_timestamp(t_list, dataset_path):
|
||||
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
|
||||
h, w = 4, 10
|
||||
fig, axs = plt.subplots(2, 1, figsize=(w, h*2))
|
||||
# process t_list
|
||||
t_float = []
|
||||
for secs, nsecs in t_list:
|
||||
t_float.append(secs + nsecs * 10E-10)
|
||||
t_float = np.array(t_float)
|
||||
|
||||
ax = axs[0]
|
||||
ax.plot(np.arange(len(t_float)), t_float)
|
||||
ax.set_title(f'Camera frame timestamps')
|
||||
ax.set_xlabel('timestep')
|
||||
ax.set_ylabel('time (sec)')
|
||||
|
||||
ax = axs[1]
|
||||
ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:])
|
||||
ax.set_title(f'dt')
|
||||
ax.set_xlabel('timestep')
|
||||
ax.set_ylabel('time (sec)')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved timestamp plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
|
||||
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
|
||||
main(vars(parser.parse_args()))
|
||||
Reference in New Issue
Block a user