Merge pull request #6395 from hiyouga/hiyouga/fix_genkwargs
[generate] fix generate kwargs
hiyouga authored Dec 19, 2024
2 parents ffbb4db + d4c1fda commit c6e3c14
Showing 6 changed files with 22 additions and 16 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -172,6 +172,3 @@ saves/
 output/
 wandb/
 generated_predictions.jsonl
-
-# unittest
-dummy_dir/
11 changes: 10 additions & 1 deletion src/llamafactory/hparams/generating_args.py
@@ -15,6 +15,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, Optional
 
+from transformers import GenerationConfig
+
 
 @dataclass
 class GeneratingArguments:
@@ -69,10 +71,17 @@ class GeneratingArguments:
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
             args.pop("max_length", None)
         else:
             args.pop("max_new_tokens", None)
 
+        if obey_generation_config:
+            generation_config = GenerationConfig()
+            for key in list(args.keys()):
+                if not hasattr(generation_config, key):
+                    args.pop(key)
+
         return args
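
A minimal usage sketch of the new flag (assuming the llamafactory package is importable and that GeneratingArguments is constructible with its defaults; the printed result is indicative, not verified output):

from llamafactory.hparams.generating_args import GeneratingArguments

generating_args = GeneratingArguments()

# Full dict, including decoding-only options such as skip_special_tokens.
all_args = generating_args.to_dict()

# Keep only keys that a default transformers.GenerationConfig recognizes,
# i.e. keys that are safe to forward to model.generate().
gen_kwargs = generating_args.to_dict(obey_generation_config=True)

print(sorted(set(all_args) - set(gen_kwargs)))  # expected to include "skip_special_tokens"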
10 changes: 3 additions & 7 deletions src/llamafactory/train/sft/trainer.py
@@ -151,7 +151,7 @@ def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "to
         return padded_tensor.contiguous()  # in contiguous memory
 
     def save_predictions(
-        self, dataset: "Dataset", predict_results: "PredictionOutput", gen_kwargs: Dict[str, Any]
+        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.
@@ -179,12 +179,8 @@ def save_predictions(
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
         decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
-        decoded_preds = self.processing_class.batch_decode(
-            preds, skip_special_tokens=gen_kwargs["skip_special_tokens"]
-        )
-        decoded_labels = self.processing_class.batch_decode(
-            labels, skip_special_tokens=gen_kwargs["skip_special_tokens"]
-        )
+        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
+        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
 
         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
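
A small illustration of why the decoupling works (a sketch, not part of this PR; the tokenizer name is an arbitrary example): skip_special_tokens is a decoding-time option consumed by batch_decode, so it no longer needs to travel inside gen_kwargs.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the demo
ids = tokenizer("Hello world").input_ids + [tokenizer.eos_token_id]

# skip_special_tokens is consumed here, at decode time -- not by model.generate().
print(tokenizer.batch_decode([ids], skip_special_tokens=False)[0])  # keeps the EOS token text
print(tokenizer.batch_decode([ids], skip_special_tokens=True)[0])   # drops it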
4 changes: 2 additions & 2 deletions src/llamafactory/train/sft/workflow.py
@@ -91,7 +91,7 @@ def run_sft(
     )
 
     # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict()
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
     gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
@@ -130,7 +130,7 @@ def run_sft(
         predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
 
     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
4 changes: 2 additions & 2 deletions tests/e2e/test_train.py
@@ -60,12 +60,12 @@
     ],
 )
 def test_run_exp(stage: str, dataset: str):
-    output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
+    output_dir = os.path.join("output", f"train_{stage}")
     run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
     assert os.path.exists(output_dir)
 
 
 def test_export():
-    export_dir = os.path.join("output", "dummy_dir/llama3_export")
+    export_dir = os.path.join("output", "llama3_export")
     export_model({"export_dir": export_dir, **INFER_ARGS})
     assert os.path.exists(export_dir)
6 changes: 5 additions & 1 deletion tests/train/test_sft_trainer.py
@@ -58,7 +58,11 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
 @pytest.mark.parametrize("disable_shuffling", [False, True])
 def test_shuffle(disable_shuffling: bool):
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(
-        {"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
+        {
+            "output_dir": os.path.join("output", f"shuffle{str(disable_shuffling).lower()}"),
+            "disable_shuffling": disable_shuffling,
+            **TRAIN_ARGS,
+        }
     )
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
