Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pretraining LLaVA with SigLIP #1603

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llava/model/multimodal_encoder/builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2

from .siglip_encoder import SiglipVisionTower

def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
Expand All @@ -11,5 +11,7 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
else:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif 'siglip' in vision_tower:
return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

raise ValueError(f'Unknown vision tower: {vision_tower}')
87 changes: 87 additions & 0 deletions llava/model/multimodal_encoder/siglip_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn

from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig

class SiglipVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()

self.is_loaded = False

self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

if not delay_load:
self.load_model()
elif getattr(args, 'unfreeze_mm_vision_tower', False):
self.load_model()
else:
self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)

def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return

self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower.requires_grad_(False)

self.is_loaded = True

def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features

@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)

return image_features

@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

@property
def dtype(self):
return self.vision_tower.dtype

@property
def device(self):
return self.vision_tower.device

@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only

@property
def hidden_size(self):
return self.config.hidden_size

@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size

@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
28 changes: 27 additions & 1 deletion llava/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

from PIL import Image

import functools


local_rank = None

Expand All @@ -46,6 +48,21 @@ def rank0_print(*args):
print(*args)


'''
This function sets interpolate_pos_encoding to True. If the image size is different than the default siglip
image size then positional embeddings are interpolated to account for the new size.
'''
def wrap_siglip_forward_method(siglip_object):
original_forward = siglip_object.forward

@functools.wraps(original_forward)
def wrapped_forward(pixel_values, interpolate_pos_encoding=True):
return original_forward(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

siglip_object.forward = wrapped_forward
return siglip_object


from packaging import version
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')

Expand Down Expand Up @@ -734,7 +751,10 @@ def expand2square(pil_img, background_color):
data_dict['image'] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
if 'siglip' in self.data_args.image_processor.image_processor_type.lower():
crop_size = self.data_args.image_processor.size
else:
crop_size = self.data_args.image_processor.crop_size
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
return data_dict

Expand Down Expand Up @@ -916,6 +936,12 @@ def make_inputs_require_grad(module, input, output):
vision_tower = model.get_vision_tower()
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

if vision_tower.__class__.__name__ == 'SiglipVisionTower':
#Enforcing interpolate_pos_encoding = True by default for Siglip embeddings
siglip_embedding = vision_tower.vision_tower.vision_model.embeddings
siglip_embedding = wrap_siglip_forward_method(siglip_embedding)
vision_tower.vision_tower.vision_model.embeddings = siglip_embedding

data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True

Expand Down
35 changes: 35 additions & 0 deletions scripts/v1_5/pretrain_siglip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash

deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path lmsys/vicuna-13b-v1.5 \
--version plain \
--data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
--image_folder ./playground/data/LLaVA-Pretrain/images \
--vision_tower google/siglip-so400m-patch14-384 \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava-v1.5-13b-pretrain \
--num_train_epochs 1 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 24000 \
--save_total_limit 1 \
--learning_rate 1e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb