
CAGE: Causal Attention Enables Data-Efficient Generalizable Robotic Manipulation

[Paper] [Project Page] [Sample Data]

CAGE is a data-efficient, generalizable robotic manipulation policy. Extensive experiments demonstrate that CAGE can effectively complete tasks in test environments with different levels of distribution shift.

🛫 Getting Started

💻 Installation

To set up the conda environment for CAGE, we provide a minimal package requirement list for training & inference (see CAGE.yaml), as well as a full environment export (CAGE_full.yaml) as a reproducibility reference.

conda env create -f CAGE.yaml
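
After creating the environment, a quick sanity check can confirm that the core dependencies are importable (a minimal sketch; the package names come from the tools used later in this README):

import accelerate
import deepspeed
import torch
from omegaconf import OmegaConf

# Confirm PyTorch can see your GPUs before launching training.
print(torch.__version__, 'CUDA available:', torch.cuda.is_available())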

🛢️ Data Collection

We follow the same data collection process as described in RH20T: the robotic arm is equipped with an in-hand camera and directly faces a fixed camera. A haptic device is then used to collect expert demonstrations. Since CAGE uses only RGB images, no calibration is required.

We provide sample data for each task on Google Drive and Baidu Netdisk (code: vwxq).

🧑🏻‍💻 Training

We adopt accelerate and deepspeed for configurable single-GPU / multi-GPU training. The main training script is train_cage.py. Its config files are located in configs with the cage prefix (e.g., cage-rh20t.yaml is used for training CAGE on the RH20T dataset).

An example of the training command is as follows:

accelerate launch \
  --num_machines=1 \
  --num_processes=4 \
  --gpu_ids=0,1,2,3 \
  -m train_cage \
  --tag [real|rh20t] \
  [--test] \
  [--resume PATH] \
  [other parameters defined in the config file (e.g. batch_size=16 dataset.obs_horizon=2)]

To train on a single GPU, simply set --num_processes=1 and --gpu_ids=0.

With a batch size of 16, CAGE takes around 60 GB of VRAM on each GPU. To reduce memory usage, you can use a smaller batch size combined with gradient accumulation (e.g., batch_size=8 with gradient_accumulation_steps=2) to fit the model on your GPUs.
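
For example, a single-GPU run with the reduced batch size would look like this (the override syntax follows the config-parameter example above):

accelerate launch \
  --num_machines=1 \
  --num_processes=1 \
  --gpu_ids=0 \
  -m train_cage \
  --tag rh20t \
  batch_size=8 gradient_accumulation_steps=2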

📁 Checkpoint

Checkpoints are automatically saved every checkpoint_steps steps and after each epoch. Once training is done, a final_ckpt will be saved for evaluation.

If you want to evaluate the policy from an intermediate checkpoint, you should first extract model.bin from DeepSpeed states using the provided script under each checkpoint directory:

cd [path_to_checkpoint] && python zero_to_fp32.py . model.bin
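
The extracted model.bin should be a plain PyTorch state dict (zero_to_fp32.py is DeepSpeed's standard extraction script), so you can inspect it before evaluation; a minimal sketch:

import torch

# Load on CPU; the file holds the full fp32 weights extracted from DeepSpeed states.
state_dict = torch.load('model.bin', map_location='cpu')
n_params = sum(t.numel() for t in state_dict.values())
print(f'{len(state_dict)} tensors, {n_params / 1e6:.1f}M parameters')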

Since model.bin is exported directly from the training states, the frozen parameters of DINOv2 are also embedded. To further reduce the file size, you can use the following command to store only the trainable parameters:

python merge_weight.py --unmerge --ckpt [path_to_model.bin] --config [path_to_config.yaml]

🤖 Inference

Running inference on CAGE is as simple as:

import os
from agent import CAGEAgent
from omegaconf import OmegaConf

conf = OmegaConf.load(os.path.join('configs', 'eval-cage.yaml'))
# create the inference agent from config file
agent = CAGEAgent(conf)

# each list is a sequence of observations in chronological order
#   (the last element is the most recent one)
# 'proprio' is a list of end-effector poses of the robot.
# If the model is trained without proprioception,
#   the current pose of the robot is still required
#   for relative action calculation (the list length should be 1).
obs_dict = {
  'global_cam': [...],  # [PIL.Image] * obs_horizon
  'wrist_cam': [...],   # [PIL.Image] * obs_horizon
  'proprio': [...],     # [np.ndarray] * obs_horizon or 1
}

# get next 8 actions based on current observations
xyz, rot, w = agent(obs_dict, act_horizon=8)
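
For instance, filling obs_dict from images on disk might look like the following (a hypothetical sketch: the file names, the pose dimension, and obs_horizon = 2 are placeholders, not values fixed by the repository):

import os

import numpy as np
from PIL import Image
from omegaconf import OmegaConf

from agent import CAGEAgent

# create the agent as above
agent = CAGEAgent(OmegaConf.load(os.path.join('configs', 'eval-cage.yaml')))

obs_horizon = 2  # must match dataset.obs_horizon from the training config

# Hypothetical file names; replace with frames from your camera streams.
obs_dict = {
    'global_cam': [Image.open(f'global_{t}.png') for t in range(obs_horizon)],
    'wrist_cam': [Image.open(f'wrist_{t}.png') for t in range(obs_horizon)],
    # Placeholder poses; substitute your robot's actual end-effector readings
    # (the pose dimension here is illustrative).
    'proprio': [np.zeros(7) for _ in range(obs_horizon)],
}

xyz, rot, w = agent(obs_dict, act_horizon=8)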

For advanced usage, please refer to agent/cage.py, agent/generic.py and eval.py for more details.

🤖 Evaluation

We provide a sample script eval.py for evaluation on our platform (Flexiv Rizon 4 robotic arm + Dahuan AG-95 gripper) with the following command:

python eval.py \
  --config [path_to_config_file] \
  --ckpt [path_to_model_ckpt] \
  --ctrl_freq 10 \
  --pred_interval 4 \
  --t_ensemble
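
The --t_ensemble flag enables temporal ensembling of overlapping action predictions. The sketch below illustrates the general idea; the exponential weighting (as popularized by ACT) is an assumption for illustration, not necessarily the repository's exact scheme:

import numpy as np

def temporal_ensemble(predictions, m=0.1):
    # predictions: action vectors predicted for the *same* timestep,
    # oldest first (each came from an earlier observation whose predicted
    # horizon still covers the current step).
    actions = np.stack(predictions)
    # Exponential weights exp(-m * i), with i = 0 for the oldest prediction;
    # m controls how strongly newer predictions are discounted.
    weights = np.exp(-m * np.arange(len(actions)))
    weights /= weights.sum()
    return (weights[:, None] * actions).sum(axis=0)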

To evaluate on a platform with the same setup as ours, extra Python libraries are required:

pip install pyserial modbus_tk pyrealsense2

You also need to install the Flexiv RDK for robot control. Specifically, download Flexiv RDK v0.9 and copy lib_py/flexivrdk.cpython-310-[arch].so to the hardware/robot directory.

For other platforms, you should modify the code in hardware/ and the evaluation script to adapt them to your own configuration, as sketched below.
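
A hypothetical skeleton of what such an adaptation involves (the class and method names are illustrative, not the repository's actual interface):

import numpy as np

class MyRobotPlatform:
    """Illustrative stand-in for the platform-specific code in hardware/."""

    def get_observation(self) -> dict:
        # Return the obs_dict consumed by CAGEAgent (see Inference above):
        # recent global/wrist camera frames plus end-effector poses.
        raise NotImplementedError

    def execute(self, xyz: np.ndarray, rot: np.ndarray, w: np.ndarray) -> None:
        # Send one predicted action step to your arm and gripper controllers.
        raise NotImplementedError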

📈 Results

As the level of distribution shift increases, the performance of the selected 2D/3D baselines drops significantly, while CAGE maintains stable performance, even when evaluated in a completely new environment.

In similar environments, CAGE offers an average 42% increase in task completion rate. While all baselines fail to execute the task in unseen environments, CAGE manages to obtain a 43% completion rate and a 51% success rate on average.

L0 Evaluation Results.

L1 Evaluation Results.

L2 Evaluation Results.

🙏 Acknowledgement

  • Our diffusion module is adapted from Diffusion Policy. This part is under MIT License.
  • Our real-world evaluation code is adapted from RISE. This part is under CC-BY-NC-SA 4.0 License.

✍️ Citation

@article{xia2024cage,
  title   = {CAGE: Causal Attention Enables Data-Efficient Generalizable Robotic Manipulation},
  author  = {Xia, Shangning and Fang, Hongjie and Lu, Cewu and Fang, Hao-Shu},
  journal = {arXiv preprint arXiv:2410.14974},
  year    = {2024}
}

📃 License

CAGE (including data and codebase) by Shangning Xia, Hongjie Fang, Cewu Lu, and Hao-Shu Fang is licensed under CC BY-NC-SA 4.0.
