Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MiniLLM and Data selection #285

Merged
merged 8 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
The diff you're trying to view is too large. We only load the first 3000 changed files.
19 changes: 15 additions & 4 deletions data_selection/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Data Selection via Optimal Control

[paper]() | [huggingface]()
[paper](https://arxiv.org/abs/2410.07064) | [huggingface](https://huggingface.co/Data-Selection)

<div>Theory Overview:</div>

<img src="./figures/theory.png" width="70%"/>
<br>
<div>Training Framwork PDS:</div>

<img src="./figures/method.png" width="70%"/>

## Overview of the Training Framework
Expand All @@ -14,7 +17,11 @@
4. Filter CC with the scores.
5. Pre-train the model.

## Selected Data and Pre-Trained Models
## Pre-Trained Models
+ [Models](https://huggingface.co/collections/Data-Selection/baseline-models-670550972a59015f6c8870ab) Trained on Redpajama CC (Conventional Pre-Training, Baselines)
+ [Models](https://huggingface.co/collections/Data-Selection/pds-models-6705504096a78d10a30837c0) Trained PDS-Selected Data

## Selected Data
TODO

## Details of the Pipeline & How to run
Expand Down Expand Up @@ -109,6 +116,10 @@ bash $BASE_PATH/scripts/eval_offline/lm/${model_size}_pds.sh $BASE_PATH
done
```


## 9 Citation
TODO
@article{gu2024data,
title={Data Selection via Optimal Control for Language Models},
author={Gu, Yuxian and Dong, Li and Wang, Hongning and Hao, Yaru and Dong, Qingxiu and Wei, Furu and Huang, Minlie},
journal={arXiv preprint arXiv:2410.07064},
year={2024}
}
9 changes: 1 addition & 8 deletions data_selection/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import argparse
import os
import deepspeed
import numpy as np
from numerize.numerize import numerize


Expand Down Expand Up @@ -94,8 +93,6 @@ def add_data_args(parser: argparse.ArgumentParser):
group.add_argument("--test-data-dir", type=str, default=None)
group.add_argument("--proxy-data-dir", type=str, default=None)
group.add_argument("--processed-data-dir", type=str, default=None)
group.add_argument("--force-process", action="store_true")
group.add_argument("--force-process-demo", action="store_true")
group.add_argument("--data-process-workers", type=int, default=-1)
group.add_argument("--precompute-data-order", action="store_true")
group.add_argument("--train-num", type=int, default=None)
Expand All @@ -107,7 +104,6 @@ def add_data_args(parser: argparse.ArgumentParser):
group.add_argument("--gen-num", type=int, default=None)
group.add_argument("--infer-num", type=int, default=None)
group.add_argument("--data-name", type=str, default=None)
group.add_argument("--prompt-type", type=str, default=None)
group.add_argument("--num-workers", type=int, default=1)
group.add_argument("--max-prompt-length", type=int, default=512)
group.add_argument("--min-prompt-length", type=int, default=128)
Expand All @@ -120,17 +116,15 @@ def add_data_args(parser: argparse.ArgumentParser):
group.add_argument("--min-offset", type=int, default=0)
group.add_argument("--data-split", type=str, default=None)
group.add_argument("--no-shuffle", action="store_true")
group.add_argument("--trunc-data", action="store_true")

group.add_argument("--prompt-data-dir", type=str)
group.add_argument("--lm-data-dir", type=str)
group.add_argument("--eval-ppl", action="store_true")
group.add_argument("--eval-gen", action="store_true")

group.add_argument("--only-prompt", action="store_true")
group.add_argument("--prompt-data-full-loss", action="store_true",
help="Compute loss on the entire sentence in prompt data type.")
group.add_argument("--remove-bos-in-training", action="store_true",
help="Remove bos token during training. This ensures the first token is bos token.")
group.add_argument("--chunk-num-per-shard", type=int, default=None)
group.add_argument("--max-shard-num", type=int, default=10000000)
group.add_argument("--max-sample-num", type=int, default=None)
Expand Down Expand Up @@ -214,7 +208,6 @@ def add_pmp_solver_args(parser: argparse.ArgumentParser):
group.add_argument("--dev-grad-batch-size", type=int, default=None)

group.add_argument("--compute-ct-interval", type=int, default=1)
group.add_argument("--trunc-data", action="store_true")

group.add_argument("--dataset-type", type=str, default="lm")

Expand Down
6 changes: 2 additions & 4 deletions data_selection/data_scorer/modeling.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import os
import json

from utils import get_model, print_rank, all_gather
from transformers import AutoModel, AutoConfig, AutoModelForCausalLM
from utils import get_model, print_rank
from transformers import AutoModel, AutoConfig


class BertBaseModel(nn.Module):
Expand Down
3 changes: 1 addition & 2 deletions data_selection/data_scorer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import json
import torch
import random
import shutil
from time import time
from tqdm import tqdm
from numerize.numerize import numerize

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from scipy.stats import spearmanr
from sklearn.metrics import f1_score, accuracy_score

from utils import print_rank, all_gather
Expand Down
6 changes: 0 additions & 6 deletions data_selection/data_utils/data_scorer_datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import random
import torch
import os
from torch.utils.data import Dataset
from torch.distributed import get_rank, get_world_size
from utils import print_rank
from tqdm import tqdm
import json
import numpy as np
import h5py

from .base_datasets import BaseDataset
Expand Down
2 changes: 0 additions & 2 deletions data_selection/data_utils/distributed_indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import os
import struct

from itertools import accumulate

import numpy as np
import torch
import torch.distributed as dist
Expand Down
10 changes: 0 additions & 10 deletions data_selection/data_utils/indexed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
import numpy as np
import torch

from utils import naive_copy_to_blob


def best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
Expand Down Expand Up @@ -167,10 +164,6 @@ def add_np_item(self, item):
self.builder.finalize(self.idx_file)
self._chunks = []

if self.tmp_output_path is not None:
naive_copy_to_blob(self.base_path, self.tmp_bin_file, self.bin_file.replace(self.base_path, ""), rm_source=True)
naive_copy_to_blob(self.base_path, self.tmp_idx_file, self.idx_file.replace(self.base_path, ""), rm_source=True)

self.ofid += 1
self.bin_file = os.path.join(self.output_path, f"{self.split}_{self.ofid}.bin")
self.idx_file = os.path.join(self.output_path, f"{self.split}_{self.ofid}.idx")
Expand All @@ -195,9 +188,6 @@ def finalize(self):
else:
self.builder.finalize(self.idx_file)
self._chunks = []
if self.tmp_output_path is not None:
naive_copy_to_blob(self.base_path, self.tmp_bin_file, self.bin_file.replace(self.base_path, ""), rm_source=True)
naive_copy_to_blob(self.base_path, self.tmp_idx_file, self.idx_file.replace(self.base_path, ""), rm_source=True)


class IndexedDataset(torch.utils.data.Dataset):
Expand Down
14 changes: 0 additions & 14 deletions data_selection/data_utils/lm_datasets.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
import random
import torch
import os
from torch.utils.data import Dataset
from .distributed_indexed import DistributedMMapIndexedDataset

from torch.distributed import get_rank, get_world_size
from utils import print_rank
from tqdm import tqdm
import json
import numpy as np
from .base_datasets import BaseDataset


Expand All @@ -27,10 +17,6 @@ def __getitem__(self, index: int):
index = int(self.order[self.epoch, index])

data = self.data[index].astype(int)

if self.args.remove_bos_in_training:
assert data[0] == self.tokenizer.bos_token_id
data = data[1:]

return index, data

Expand Down
13 changes: 0 additions & 13 deletions data_selection/data_utils/prompt_datasets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
import random
import torch
import os
from torch.utils.data import Dataset
from .distributed_indexed import DistributedMMapIndexedDataset

from torch.distributed import get_rank, get_world_size
from utils import print_rank
from tqdm import tqdm
import json
import numpy as np
from .base_datasets import BaseDataset

Expand Down Expand Up @@ -50,10 +41,6 @@ def __getitem__(self, index: int):
assert len(prompt_ids) + len(response_ids) <= self.args.max_length + 1, \
f"Prompt and response too long: {len(prompt_ids)} + {len(response_ids)} > {self.args.max_length + 1}"

if self.args.remove_bos_in_training:
assert prompt_ids[0] == self.tokenizer.bos_token_id
prompt_ids = prompt_ids[1:]

return index, prompt_ids, response_ids

def collate(self, samples):
Expand Down
5 changes: 1 addition & 4 deletions data_selection/eval_offline/lm/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
import torch
import json
import time
import datasets

from utils import print_rank, save_rank, get_tokenizer, BOS_MODELS
from utils import save_rank, get_tokenizer
from train_eval_utils.base_evaluator import BaseEvaluator
from pretrain.trainer import PreTrainer

Expand Down
1 change: 0 additions & 1 deletion data_selection/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ pip3 install rich
pip3 install accelerate
pip3 install datasets
pip3 install sentencepiece
pip3 install peft
pip3 install matplotlib
pip3 install wandb
pip3 install cvxpy
Expand Down
2 changes: 1 addition & 1 deletion data_selection/pmp_solver/grad_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.func import grad, jvp, vmap, grad_and_value
from torch.func import grad, jvp, vmap
from .model_wrapper import TransformerWrapper


Expand Down
1 change: 0 additions & 1 deletion data_selection/pmp_solver/model_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
from torch.func import functional_call
from utils import print_rank


class TransformerWrapper(nn.Module):
Expand Down
10 changes: 3 additions & 7 deletions data_selection/pmp_solver/trainer.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
import os
import math
import random
import numpy as np
from tqdm import tqdm
from time import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from torch.func import grad, jvp, vmap, grad_and_value
from torch.func import grad_and_value

from transformers import AutoConfig, AutoModelForCausalLM

from utils import all_gather, print_rank, save_rank, get_model
from utils import all_gather, print_rank, get_model
from train_eval_utils import BaseTrainer
from data_utils import PromptDataset, LMDataset

from .model_wrapper import TransformerWrapper
from .checkpointing import Checkpointing
from .grad_utils import jvp_single, jvp_batch, hvp_fwdrev
from .grad_utils import jvp_batch, hvp_fwdrev


class GammaTrainer(BaseTrainer):
Expand Down
10 changes: 1 addition & 9 deletions data_selection/pretrain/trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import os
import re
import wandb
from collections import defaultdict
from train_eval_utils.base_trainer import BaseTrainer
from utils import print_rank, save_rank
from torch.distributed import get_rank
from data_utils.prompt_datasets import PromptDataset
from data_utils.lm_datasets import LMDataset


Expand All @@ -26,16 +23,11 @@ def set_datasets(self, args=None, do_train=True):
data_split = args.data_split or "data"
if do_train:
if args.dev_data_dir is None or os.path.samefile(args.dev_data_dir, args.data_dir):
print_rank("### Spliting dev data from training data ###")
args.dev_data_dir = args.data_dir
min_train_offset = 100000
assert 0 == 1, "dangerous!"
raise ValueError("dev_data_dir should be different from data_dir")
else:
min_train_offset = 0
self.train_dataset = LMDataset(args, self.tokenizer, data_split, args.data_dir, args.train_num, data_name="lm", min_offset=min_train_offset+self.args.min_offset, min_state=self.args.min_state)
self.print_and_save(f"### Training Data Number: {len(self.train_dataset)}")
# self.train_dataset = LMDataset(args, self.tokenizer, "data", args.data_dir, args.dev_num, max_offset=10000)
# print_rank("train num", len(self.train_dataset))
self.eval_dataset = LMDataset(args, self.tokenizer, data_split, args.dev_data_dir, args.dev_num, data_name="lm_dev", max_offset=100000)
self.print_and_save(f"### Dev Data Number: {len(self.eval_dataset)}")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ DISTRIBUTED_ARGS="--num_gpus $GPUS_PER_NODE \

# model
BASE_PATH=${1-"/home/MiniLLM"}
CKPT="${BASE_PATH}/results/data_scorer/cc-sgd100-160M-10k-lima-163840/fairseq_125M/e5-w10-bs16-lr0.0001cosine1e-07-G2-N16-NN2/mean-bias-linear/best"
CKPT_NAME="cc-sgd100-160M-10k-lima"
CKPT="${BASE_PATH}/results/data_scorer/"
CKPT_NAME="cc-160M-lima"
# data
DATA_DIR="${BASE_PATH}/processed_data/data_scorer_infer/cc/mistral-fairseq-1024"
# hp
Expand Down
2 changes: 1 addition & 1 deletion data_selection/scripts/pmp_solver/160M.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ OPTS=""
OPTS+=" --type pmp_solver"
# model
OPTS+=" --model-type mistral"
OPTS+=" --model-path ${BASE_PATH}/results/pretrain/cc/mistral_160M/t100K-w2K-bs8-lr0.0006cosine6e-05-G4-N16-NN2-scr/10000"
OPTS+=" --model-path ${BASE_PATH}/results/pretrain/mistral_160M-10K/"
OPTS+=" --base-path ${BASE_PATH}"
OPTS+=" --ckpt-name 160M-10k"
OPTS+=" --n-gpu ${GPUS_PER_NODE}"
Expand Down
46 changes: 0 additions & 46 deletions data_selection/sft/lm_trainer.py

This file was deleted.

Loading