# Copyright 2024, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Moshi model."""

import copy
import tempfile
import unittest
from functools import cached_property

import numpy as np
import pytest
from datasets import Audio, load_dataset
from parameterized import parameterized

from transformers import (
    MoshiConfig,
    PreTrainedConfig,
)
from transformers.integrations.deepspeed import (
    is_deepspeed_available,
    is_deepspeed_zero3_enabled,
)
from transformers.testing_utils import (
    is_flaky,
    is_torch_available,
    require_torch,
    require_torch_fp16,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
    ModelTesterMixin,
    floats_tensor,
    ids_tensor,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_deepspeed_available():
    import deepspeed

if is_torch_available():
    import torch

    from transformers import (
        AutoFeatureExtractor,
        AutoTokenizer,
        MoshiForCausalLM,
        MoshiForConditionalGeneration,
        MoshiModel,
    )


def _config_zero_init(config):
    """Return a deep copy of `config` with its initialization-scale attributes set to 1e-10.

    Every attribute whose name contains `_range`, `_std`, `initializer_factor` or `layer_scale` is overwritten
    with 1e-10. Attributes holding a nested `PreTrainedConfig` are replaced by a recursively processed copy.
    The input `config` is left untouched.
    """
    zeroed = copy.deepcopy(config)
    init_markers = ("_range", "_std", "initializer_factor", "layer_scale")
    for attr_name in zeroed.__dict__:
        if any(marker in attr_name for marker in init_markers):
            setattr(zeroed, attr_name, 1e-10)
        # Read back after the possible overwrite above, so an overwritten attribute is never recursed into.
        current_value = getattr(zeroed, attr_name, None)
        if isinstance(current_value, PreTrainedConfig):
            setattr(zeroed, attr_name, _config_zero_init(current_value))
    return zeroed


class MoshiDecoderTester:
    """Produces a tiny `MoshiConfig` and random text inputs for the decoder-only Moshi tests."""

    def __init__(
        self,
        parent,
        batch_size=4,  # need batch_size != num_hidden_layers
        seq_length=7,
        is_training=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="silu",
        rms_norm_eps=0.001,
        ffn_dim=32,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=100,
        pad_token_id=25,
        num_codebooks=4,
        audio_encoder_type="mimi",
        attn_implementation="eager",
    ):
        self.parent = parent

        # Shape of the dummy batch.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training

        # Text-decoder architecture forwarded to `MoshiConfig` by `get_config`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.ffn_dim = ffn_dim
        self.rms_norm_eps = rms_norm_eps

        # Stored for reference only; `get_config` does not forward these.
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings

        # Special tokens, audio codebooks, audio encoder type and attention backend.
        self.pad_token_id = pad_token_id
        self.num_codebooks = num_codebooks
        self.audio_encoder_type = audio_encoder_type
        self.attn_implementation = attn_implementation

    def prepare_config_and_inputs(self, batch_size=None):
        """Return `(config, inputs_dict)` with random `input_ids` and the matching `attention_mask`.

        `batch_size` defaults to the tester's own `batch_size` when not given.
        """
        if batch_size is None:
            batch_size = self.batch_size
        input_ids = ids_tensor([batch_size, self.seq_length], self.vocab_size)
        config = self.get_config()
        inputs_dict = {
            "input_ids": input_ids,
            # Positions holding the pad token are masked out.
            "attention_mask": input_ids.ne(self.pad_token_id),
        }
        return config, inputs_dict

    def get_config(self):
        """Build the small `MoshiConfig` used by the decoder tests (untied embeddings, configurable audio encoder)."""
        return MoshiConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            d_ff=self.intermediate_size,
            num_codebooks=self.num_codebooks,
            rms_norm_eps=self.rms_norm_eps,
            tie_word_embeddings=False,
            pad_token_id=self.pad_token_id,
            ffn_dim=self.ffn_dim,
            audio_encoder_config={"model_type": self.audio_encoder_type},
            attn_implementation=self.attn_implementation,
        )

    def prepare_config_and_inputs_for_common(self, batch_size=None):
        """Entry point used by the common test mixins; identical to `prepare_config_and_inputs`."""
        return self.prepare_config_and_inputs(batch_size)


@require_torch
class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common modeling, generation and pipeline tests for the Moshi text decoder (`MoshiModel`, `MoshiForCausalLM`)."""

    all_model_classes = (MoshiModel, MoshiForCausalLM) if is_torch_available() else ()

    test_resize_embeddings = True
    pipeline_model_mapping = (
        {
            "feature-extraction": MoshiModel,
            "text-generation": MoshiForCausalLM,
        }
        if is_torch_available()
        else {}
    )

    def setUp(self):
        self.model_tester = MoshiDecoderTester(self)
        self.config_tester = ConfigTester(
            self,
            config_class=MoshiConfig,
            hidden_size=16,
            audio_encoder_config={"model_type": self.model_tester.audio_encoder_type},
        )

    @unittest.skip(reason="The MoshiModel does not have support dynamic compile yet")
    @pytest.mark.torch_compile_test
    def test_sdpa_can_compile_dynamic(self):
        pass

    def _get_input_ids_and_config(self, batch_size=1):
        """Return `(config, input_ids, attention_mask, remaining_inputs)`, with ids and mask moved to `torch_device`."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(batch_size)
        input_ids = inputs_dict.pop("input_ids").to(torch_device)
        attention_mask = inputs_dict.pop("attention_mask").to(torch_device)

        return config, input_ids, attention_mask, inputs_dict

    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
        """Return no extra logits-processor kwargs for the decoder generation tests."""
        logits_processor_kwargs = {}
        return logits_processor_kwargs

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    def test_eager_matches_sdpa_inference(
        self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
    ):
        # Known-failing combinations: every run with an attention mask, and fp32 runs without `output_attentions`.
        if use_attention_mask or (dtype == "fp32" and not output_attentions):
            self.skipTest("Test is failing, fix me :) ")
        # Otherwise defer to the common implementation registered under the same generated test name.
        parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
        parent_parameterized_test(self)

    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_resize_tokens_embeddings.
    # Differences from the common version:
    # - the Moshi-specific `pad_token_id` override in the small-vocabulary section;
    # - the pad-to-multiple checks use `assertEqual`. `assertTrue(a, b)` treats `b` as the failure message,
    #   so those checks could never fail.
    def test_resize_tokens_embeddings(self):
        if not self.test_resize_embeddings:
            self.skipTest(reason="test_resize_embeddings is set to `False`")

        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            if is_deepspeed_zero3_enabled():
                with deepspeed.zero.Init():
                    model = model_class(config)
            else:
                model = model_class(config)
                model.to(torch_device)

            model_embed_pre_resize = model.get_input_embeddings()
            type_model_embed_pre_resize = type(model_embed_pre_resize)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.get_text_config().vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check to make sure the type of embeddings returned post resizing is same as type of input
            type_model_embed_post_resize = type(model_embed)
            self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
            # Check that added embeddings mean is close to the old embeddings mean
            if is_deepspeed_zero3_enabled():
                with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
                    old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
                    new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
            else:
                old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
                new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
            torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, rtol=1e-3, atol=1e-3)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            if not is_deepspeed_zero3_enabled():
                # A distributed launcher is needed for the forward pass when deepspeed is enabled
                model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Input ids should be clamped to the maximum size of the vocabulary
            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)

            # make sure that decoder_input_ids are resized as well
            if not is_deepspeed_zero3_enabled():
                # A distributed launcher is needed for the forward pass when deepspeed is enabled
                if "decoder_input_ids" in inputs_dict:
                    inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
                model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

            del model
            if is_deepspeed_zero3_enabled():
                with deepspeed.zero.Init():
                    model = model_class(config)
            else:
                model = model_class(config)
                model.to(torch_device)

            model_vocab_size = config.get_text_config().vocab_size
            model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)

            model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(model_embed.weight.shape[0] % 64, 0)

            self.assertEqual(model_embed.weight.shape[0], new_model_vocab_size)
            self.assertEqual(new_model_vocab_size, model.vocab_size)

            model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
            self.assertEqual(model_embed.weight.shape[0] % 64, 0)

            # Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
            target_dimension = 128
            model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
            self.assertEqual(model_embed.weight.shape[0], target_dimension)

            with self.assertRaisesRegex(
                ValueError,
                "Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
            ):
                model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

            # Test when `vocab_size` is smaller than `hidden_size`.
            del model
            config.vocab_size = 4
            # Moshi-specific override (absent from the common test): replace the tester's pad_token_id (25),
            # which lies far outside the shrunk vocabulary.
            config.pad_token_id = 4
            if is_deepspeed_zero3_enabled():
                with deepspeed.zero.Init():
                    model = model_class(config)
            else:
                model = model_class(config)
                model.to(torch_device)

            model_vocab_size = config.get_text_config().vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            new_model_vocab_size = model.config.get_text_config().vocab_size
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check to make sure the type of embeddings returned post resizing is same as type of input
            type_model_embed_post_resize = type(model_embed)
            self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
            # Check that added embeddings mean is close to the old embeddings mean
            if is_deepspeed_zero3_enabled():
                with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
                    old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
                    new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
            else:
                old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
                new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
            torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, rtol=1e-3, atol=1e-3)

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_cpu_offload(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
    def test_generate_continue_from_inputs_embeds(self):
        pass

    @is_flaky(max_attempts=5, description="flaky on some models.")
    def test_save_load(self):
        super().test_save_load()


class MoshiTester:
    """Builds a tiny `MoshiConfig` and random inputs for the `MoshiForConditionalGeneration` tests.

    The config has a text decoder, a depth decoder and a Mimi audio encoder. The inputs are random text ids,
    Moshi/user audio codes and an attention mask.
    """

    def __init__(
        self,
        parent,
        batch_size=4,  # need batch_size != num_hidden_layers
        seq_length=7,
        is_training=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=4,
        hidden_act="silu",
        rms_norm_eps=0.001,
        ffn_dim=32,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=100,
        pad_token_id=25,
        bos_token_id=25,
        num_codebooks=4,
        audio_encoder_type="mimi",
        attn_implementation="eager",
        depth_hidden_size=16,
        depth_num_hidden_layers=2,
        depth_max_position_embeddings=5,
        depth_num_attention_heads=8,
        depth_ffn_dim=16,
        depth_sliding_window=4,
        mimi_intermediate_size=40,
        mimi_hidden_size=32,
        mimi_num_filters=8,
        mimi_num_residual_layers=1,
        mimi_upsampling_ratios=None,
        mimi_codebook_size=64,
        mimi_vector_quantization_hidden_dimension=64,
        mimi_codebook_dim=64,
        mimi_upsample_groups=32,
        mimi_num_hidden_layers=2,
        mimi_num_attention_heads=2,
        mimi_num_key_value_heads=2,
        mimi_sliding_window=3,
        sampling_rate=800,
    ):
        # Resolve the default here instead of in the signature. A list literal default would be one object
        # shared by every tester instance, and possibly by every config built from it.
        if mimi_upsampling_ratios is None:
            mimi_upsampling_ratios = [8, 4]

        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.ffn_dim = ffn_dim
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.num_codebooks = num_codebooks
        self.attn_implementation = attn_implementation

        # Depth-decoder sub-config values.
        self.depth_hidden_size = depth_hidden_size
        self.depth_num_hidden_layers = depth_num_hidden_layers
        self.depth_max_position_embeddings = depth_max_position_embeddings
        self.depth_num_attention_heads = depth_num_attention_heads
        self.depth_ffn_dim = depth_ffn_dim
        self.depth_sliding_window = depth_sliding_window

        # Audio-encoder (Mimi) sub-config values.
        self.audio_encoder_type = audio_encoder_type
        self.mimi_intermediate_size = mimi_intermediate_size
        self.mimi_hidden_size = mimi_hidden_size
        self.mimi_num_filters = mimi_num_filters
        self.mimi_num_residual_layers = mimi_num_residual_layers
        self.mimi_upsampling_ratios = mimi_upsampling_ratios
        self.mimi_codebook_size = mimi_codebook_size
        self.mimi_vector_quantization_hidden_dimension = mimi_vector_quantization_hidden_dimension
        self.mimi_codebook_dim = mimi_codebook_dim
        self.mimi_upsample_groups = mimi_upsample_groups
        self.mimi_num_hidden_layers = mimi_num_hidden_layers
        self.mimi_num_attention_heads = mimi_num_attention_heads
        self.mimi_num_key_value_heads = mimi_num_key_value_heads
        self.mimi_sliding_window = mimi_sliding_window
        self.sampling_rate = sampling_rate

        self.num_hidden_states_types = 2

    def prepare_config_and_inputs(self, batch_size=None):
        """Return `(config, inputs_dict)` with random text ids, Moshi/user audio codes and an attention mask.

        The audio codes have shape `(batch_size, num_codebooks, seq_length)` and values below
        `mimi_codebook_size`.
        """
        batch_size = self.batch_size if batch_size is None else batch_size

        input_ids = ids_tensor([batch_size, self.seq_length], self.vocab_size)

        moshi_audio_codes = ids_tensor([batch_size, self.num_codebooks, self.seq_length], self.mimi_codebook_size)
        user_audio_codes = ids_tensor([batch_size, self.num_codebooks, self.seq_length], self.mimi_codebook_size)
        attention_mask = input_ids.ne(self.pad_token_id)

        config = self.get_config()
        inputs_dict = {
            "input_ids": input_ids,
            "moshi_audio_codes": moshi_audio_codes,
            "user_audio_codes": user_audio_codes,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def get_config(self):
        """Build a tiny `MoshiConfig` with explicit audio-encoder (Mimi) and depth-decoder sub-configs."""
        mimi_dict_config = {
            "model_type": self.audio_encoder_type,
            "audio_channels": 1,
            "hidden_size": self.mimi_hidden_size,
            "num_filters": self.mimi_num_filters,
            "num_residual_layers": self.mimi_num_residual_layers,
            "upsampling_ratios": self.mimi_upsampling_ratios,
            "codebook_size": self.mimi_codebook_size,
            "vector_quantization_hidden_dimension": self.mimi_vector_quantization_hidden_dimension,
            "upsample_groups": self.mimi_upsample_groups,
            "num_hidden_layers": self.mimi_num_hidden_layers,
            "num_attention_heads": self.mimi_num_attention_heads,
            "num_key_value_heads": self.mimi_num_key_value_heads,
            "sliding_window": self.mimi_sliding_window,
            "codebook_dim": self.mimi_codebook_dim,
            "use_cache": False,
            "sampling_rate": self.sampling_rate,
        }

        depth_dict_config = {
            "hidden_size": self.depth_hidden_size,
            "num_hidden_layers": self.depth_num_hidden_layers,
            "max_position_embeddings": self.depth_max_position_embeddings,
            "num_attention_heads": self.depth_num_attention_heads,
            "ffn_dim": self.depth_ffn_dim,
            "sliding_window": self.depth_sliding_window,
        }

        config = MoshiConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            d_ff=self.intermediate_size,
            num_codebooks=self.num_codebooks,
            rms_norm_eps=self.rms_norm_eps,
            tie_word_embeddings=False,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            ffn_dim=self.ffn_dim,
            audio_encoder_config=mimi_dict_config,
            depth_decoder_config=depth_dict_config,
            attn_implementation=self.attn_implementation,
        )
        return config

    def prepare_config_and_inputs_for_common(self, batch_size=None):
        """Entry point used by the common test mixins; same as `prepare_config_and_inputs`."""
        config, inputs_dict = self.prepare_config_and_inputs(batch_size)
        return config, inputs_dict


@require_torch
class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else ()
    # training is not supported yet for Moshi
    test_resize_embeddings = False

    def setUp(self):
        # Helper that builds the tiny config and dummy inputs used by every test in this class.
        self.model_tester = MoshiTester(parent=self)

    # Overridden to also provide Moshi's `text_labels` when labels are requested
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Extend the common input preparation with Moshi's `text_labels` when labels are requested.

        The labels are an all-zero `torch.long` tensor on `torch_device`. Its `(batch, seq)` size follows the
        `input_ids` being prepared, when present. Otherwise it falls back to the tester's default `batch_size` /
        `seq_length`. This keeps labels aligned with inputs built with a non-default batch size.
        """
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            input_ids = inputs_dict.get("input_ids")
            if isinstance(input_ids, torch.Tensor) and input_ids.dim() >= 2:
                labels_shape = tuple(input_ids.shape[:2])
            else:
                labels_shape = (self.model_tester.batch_size, self.model_tester.seq_length)
            inputs_dict["text_labels"] = torch.zeros(labels_shape, dtype=torch.long, device=torch_device)
        return inputs_dict

    def _get_input_ids_and_config(self, batch_size=2):
        """Split the common inputs into `(config, input_ids, attention_mask, extra_kwargs)` for generation tests.

        `input_ids` and `attention_mask` are moved to `torch_device`. `extra_kwargs` keeps the remaining inputs
        (the audio codes) and turns off returning audio waveforms/codes and concatenating unconditional inputs,
        so that `generate` only returns text ids. Audio codes are still generated internally; other tests check
        the generated audio waveforms and codes.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(batch_size)
        input_ids, attention_mask = (inputs_dict.pop(key).to(torch_device) for key in ("input_ids", "attention_mask"))

        for flag in ("return_audio_waveforms", "return_audio_codes", "concat_unconditional_inputs"):
            inputs_dict[flag] = False

        return config, input_ids, attention_mask, inputs_dict

    def prepare_config_and_inputs_for_generate(self, batch_size=2):
        """Return the common generate inputs, configured so that `generate` only returns text ids.

        Audio codes are still generated internally, so the audio path is exercised. Other tests check the
        generated audio waveforms and codes themselves.
        """
        config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
        filtered_inputs_dict.update(
            {
                "return_audio_waveforms": False,
                "return_audio_codes": False,
                "concat_unconditional_inputs": False,
            }
        )
        return config, filtered_inputs_dict

    def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
        # Overwritten: Moshi's `generate` always goes through `inputs_embeds`, so the outputs are always checked
        # as if `use_cache=True`, regardless of the value the caller passed.
        del use_cache  # intentionally ignored, see above
        super()._check_generate_outputs(
            output,
            config,
            use_cache=True,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
        )

    @unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs")
    def test_generate_continue_from_past_key_values(self):
        # Intentionally empty: the decorator skips this inherited generation test for Moshi.
        ...

    @unittest.skip(
        reason="Moshi either needs default generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
    )
    def test_greedy_generate_dict_outputs_use_cache(self):
        # Intentionally empty: disabled for the reason given in the decorator.
        ...

    @unittest.skip(
        reason="Moshi either needs default generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
    )
    def test_beam_search_generate_dict_outputs_use_cache(self):
        # Intentionally empty: disabled for the reason given in the decorator.
        ...

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @unittest.skip(reason="Unimplemented. Relies on `test_eager_matches_sdpa_generate` to check correctness.")
    def test_eager_matches_sdpa_inference(
        self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
    ):
        # Every parameterized variant is skipped; eager/SDPA parity is checked by the generate-based test instead.
        ...

    @unittest.skip(reason="The Moshi model does not have support dynamic compile yet")
    @pytest.mark.torch_compile_test
    def test_sdpa_can_compile_dynamic(self):
        # Intentionally empty: dynamic compile is not supported yet (see the skip reason).
        ...

    @pytest.mark.generate
    def test_left_padding_compatibility(self):
        # Overwrite -- Moshi needs to prepare the audio codes, and they must be padded accordingly
        config, inputs_dict = self.prepare_config_and_inputs_for_generate()
        input_ids = inputs_dict["input_ids"]

        # 32 left-padding frames per codebook, filled with `config.audio_vocab_size`, the value this test uses
        # to pad audio codes.
        padding = torch.full(
            (input_ids.shape[0], self.model_tester.num_codebooks, 32),
            config.audio_vocab_size,
            dtype=input_ids.dtype,
            device=torch_device,
        )

        # The audio codes are randomly generated in `prepare_config_and_inputs_for_generate`. Their padded
        # versions must be built from the same codes for the test to be valid, so we pass both.
        unpadded_custom_inputs = {
            "moshi_audio_codes": inputs_dict["moshi_audio_codes"],
            "user_audio_codes": inputs_dict["user_audio_codes"],
        }
        padded_custom_inputs = {
            key: torch.cat((padding, codes), dim=2) for key, codes in unpadded_custom_inputs.items()
        }
        super().test_left_padding_compatibility(
            unpadded_custom_inputs=unpadded_custom_inputs, padded_custom_inputs=padded_custom_inputs
        )

    @slow
    @is_flaky(max_attempts=5, description="flaky on some models.")
    def test_eager_matches_sdpa_generate(self):
        """Overwritten -- Moshi has custom inputs and custom output checks.

        Saves a randomly initialized model, then reloads it in float16 twice: once without an explicit
        `attn_implementation` (expected to resolve to SDPA) and once with eager attention. Greedy generation
        (text and depth decoder) must yield identical text and audio sequences for both.
        """

        def _is_sdpa_module(module):
            # SDPA attention layers are identified by their class name.
            class_name = module.__class__.__name__
            return "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name

        max_new_tokens = 5

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            inputs_dict[model_class.main_input_name] = dummy_input

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                model_sdpa = model_class.from_pretrained(
                    tmpdirname,
                    dtype=torch.float16,
                ).to(torch_device)
                self.assertEqual(model_sdpa.config._attn_implementation, "sdpa")

                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    dtype=torch.float16,
                    attn_implementation="eager",
                ).to(torch_device)
                self.assertEqual(model_eager.config._attn_implementation, "eager")

                if any(_is_sdpa_module(submodule) for submodule in model_eager.modules()):
                    raise ValueError("The eager model should not have SDPA attention layers")
                if not any(_is_sdpa_module(submodule) for submodule in model_sdpa.modules()):
                    raise ValueError("The SDPA model should have SDPA attention layers")

                # Greedy decoding on both the main and the depth decoder, so both backends must agree exactly.
                generate_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "do_sample": False,
                    "depth_decoder_do_sample": False,
                }
                res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
                res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)

                torch.testing.assert_close(res_eager.sequences, res_sdpa.sequences)
                torch.testing.assert_close(res_eager.audio_sequences, res_sdpa.audio_sequences)

    @pytest.mark.generate
    def test_generate_without_input_ids(self):
        # `generate` must produce output when called with no prompt (no `input_ids` or audio inputs at all).
        config, _, _, _ = self._get_input_ids_and_config()

        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            output_ids_generate = model.generate(
                do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
            )
            self.assertIsNotNone(output_ids_generate)

    @unittest.skip(reason="The audio encoder has no gradients.")
    def test_training_gradient_checkpointing(self):
        # Intentionally empty: skipped for the reason given in the decorator.
        ...

    @unittest.skip(reason="The audio encoder has no gradients.")
    def test_training_gradient_checkpointing_use_reentrant(self):
        # Intentionally empty: skipped for the reason given in the decorator.
        ...

    @unittest.skip(reason="The audio encoder has no gradients.")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        """Disabled for Moshi (non-reentrant variant); the skip decorator records why."""

    def test_generate_from_input_values(self):
        """Generating from raw waveforms must match generating from their pre-encoded audio codes.

        Random user/moshi waveforms are encoded with ``model.audio_encoder``. ``generate`` is then called
        once with the resulting codes and once with the waveforms, and the two outputs are compared.
        """
        for model_class in self.all_generative_model_classes:
            config, input_ids, _, _ = self._get_input_ids_and_config()

            model = model_class(config).to(torch_device).eval()

            # Waveform length in samples; presumably chosen so the encoder yields `seq_length` frames
            # (seq_length * sampling_rate / frame_rate) -- TODO confirm units of `frame_rate`.
            input_values_length = int(
                self.model_tester.seq_length * config.sampling_rate / config.audio_encoder_config.frame_rate
            )

            user_input_values = floats_tensor((input_ids.shape[0], 1, input_values_length))
            moshi_input_values = floats_tensor((input_ids.shape[0], 1, input_values_length))

            user_audio_codes = model.audio_encoder.encode(user_input_values, num_quantizers=model.num_codebooks)[0]
            moshi_audio_codes = model.audio_encoder.encode(moshi_input_values, num_quantizers=model.num_codebooks)[0]

            outputs_from_audio_codes = model.generate(
                input_ids, max_new_tokens=5, user_audio_codes=user_audio_codes, moshi_audio_codes=moshi_audio_codes
            )

            outputs_from_audio_values = model.generate(
                input_ids, max_new_tokens=5, user_input_values=user_input_values, moshi_input_values=moshi_input_values
            )

            # `assert_close` also checks shape and dtype; `(a == b).all()` / `torch.allclose` broadcast
            # and could silently accept outputs of different shapes.
            torch.testing.assert_close(
                outputs_from_audio_values.sequences, outputs_from_audio_codes.sequences, rtol=0, atol=0
            )
            # Same tolerances as the default `torch.allclose` (rtol=1e-5, atol=1e-8).
            torch.testing.assert_close(
                outputs_from_audio_codes.audio_sequences,
                outputs_from_audio_values.audio_sequences,
                rtol=1e-5,
                atol=1e-8,
            )

    def test_generate_depth_decoder_kwargs(self):
        """``generate`` must accept ``depth_decoder_*`` kwargs: sampling alone, then sampling with beams."""
        for model_class in self.all_generative_model_classes:
            config, prompt_ids, _, extra_inputs = self._get_input_ids_and_config()

            model = model_class(config).to(torch_device).eval()

            # depth decoder with sampling only
            model.generate(prompt_ids, max_new_tokens=5, **extra_inputs, depth_decoder_do_sample=True)

            # depth decoder with sampling plus beam search
            model.generate(
                prompt_ids,
                max_new_tokens=5,
                **extra_inputs,
                depth_decoder_do_sample=True,
                depth_decoder_num_beams=5,
            )

    def test_generate_from_unconditional(self):
        """``get_unconditional_inputs`` must work for batches > 1 and match prompt-less generation.

        Greedy generation from ``model.get_unconditional_inputs(num_samples=1)`` (with
        ``concat_unconditional_inputs=False``) is compared against ``generate`` called with no inputs.
        """
        for model_class in self.all_generative_model_classes:
            config, _, _, _ = self._get_input_ids_and_config()

            model = model_class(config).to(torch_device).eval()

            # check bs>1
            model.generate(
                **model.get_unconditional_inputs(num_samples=4), max_new_tokens=5, concat_unconditional_inputs=False
            )

            # check same results from unconditional or no inputs
            outputs_from_unconditional = model.generate(
                **model.get_unconditional_inputs(num_samples=1), max_new_tokens=5, concat_unconditional_inputs=False
            )
            outputs_from_none = model.generate(max_new_tokens=5)

            # `assert_close` also checks shape and dtype, unlike the broadcasting `==` / `torch.allclose`.
            torch.testing.assert_close(
                outputs_from_unconditional.sequences, outputs_from_none.sequences, rtol=0, atol=0
            )
            # Same tolerances as the default `torch.allclose` (rtol=1e-5, atol=1e-8).
            torch.testing.assert_close(
                outputs_from_unconditional.audio_sequences,
                outputs_from_none.audio_sequences,
                rtol=1e-5,
                atol=1e-8,
            )

    @unittest.skip(reason="Compile not yet supported in Moshi models")
    def test_sdpa_can_dispatch_on_flash(self):
        """Disabled for Moshi; the skip decorator records why."""

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_cpu_offload(self):
        """Temporarily disabled for Moshi (CPU offload); the skip decorator records why."""

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_bin(self):
        """Temporarily disabled for Moshi (disk offload, .bin weights); the skip decorator records why."""

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_safetensors(self):
        """Temporarily disabled for Moshi (disk offload, safetensors); the skip decorator records why."""

    @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities")
    def test_generate_continue_from_inputs_embeds(self):
        """Disabled for Moshi; the skip decorator records why."""

    @is_flaky(max_attempts=5, description="flaky on some models.")
    def test_save_load(self):
        """Delegate to the common save/load test, allowing up to 5 attempts via ``is_flaky``."""
        super().test_save_load()

    @pytest.mark.generate
    @unittest.skip(reason="Moshi requires setting `model.generated_audio_codes` in generate() before preparing inputs")
    def test_prepare_inputs_for_generation_kwargs_forwards(self):
        """Currently skipped; would forward Moshi-specific extra kwargs to the common test.

        If in the future `model.generated_audio_codes` is not required, this test can be re-enabled.
        """
        dummy_last_hidden_state = torch.randn(2, 3, 32)
        super().test_prepare_inputs_for_generation_kwargs_forwards(
            last_hidden_state=dummy_last_hidden_state,
            kwargs_depth_decoder={},
        )

    @unittest.skip(reason="Moshi has no separate base model without a head.")
    def test_model_base_model_prefix(self):
        """Not applicable to Moshi; the skip decorator records why."""


def place_dict_on_device(dict_to_place, device):
    """Move every ``torch.Tensor`` value of ``dict_to_place`` to ``device`` in place.

    Values that are not tensors (including ``None``) are left untouched.

    Args:
        dict_to_place: dict whose tensor values should be moved. It is mutated.
        device: target passed to ``torch.Tensor.to``.

    Returns:
        The same dict object that was passed in.
    """
    for key, value in dict_to_place.items():
        # `isinstance` already rejects `None`, so no separate None check is needed. Reassigning an
        # existing key does not change the dict's size, so this is safe while iterating.
        if isinstance(value, torch.Tensor):
            dict_to_place[key] = value.to(device)
    return dict_to_place


@require_torch
class MoshiIntegrationTests(unittest.TestCase):
    """Slow end-to-end tests comparing released Moshi checkpoints against hard-coded reference generations."""

    @cached_property
    def feature_extractor(self):
        """Feature extractor of ``kmhf/hf-moshiko``, loaded lazily and cached on the test instance."""
        return AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko")

    @cached_property
    def tokenizer(self):
        """Tokenizer of ``kmhf/hf-moshiko``, loaded lazily and cached on the test instance."""
        return AutoTokenizer.from_pretrained("kmhf/hf-moshiko")

    def _load_datasample(self):
        """Return the waveform of the first sample (sorted by ``id``) of the LibriSpeech dummy set.

        The audio column is cast to the feature extractor's sampling rate before decoding.
        """
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        dataset = ds.cast_column("audio", Audio(sampling_rate=self.feature_extractor.sampling_rate))
        # automatic decoding with librispeech
        speech_sample = dataset.sort("id")[0]["audio"]["array"]
        return speech_sample

    def _assert_greedy_unconditional_outputs(
        self,
        model,
        expected_text_tokens,
        some_expected_audio_tokens,
        expected_audio_codesum,
        audio_codesum_tolerance,
    ):
        """Run 10 steps of greedy, unconditional generation and compare against reference values.

        Args:
            model: the ``MoshiForConditionalGeneration`` checkpoint under test.
            expected_text_tokens: expected ``sequences[0, 1:]`` as a Python list.
            some_expected_audio_tokens: expected ``audio_codes[0, :, :2]`` as nested Python lists.
            expected_audio_codesum: reference sum over every generated audio code.
            audio_codesum_tolerance: maximum allowed absolute deviation from ``expected_audio_codesum``.
        """
        model_outputs = model.generate(
            do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10
        )

        # make sure audio encoded codes are correct
        audio_code_sums = model_outputs.audio_codes.sum().item()
        self.assertLessEqual(np.abs(audio_code_sums - expected_audio_codesum), audio_codesum_tolerance)

        self.assertEqual(expected_text_tokens, model_outputs.sequences[0, 1:].tolist())
        self.assertEqual(some_expected_audio_tokens, model_outputs.audio_codes[0, :, :2].tolist())

    @slow
    @require_torch_fp16
    def test_moshika_conditional_greedy(self):
        model = MoshiForConditionalGeneration.from_pretrained(
            "kmhf/hf-moshika", dtype=torch.float16, device_map="auto"
        )
        inputs = self.feature_extractor(self._load_datasample(), return_tensors="pt").to(
            device=torch_device, dtype=torch.float16
        )

        user_audio_codes = model.audio_encoder.encode(**inputs, num_quantizers=8).audio_codes

        input_ids = self.tokenizer.encode("<pad><pad><pad><pad><unk> Hello,<pad><unk>", return_tensors="pt").to(
            torch_device
        )

        # fmt: off
        moshi_audio_codes = [[[1049, 127, 1880, 972, 972, 1156, 1913, 415, 1933],
                              [1700, 243, 91, 91, 91, 745, 1478, 638, 57],
                              [1626, 457, 457, 457, 457, 1839, 200, 2011, 1142],
                              [546, 290, 390, 390, 290, 1408, 1812, 1187, 1911],
                              [306, 306, 1314, 1314, 1314, 759, 796, 854, 1466],
                              [1443, 1443, 1030, 317, 347, 1178, 613, 1576, 2023],
                              [1871, 428, 1433, 1433, 1978, 1405, 1755, 820, 610],
                              [2008, 1744, 1511, 568, 1533, 550, 237, 1412, 1401]]]
        # fmt: on

        moshi_audio_codes = torch.tensor(moshi_audio_codes, device=torch_device)
        # Trim the user stream to the same number of steps as the moshi stream.
        user_audio_codes = user_audio_codes[:, :, : moshi_audio_codes.shape[-1]]

        model_outputs = model.generate(
            user_audio_codes=user_audio_codes,
            moshi_audio_codes=moshi_audio_codes,
            input_ids=input_ids,
            do_sample=False,
            depth_decoder_do_sample=False,
            return_audio_codes=True,
            max_new_tokens=2,
        )

        expected_text_token = 452
        expected_audio_tokens = [916, 1396, 1238, 579, 1105, 914, 1257, 810]  # fmt: skip

        self.assertEqual(expected_text_token, model_outputs.sequences[0, -2].item())
        self.assertEqual(expected_audio_tokens, model_outputs.audio_codes[0, :, -1].tolist())

    @slow
    @require_torch_fp16
    def test_moshiko_greedy_unconditional_fp16_eager(self):
        # Explicitly request eager attention; otherwise this would just repeat the default (SDPA) test.
        model = MoshiForConditionalGeneration.from_pretrained(
            "kmhf/hf-moshiko", dtype=torch.float16, device_map="auto", attn_implementation="eager"
        )
        some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]]  # fmt: skip

        model_outputs = model.generate(
            do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10
        )

        # eager equivalence is not as strict as sdpa, so only the first two audio steps are compared.
        self.assertEqual(some_expected_audio_tokens, model_outputs.audio_codes[0, :, :2].tolist())

    @slow
    def test_moshiko_greedy_unconditional_fp32(self):
        model = MoshiForConditionalGeneration.from_pretrained(
            "kmhf/hf-moshiko", dtype=torch.float32, device_map="auto"
        )

        expected_audio_codesum = 72065
        expected_text_tokens = [3, 3, 3, 0, 11725, 261, 3, 3, 3, 3]  # fmt: skip
        some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]]  # fmt: skip

        self._assert_greedy_unconditional_outputs(
            model,
            expected_text_tokens=expected_text_tokens,
            some_expected_audio_tokens=some_expected_audio_tokens,
            expected_audio_codesum=expected_audio_codesum,
            # 0.3% relative tolerance, measured against the reference rather than the observed value.
            audio_codesum_tolerance=3e-3 * expected_audio_codesum,
        )

    @slow
    @require_torch_fp16
    def test_moshiko_greedy_unconditional_fp16(self):
        model = MoshiForConditionalGeneration.from_pretrained(
            "kmhf/hf-moshiko", dtype=torch.float16, device_map="auto"
        )

        expected_audio_codesum = 72065
        expected_text_tokens = [3, 3, 3, 0, 11725, 261, 3, 3, 3, 3]  # fmt: skip
        some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]]  # fmt: skip

        self._assert_greedy_unconditional_outputs(
            model,
            expected_text_tokens=expected_text_tokens,
            some_expected_audio_tokens=some_expected_audio_tokens,
            expected_audio_codesum=expected_audio_codesum,
            # 0.3% relative tolerance, measured against the reference rather than the observed value.
            audio_codesum_tolerance=3e-3 * expected_audio_codesum,
        )

    @slow
    @require_torch_fp16
    def test_moshika_greedy_unconditional_fp16(self):
        model = MoshiForConditionalGeneration.from_pretrained(
            "kmhf/hf-moshika", dtype=torch.float16, device_map="auto"
        )

        expected_audio_codesum = 72932
        expected_text_tokens = [3, 3, 3, 0, 667, 263, 3, 3, 0, 705]  # fmt: skip
        some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 347], [1871, 428], [2008, 2008]]  # fmt: skip

        self._assert_greedy_unconditional_outputs(
            model,
            expected_text_tokens=expected_text_tokens,
            some_expected_audio_tokens=some_expected_audio_tokens,
            expected_audio_codesum=expected_audio_codesum,
            audio_codesum_tolerance=2048,
        )
