# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import glob
import json
import os
import os.path
import subprocess
import sys
import tempfile
import textwrap
import threading
import unittest
import unittest.mock as mock
import uuid
import warnings
from pathlib import Path
from unittest.mock import patch

import httpx
import pytest
from huggingface_hub import HfApi, snapshot_download, split_torch_state_dict_into_shards
from parameterized import parameterized
from pytest import mark

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForImageClassification,
    AutoModelForSequenceClassification,
    BartConfig,
    BartForConditionalGeneration,
    BartModel,
    CLIPTextModelWithProjection,
    DynamicCache,
    GPT2Config,
    GPT2LMHeadModel,
    LlavaConfig,
    LlavaForConditionalGeneration,
    MistralConfig,
    MistralForCausalLM,
    OPTConfig,
    OPTForCausalLM,
    OwlViTForObjectDetection,
    PreTrainedConfig,
    T5Config,
    T5ForConditionalGeneration,
    is_torch_available,
    logging,
)
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from transformers.models.mistral.modeling_mistral import MistralModel
from transformers.testing_utils import (
    TOKEN,
    CaptureLogger,
    LoggingLevel,
    TemporaryHubRepo,
    TestCasePlus,
    hub_retry,
    is_staging_test,
    require_accelerate,
    require_non_hpu,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_multi_accelerator,
    slow,
    torch_device,
)
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    check_torch_load_is_safe,
)
from transformers.utils.import_utils import (
    is_flash_attn_2_available,
    is_flash_attn_3_available,
    is_kernels_available,
    is_torch_npu_available,
)

from ..test_modeling_common import compare_state_dicts


sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
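# Make the test_module package under the repo's utils/ directory importable (provides CustomConfig/CustomModel).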

from test_module.custom_configuration import CustomConfig


if is_torch_available():
    import torch
    from safetensors.torch import save_file as safe_save_file
    from test_module.custom_modeling import CustomModel
    from torch import nn

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BertConfig,
        BertModel,
        CLIPTextModel,
        GenerationMixin,
        MusicgenConfig,
        MusicgenForConditionalGeneration,
        PreTrainedModel,
        T5Config,
        T5ForConditionalGeneration,
    )
    from transformers.modeling_attn_mask_utils import (
        AttentionMaskConverter,
        _create_4d_causal_attention_mask,
        _prepare_4d_attention_mask,
        _prepare_4d_causal_attention_mask,
    )
    from transformers.modeling_utils import (
        _find_disjoint,
        _find_identical,
    )
    from transformers.pytorch_utils import isin_mps_friendly

    # Fake pretrained models for tests
    class BaseModel(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class BaseModelWithUnexpectedKeys(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig
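        # Checkpoint keys matching these patterns are silently dropped at load time instead of being reported as unexpected.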
        _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(50, 50)
            self.linear_2 = nn.Linear(50, 50)
            self.post_init()

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class BaseModelWithMissingKeys(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig
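        # Model keys matching these patterns are not reported as missing when they are absent from the checkpoint.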
        _keys_to_ignore_on_load_missing = [r"^linear"]

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(50, 50)
            self.linear_2 = nn.Linear(50, 50)
            self.post_init()

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class BaseModelWithTiedWeights(PreTrainedModel):
        config_class = PreTrainedConfig
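        # Mapping of {tied key: source key}: linear_2.weight is tied to (shares its weights with) linear.weight.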
        _tied_weights_keys = {"linear_2.weight": "linear.weight"}

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class BaseModelWithMultipleTiedWeights(PreTrainedModel):
        config_class = PreTrainedConfig
        _tied_weights_keys = {"linear_2.weight": "linear.weight", "linear_3.weight": "linear.weight"}

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)
            self.linear_3 = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class BaseModelWithMultipleMixedTiedWeights(PreTrainedModel):
        config_class = PreTrainedConfig
        # Both tied keys ultimately resolve to `linear.weight`, but the mapping is expressed as a chain
        # (linear_3.weight -> linear_2.weight -> linear.weight) rather than pointing each key directly at the source
        _tied_weights_keys = {"linear_2.weight": "linear.weight", "linear_3.weight": "linear_2.weight"}

        def __init__(self, config):
            super().__init__(config)
            self.linear = nn.Linear(5, 5)
            self.linear_2 = nn.Linear(5, 5)
            self.linear_3 = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.linear_2(self.linear(x))

    class ModelWithHead(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig

        def _init_weights(self, module):
            pass

        def __init__(self, config):
            super().__init__(config)
            self.base = BaseModel(config)
            # linear is a common name between Base and Head on purpose.
            self.linear = nn.Linear(5, 5)
            self.linear2 = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.linear2(self.linear(self.base(x)))

    class ModelWithDirectParam(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig

        def _init_weights(self, module):
            pass

        def __init__(self, config):
            super().__init__(config)
            # having both direct parameters and submodules is helpful for testing offloading logic
            self.weight = nn.Parameter(torch.rand((5, 5)))
            self.base = BaseModel(config)
            self.post_init()

        def forward(self, x):
            return self.base(x @ self.weight.T)

    class ModelWithDirectParamSubmodule(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig

        def _init_weights(self, module):
            pass

        def __init__(self, config):
            super().__init__(config)
            self.submodule = ModelWithDirectParam(config)
            # needed so the model has at least one module on the accelerator
            self.linear = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.linear(self.submodule(x))

    class ModelWithHeadAndTiedWeights(PreTrainedModel):
        base_model_prefix = "base"
        config_class = PreTrainedConfig
        _tied_weights_keys = {"decoder.weight": "base.linear.weight"}

        def _init_weights(self, module):
            pass

        def __init__(self, config):
            super().__init__(config)
            self.base = BaseModel(config)
            self.decoder = nn.Linear(5, 5)
            self.post_init()

        def forward(self, x):
            return self.decoder(self.base(x))

    class Prepare4dCausalAttentionMaskModel(nn.Module):
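        # Thin wrapper so the mask helper can be exercised through a plain forward pass.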
        def forward(self, inputs_embeds):
            batch_size, seq_length, _ = inputs_embeds.shape
            past_key_values_length = 4
            attention_mask = _prepare_4d_causal_attention_mask(
                None, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
            return attention_mask

    class Create4dCausalAttentionMaskModel(nn.Module):
        def forward(self, inputs_embeds):
            batch_size, seq_length, _ = inputs_embeds.shape
            past_key_values_length = 4
            attention_mask = _create_4d_causal_attention_mask(
                (batch_size, seq_length),
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )
            return attention_mask

    class Prepare4dAttentionMaskModel(nn.Module):
        def forward(self, mask, inputs_embeds):
            attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
            return attention_mask

    class TestOffline(unittest.TestCase):
        def test_offline(self):
            with tempfile.TemporaryDirectory() as tmpdir:
                # First offline load should fail
                with patch("transformers.utils.hub.is_offline_mode", return_value=True):
                    with pytest.raises(OSError):
                        AutoModelForImageClassification.from_pretrained(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)

                # Enable online mode for download
                with patch("transformers.utils.hub.is_offline_mode", return_value=False):
                    snapshot_download(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)

                # Load again in offline mode - should work now
                with patch("transformers.utils.hub.is_offline_mode", return_value=True):
                    AutoModelForImageClassification.from_pretrained(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)

        def test_local_files_only(self):
            with tempfile.TemporaryDirectory() as tmpdir:
                # Empty cache => fail to load from cache
                with pytest.raises(OSError):
                    AutoModelForImageClassification.from_pretrained(
                        TINY_IMAGE_CLASSIF, cache_dir=tmpdir, local_files_only=True
                    )

                # Populate cache
                snapshot_download(TINY_IMAGE_CLASSIF, cache_dir=tmpdir)

                # Load again from cache => success
                AutoModelForImageClassification.from_pretrained(
                    TINY_IMAGE_CLASSIF, cache_dir=tmpdir, local_files_only=True
                )


# These classes need to be serializable (picklable), so they cannot be defined inside a test method
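# They define parameters literally named `gamma`/`beta`, which checkpoint loading historically renamed to weight/bias.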
class TestGammaBetaNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(1))
        self.beta = torch.nn.Parameter(torch.zeros(1))

    def forward(self):
        return self.gamma.sum() + self.beta.sum()


class TestModelGammaBeta(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.LayerNorm = TestGammaBetaNorm()
        self.post_init()

    def forward(self):
        return self.LayerNorm()


TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"

LOG = logging.get_logger(__name__)


def check_models_equal(model1, model2):
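    """Return True if every parameter of `model1` matches the corresponding parameter of `model2`."""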
    models_are_equal = True
    for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
        if model1_p.data.ne(model2_p.data).sum() > 0:
            models_are_equal = False

    return models_are_equal


@require_torch
class ModelUtilsTest(TestCasePlus):
    def setUp(self):
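        # Some tests change torch's default dtype; remember it so tearDown can restore it.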
        self.old_dtype = torch.get_default_dtype()
        super().setUp()

    def tearDown(self):
        torch.set_default_dtype(self.old_dtype)
        super().tearDown()

    def test_hub_retry(self):
        @hub_retry(max_attempts=2)
        def test_func():
            # First attempt will fail with a connection error
            if not hasattr(test_func, "attempt"):
                test_func.attempt = 1
                raise httpx.ConnectError("Connection failed")
            # Second attempt will succeed
            return True

        self.assertTrue(test_func())

    @slow
    def test_model_from_pretrained(self):
        model_name = "google-bert/bert-base-uncased"
        config = BertConfig.from_pretrained(model_name)
        self.assertIsNotNone(config)
        self.assertIsInstance(config, PreTrainedConfig)

        model = BertModel.from_pretrained(model_name)
        model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, PreTrainedModel)

        self.assertEqual(len(loading_info["missing_keys"]), 0)
        self.assertEqual(len(loading_info["unexpected_keys"]), 8)
        self.assertEqual(len(loading_info["mismatched_keys"]), 0)
        self.assertEqual(len(loading_info["error_msgs"]), 0)

        config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)

        # Not sure this is the intended behavior. TODO fix Lysandre & Thom
        config.name_or_path = model_name

        model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
        self.assertEqual(model.config.output_hidden_states, True)
        self.assertEqual(model.config, config)

    def test_model_from_pretrained_subfolder(self):
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)

        subfolder = "bert"
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(os.path.join(tmp_dir, subfolder))

            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(tmp_dir)

            model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)

        self.assertTrue(check_models_equal(model, model_loaded))

    def test_model_manually_shared_disjointed_tensors_optimum(self):
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)

        # Let's fuse qkv
        attn = model.encoder.layer[0].attention.self
        q = attn.query.weight
        k = attn.key.weight
        v = attn.value.weight
        # Force some shared storage
        qkv = torch.stack([q, k, v], dim=0)
        attn.query.weight = torch.nn.Parameter(qkv[0])
        attn.key.weight = torch.nn.Parameter(qkv[1])
        attn.value.weight = torch.nn.Parameter(qkv[2])
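        # The three parameters are now disjoint views into one shared storage; saving must split them back into independent tensors.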
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            model_loaded = BertModel.from_pretrained(tmp_dir)

        self.assertTrue(check_models_equal(model, model_loaded))

    def test_model_from_pretrained_subfolder_sharded(self):
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
        model = BertModel(config)

        subfolder = "bert"
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")

            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(tmp_dir)

            model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)

        self.assertTrue(check_models_equal(model, model_loaded))

    def test_model_from_pretrained_hub_subfolder(self):
        subfolder = "bert"
        model_id = "hf-internal-testing/tiny-random-bert-subfolder"
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(model_id)

        model = BertModel.from_pretrained(model_id, subfolder=subfolder)

        self.assertIsNotNone(model)

    def test_model_from_pretrained_with_different_pretrained_model_name(self):
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertIsNotNone(model)

        logger = logging.get_logger("transformers.configuration_utils")
        with LoggingLevel(logging.WARNING):
            with CaptureLogger(logger) as cl:
                BertModel.from_pretrained(TINY_T5)
        self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)

    @require_accelerate
    def test_model_from_pretrained_with_none_quantization_config(self):
        # Needs a device_map to enter the low_cpu_mem branch. We also deliberately load
        # AutoModelForSequenceClassification to enter the missing-keys branch.
        model = AutoModelForSequenceClassification.from_pretrained(
            TINY_MISTRAL, device_map="auto", quantization_config=None
        )
        self.assertIsNotNone(model)

    def test_model_from_config_dtype(self):
        # test that the model can be instantiated with a dtype of the user's choice - as long as it's a
        # float dtype. For this to work, config.dtype needs to be set before instantiating the
        # model from the config object.

        config = T5Config.from_pretrained(TINY_T5)
        model = AutoModel.from_config(config)
        # XXX: isn't supported
        # model = T5ForConditionalGeneration.from_config(config)
        self.assertEqual(model.dtype, torch.float32)

        model = AutoModel.from_config(config, dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
        with self.assertRaises(ValueError):
            model = AutoModel.from_config(config, dtype=torch.int64)

    def test_model_from_config_dtype_str(self):
        # test that from_pretrained works when dtype is given as a string like "float32" (PyTorch backend)
        model = AutoModel.from_pretrained(TINY_T5, dtype="float32")
        self.assertEqual(model.dtype, torch.float32)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        model = AutoModel.from_pretrained(TINY_T5, dtype="float16")
        self.assertEqual(model.dtype, torch.float16)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
        with self.assertRaises(ValueError):
            model = AutoModel.from_pretrained(TINY_T5, dtype="int64")

    def test_model_from_config_dtype_composite(self):
        """
        Test that from_pretrained works when dtype is given as a dict with one entry per sub-config of a
        composite config. The tiny Llava checkpoint was saved with dtype `torch.float32` for all modules.
        """
        # Load without dtype specified
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
        self.assertEqual(model.model.language_model.dtype, torch.float32)
        self.assertEqual(model.model.vision_tower.dtype, torch.float32)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # should be able to set dtype as a simple string and the model loads it correctly
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, dtype="float32")
        self.assertEqual(model.model.language_model.dtype, torch.float32)
        self.assertEqual(model.model.vision_tower.dtype, torch.float32)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, dtype=torch.float16)
        self.assertEqual(model.model.language_model.dtype, torch.float16)
        self.assertEqual(model.model.vision_tower.dtype, torch.float16)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # should be able to set dtype as a dict for each sub-config
        model = LlavaForConditionalGeneration.from_pretrained(
            TINY_LLAVA, dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
        )
        self.assertEqual(model.model.language_model.dtype, torch.float32)
        self.assertEqual(model.model.vision_tower.dtype, torch.float16)
        self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # should be able to set the values as torch.dtype (not str)
        model = LlavaForConditionalGeneration.from_pretrained(
            TINY_LLAVA, dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
        )
        self.assertEqual(model.model.language_model.dtype, torch.float32)
        self.assertEqual(model.model.vision_tower.dtype, torch.float16)
        self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # should be able to set the values in configs directly and pass it to `from_pretrained`
        config = copy.deepcopy(model.config)
        config.text_config.dtype = torch.float32
        config.vision_config.dtype = torch.bfloat16
        config.dtype = torch.float16
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto")
        self.assertEqual(model.model.language_model.dtype, torch.float32)
        self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
        self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto")
        self.assertEqual(
            model.model.language_model.dtype, torch.float32
        )  # remember config says float32 for text_config
        self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16)
        self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
        self.assertIsInstance(model.config.dtype, torch.dtype)

        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
        with self.assertRaises(ValueError):
            model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, dtype="int64")
            model = LlavaForConditionalGeneration.from_pretrained(
                TINY_LLAVA, dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
            )

    def test_model_from_pretrained_dtype(self):
        # test that the model can be instantiated with dtype of either
        # 1. explicit from_pretrained's dtype argument
        # 2. via autodiscovery by looking at model weights (dtype="auto")
        # so if a model.half() was saved, we want it to be instantiated as such.
        #
        # test an explicit model class, but also AutoModel separately as the latter goes through a different code path
        model_path = self.get_auto_remove_tmp_dir()

        # baseline - we know TINY_T5 is fp32 model
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertEqual(model.dtype, torch.float32)

        def remove_dtype(model_path):
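            # Drop the "dtype" entry from config.json so dtype="auto" has to be derived from the weights.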
            file = f"{model_path}/config.json"
            with open(file, encoding="utf-8") as f:
                s = json.load(f)
            s.pop("dtype")
            with open(file, "w", encoding="utf-8") as f:
                json.dump(s, f)

        # test the default fp32 save_pretrained => from_pretrained cycle
        model.save_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.assertEqual(model.dtype, torch.float32)
        # 1. test dtype="auto" via `config.dtype`
        model = T5ForConditionalGeneration.from_pretrained(model_path, dtype="auto")
        self.assertEqual(model.dtype, torch.float32)
        # 2. test dtype="auto" via auto-derivation
        # now remove the dtype entry from config.json and try "auto" again which should
        # perform auto-derivation from weights
        remove_dtype(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path, dtype="auto")
        self.assertEqual(model.dtype, torch.float32)

        # test forced loading in fp16 (even though the weights are in fp32)
        model = T5ForConditionalGeneration.from_pretrained(model_path, dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test fp16 save_pretrained, loaded with auto-detection
        model = model.half()
        model.save_pretrained(model_path)
        # 1. test dtype="auto" via `config.dtype`
        model = T5ForConditionalGeneration.from_pretrained(model_path, dtype="auto")
        self.assertEqual(model.config.dtype, torch.float16)
        self.assertEqual(model.dtype, torch.float16)
        # tests `config.dtype` saving
        with open(f"{model_path}/config.json") as f:
            config_dict = json.load(f)
        self.assertEqual(config_dict["dtype"], "float16")
        # 2. test dtype="auto" via auto-derivation
        # same as above, but now without the config dtype entry (auto-derivation from the weights)
        remove_dtype(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path, dtype="auto")
        self.assertEqual(model.dtype, torch.float16)

        # 3. now retest that AutoModel behaves the same wrt dtype="auto" as T5ForConditionalGeneration
        model = AutoModel.from_pretrained(model_path, dtype="auto")
        self.assertEqual(model.dtype, torch.float16)

        # test fp16 save_pretrained, loaded with the explicit fp16
        model = T5ForConditionalGeneration.from_pretrained(model_path, dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test AutoModel separately as it goes through a different path
        # test auto-detection - as currently TINY_T5 doesn't have dtype entry
        model = AutoModel.from_pretrained(TINY_T5, dtype="auto")
        # test that the config object didn't get polluted with dtype="auto"
        # there was a bug that after this call we ended up with config.dtype=="auto"
        self.assertNotEqual(model.config.dtype, "auto")
        # now test the outcome
        self.assertEqual(model.dtype, torch.float32)
        model = AutoModel.from_pretrained(TINY_T5, dtype=torch.float16)
        self.assertEqual(model.dtype, torch.float16)

        # test model whose first param is not of a floating type, but int
        model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, dtype="auto")
        self.assertEqual(model.dtype, torch.float32)

        # test a model that initializes via _from_config
        model = CLIPTextModelWithProjection.from_pretrained(
            "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
            subfolder="text_encoder",
            dtype=torch.bfloat16,
        )
        self.assertEqual(model.dtype, torch.bfloat16)

    def test_model_from_pretrained_attn_implementation(self):
        # test that the model can be instantiated with attn_implementation of either
        # 1. explicit from_pretrained's attn_implementation argument
        # 2. explicit from_pretrained's attn_implementation argument with a config argument
        attn_implementation_available = ["eager", "sdpa"]

        if is_flash_attn_available():
            attn_implementation_available.append("flash_attention_2")

        if is_flash_attn_3_available():
            attn_implementation_available.append("flash_attention_3")

        for requested_attn_implementation in attn_implementation_available:
            model = AutoModelForCausalLM.from_pretrained(
                TINY_MISTRAL, attn_implementation=requested_attn_implementation
            )
            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)

            config = AutoConfig.from_pretrained(TINY_MISTRAL)
            model = AutoModelForCausalLM.from_pretrained(
                TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
            )
            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)

    def test_model_from_config_attn_implementation(self):
        # test that the model can be instantiated with attn_implementation of either
        # 1. config created with explicit attn_implementation and from_config
        # 2. explicit from_config's attn_implementation argument with a config argument
        # 3. config created with explicit attn_implementation and from_config overriding with explicit attn_implementation argument
        attn_implementation_available = ["eager", "sdpa"]

        if is_flash_attn_available():
            attn_implementation_available.append("flash_attention_2")

        if is_flash_attn_3_available():
            attn_implementation_available.append("flash_attention_3")

        for requested_attn_implementation in attn_implementation_available:
            config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
            # Ensure the config was set correctly
            self.assertEqual(config._attn_implementation, requested_attn_implementation)
            model = AutoModelForCausalLM.from_config(config)
            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)

            config = AutoConfig.from_pretrained(TINY_MISTRAL)
            # When not set explicitly, the config attribute is None (a default is resolved at model creation)
            self.assertEqual(config._attn_implementation, None)
            model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)

            # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
            config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
            self.assertEqual(config._attn_implementation, "foo-bar-baz")
            model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
            self.assertEqual(model.config._attn_implementation, requested_attn_implementation)

    def test_checkpoint_sharding_local_bin(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
            for max_size in ["50kB", "100kB", "200kB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)

                # Get each shard file and its size
                shard_to_size = {}
                for shard in os.listdir(tmp_dir):
                    if shard.endswith(".bin"):
                        shard_file = os.path.join(tmp_dir, shard)
                        shard_to_size[shard_file] = os.path.getsize(shard_file)

                index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
                # Check there is an index but no regular weight file
                self.assertTrue(os.path.isfile(index_file))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

                # Check a file is bigger than max_size only when it has a single weight
                for shard_file, size in shard_to_size.items():
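                    # max_size is e.g. "50kB"; convert it to a byte count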
                    max_size_int = int(max_size[:-2]) * 10**3
                    # Note: pickle adds some overhead, so the file size can end up slightly bigger than
                    # the size asked for (the sharding logic only counts parameter sizes)
                    if size >= max_size_int + 50000:
                        check_torch_load_is_safe()
                        state_dict = torch.load(shard_file, weights_only=True)
                        self.assertEqual(len(state_dict), 1)

                # Check the index and the shard files found match
                with open(index_file, encoding="utf-8") as f:
                    index = json.loads(f.read())

                all_shards = set(index["weight_map"].values())
                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".bin")}
                self.assertSetEqual(all_shards, shards_found)

                # Finally, check the model can be reloaded
                new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
                for p1, p2 in zip(model.parameters(), new_model.parameters()):
                    torch.testing.assert_close(p1, p2)

    def test_checkpoint_sharding_from_hub(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
        # the model above is the same as the model below, just a sharded version.
        ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        for p1, p2 in zip(model.parameters(), ref_model.parameters()):
            torch.testing.assert_close(p1, p2)

    def test_checkpoint_variant_local_bin(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)

            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])

            weights_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2", use_safetensors=False)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            torch.testing.assert_close(p1, p2)

    @unittest.skip("Skipping it for now, not sure how critial but does not look hard to fix.")
    def test_checkpoint_variant_local_sharded_bin(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)

            weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))

            for i in range(1, 5):
                weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["bin"])
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2", use_safe_tensors=False)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            torch.testing.assert_close(p1, p2)

    def test_checkpoint_variant_local_safe(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=True)

            weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["safetensors"])

            weights_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            torch.testing.assert_close(p1, p2)

    def test_checkpoint_variant_local_sharded_safe(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)

            weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

            for i in range(1, 5):
                weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["safetensors"])
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(tmp_dir)

            new_model = BertModel.from_pretrained(tmp_dir, variant="v2")

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            torch.testing.assert_close(p1, p2)

    def test_checkpoint_loading_only_safetensors_available(self):
        # Test that the loading behaviour is as expected when only safetensor checkpoints are available
        # - We can load the model with use_safetensors=True
        # - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
        #   preferring safetensors
        # - We cannot load the model with use_safetensors=False
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)

            weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))

            for i in range(1, 5):
                weights_name = f"model-0000{i}-of-00005" + ".safetensors"
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            # Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

            # We can load the model with use_safetensors=True
            new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

            # We can load the model without specifying use_safetensors
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            torch.testing.assert_close(p1, p2)

    def test_checkpoint_loading_only_pytorch_bin_available(self):
        # Test that the loading behaviour is as expected when only pytorch checkpoints are available
        # - We can load the model with use_safetensors=False
        # - We cannot load the model with use_safetensors=True
        # - We cannot load the model without specifying use_safetensors, since the torch .bin checkpoint is
        #   not picked up automatically
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)

            weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
            weights_index_file = os.path.join(tmp_dir, weights_index_name)
            self.assertTrue(os.path.isfile(weights_index_file))

            for i in range(1, 5):
                weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
                weights_name_file = os.path.join(tmp_dir, weights_name)
                self.assertTrue(os.path.isfile(weights_name_file))

            # Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
            with self.assertRaises(OSError):
                _ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

            # We can load the model with use_safetensors=False
            _ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

            # Loading without specifying use_safetensors raises, since no safetensors checkpoint is available
            with self.assertRaises(OSError):
                BertModel.from_pretrained(tmp_dir)

    def test_checkpoint_variant_hub(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir)
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2", use_safetensors=False
            )
        self.assertIsNotNone(model)

    def test_checkpoint_variant_hub_sharded(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(
                    "hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir
                )
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant-sharded",
                cache_dir=tmp_dir,
                variant="v2",
                use_safetensors=False,
            )
        self.assertIsNotNone(model)

    def test_checkpoint_variant_hub_safe(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir)
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir, variant="v2"
            )
        self.assertIsNotNone(model)

    def test_checkpoint_variant_hub_sharded_safe(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(EnvironmentError):
                _ = BertModel.from_pretrained(
                    "hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir
                )
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir, variant="v2"
            )
        self.assertIsNotNone(model)

    def test_checkpoint_variant_save_load_bin(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = BertModel.from_pretrained(
                "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2", use_safetensors=False
            )
            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])

            model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
            # saving will create a variant checkpoint
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))

            model.save_pretrained(tmp_dir, safe_serialization=False)
            # saving shouldn't delete variant checkpoints
            weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))

            # there should be a normal checkpoint
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

        self.assertIsNotNone(model)

    @require_non_hpu
    @require_accelerate
    @mark.accelerate_tests
    @require_torch_multi_accelerator
    @slow
    def test_model_parallelism_gpt2(self):
        device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
        for i in range(12):
            device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1

        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", device_map=device_map)

        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        inputs = tokenizer("Hello, my name is", return_tensors="pt")
        output = model.generate(inputs["input_ids"].to(f"{torch_device}:0"))

        text_output = tokenizer.decode(output[0].tolist())
        self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_from_pretrained_disk_offload_task_model(self):
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        device_map = {
            "transformer.wte": f"{torch_device}:0",
            "transformer.wpe": f"{torch_device}:0",
            "transformer.h.0": "cpu",
            "transformer.h.1": "cpu",
            "transformer.h.2": "cpu",
            "transformer.h.3": "disk",
            "transformer.h.4": "disk",
            "transformer.ln_f": f"{torch_device}:0",
            "lm_head": f"{torch_device}:0",
        }
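        # Modules mapped to "disk" are offloaded to `offload_folder` and loaded back on the fly during forward.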
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")

            model.save_pretrained(tmp_dir)
            new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(f"{torch_device}:0")
            outputs1 = new_model.to(f"{torch_device}:0")(inputs)

            offload_folder = os.path.join(tmp_dir, "offload")
            new_model_with_offload = AutoModelForCausalLM.from_pretrained(
                tmp_dir, device_map=device_map, offload_folder=offload_folder
            )
            outputs2 = new_model_with_offload(inputs)

            torch.testing.assert_close(outputs1.logits.cpu(), outputs2.logits.cpu())

            # With state dict temp offload
            new_model_with_offload = AutoModelForCausalLM.from_pretrained(
                tmp_dir,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_state_dict=True,
            )
            outputs2 = new_model_with_offload(inputs)
            torch.testing.assert_close(outputs1.logits.cpu(), outputs2.logits.cpu())

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_from_pretrained_disk_offload_derived_to_base_model(self):
        derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

        device_map = {
            "wte": f"{torch_device}:0",
            "wpe": f"{torch_device}:0",
            "h.0": "cpu",
            "h.1": "cpu",
            "h.2": "cpu",
            "h.3": "disk",
            "h.4": "disk",
            "ln_f": f"{torch_device}:0",
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
            derived_model.save_pretrained(tmp_dir, use_safetensors=True)
            base_model = AutoModel.from_pretrained(tmp_dir)
            outputs1 = base_model.to(f"{torch_device}:0")(inputs)

            # with disk offload
            offload_folder = os.path.join(tmp_dir, "offload")
            base_model_with_offload = AutoModel.from_pretrained(
                tmp_dir, device_map=device_map, offload_folder=offload_folder
            )
            outputs2 = base_model_with_offload(inputs)
            torch.testing.assert_close(outputs1[0].cpu(), outputs2[0].cpu())

            # With state dict temp offload
            new_model_with_offload = AutoModel.from_pretrained(
                tmp_dir,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_state_dict=True,
            )
            outputs2 = new_model_with_offload(inputs)
            torch.testing.assert_close(outputs1[0].cpu(), outputs2[0].cpu())

    @slow
    @require_torch
    def test_from_pretrained_non_contiguous_checkpoint(self):
        # See: https://github.com/huggingface/transformers/pull/28414
        # Tiny models on the Hub have contiguous weights, unlike google/owlvit
        model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight")
        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())

        model = OwlViTForObjectDetection.from_pretrained(
            "fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto"
        )
        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=False)
            model.save_pretrained(tmp_dir, safe_serialization=True)

    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
            "failed", request=mock.Mock(), response=mock.Mock()
        )
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
            _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
            # This checks that the mocked head request was called
            mock_head.assert_called()

    @require_accelerate
    @mark.accelerate_tests
    def test_save_model_with_device_map_cpu(self):
        model_id = "hf-internal-testing/tiny-random-gpt2"
        inputs = torch.tensor([[1, 2, 3]])

        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
            output = model(inputs)[0]
            model.save_pretrained(
                tmp_dir, max_shard_size="200KB"
            )  # model is ~1.6MB, so a 200KB max shard size forces a sharded save
            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cpu")
            saved_model_output = saved_model(inputs)[0]

        torch.testing.assert_close(output, saved_model_output)

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_save_offloaded_model(self):
        device_map = {
            "transformer.wte": f"{torch_device}:0",
            "transformer.wpe": f"{torch_device}:0",
            "transformer.h.0": "cpu",
            "transformer.h.1": "cpu",
            "transformer.h.2": "cpu",
            "transformer.h.3": "disk",
            "transformer.h.4": "disk",
            "transformer.ln_f": f"{torch_device}:0",
            "lm_head": f"{torch_device}:0",
        }

        # check_models_equal requires onloaded tensors
        model_id = "hf-internal-testing/tiny-random-gpt2"
        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(f"{torch_device}:0")
        inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
        output = onloaded_model(inputs)[0]

        with tempfile.TemporaryDirectory() as tmp_dir:
            offload_folder = os.path.join(tmp_dir, "offload")
            offloaded_model = AutoModelForCausalLM.from_pretrained(
                model_id, device_map=device_map, offload_folder=offload_folder
            )
            presaved_output = offloaded_model(inputs)[0]
            offloaded_model.save_pretrained(
                tmp_dir, max_shard_size="200KB"
            )  # model is ~1.6MB, so a 200KB max shard size forces a sharded save
            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
            postsaved_output = saved_model(inputs)[0]

        torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4)
        torch.testing.assert_close(presaved_output, postsaved_output)

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    def test_save_offloaded_model_with_direct_params(self):
        from accelerate import dispatch_model

        device_map = {"submodule": "cpu", "linear": f"{torch_device}:0"}
        model = ModelWithDirectParamSubmodule(PreTrainedConfig())
        dispatch_model(model, device_map)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_accelerator
    @unittest.skip("TODO @cyrilvallez when saving")
    def test_save_offloaded_model_dynamic_tied_weights_keys(self):
        from accelerate import dispatch_model

        device_map = {"base": f"{torch_device}:0", "linear": "cpu", "linear2": "cpu"}
        model = ModelWithHead(PreTrainedConfig())
        dispatch_model(model, device_map)

        transform_a = torch.nn.Linear(1, 1, bias=False)
        transform_a._dynamic_tied_weights_keys = ["weight"]
        transform_b = torch.nn.Linear(1, 1, bias=False)
        transform_b._dynamic_tied_weights_keys = ["weight"]

        model.linear.register_module("transform_a", transform_a)
        model.linear.register_module("transform_b", transform_b)
        model.linear2.register_module("transform_a", transform_a)
        model.linear2.register_module("transform_b", transform_b)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

    def test_use_safetensors(self):
        # Should not raise anymore
        AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)

        # test that an error is raised if use_safetensors=False but only safetensors checkpoints are available
        with self.assertRaises(OSError) as env_error:
            BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)

        self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))

        # test that only the torch .bin weights are downloaded when both formats are available and use_safetensors=False
        with tempfile.TemporaryDirectory() as tmp_dir:
            CLIPTextModel.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
                subfolder="text_encoder",
                use_safetensors=False,
                cache_dir=tmp_dir,
            )

            all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
            self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
            self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))

        # test that only the safetensors weights are downloaded when both formats are available and use_safetensors=True
        with tempfile.TemporaryDirectory() as tmp_dir:
            CLIPTextModel.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
                subfolder="text_encoder",
                use_safetensors=True,
                cache_dir=tmp_dir,
            )

            all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
            self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
            self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))

        # test that an error is raised when no model file is found and use_safetensors=None (the default when the safetensors package is available)
        with self.assertRaises(OSError) as missing_model_file_error:
            BertModel.from_pretrained("hf-internal-testing/config-no-model")

        with self.assertRaises(OSError) as missing_model_file_error:
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(os.path.join(tmp_dir, "config.json"), "w") as f:
                    f.write("{}")
                BertModel.from_pretrained(tmp_dir)

        self.assertTrue(
            "Error no file named model.safetensors found in directory" in str(missing_model_file_error.exception),
            msg=missing_model_file_error.exception,
        )

    def test_safetensors_save_and_load(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            # No pytorch_model.bin file, only a model.safetensors
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            new_model = BertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                torch.testing.assert_close(p1, p2)

    def test_safetensors_load_from_hub(self):
        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Check models are equal
        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
            torch.testing.assert_close(p1, p2)

    def test_safetensors_save_and_load_sharded(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
            # No pytorch_model.bin index file, only a model.safetensors index
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
            # No regular weights file
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

            new_model = BertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                torch.testing.assert_close(p1, p2)

    def test_safetensors_load_from_hub_sharded(self):
        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")

        # Check models are equal
        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
            torch.testing.assert_close(p1, p2)

    @unittest.skip("This now just works by defaults :) no complicated load from task blah blah")
    def test_base_model_to_head_model_load(self):
        base_model = BaseModel(PreTrainedConfig())
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model.save_pretrained(tmp_dir)

            # Can load a base model in a model with head
            model = ModelWithHead.from_pretrained(tmp_dir)
            for p1, p2 in zip(model.base.parameters(), base_model.parameters()):
                torch.testing.assert_close(p1, p2)

            # It doesn't work if the state dict mixes head keys with base-model keys (which lack the `base.` prefix) though.
            base_state_dict = base_model.state_dict()
            head_state_dict = model.state_dict()
            base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
            base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
            safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})

            with self.assertRaisesRegex(
                ValueError, "The state dictionary of the model you are trying to load is corrupted."
            ):
                _ = ModelWithHead.from_pretrained(tmp_dir)

    def test_tied_weights_reload(self):
        # Base
        model = BaseModelWithTiedWeights(PreTrainedConfig())
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            new_model = BaseModelWithTiedWeights.from_pretrained(tmp_dir)
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)

            state_dict = model.state_dict()
            # Remove tied weight from state_dict -> model should load without complaining about missing keys
            del state_dict["linear_2.weight"]
            torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
            new_model, load_info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertSetEqual(load_info["missing_keys"], set())
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)

            # With head
            model = BaseModel(PreTrainedConfig())
            model.save_pretrained(tmp_dir)
            new_model, load_info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertIs(new_model.base.linear.weight, new_model.decoder.weight)
            # Should only complain about the missing bias
            self.assertSetEqual(load_info["missing_keys"], {"decoder.bias"})

    def test_tied_weights_can_load_symmetrically(self):
        """Test that we can correctly load and tie weights even though the wrong key was saved."""
        model = BaseModelWithTiedWeights(PreTrainedConfig())
        # Just to be sure it's actually tied
        self.assertIs(model.linear.weight, model.linear_2.weight, msg="Weights are not tied!")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save the config
            with open(os.path.join(tmp_dir, "config.json"), "w") as f:
                f.write(json.dumps(model.config.to_dict()))

            state_dict = model.state_dict()
            # Save using the wrong key
            state_dict.pop("linear.weight")
            safe_save_file(state_dict, os.path.join(tmp_dir, "model.safetensors"))

            new_model, load_info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
            # Assert no missing keys
            self.assertSetEqual(load_info["missing_keys"], set(), msg=f"{load_info['missing_keys']} are missing!")
            # It's still the same weight
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight, msg="Weights are not tied!")

            # Make sure both state dicts are the same
            compare_state_dicts(model.state_dict(), new_model.state_dict())

    def test_tied_weights_can_load_symmetrically_multiple_keys(self):
        """Test that we can correctly load and tie weights even though the wrong key was saved, when we
        have more than 1 target to the same source."""
        # The first class is consistent in how it provides the source, the second is not -> make sure it works in both cases
        for model_class in [BaseModelWithMultipleTiedWeights, BaseModelWithMultipleMixedTiedWeights]:
            with self.subTest(model_class.__name__):
                model = model_class(PreTrainedConfig())
                # Just to be sure it's actually tied
                self.assertIs(model.linear.weight, model.linear_2.weight, msg="Weights are not tied!")
                self.assertIs(model.linear.weight, model.linear_3.weight, msg="Weights are not tied!")
                with tempfile.TemporaryDirectory() as tmp_dir:
                    # Save the config
                    with open(os.path.join(tmp_dir, "config.json"), "w") as f:
                        f.write(json.dumps(model.config.to_dict()))

                    state_dict = model.state_dict()
                    # Keep only 1 of the 3 tied keys, but not the source (which is `linear.weight`)
                    state_dict.pop("linear.weight")
                    state_dict.pop("linear_3.weight")
                    safe_save_file(state_dict, os.path.join(tmp_dir, "model.safetensors"))

                    new_model, load_info = model_class.from_pretrained(
                        tmp_dir, output_loading_info=True
                    )
                    # Assert no missing keys
                    self.assertSetEqual(
                        load_info["missing_keys"], set(), msg=f"{load_info['missing_keys']} are missing!"
                    )
                    # It's still the same weight
                    self.assertIs(new_model.linear.weight, new_model.linear_2.weight, msg="Weights are not tied!")
                    self.assertIs(new_model.linear.weight, new_model.linear_3.weight, msg="Weights are not tied!")

                    # Make sure both state dicts are the same
                    compare_state_dicts(model.state_dict(), new_model.state_dict())

                    # Now, do the same but try to keep `linear_2.weight` in the saved key instead of `linear_3.weight`
                    # to make sure it does not matter
                    state_dict = model.state_dict()
                    # Keep only 1 of the 3 tied keys, but not the source (which is `linear.weight`)
                    state_dict.pop("linear.weight")
                    state_dict.pop("linear_2.weight")
                    safe_save_file(state_dict, os.path.join(tmp_dir, "model.safetensors"))

                    new_model, load_info = model_class.from_pretrained(
                        tmp_dir, output_loading_info=True
                    )
                    # Assert no missing keys
                    self.assertSetEqual(
                        load_info["missing_keys"], set(), msg=f"{load_info['missing_keys']} are missing!"
                    )
                    # It's still the same weight
                    self.assertIs(new_model.linear.weight, new_model.linear_2.weight, msg="Weights are not tied!")
                    self.assertIs(new_model.linear.weight, new_model.linear_3.weight, msg="Weights are not tied!")

                    # Make sure both state dicts are the same
                    compare_state_dicts(model.state_dict(), new_model.state_dict())

    def test_tied_weights_are_not_tied_if_both_present(self):
        """Test that if both the source and target of tied weights are present, we do NOT tie them, and instead
        raise a warning"""
        model = BaseModelWithTiedWeights(PreTrainedConfig())
        # Just to be sure it's actually tied
        self.assertIs(model.linear.weight, model.linear_2.weight, msg="Weights are not tied!")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save the config
            with open(os.path.join(tmp_dir, "config.json"), "w") as f:
                f.write(json.dumps(model.config.to_dict()))

            state_dict = model.state_dict()
            # Clone every param to make sure nothing is tied -> we save everything
            state_dict = {k: v.clone() for k, v in state_dict.items()}
            safe_save_file(state_dict, os.path.join(tmp_dir, "model.safetensors"))

            logger = logging.get_logger("transformers.modeling_utils")
            with CaptureLogger(logger) as cl:
                new_model, load_info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)

            # We should have emitted a warning here saying that we will NOT tie the weights
            self.assertIn("both are present in the checkpoints, so we will NOT tie them.", cl.out)
            # Assert no missing keys
            self.assertSetEqual(load_info["missing_keys"], set(), msg=f"{load_info['missing_keys']} are missing!")
            # It should not be the same weight anymore
            self.assertIsNot(
                new_model.linear.weight, new_model.linear_2.weight, msg="Weights are tied but they should not!"
            )

            # Make sure both state dicts are the same (the values are still the same, it's just not tied)
            compare_state_dicts(model.state_dict(), new_model.state_dict())

    def test_tied_weights_are_missing_if_both_absent(self):
        """Test that if both the source and target of tied weights are absent, we do tie them, but they are missing"""
        model = BaseModelWithTiedWeights(PreTrainedConfig())
        # Just to be sure it's actually tied
        self.assertIs(model.linear.weight, model.linear_2.weight, msg="Weights are not tied!")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save the config
            with open(os.path.join(tmp_dir, "config.json"), "w") as f:
                f.write(json.dumps(model.config.to_dict()))

            state_dict = model.state_dict()
            # Remove both from the state dict
            state_dict.pop("linear.weight")
            state_dict.pop("linear_2.weight")
            safe_save_file(state_dict, os.path.join(tmp_dir, "model.safetensors"))

            logger = logging.get_logger("transformers.modeling_utils")
            with CaptureLogger(logger) as cl:
                new_model, load_info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)

            # We should have emitted a warning here saying that the checkpoint seems corrupted (all tied weights are missing)
            self.assertIn(
                "This checkpoint seem corrupted. The tied weights mapping for this model specifies to tie", cl.out
            )
            # Assert both are in the missing keys
            self.assertSetEqual(load_info["missing_keys"], {"linear.weight", "linear_2.weight"})
            # They should still be tied though
            self.assertIs(new_model.linear.weight, new_model.linear_2.weight, msg="Weights are not tied!")

    def test_unexpected_keys_warnings(self):
        model = ModelWithHead(PreTrainedConfig())
        logger = logging.get_logger("transformers.modeling_utils")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)

            # Loading the checkpoint into a different model class: we don't get a warning for unexpected weights, just an info message
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertNotIn("were not used when initializing ModelWithHead", cl.out)
            self.assertEqual(
                set(loading_info["unexpected_keys"]),
                {"linear2.weight", "linear2.bias"},
            )

            # Loading the model with the same class, we do get a warning for unexpected weights
            state_dict = model.state_dict()
            state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
            safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
            self.assertIn("added_key | UNEXPECTED", cl.out)
            self.assertEqual(loading_info["unexpected_keys"], {"added_key"})

    def test_warn_if_padding_and_no_attention_mask(self):
        logger = logging.get_logger("transformers.modeling_utils")

        with self.subTest("Ensure no warnings when pad_token_id is None."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config_no_pad_token = PreTrainedConfig()
                    config_no_pad_token.pad_token_id = None
                    model = ModelWithHead(config_no_pad_token)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure no warnings when there is an attention_mask."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PreTrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PreTrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PreTrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PreTrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

        with self.subTest("Ensure that the warning is shown at most once."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PreTrainedConfig()
                    config.pad_token_id = 0
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertEqual(cl.out.count("We strongly recommend passing in an `attention_mask`"), 1)

        with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."):
            logger.warning_once.cache_clear()
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl:
                    config = PreTrainedConfig()
                    config.pad_token_id = 0
                    config.bos_token_id = config.pad_token_id
                    model = ModelWithHead(config)
                    input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]])
                    model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)

        with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
            logger.warning_once.cache_clear()
            from torch._dynamo import config, testing

            config = PreTrainedConfig()
            config.pad_token_id = 0
            model = ModelWithHead(config)
            input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])

            def f(input_ids):
                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)

            compile_counter = testing.CompileCounter()
            opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
            opt_fn(input_ids)
            self.assertEqual(compile_counter.frame_count, 0)

    @require_torch_accelerator
    @slow
    def test_pretrained_low_mem_new_config(self):
        # Checking for 1 model (the same one described in the issue).
        model_ids = ["openai-community/gpt2"]

        for model_id in model_ids:
            model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
            model_config.n_layer = 48
            model_config.n_head = 25
            model_config.n_embd = 1600
            model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=model_id,
                config=model_config,
                ignore_mismatched_sizes=True,
                dtype=torch.float16,
            )
            model_ref = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)

            self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)

    def test_generation_config_is_loaded_with_model(self):
        # Note: `hf-internal-testing/tiny-random-MistralForCausalLM` has a `generation_config.json`
        # containing `bos_token_id: 1`

        # 1. Load without further parameters
        model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
        self.assertEqual(model.generation_config.bos_token_id, 1)

        # 2. Load with `device_map`
        model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto")
        self.assertEqual(model.generation_config.bos_token_id, 1)

    def test_safetensors_torch_from_torch(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def test_safetensors_torch_from_torch_sharded(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
            new_model = BertModel.from_pretrained(tmp_dir)

        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def test_saving_model_config_with_generation_params(self):
        """
        Calling `model.save_pretrained` when generation parameters are set on the model config should raise a `ValueError`
        """
        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        self.assertTrue(model.generation_config.repetition_penalty == 1.0)
        self.assertFalse(hasattr(model.config, "repetition_penalty"))

        # If the user attempts to save a custom generation parameter, we raise an Error
        model.config.repetition_penalty = 3.0
        with self.assertRaises(ValueError):
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir)

    def test_model_from_pretrained_from_mlx(self):
        from safetensors import safe_open

        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx")
        self.assertIsNotNone(model)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True)
            with safe_open(os.path.join(tmp_dir, "model.safetensors"), framework="pt") as f:
                metadata = f.metadata()
                self.assertEqual(metadata.get("format"), "pt")
            new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)

        input_ids = torch.randint(100, 1000, (1, 10))
        with torch.no_grad():
            outputs = model(input_ids)
            outputs_from_saved = new_model(input_ids)
            torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])

    def test_warning_for_beta_gamma_parameters(self):
        config = PreTrainedConfig()
        model = TestModelGammaBeta(config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            with LoggingLevel(logging.INFO):
                _, loading_info = TestModelGammaBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)

        missing_keys = loading_info["missing_keys"]
        unexpected_keys = loading_info["unexpected_keys"]
        self.assertIn("LayerNorm.gamma", missing_keys)
        self.assertIn("LayerNorm.weight", unexpected_keys)
        self.assertIn("LayerNorm.beta", missing_keys)
        self.assertIn("LayerNorm.bias", unexpected_keys)

    def test_isin_mps_friendly(self):
        """tests that our custom `isin_mps_friendly` matches `torch.isin`"""
        random_ids = torch.randint(0, 100, (100,))
        # We can match against an integer
        random_test_integer = torch.randint(0, 100, (1,)).item()
        self.assertTrue(
            torch.equal(
                torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
            )
        )
        # We can match against a 0D tensor
        random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
        self.assertTrue(
            torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
        )
        # We can match against a 1D tensor (with many items)
        random_test_tensor = torch.randint(0, 100, (10,))
        self.assertTrue(
            torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
        )

    def test_can_generate(self):
        """Tests the behavior of `PreTrainedModel.can_generate` method."""
        logger = logging.get_logger("transformers.modeling_utils")
        logger.warning_once.cache_clear()

        # 1 - By default, a model CAN'T generate
        can_generate = BertModel.can_generate()
        self.assertFalse(can_generate)

        # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
        class DummyBertWithMixin(BertModel, GenerationMixin):
            pass

        with CaptureLogger(logger) as cl:
            can_generate = DummyBertWithMixin.can_generate()
        self.assertTrue("" == cl.out)
        self.assertTrue(can_generate)

        # 3 - Finally, it can inherit from a model that can generate
        class DummyBertWithParent(DummyBertWithMixin):
            pass

        with CaptureLogger(logger) as cl:
            can_generate = DummyBertWithParent.can_generate()
        self.assertTrue("" == cl.out)
        self.assertTrue(can_generate)

        # 4 - Legacy: models with a custom `prepare_inputs_for_generation` used to be able to generate (it was
        # assumed they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
        class DummyBertWithPrepareInputs(BertModel):
            def prepare_inputs_for_generation(self):
                pass

        with CaptureLogger(logger) as cl:
            can_generate = DummyBertWithPrepareInputs.can_generate()
        self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
        self.assertFalse(can_generate)

    def test_save_and_load_config_with_custom_generation(self):
        """
        Tests that saving a config on which a custom generation kwarg has been set is not possible (a `ValueError` is raised)
        """
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)

        # The default for `num_beams` is 1 and `early_stopping` is False
        # NOTE: accessible only from generation config, EVEN IF they are saved
        # in `config.json` file in the hub
        self.assertTrue(model.generation_config.num_beams == 1)
        self.assertTrue(model.generation_config.early_stopping is False)
        self.assertFalse(hasattr(model.config, "num_beams"))
        self.assertFalse(hasattr(model.config, "early_stopping"))

        # Sanity check: We can run `generate` with the model without any warnings
        random_ids = torch.randint(0, 100, (1, 5))
        with warnings.catch_warnings(record=True) as w:
            model.generate(random_ids, max_new_tokens=3)
        self.assertTrue(len(w) == 0)

        # When we save the model and the config contains generation-related parameters,
        # an error is raised, nudging the user to save those attributes in the generation_config instead
        model.config.num_beams = 5
        model.config.early_stopping = True
        self.assertTrue(model.generation_config.num_beams == 1)  # unmodified generation config
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(ValueError):
                model.save_pretrained(tmp_dir)

    def test_load_model_with_state_dict_only(self):
        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        state_dict = model.state_dict()
        config = model.config

        model_loaded = BertModel.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )
        self.assertTrue(check_models_equal(model, model_loaded))

    @unittest.skip("Skipping flaky test")
    def test_cache_when_needed_at_train_time(self):
        """
        Some fine-tuning methods require the use of a cache, like prefix tuning in PEFT. This test checks that a cache
        is used at train time if we request it. Related issue: #35648
        """
        model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
        tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
        model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

        # By default the model is not in training mode, we have to set it explicitly
        self.assertFalse(model.training)
        model.train()

        # If we set `use_cache=True` while training, then a cache is returned
        model_outputs = model(**model_inputs, use_cache=True)
        self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
        self.assertTrue(model.training)

        # simulate injecting virtual tokens like in prefix tuning
        num_virtual_tokens = 3
        past_key_values = [
            (torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
            (torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
        ]
        past_key_values = DynamicCache(past_key_values)
        model_inputs["attention_mask"] = torch.cat(
            (
                model_inputs["attention_mask"],
                torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
            ),
            dim=1,
        )
        model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
        self.assertTrue(model.training)

        # We can also disable the cache to skip a few operations, if the training loop doesn't need cache
        # NOTE: after #41900, we need to pass the correct attention mask size
        model_inputs["attention_mask"] = model_inputs["attention_mask"][:, :-num_virtual_tokens]
        model_outputs = model(**model_inputs, use_cache=False)
        self.assertIsNone(model_outputs.past_key_values)
        self.assertTrue(model.training)

    def test_restore_default_dtype_from_pretrained(self):
        """
        Tests that the default torch dtype is restored
        when an error happens during the loading of a model.
        """
        old_dtype = torch.get_default_dtype()
        # set default type to float32
        torch.set_default_dtype(torch.float32)

        # Mock injection point which is right after the call to `_set_default_dtype`
        original_set_default_dtype = MistralForCausalLM._set_default_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_dtype(*args, **kwargs)
            raise RuntimeError

        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", dtype=torch.float16)
        # default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

    def test_restore_default_dtype_from_config(self):
        """
        Tests that the default torch dtype is restored
        when an error happens during the loading of a model.
        """
        old_dtype = torch.get_default_dtype()
        # set default type to float32
        torch.set_default_dtype(torch.float32)

        config = AutoConfig.from_pretrained(
            TINY_MISTRAL,
        )

        # Mock injection point which is right after the call to `_set_default_dtype`
        original_set_default_dtype = MistralForCausalLM._set_default_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_dtype(*args, **kwargs)
            raise RuntimeError

        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                config.dtype = torch.float16
                _ = AutoModelForCausalLM.from_config(
                    config,
                )
        # default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

    def test_unknown_quantization_config(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config = BertConfig(
                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
            )
            model = BertModel(config)
            config.quantization_config = {"quant_method": "unknown"}
            model.save_pretrained(tmpdir)
            with self.assertLogs("transformers", level="WARNING") as cm:
                BertModel.from_pretrained(tmpdir)
            self.assertEqual(len(cm.records), 1)
            self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got"))

    @parameterized.expand([("Qwen/Qwen2.5-3B-Instruct", 10), ("meta-llama/Llama-2-7b-chat-hf", 10)])
    @slow
    @require_read_token
    @require_torch_accelerator
    def test_loading_is_fast_on_gpu(self, model_id: str, max_loading_time: float):
        """
        This test is used to avoid regression on https://github.com/huggingface/transformers/pull/36380.
        10s should be more than enough for both models, and allows for some margin as loading times are quite
        unstable. Before #36380, it used to take more than 40s, so 10s is still reasonable.
        Note that we run this test in a subprocess, to ensure that cuda is not already initialized/warmed-up.
        """
        # First download the weights if not already on disk
        _ = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16)

        script_to_run = textwrap.dedent(
            """
            import torch
            import time
            import argparse
            from transformers import AutoModelForCausalLM
            from transformers.utils import is_torch_accelerator_available

            parser = argparse.ArgumentParser()
            parser.add_argument("model_id", type=str)
            parser.add_argument("max_loading_time", type=float)
            args = parser.parse_args()

            device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
            device = torch.device(f"{device_type}:0")

            torch_accelerator_module = getattr(torch, device_type, torch.cuda)
            torch_accelerator_module.synchronize(device)
            t0 = time.time()
            model = AutoModelForCausalLM.from_pretrained(args.model_id, dtype=torch.float16, device_map=device)
            torch_accelerator_module.synchronize(device)
            dt = time.time() - t0

            # Assert loading is fast enough (the threshold leaves plenty of margin in both cases)
            if dt > args.max_loading_time:
                raise ValueError(f"Loading took {dt:.2f}s! It should not take more than {args.max_loading_time}s")
            # Ensure everything is correctly loaded on accelerator
            bad_device_params = {k for k, v in model.named_parameters() if v.device != device}
            if len(bad_device_params) > 0:
                raise ValueError(f"The following parameters are not on accelerator: {bad_device_params}")
            """
        )

        with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
            tmp.write(script_to_run)
            tmp.flush()
            tmp.seek(0)
            cmd = f"python {tmp.name} {model_id} {max_loading_time}".split()
            try:
                # We cannot use a timeout of `max_loading_time` as cuda initialization can take up to 15-20s
                _ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True, timeout=60)
            except subprocess.CalledProcessError as e:
                raise Exception(f"The following error was captured: {e.stderr}")

    def test_explicit_transformers_weights(self):
        """
        Transformers supports loading from repos where the weights file is explicitly set in the config.
        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
        If so, it will load from that file.

        Here, we ensure that the correct file is loaded.
        """
        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
        self.assertEqual(model.num_parameters(), 87929)

    def test_explicit_transformers_weights_index(self):
        """
        Transformers supports loading from repos where the weights file is explicitly set in the config.
        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
        If so, it will load from that file.

        Here, we ensure that the correct file is loaded, given the file is an index of multiple weights.
        """
        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
        self.assertEqual(model.num_parameters(), 87929)

    def test_explicit_transformers_weights_save_and_reload(self):
        """
        Transformers supports loading from repos where the weights file is explicitly set in the config.
        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
        If so, it will load from that file.

        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
        otherwise, transformers will try to load from that file whereas it should simply load from the default file.

        We test that for a non-sharded repo.
        """
        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
        explicit_transformers_weights = model.config.transformers_weights

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # The config should not have a mention of transformers_weights
            with open(os.path.join(tmpdirname, "config.json")) as f:
                config = json.loads(f.read())
                self.assertFalse("transformers_weights" in config)

            # The serialized weights should be in model.safetensors and not the transformers_weights
            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
            self.assertTrue("model.safetensors" in os.listdir(tmpdirname))

    def test_explicit_transformers_weights_index_save_and_reload(self):
        """
        Transformers supports loading from repos where the weights file is explicitly set in the config.
        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
        If so, it will load from that file.

        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
        otherwise, transformers will try to load from that file whereas it should simply load from the default file.

        We test that for a sharded repo.
        """
        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
        explicit_transformers_weights = model.config.transformers_weights

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, max_shard_size="100kb")

            # The config should not have a mention of transformers_weights
            with open(os.path.join(tmpdirname, "config.json")) as f:
                config = json.loads(f.read())
                self.assertFalse("transformers_weights" in config)

            # The serialized weights should be in model.safetensors and not the transformers_weights
            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
            self.assertTrue("model.safetensors.index.json" in os.listdir(tmpdirname))

    def test_config_class_attribute(self):
        # custom configs
        class MyConfigA(PreTrainedConfig):
            pass

        class MyConfigB(PreTrainedConfig):
            pass

        class MyConfigC(PreTrainedConfig):
            pass

        # custom models
        class MyModelA(PreTrainedModel):
            config: dict
            config_class = MyConfigA

        class MyModelB(MyModelA):
            config: MyConfigB

        class MyModelC(MyModelA):
            config_class = MyConfigC

        class MyModelD(MyModelA):
            pass

        # child config_class > child 'config:' > parent config_class > parent 'config:'
        self.assertIs(MyModelA.config_class, MyConfigA)
        self.assertIs(MyModelB.config_class, MyConfigB)
        self.assertIs(MyModelC.config_class, MyConfigC)
        self.assertIs(MyModelD.config_class, MyConfigA)

    def test_ignore_missing_key_works(self):
        """Test that if a parameter (not buffer) is specified in `_keys_to_ignore_on_load_missing` and is actually
        missing from the checkpoint, it will still be moved to cpu and initialized"""
        temp = tempfile.TemporaryDirectory()
        # Create dummy model
        model = BaseModelWithMissingKeys(PreTrainedConfig())

        # Save the config
        model.config.save_pretrained(temp.name)
        # Get the state dict to save
        state_dict = model.state_dict()
        # Remove the layer that we should ignore if missing
        del state_dict["linear.weight"], state_dict["linear.bias"]
        # Save the state dict as a single shard
        safe_save_file(state_dict, Path(temp.name) / "model.safetensors", metadata={"format": "pt"})

        # Try loading back, with the missing key not present in the state_dict
        model = BaseModelWithMissingKeys.from_pretrained(temp.name)

        # Make sure the skipped missing key is not still on meta device!
        for k, v in model.state_dict().items():
            self.assertTrue(v.device.type == "cpu", f"{k} is not on cpu!")

    def test_device_map_works_with_unexpected_keys(self):
        """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually
        present in the checkpoint, it will correctly be removed from the weights we load, especially those
        we use if the device map has offloading"""
        temp = tempfile.TemporaryDirectory()

        # Create dummy model
        model = BaseModelWithUnexpectedKeys(PreTrainedConfig())

        # Save the config
        model.config.save_pretrained(temp.name)

        # Get the state dict to save
        state_dict = model.state_dict()
        # Add a layer that is in the "_keys_to_ignore_on_load_unexpected" list to ignore
        state_dict["mtp"] = torch.randn(12, 12)
        # Save the state dict as a single shard
        safe_save_file(state_dict, Path(temp.name) / "model.safetensors", metadata={"format": "pt"})

        # Load the model with entire shards placed on disk in order to trigger `get_disk_only_shard_files`.
        # Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out.
        BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})

    def test_device_map_works_with_unexpected_keys_sharded(self):
        """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually
        present in the checkpoint, it will correctly be removed from the weights we load, especially those
        we use if the device map has offloading"""
        temp = tempfile.TemporaryDirectory()

        # Create dummy model
        model = BaseModelWithUnexpectedKeys(PreTrainedConfig())

        # Save the config
        model.config.save_pretrained(temp.name)

        # Get the state dict to save
        state_dict = model.state_dict()

        # Add a layer that is in the "_keys_to_ignore_on_load_unexpected" list to ignore
        state_dict["mtp"] = torch.randn(50, 50)

        # Split the state dict in shards, save the index and the shards
        shards = split_torch_state_dict_into_shards(state_dict, max_shard_size="1kb")
        index = {
            "metadata": {"total_parameters": model.num_parameters(), **shards.metadata},
            "weight_map": shards.tensor_to_filename,
        }
        with open(Path(temp.name) / SAFE_WEIGHTS_INDEX_NAME, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)

        # Save each shard
        filename_to_tensors = shards.filename_to_tensors.items()
        for shard_file, tensors in filename_to_tensors:
            shard = {}
            for tensor in tensors:
                shard[tensor] = state_dict[tensor].contiguous()
            safe_save_file(shard, Path(temp.name) / shard_file, metadata={"format": "pt"})

        # Load the model with entire shards placed on disk in order to trigger `get_disk_only_shard_files`.
        # Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out.
        BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})

    def test_loading_respect_env_variable_for_threading(self):
        """Test that we can correctly control threading during loading"""
        model = BaseModel(PreTrainedConfig())

        # Monkey patch Thread.__init__ to add a counter of launched threads
        original_init = threading.Thread.__init__
        counter = 0

        def tracking_init(self, *args, **kwargs):
            nonlocal counter
            counter += 1
            original_init(self, *args, **kwargs)

        threading.Thread.__init__ = tracking_init

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # Use threading
            os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "0"
            before = counter
            _ = BaseModel.from_pretrained(tmpdirname)
            after = counter
            self.assertTrue(after - before > 0, "Loading should have spawned new threads!")

            # Deactivate threading
            os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
            before = counter
            _ = BaseModel.from_pretrained(tmpdirname)
            after = counter
            self.assertTrue(after == before, "It looks like loading did spawn new threads, but it should not have!")

        # Reverse monkey patch
        threading.Thread.__init__ = original_init


@slow
@require_torch
class ModelOnTheFlyConversionTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.user = "huggingface-hub-ci"
        cls.token = os.getenv("HUGGINGFACE_PRODUCTION_USER_TOKEN", None)

        if cls.token is None:
            raise ValueError("Cannot run tests as secret isn't setup.")

        cls.api = HfApi(token=cls.token)

    def setUp(self) -> None:
        self.repo_name = f"{self.user}/test-model-on-the-fly-{uuid.uuid4()}"

    def tearDown(self) -> None:
        self.api.delete_repo(self.repo_name)

    def test_safetensors_on_the_fly_conversion(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_conversion_private(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name, token=self.token)
            discussion = next(discussions)
            self.assertEqual(discussion.author, self.user)
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_conversion_gated(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
        self.api.update_repo_settings(self.repo_name, gated="auto")
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_sharded_conversion(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb")
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_sharded_conversion_private(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(
            self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb", private=True
        )
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, self.user)
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_sharded_conversion_gated(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, max_shard_size="200kb", safe_serialization=False)
        headers = {"Authorization": f"Bearer {self.token}"}
        httpx.put(
            f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
        )
        converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("Initial and converted models are equal"):
            for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, "SFconvertbot")
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

    @unittest.skip(reason="Edge case, should work once the Space is updated`")
    def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
        BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        # This should have opened a PR with the user's account
        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)
            discussion = next(discussions)
            self.assertEqual(discussion.author, self.user)
            self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")

        # We now switch the repo visibility to public
        self.api.update_repo_settings(self.repo_name, private=False)

        # We once again call from_pretrained, which should call the bot to open a PR
        BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)

        with self.subTest("PR was open with the safetensors account"):
            discussions = self.api.get_repo_discussions(self.repo_name)

            bot_opened_pr = None
            bot_opened_pr_title = None

            for discussion in discussions:
                if discussion.author == "SFconvertbot":
                    bot_opened_pr = True
                    bot_opened_pr_title = discussion.title

            self.assertTrue(bot_opened_pr)
            self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")

    def test_safetensors_on_the_fly_specific_revision(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        # Push a model on `main`
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)

        # Push a model on a given revision
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, revision="new-branch")

        # Trying to convert the model on that revision should raise
        with self.assertRaises(EnvironmentError):
            BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")

    def test_absence_of_safetensors_triggers_conversion(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        # Push a model on `main`
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)

        # Download the model that doesn't have safetensors
        BertModel.from_pretrained(self.repo_name, token=self.token)
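        # The safetensors auto-conversion is spawned in a background thread ("Thread-autoconversion");
        # it is joined below so the bot-opened discussion is visible before we query the repo.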

        for thread in threading.enumerate():
            if thread.name == "Thread-autoconversion":
                thread.join(timeout=10)

        discussions = self.api.get_repo_discussions(self.repo_name)

        bot_opened_pr = None
        bot_opened_pr_title = None

        for discussion in discussions:
            if discussion.author == "SFconvertbot":
                bot_opened_pr = True
                bot_opened_pr_title = discussion.title

        self.assertTrue(bot_opened_pr)
        self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")

    @mock.patch("transformers.safetensors_conversion.spawn_conversion")
    def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):
        spawn_conversion_mock.side_effect = httpx.HTTPError("failed")

        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        initial_model = BertModel(config)

        # Push a model on `main`
        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)

        # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread
        BertModel.from_pretrained(self.repo_name, token=self.token)


@require_torch
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN

    @unittest.skip(reason="This test is flaky")
    def test_push_to_hub(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = BertConfig(
                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
            )
            model = BertModel(config)
            model.push_to_hub(tmp_repo.repo_id, token=self._token)

            new_model = BertModel.from_pretrained(tmp_repo.repo_id)
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

    @unittest.skip(reason="This test is flaky")
    def test_push_to_hub_via_save_pretrained(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = BertConfig(
                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
            )
            model = BertModel(config)
            # Push to hub via save_pretrained
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)

            new_model = BertModel.from_pretrained(tmp_repo.repo_id)
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

    def test_push_to_hub_with_description(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = BertConfig(
                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
            )
            model = BertModel(config)
            COMMIT_DESCRIPTION = """
The commit description supports markdown syntax, see:
```python
>>> from transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
```
"""
            commit_details = model.push_to_hub(
                tmp_repo.repo_id, create_pr=True, token=self._token, commit_description=COMMIT_DESCRIPTION
            )
            self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)

    @unittest.skip(reason="This test is flaky")
    def test_push_to_hub_in_organization(self):
        with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
            config = BertConfig(
                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
            )
            model = BertModel(config)
            model.push_to_hub(tmp_repo.repo_id, token=self._token)

            new_model = BertModel.from_pretrained(tmp_repo.repo_id)
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

    @unittest.skip(reason="This test is flaky")
    def test_push_to_hub_in_organization_via_save_pretrained(self):
        with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
            config = BertConfig(
                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
            )
            model = BertModel(config)
            # Push to hub via save_pretrained
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo.repo_id)

            new_model = BertModel.from_pretrained(tmp_repo.repo_id)
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

    def test_push_to_hub_dynamic_model(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            CustomConfig.register_for_auto_class()
            CustomModel.register_for_auto_class()

            config = CustomConfig(hidden_size=32)
            model = CustomModel(config)

            model.push_to_hub(tmp_repo.repo_id, token=self._token)
            # check that the auto_map was populated with the custom classes
            self.assertDictEqual(
                config.auto_map,
                {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
            )

            new_model = AutoModel.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
            # Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
            self.assertEqual(new_model.__class__.__name__, "CustomModel")
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

            config = AutoConfig.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
            new_model = AutoModel.from_config(config, trust_remote_code=True)
            self.assertEqual(new_model.__class__.__name__, "CustomModel")

    def test_push_to_hub_with_tags(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            from huggingface_hub import ModelCard

            new_tags = ["tag-1", "tag-2"]

            CustomConfig.register_for_auto_class()
            CustomModel.register_for_auto_class()

            config = CustomConfig(hidden_size=32)
            model = CustomModel(config)

            self.assertTrue(model.model_tags is None)

            model.add_model_tags(new_tags)

            self.assertTrue(model.model_tags == new_tags)

            model.push_to_hub(tmp_repo.repo_id, token=self._token)

            loaded_model_card = ModelCard.load(tmp_repo.repo_id)
            self.assertEqual(loaded_model_card.data.tags, new_tags)


@require_torch
class AttentionMaskTester(unittest.TestCase):
    def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
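        # Every position that is padding in the 2D mask must be fully masked in the 4D mask,
        # i.e. set to -inf or to the minimum value of the mask dtype.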
        mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
        mask_4d_values = mask_4d[:, 0][mask_indices]
        is_inf = mask_4d_values == -float("inf")
        is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
        assert torch.logical_or(is_inf, is_min).all()

    def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
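        # Build a 2D padding mask (optionally zeroing some positions), expand it to 4D and check the
        # number of masked positions against the expected causal / sliding-window counts.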
        mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)

        if additional_mask is not None:
            for bsz_idx, seq_idx in additional_mask:
                mask_2d[bsz_idx, seq_idx] = 0

        mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len, dtype=torch.float32)

        assert mask_4d.shape == (bsz, 1, q_len, kv_len)

        # make sure there are no overflows
        assert mask_4d.min() != float("-inf")

        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
            # k * (k + 1) / 2 tokens are masked in triangular masks (here k = q_len - 1)
            num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

            if 0 not in mask_2d:
                assert (mask_4d != 0).sum().item() == num_tokens_masked
            if 0 in mask_2d:
                # at least causal mask + maybe more
                assert (mask_4d != 0).sum().item() >= num_tokens_masked
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
        elif not mask_converter.is_causal and context is None:
            if 0 not in mask_2d:
                assert (mask_4d != 0).sum().item() == 0
            if 0 in mask_2d:
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
        elif mask_converter.is_causal and context is not None:
            # k * (k + 1) / 2 tokens are masked in triangular masks (here k = q_len - 1)
            num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
            num_tokens_masked = bsz * num_tokens_masked

            if 0 not in mask_2d:
                assert (mask_4d != 0).sum().item() == num_tokens_masked
            if 0 in mask_2d:
                # at least causal mask + maybe more
                assert (mask_4d != 0).sum().item() >= num_tokens_masked
                self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)

    def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
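        # Build a 4D mask directly from the batch/sequence sizes (no 2D padding mask) and check the
        # masked-position count.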
        mask_4d = mask_converter.to_causal_4d(
            bsz, query_length=q_len, key_value_length=kv_len, device=torch_device, dtype=torch.float32
        )

        if q_len == 1 and mask_converter.sliding_window is None:
            # no causal mask if q_len is 1
            assert mask_4d is None
            return

        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
            # k * (k + 1) / 2 tokens are masked in triangular masks (here k = q_len - 1)
            num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)

            assert (mask_4d != 0).sum().item() == num_tokens_masked
        elif not mask_converter.is_causal and context is None:
            assert (mask_4d != 0).sum().item() == 0
        elif mask_converter.is_causal and context is not None:
            # k * (k + 1) / 2 tokens are masked in triangular masks (here k = q_len - 1)
            num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
            num_tokens_masked = bsz * num_tokens_masked

            assert (mask_4d != 0).sum().item() == num_tokens_masked

    def compute_num_context_mask(self, kv_len, context, q_len):
        # This function computes the number of additional positions masked by
        # the sliding window (on top of the causal mask)
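        # e.g. kv_len=7, context=3, q_len=7: c_mask_len = 3, so 3 * 4 // 2 = 6 extra positions fall
        # outside the sliding window and are masked on top of the causal mask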
        c_mask_len = kv_len - context - 1
        num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
        cut_mask_len = max(c_mask_len - q_len, 0)
        num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
        return num_mask_triangle - num_cut_mask

    def test_2d_to_4d_causal(self):
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

        # check that the mask does not overflow on causal masked tokens
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 0), (1, 0), (1, 1)])

    def test_2d_to_4d(self):
        mask_converter = AttentionMaskConverter(is_causal=False)

        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

    def test_2d_to_4d_causal_sliding(self):
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=5)

        # auto-regressive use case
        self.check_to_4d(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_4d(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_4d(mask_converter, q_len=7, kv_len=7)

        # same with extra attention masks
        self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
        self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])

    def test_causal_mask(self):
        mask_converter = AttentionMaskConverter(is_causal=True)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    def test_causal_mask_sliding(self):
        mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=3)

        # auto-regressive use case
        self.check_to_causal(mask_converter, q_len=1, kv_len=7)
        # special auto-regressive case
        self.check_to_causal(mask_converter, q_len=3, kv_len=7)
        # non auto-regressive case
        self.check_to_causal(mask_converter, q_len=7, kv_len=7)

    @pytest.mark.torch_compile_test
    def test_torch_compile_fullgraph(self):
        model = Prepare4dCausalAttentionMaskModel()

        inputs_embeds = torch.rand([1, 3, 32])
        res_non_compiled = model(inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)

        res_compiled = compiled_model(inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

        model = Create4dCausalAttentionMaskModel()

        inputs_embeds = torch.rand(2, 4, 16)
        res_non_compiled = model(inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)
        res_compiled = compiled_model(inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

        model = Prepare4dAttentionMaskModel()

        mask = torch.ones(2, 4)
        mask[0, :2] = 0
        inputs_embeds = torch.rand(2, 4, 16)

        res_non_compiled = model(mask, inputs_embeds)

        compiled_model = torch.compile(model, fullgraph=True)
        res_compiled = compiled_model(mask, inputs_embeds)

        self.assertTrue(torch.equal(res_non_compiled, res_compiled))

    @require_torch
    @slow
    def test_unmask_unattended_left_padding(self):
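        # `_unmask_unattended` re-opens rows whose query token attends to no position at all (fully
        # left-padded rows here), which avoids NaN outputs with some SDPA backends.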
        attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)

        expanded_mask = torch.Tensor(
            [
                [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
                [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
                [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
            ]
        ).to(torch.int64)

        reference_output = torch.Tensor(
            [
                [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
                [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
                [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
            ]
        ).to(torch.int64)

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)

        self.assertTrue(torch.equal(result, reference_output))

        attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
        min_inf = torch.finfo(torch.float32).min
        reference_output = torch.Tensor(
            [
                [
                    [
                        [0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0],
                        [min_inf, min_inf, 0, min_inf, min_inf],
                        [min_inf, min_inf, 0, 0, min_inf],
                        [min_inf, min_inf, 0, 0, 0],
                    ]
                ],
                [
                    [
                        [0, min_inf, min_inf, min_inf, min_inf],
                        [0, 0, min_inf, min_inf, min_inf],
                        [0, 0, 0, min_inf, min_inf],
                        [0, 0, 0, 0, min_inf],
                        [0, 0, 0, 0, 0],
                    ]
                ],
                [
                    [
                        [0, 0, 0, 0, 0],
                        [min_inf, 0, min_inf, min_inf, min_inf],
                        [min_inf, 0, 0, min_inf, min_inf],
                        [min_inf, 0, 0, 0, min_inf],
                        [min_inf, 0, 0, 0, 0],
                    ]
                ],
            ]
        )

        self.assertTrue(torch.equal(reference_output, result))

    @require_torch
    @slow
    def test_unmask_unattended_right_padding(self):
        attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

        self.assertTrue(torch.equal(expanded_mask, result))

    @require_torch
    @slow
    def test_unmask_unattended_random_mask(self):
        attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)

        attn_mask_converter = AttentionMaskConverter(is_causal=True)
        past_key_values_length = 0
        key_value_length = attention_mask.shape[-1] + past_key_values_length

        expanded_mask = attn_mask_converter.to_4d(
            attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
        )

        result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

        self.assertTrue(torch.equal(expanded_mask, result))


@require_torch
class TestAttentionImplementation(unittest.TestCase):
    @unittest.skip("Just a bit annoying")
    def test_error_no_sdpa_available(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")

        self.assertTrue(
            "does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
            in str(cm.exception)
        )

        _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

    # TODO (ydshieh): use another model
    @unittest.skip("model deleted")
    def test_error_no_flash_available(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
            )

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    # TODO (ydshieh): use another model
    @unittest.skip("model deleted")
    def test_error_no_flash_available_with_config(self):
        with self.assertRaises(ValueError) as cm:
            config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
            )

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    # TODO (ydshieh): use another model
    @unittest.skip("model deleted")
    def test_error_wrong_attn_implementation(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")

        self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))

    def test_not_available_flash(self):
        if is_flash_attn_2_available():
            self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")

        if is_torch_npu_available():
            self.skipTest(
                reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
            )

        if is_kernels_available():
            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash`")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
            )
        self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

    def test_not_available_flash_with_config(self):
        if is_flash_attn_2_available():
            self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")

        if is_torch_npu_available():
            self.skipTest(
                reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
            )

        if is_kernels_available():
            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_flash_with_config`")

        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel",
                config=config,
                attn_implementation="flash_attention_2",
            )

        self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

    def test_kernels_fallback(self):
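        # With `kernels` installed but `flash-attn` missing, requesting flash_attention_2 should fall
        # back to the kernels-provided implementation and emit the warning checked below.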
        if not is_kernels_available():
            self.skipTest(reason="Please install `kernels` package to run `test_kernels_fallback`")

        if is_flash_attn_2_available():
            self.skipTest(reason="Please uninstall flash-attn package to run test_kernels_fallback")

        if is_torch_npu_available():
            self.skipTest(
                reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case."
            )

        logger = logging.get_logger("transformers.modeling_utils")
        with LoggingLevel(logging.WARNING):
            with CaptureLogger(logger) as cl:
                _ = AutoModel.from_pretrained(
                    "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
                )

        self.assertTrue(
            "You do not have `flash_attn` installed, using `kernels-community/flash-attn2` from the `kernels` library instead!"
            in cl.out
        )

    # TODO (ydshieh): use another model
    @unittest.skip("model deleted")
    def test_not_available_kernels(self):
        if is_kernels_available():
            self.skipTest(reason="Please uninstall `kernels` package to run `test_not_available_kernels`")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn2"
            )

        self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))


@require_torch
class TestTensorSharing(TestCasePlus):
    def test_disjoint(self):
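        # `main[:5]` and `main[5:]` share storage but cover non-overlapping regions, so they can be
        # split into independent tensors; the interleaved strided views further down cannot, and are
        # reported as shared instead.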
        main = torch.zeros(10)
        a = main[:5]
        b = main[5:]
        state_dict = {"a": a, "b": b}

        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
        self.assertEqual(shared_names, [])
        self.assertEqual(disjoint_names, ["a", "b"])

        a = main[::2]
        b = main[1::2]
        state_dict = {"a": a, "b": b}

        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
        self.assertEqual(shared_names, [{"a", "b"}])
        self.assertEqual(disjoint_names, [])

    def test_identical(self):
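        # `b` is the very same tensor object as `a`, so the pair is reported as identical; a partial
        # slice (`a[:5]`) only shares storage and is reported as shared instead.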
        a = torch.zeros(10)
        b = a
        state_dict = {"a": a, "b": b}

        shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
        self.assertEqual(shared_names, [])
        self.assertEqual(identical_names, [{"a", "b"}])

        b = a[:5]
        state_dict = {"a": a, "b": b}

        shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
        self.assertEqual(shared_names, [{"a", "b"}])
        self.assertEqual(identical_names, [])


@require_torch
@unittest.skip(
    "These tests are currently failing and need to be fixed, but not sure we want to support this/not sure its even used! Fix this line:https://github.com/huggingface/transformers/blob/b750e6b9eeed5fb9adc2f8c7adb46639c8e41963/src/transformers/core_model_loading.py#L512"
)
class TestSaveAndLoadModelWithExtraState(TestCasePlus):
    """
    This test checks that a model that uses the torch extra state API can be saved and loaded.
    https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state.

    Currently, only tensor-valued extra_states are supported.
    """

    def test_save_and_load_model_with_tensor_extra_state(self):
        class MyConfig(PreTrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.some_counter = 0
                self.linear = torch.nn.Linear(320, 320)

            def get_extra_state(self):
                return torch.tensor(self.some_counter)

            def set_extra_state(self, state):
                self.some_counter = state.item()

        class MyModel(PreTrainedModel):
            config_class = MyConfig

            def __init__(self, config: MyConfig):
                super().__init__(config)
                self.my_layer = MyModule()

            def forward(self, hidden_states, attention_mask):
                return self.my_layer(hidden_states, attention_mask)

        config = MyConfig()
        model = MyModel(config)
        model.my_layer.some_counter = 42

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = MyModel.from_pretrained(tmpdirname)
            self.assertEqual(model.my_layer.some_counter, 42)

    @mark.xfail(reason="save and from_pretrained currently only supports tensor extra_state")
    def test_save_and_load_model_with_dict_extra_state(self):
        class MyConfig(PreTrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.some_counter = 0
                self.linear = torch.nn.Linear(320, 320)

            def get_extra_state(self):
                return {"some_counter": self.some_counter}

            def set_extra_state(self, state):
                self.some_counter = state["some_counter"]

        class MyModel(PreTrainedModel):
            config_class = MyConfig

            def __init__(self, config: MyConfig):
                super().__init__(config)
                self.my_layer = MyModule()

            def forward(self, hidden_states, attention_mask):
                return self.my_layer(hidden_states, attention_mask)

        config = MyConfig()
        model = MyModel(config)
        model.my_layer.some_counter = 42

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = MyModel.from_pretrained(tmpdirname)
            self.assertEqual(model.my_layer.some_counter, 42)


class TestGetDecoder(unittest.TestCase):
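    """
    Tests for the `get_decoder()` contract: causal LMs return their text backbone, seq2seq models
    return their decoder submodule, base models return `self`, and multimodal wrappers (e.g. LLaVA)
    return their language model.
    """
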
    def test_causal_lm_get_decoder_returns_underlying_model(self):
        cfg = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )
        model = MistralForCausalLM(cfg)
        dec = model.get_decoder()

        assert dec is model.model, f"Expected get_decoder() to return model.model, got {type(dec)}"

    def test_seq2seq_get_decoder_still_returns_decoder_module(self):
        cfg = BartConfig(
            vocab_size=128,
            d_model=32,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=4,
            decoder_attention_heads=4,
            encoder_ffn_dim=64,
            decoder_ffn_dim=64,
        )
        model = BartForConditionalGeneration(cfg)
        dec = model.get_decoder()

        assert dec is model.model.decoder, "Seq2seq get_decoder() should return the decoder submodule"

    def test_base_model_returns_self(self):
        """Test that base transformer models (no decoder/model attributes) return self."""
        cfg = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )
        base_model = MistralModel(cfg)
        dec = base_model.get_decoder()

        assert dec is base_model, f"Base model get_decoder() should return self, got {type(dec)}"

    def test_explicit_decoder_attribute_opt(self):
        """Test models with explicit decoder attribute (OPT style)."""
        cfg = OPTConfig(
            vocab_size=128,
            hidden_size=32,
            ffn_dim=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            max_position_embeddings=512,
        )
        model = OPTForCausalLM(cfg)
        dec = model.get_decoder()

        assert dec is model.model.decoder, f"OPT get_decoder() should return model.decoder, got {type(dec)}"

    def test_explicit_decoder_attribute_t5(self):
        """Test encoder-decoder models with explicit decoder attribute."""
        cfg = T5Config(
            vocab_size=128,
            d_model=32,
            d_ff=64,
            num_layers=2,
            num_heads=4,
        )
        model = T5ForConditionalGeneration(cfg)
        dec = model.get_decoder()

        assert dec is model.decoder, f"T5 get_decoder() should return decoder attribute, got {type(dec)}"

    def test_same_type_recursion_prevention(self):
        """Test that same-type recursion is prevented (see issue #40815)."""
        cfg = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )
        model = MistralForCausalLM(cfg)

        assert type(model) is not type(model.model), "Types should be different to prevent recursion"

        dec = model.get_decoder()
        assert dec is model.model, f"Should return model.model without infinite recursion, got {type(dec)}"

        inner_dec = model.model.get_decoder()
        assert inner_dec is model.model, f"Inner model should return itself, got {type(inner_dec)}"

    def test_nested_wrapper_recursion(self):
        """Test models that don't have model/decoder attributes return self."""
        cfg = GPT2Config(
            vocab_size=128,
            n_embd=32,
            n_layer=2,
            n_head=4,
            n_positions=512,
        )
        model = GPT2LMHeadModel(cfg)
        dec = model.get_decoder()

        assert dec is model.transformer, f"GPT2 get_decoder() should return model.transformer, got {type(dec)}"

    def test_model_without_get_decoder(self):
        """Test edge case where model has model attribute but no get_decoder method."""

        class MockInnerModel:
            """Mock model without get_decoder method."""

            pass

        class MockWrapperModel:
            """Mock wrapper with model attribute but inner has no get_decoder."""

            def __init__(self):
                self.model = MockInnerModel()

            def get_decoder(self):
                if hasattr(self, "decoder"):
                    return self.decoder
                if hasattr(self, "model"):
                    inner = self.model
                    if hasattr(inner, "get_decoder") and type(inner) is not type(self):
                        return inner.get_decoder()
                    return inner
                return self

        wrapper = MockWrapperModel()
        dec = wrapper.get_decoder()

        assert dec is wrapper.model, f"Should return inner model when no get_decoder, got {type(dec)}"

    def test_vision_language_model(self):
        """Test vision-language models like LLaVA that delegate to language_model."""
        text_config = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )

        vision_config = {
            "hidden_size": 32,
            "intermediate_size": 64,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "num_channels": 3,
            "image_size": 224,
            "patch_size": 16,
        }

        cfg = LlavaConfig(
            text_config=text_config.to_dict(),
            vision_config=vision_config,
            vocab_size=128,
        )

        model = LlavaForConditionalGeneration(cfg)
        dec = model.get_decoder()

        assert dec is model.model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}"


class TestGetEncoder(unittest.TestCase):
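    """
    Tests for the `get_encoder()` contract: seq2seq models return their encoder submodule,
    decoder-only models return `self`, and multimodal models resolve modality-specific encoders
    via `get_encoder(modality=...)`.
    """
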
    def test_seq2seq_lm_get_encoder_returns_encoder(self):
        cfg = BartConfig(
            vocab_size=128,
            d_model=32,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=4,
            decoder_attention_heads=4,
            encoder_ffn_dim=64,
            decoder_ffn_dim=64,
        )
        model = BartForConditionalGeneration(cfg)
        encoder = model.get_encoder()

        assert encoder is model.model.encoder, (
            f"Expected get_encoder() to return model.model.encoder, got {type(encoder)}"
        )

    def test_base_model_returns_encoder(self):
        cfg = BartConfig(
            vocab_size=128,
            d_model=32,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=4,
            decoder_attention_heads=4,
            encoder_ffn_dim=64,
            decoder_ffn_dim=64,
        )
        model = BartModel(cfg)
        encoder = model.get_encoder()

        assert encoder is model.encoder, f"Expected get_encoder() to return model.encoder, got {type(encoder)}"

    def test_decoder_only_model_returns_self(self):
        """Test that decoder-only models (no encoder) return self."""
        cfg = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )
        model = MistralForCausalLM(cfg)
        encoder = model.get_encoder()

        assert encoder is model, f"Base model get_encoder() should return self, got {type(encoder)}"

    def test_when_encoder_has_different_name(self):
        """Test models with non-standard name for encoder modular (Musicgen has `self.model.text_encoder`)."""
        cfg = MusicgenConfig(
            text_encoder={
                "model_type": "t5",
                "vocab_size": 99,
                "d_model": 32,
                "d_ff": 37,
                "num_layers": 2,
                "num_heads": 2,
            },
            audio_encoder={
                "model_type": "encodec",
                "hidden_size": 99,
                "compress": 1,
                "num_filters": 2,
                "codebook_size": 32,
                "codebook_dim": 32,
            },
            decoder={
                "vocab_size": 99,
                "ffn_dim": 32,
                "num_attention_heads": 2,
                "hidden_size": 32,
                "num_hidden_layers": 2,
            },
        )
        model = MusicgenForConditionalGeneration(cfg)
        encoder = model.get_encoder()

        assert encoder is model.text_encoder, (
            f"MusicgenForConditionalGeneration get_encoder() should return model.text_encoder, got {type(encoder)}"
        )

    def test_audio_encoder(self):
        """Test models with multiple modality encoders (Musicgen has `self.model.audio_encoder`)."""
        cfg = MusicgenConfig(
            text_encoder={
                "model_type": "t5",
                "vocab_size": 99,
                "d_model": 32,
                "d_ff": 37,
                "num_layers": 2,
                "num_heads": 2,
            },
            audio_encoder={
                "model_type": "encodec",
                "hidden_size": 99,
                "compress": 1,
                "num_filters": 2,
                "codebook_size": 32,
                "codebook_dim": 32,
            },
            decoder={
                "vocab_size": 99,
                "ffn_dim": 32,
                "num_attention_heads": 2,
                "hidden_size": 32,
                "num_hidden_layers": 2,
            },
        )
        model = MusicgenForConditionalGeneration(cfg)
        encoder = model.get_encoder(modality="audio")

        assert encoder is model.audio_encoder, (
            f"MusicgenForConditionalGeneration get_encoder(modality='audio') should return model.audio_encoder, got {type(encoder)}"
        )

    def test_non_existent_modality_throws_error(self):
        """Test that an error is thrown when a requested modality does not exist."""
        cfg = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )
        model = MistralModel(cfg)
        with self.assertRaises(ValueError):
            _ = model.get_encoder(modality="3d")

    def test_encoder_return_self_when_modality_not_found(self):
        """Test that `self` is returned if the model has no encoder for requested modality."""
        cfg = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )
        model = MistralModel(cfg)
        encoder = model.get_encoder(modality="image")

        assert encoder is model, f"Mistral get_encoder(modality='image') should return self, got {type(encoder)}"

    def test_model_without_get_encoder(self):
        """Test edge case where model has model attribute but no get_encoder method."""

        class MockInnerModel:
            """Mock model without get_encoder method."""

            pass

        class MockWrapperModel:
            """Mock wrapper with model attribute but inner has no get_encoder."""

            def __init__(self):
                self.model = MockInnerModel()

            def get_encoder(self):
                if hasattr(self, "encoder"):
                    return self.encoder
                if hasattr(self, "model"):
                    inner = self.model
                    if hasattr(inner, "get_encoder") and type(inner) is not type(self):
                        return inner.get_encoder()
                    return inner
                return self

        wrapper = MockWrapperModel()
        encoder = wrapper.get_encoder()

        assert encoder is wrapper.model, f"Should return inner model when no get_encoder, got {type(encoder)}"

    def test_vision_language_model(self):
        """Test vision-language models like LLaVA can find the modality encoder ("image")."""
        text_config = MistralConfig(
            vocab_size=128,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
        )

        vision_config = {
            "hidden_size": 32,
            "intermediate_size": 64,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "num_channels": 3,
            "image_size": 224,
            "patch_size": 16,
        }

        cfg = LlavaConfig(
            text_config=text_config.to_dict(),
            vision_config=vision_config,
            vocab_size=128,
        )

        model = LlavaForConditionalGeneration(cfg)
        image_encoder = model.get_encoder(modality="image")

        assert image_encoder is model.model.vision_tower, (
            f"LLaVA get_encoder(modality='image') should return vision_tower, got {type(image_encoder)}"
        )
