convert : fix set_vocab_sentencepiece (#6866)
* convert : fix set_vocab_sentencepiece

* Update convert-hf-to-gguf.py
ggerganov authored May 18, 2024
1 parent 0583484 commit b49a13d
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions convert-hf-to-gguf.py
@@ -573,6 +573,10 @@ def _set_vocab_sentencepiece(self):
 
         vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
 
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+
         for token_id in range(tokenizer.vocab_size()):
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
@@ -588,21 +592,23 @@ def _set_vocab_sentencepiece(self):
             elif tokenizer.IsByte(token_id):
                 toktype = SentencePieceTokenTypes.BYTE
 
-            tokens.append(text)
-            scores.append(score)
-            toktypes.append(toktype)
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
 
         added_tokens_file = self.dir_model / 'added_tokens.json'
         if added_tokens_file.is_file():
             with open(added_tokens_file, "r", encoding="utf-8") as f:
                 added_tokens_json = json.load(f)
-
                 for key in added_tokens_json:
-                    key = key.encode("utf-8")
-                    if key not in tokens:
-                        tokens.append(key)
-                        scores.append(-1000.0)
-                        toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
+                    token_id = added_tokens_json[key]
+                    if (token_id >= vocab_size):
+                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
+
+                    tokens[token_id] = key.encode("utf-8")
+                    scores[token_id] = -1000.0
+                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
 
         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)
@@ -612,8 +618,6 @@ def _set_vocab_sentencepiece(self):
                 scores.append(-1000.0)
                 toktypes.append(SentencePieceTokenTypes.UNUSED)
 
-        assert len(tokens) == vocab_size
-
         self.gguf_writer.add_tokenizer_model("llama")
         self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)

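The change pre-sizes the token tables to vocab_size and writes each entry at its token id instead of appending, so entries from added_tokens.json land in their declared slots and ids that are never written stay as [PAD{i}] placeholders. A minimal standalone sketch of that layout, assuming toy inputs (build_token_table, base_pieces and added are made-up names) rather than the real sentencepiece tokenizer and gguf writer:

# Illustrative only: build_token_table, base_pieces and added are invented
# names; the converter itself works with sentencepiece and gguf objects.
def build_token_table(base_pieces: dict[int, str], added: dict[str, int], vocab_size: int):
    # Pre-fill every slot so that list index == token id; ids never written
    # below remain [PAD{i}] placeholders instead of shifting later entries.
    tokens = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
    scores = [-10000.0] * vocab_size

    # Base tokenizer pieces are written at their own ids rather than appended.
    for token_id, piece in base_pieces.items():
        tokens[token_id] = piece.encode("utf-8")
        scores[token_id] = 0.0

    # added_tokens.json maps token text -> explicit id; out-of-range ids are
    # skipped instead of being appended past vocab_size.
    for text, token_id in added.items():
        if token_id >= vocab_size:
            continue
        tokens[token_id] = text.encode("utf-8")
        scores[token_id] = -1000.0

    return tokens, scores

# Example: the base tokenizer defines ids 0-3, added_tokens.json supplies
# id 5, and id 4 is left as a placeholder.
tokens, scores = build_token_table(
    base_pieces={0: "<unk>", 1: "<s>", 2: "</s>", 3: "hello"},
    added={"<custom>": 5},
    vocab_size=6,
)
assert tokens[4] == b"[PAD4]" and tokens[5] == b"<custom>"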