Failed to build the model execution plan using a model architecture file #2325

Open · Skyline-23 opened this issue Aug 27, 2024 · 9 comments
Labels: bug (Unexpected behaviour that should be corrected (type)), PyTorch (traced)

Skyline-23 commented Aug 27, 2024

🐞Describing the bug

Hello. I'm trying to convert a PyTorch model to a stateful Core ML model.

I wrote this code with reference to the Mistral-7B example from the WWDC 2024 session. The Core ML file is produced after the run, but a "Failed to build the model execution plan using a model architecture file" error appears when the Core ML model class is initialized.
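Just loading the converted package from Python reproduces the warning (a minimal sketch, assuming the output path from the script below):

import coremltools as ct

# Loading the package compiles it; coremltools emits the RuntimeWarning shown
# in the trace below when it determines that predict() will not be runnable.
mlmodel = ct.models.MLModel("zenz_v1_cached.mlpackage")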

Stack Trace

/opt/homebrew/lib/python3.11/site-packages/transformers/modeling_utils.py:4779: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
  warnings.warn(
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
/Users/kimbuseong/Downloads/zenz-CoreML/convert-to-CoreML-Stateful.py:70: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if past_key.size(-2) > 0:
Torch var valueCache is added again.
Torch var keyCache is added again.
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1600/1600 [00:00<00:00, 2510.79 ops/s]
Running MIL frontend_pytorch pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 25.71 passes/s]
Running MIL default pipeline:  65%|█████████████████████████████████████████████████████████████████████████████                                          | 57/88 [00:03<00:02, 13.10 passes/s]
/opt/homebrew/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py:894: RuntimeWarning: overflow encountered in cast
  return input_var.val.astype(dtype=string_to_nptype(dtype_val))
Running MIL default pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 88/88 [00:06<00:00, 12.66 passes/s]
Running MIL backend_mlprogram pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 62.18 passes/s]
/opt/homebrew/lib/python3.11/site-packages/coremltools/models/model.py:489: RuntimeWarning: You will not be able to run predict() on this Core ML model. Underlying exception message was: {
    NSLocalizedDescription = "Failed to build the model execution plan using a model architecture file '/private/var/folders/pz/rmstwmls5ls_0hrn5_jj01kh0000gn/T/tmppa7zpned.mlmodelc/model.mil' with error code: -14.";
}
  _warnings.warn(
Model successfully converted and saved as: zenz_v1_cached.mlpackage

To Reproduce

import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention, GPT2_ATTENTION_CLASSES
from transformers import AutoTokenizer
import coremltools as ct
from typing import Optional, Tuple
import numpy as np
from transformers.cache_utils import Cache
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...], 
        device="cpu",
        dtype=torch.float32
    ) -> None:
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor, 
        v_state: torch.Tensor, 
        layer_idx: int, 
        slice_indices: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(slice_indices) != 2:
            raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
        
        begin, end = slice_indices
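        # Write the new key/value states into the preallocated caches at
        # positions [begin, end), then return views covering tokens [0, end).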
        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, _: int = 0) -> int:
        return self.past_seen_tokens
    
    def to_past_key_values(self):
        """Convert the internal cache to a format expected by GPT2."""
        return [(self.k_cache[layer], self.v_cache[layer]) for layer in range(self.k_cache.size(0))]

class SliceUpdateGPT2Attention(GPT2Attention):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config=config, layer_idx=layer_idx)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor, 
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None, 
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: bool = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Original GPT2Attention code kept as-is
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            if past_key.size(-2) > 0:
                key = torch.cat([past_key, key], dim=-2)
                value = torch.cat([past_value, value], dim=-2)

        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :, -key.size(-2):]

        # Modified so the attention weights are also returned
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)

        present = (key, value) if use_cache else None

        if output_attentions:
            return attn_output, present, attn_weights
        else:
            return attn_output, present

class StatefulZenz(torch.nn.Module):
    def __init__(self, model, max_context_size: int = 256, batch_size: int = 1):
        super().__init__()

        GPT2_ATTENTION_CLASSES["sdpa"] = SliceUpdateGPT2Attention

        self.model = model
        config = self.model.config
        self.kv_cache_shape: Tuple[int, ...] = (
            config.num_hidden_layers,
            batch_size,
            config.n_head,
            max_context_size,
            config.hidden_size // config.num_attention_heads,
        )
        self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
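        # Registering the caches as named buffers exposes them in the traced
        # graph; coremltools matches these names to the StateType names below.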
        self.register_buffer("keyCache", self.kv_cache.k_cache)
        self.register_buffer("valueCache", self.kv_cache.v_cache)

    def _extend_attention_mask(self, attention_mask, past_key_values):
        if past_key_values is not None:
            past_length = past_key_values[0][0].size(-2)
            new_length = past_length + attention_mask.size(-1)
            extended_attention_mask = torch.ones(
                (attention_mask.size(0), 1, 1, new_length),
                dtype=torch.float32,
                device=attention_mask.device
            )
            extended_attention_mask[:, :, :, -attention_mask.size(-1):] = attention_mask
            return extended_attention_mask
        return attention_mask

    @torch.no_grad()
    def forward(self, input_ids, attention_mask):
        self.kv_cache.past_seen_tokens = attention_mask.shape[-1] - input_ids.shape[-1]
        past_key_values = self.kv_cache.to_past_key_values()
        outputs = self.model(
            input_ids, 
            attention_mask=self._extend_attention_mask(attention_mask=attention_mask, past_key_values=past_key_values), 
            past_key_values=past_key_values, 
            use_cache=True,
            output_attentions=True  # return the attention weights as well
        )
        return outputs.logits

def convert_model(model_name: str, output_path: str):
    # Set up model and tokenizer
    GPT2_ATTENTION_CLASSES["sdpa"] = SliceUpdateGPT2Attention
    model = GPT2LMHeadModel.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Prepare example input
    text = "Example sentence"
    inputs = tokenizer(text, return_tensors="pt")
    
    # Create stateful model
    stateful_zenz = StatefulZenz(model).eval()
    
    # Trace the model with example inputs
    example_inputs = (inputs['input_ids'], inputs['attention_mask'])
    traced_model = torch.jit.trace(
        stateful_zenz,
        example_inputs,
        check_trace=False  # Disable trace checking to avoid minor numerical differences
    )

    # Convert to CoreML
    mlmodel = ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(
                name="input_ids",
                shape=(1, ct.RangeDim(1, 256)),
                dtype=np.float32
            ),
            ct.TensorType(
                name="attention_mask",
                shape=(1, ct.RangeDim(1, 256)),
                dtype=np.float32
            )
        ],
        outputs=[
            ct.TensorType(
                name="output",
                dtype=np.float32
            )
        ],
        states=[
            ct.StateType(
                wrapped_type=ct.TensorType(
                    shape=stateful_zenz.kv_cache_shape,
                    dtype=np.float16
                ),
                name="keyCache",
            ),
            ct.StateType(
                wrapped_type=ct.TensorType(
                    shape=stateful_zenz.kv_cache_shape,
                    dtype=np.float16
                ),
                name="valueCache",
            ),
        ],
        minimum_deployment_target=ct.target.iOS18,
    )

    mlmodel.save(output_path)
    print(f"Model successfully converted and saved as: {output_path}")

# Usage
model_name = "Miwa-Keita/zenz-v1-checkpoints"
convert_model(model_name, "zenz_v1_cached.mlpackage")

System environment (please complete the following information):

  • coremltools version: 8.0b2
  • OS (e.g. MacOS version or Linux type): macOS 15.1 Beta (24B5024e)
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • python 3.11 with homebrew
    • torch-2.3.0
    • torchvision-0.18.0
    • transformers-4.41.0
Skyline-23 added the bug (Unexpected behaviour that should be corrected (type)) label on Aug 27, 2024
TobyRoseman (Collaborator) commented

@Skyline-23 that is a lot of code. Can you give us a more minimal example?

Skyline-23 (Author) commented

@TobyRoseman All of the code is required to run a stateful model based on GPT-2. Sorry 😢

lithium0003 commented

The official documentation example says:

converted_model_kvcache = ct.convert(
    traced_model_kvcache,
    inputs=inputs,
    outputs=outputs,
    states=states,
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
)

I got the same error with compute_units=ct.ComputeUnit.ALL, but it passes with compute_units=ct.ComputeUnit.CPU_AND_GPU.
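The same override also applies when loading an already-converted package (a minimal sketch, assuming the output path from the reproduction script):

import coremltools as ct

# Keeping execution off the Neural Engine avoids the -14 error here.
mlmodel = ct.models.MLModel(
    "zenz_v1_cached.mlpackage",
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
)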

Skyline-23 (Author) commented

@lithium0003 It's not working....

lithium0003 commented

With compute_units=ct.ComputeUnit.CPU_AND_GPU, and with the attention mask ignored by setting attention_mask = None just before self._attn(), it seems to pass, but I don't know why it passes.
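A sketch of that change inside SliceUpdateGPT2Attention.forward, with everything else unchanged:

        # Workaround: drop the attention mask right before the attention call.
        # It is unclear why this avoids the -14 error.
        attention_mask = None
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)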

adizhol-str commented Nov 14, 2024

I'm having the same error with Apple's checkpoint of DepthAnything. It worked a month ago.

Skyline-23 (Author) commented

@lithium0003 It works after adding attention_mask = None before attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask).
But it produces another error:

loc("/Users/kimbuseong/Library/Caches/org.python.python/com.apple.e5rt.e5bundlecache/24C5089c/482419A8FF15595378FF575F9BDC33548B8A4933527ED3D8F1364CA6FEF48A51/70D546C0C146A466AD586A6DE692334F73CB4C7D816C6FA63C4AF624CBCB818D.bundle/H13S.bundle/main/main_mps_graph/main_mps_graph.mpsgraphpackage/model_0.mpsgraph":0:0): error: attempting to parse a byte at the end of the bytecode

I think it's an MPS error, but I don't know how to resolve it.

Skyline-23 (Author) commented

I found a similar error in swift-chat: huggingface/swift-chat#24

jsflax commented Dec 8, 2024

I am also encountering this. Notably, if you try to run the mlpackage (from Swift) using .cpuAndNeuralEngine, that triggers it without fail. Attempting to run the model through coremltools.optimize.coreml.experimental.linear_quantize_activations during the quantization phase (if you choose to quantize it) will also trigger it.

Removing the state parameters entirely from ct.convert and using a simple torch.nn.Module allows you to use the Neural Engine again (see the sketch below), but obviously this means you do not get to leverage the new stateful features.
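A minimal sketch of that stateless fallback; traced_stateless_model is a hypothetical trace of a plain torch.nn.Module wrapper without the registered cache buffers:

import numpy as np
import coremltools as ct

# Same conversion as the reproduction script, but with no `states` argument,
# which lets the resulting model run on the Neural Engine again.
mlmodel = ct.convert(
    traced_stateless_model,  # hypothetical: plain (stateless) traced module
    inputs=[
        ct.TensorType(name="input_ids", shape=(1, ct.RangeDim(1, 256)), dtype=np.float32),
        ct.TensorType(name="attention_mask", shape=(1, ct.RangeDim(1, 256)), dtype=np.float32),
    ],
    outputs=[ct.TensorType(name="output", dtype=np.float32)],
    minimum_deployment_target=ct.target.iOS18,
)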
