# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Encodec model."""

import copy
import inspect
import os
import unittest

import numpy as np
from datasets import Audio, load_dataset
from parameterized import parameterized

from transformers import AutoProcessor, EncodecConfig
from transformers.testing_utils import (
    is_torch_available,
    require_torch,
    slow,
    torch_device,
)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import EncodecFeatureExtractor, EncodecModel


def prepare_inputs_dict(
    config,
    input_ids=None,
    input_values=None,
    decoder_input_ids=None,
    attention_mask=None,
    decoder_attention_mask=None,
):
    """Build the inputs dict for the model; `config` and the mask arguments are
    accepted for signature compatibility with other model testers but are unused here."""
    if input_ids is not None:
        encoder_dict = {"input_ids": input_ids}
    else:
        encoder_dict = {"input_values": input_values}

    decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {}

    return {**encoder_dict, **decoder_dict}


@require_torch
class EncodecModelTester:
    def __init__(
        self,
        parent,
        # `batch_size` needs to be an even number if the model has some outputs with batch dim != 0.
        batch_size=12,
        num_channels=2,
        is_training=False,
        intermediate_size=40,
        hidden_size=32,
        num_filters=8,
        num_residual_layers=1,
        upsampling_ratios=[8, 4],
        num_lstm_layers=1,
        codebook_size=64,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.num_residual_layers = num_residual_layers
        self.upsampling_ratios = upsampling_ratios
        self.num_lstm_layers = num_lstm_layers
        self.codebook_size = codebook_size

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
        config = self.get_config()
        inputs_dict = {"input_values": input_values}
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def prepare_config_and_inputs_for_model_class(self, model_class):
        config, inputs_dict = self.prepare_config_and_inputs()
        inputs_dict["audio_codes"] = ids_tensor([1, self.batch_size, 1, self.num_channels], self.codebook_size).type(
            torch.int32
        )
        inputs_dict["audio_scales"] = [None]

        return config, inputs_dict

    def prepare_config_and_inputs_for_normalization(self):
        input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
        config = self.get_config()
        config.normalize = True

        processor = EncodecFeatureExtractor(feature_size=config.audio_channels, sampling_rate=config.sampling_rate)
        input_values = input_values.tolist()
        inputs_dict = processor(
            input_values, sampling_rate=config.sampling_rate, padding=True, return_tensors="pt"
        ).to(torch_device)

        return config, inputs_dict

    def get_config(self):
        return EncodecConfig(
            audio_channels=self.num_channels,
            chunk_length_s=None,
            hidden_size=self.hidden_size,
            num_filters=self.num_filters,
            num_residual_layers=self.num_residual_layers,
            upsampling_ratios=self.upsampling_ratios,
            num_lstm_layers=self.num_lstm_layers,
            codebook_size=self.codebook_size,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        model = EncodecModel(config=config).to(torch_device).eval()
        result = model(**inputs_dict)
        self.parent.assertEqual(
            result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
        )


@require_torch
class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (EncodecModel,) if is_torch_available() else ()
    is_encoder_decoder = True

    test_resize_embeddings = False
    pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        # model does not have attention and does not support returning hidden states
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
        if "output_attentions" in inputs_dict:
            inputs_dict.pop("output_attentions")
        if "output_hidden_states" in inputs_dict:
            inputs_dict.pop("output_hidden_states")
        return inputs_dict

    def setUp(self):
        self.model_tester = EncodecModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=EncodecConfig, hidden_size=37, common_properties=[], has_text_modality=False
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_values", "padding_mask", "bandwidth"]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    @unittest.skip(reason="The EncodecModel is not transformers based, thus it does not have `inputs_embeds` logics")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="The EncodecModel is not transformers based, thus it does not have `inputs_embeds` logics")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
    )
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
    )
    def test_attention_outputs(self):
        pass

    def test_feed_forward_chunking(self):
        (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            torch.manual_seed(0)
            config = copy.deepcopy(original_config)
            config.chunk_length_s = None
            config.overlap = None
            config.sampling_rate = 20

            model = model_class(config)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)
            inputs["input_values"] = inputs["input_values"].repeat(1, 1, 10)

            audio_values_no_chunk = model(**inputs)[1]  # index 1 of the output tuple is `audio_values`

            torch.manual_seed(0)
            config.chunk_length_s = 2
            config.overlap = 0
            config.sampling_rate = 20

            model = model_class(config)
            model.to(torch_device)
            model.eval()

            audio_values_with_chunk = model(**inputs)[1]
            torch.testing.assert_close(audio_values_no_chunk, audio_values_with_chunk, rtol=1e-1, atol=1e-2)

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `hidden_states` logic"
    )
    def test_hidden_states_output(self):
        pass

    def test_determinism(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_determinism(first, second):
            # outputs may be lists rather than plain tensors (since the sequences
            # don't all share the same frame_length)
            out_1 = first.cpu().numpy()
            out_2 = second.cpu().numpy()
            out_1 = out_1[~np.isnan(out_1)]
            out_2 = out_2[~np.isnan(out_2)]
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_determinism(tensor1, tensor2)
            else:
                check_determinism(first, second)

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # NaN != NaN, so `t != t` selects exactly the NaN entries
            t[t != t] = 0
            return t

        def assert_nested_tensors_close(a, b):
            if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
                assert len(a) == len(b), f"Length mismatch: {len(a)} vs {len(b)}"
                for i, (x, y) in enumerate(zip(a, b)):
                    assert_nested_tensors_close(x, y)
            elif torch.is_tensor(a) and torch.is_tensor(b):
                a_clean = set_nan_tensor_to_zero(a)
                b_clean = set_nan_tensor_to_zero(b)
                assert torch.allclose(a_clean, b_clean, atol=1e-5), (
                    "Tuple and dict output are not equal. Difference:"
                    f" Max diff: {torch.max(torch.abs(a_clean - b_clean))}. "
                    f"Tuple has nan: {torch.isnan(a).any()} and inf: {torch.isinf(a)}. "
                    f"Dict has nan: {torch.isnan(b).any()} and inf: {torch.isinf(b)}."
                )
            else:
                raise ValueError(f"Mismatch between {a} vs {b}")

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = additional_kwargs or {}
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)

            self.assertTrue(isinstance(tuple_output, tuple))
            self.assertTrue(isinstance(dict_output, dict))
            # cast dict_output.values() to list as it is an odict_values object
            assert_nested_tensors_close(tuple_output, list(dict_output.values()))

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

    def test_identity_shortcut(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        config.use_conv_shortcut = False
        self.model_tester.create_and_check_model_forward(config, inputs_dict)

    def test_model_forward_with_normalization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_normalization()
        self.model_tester.create_and_check_model_forward(config, inputs_dict)


def normalize(arr):
    """L2-normalize an array."""
    norm = np.linalg.norm(arr)
    normalized_arr = arr / norm
    return normalized_arr


def compute_rmse(arr1, arr2):
    """Compute the RMSE between two waveforms, truncated to the shorter length and L2-normalized."""
    arr1_np = arr1.cpu().numpy().squeeze()
    arr2_np = arr2.cpu().numpy().squeeze()
    min_length = min(arr1.shape[-1], arr2.shape[-1])
    arr1_np = arr1_np[..., :min_length]
    arr2_np = arr2_np[..., :min_length]
    arr1_normalized = normalize(arr1_np)
    arr2_normalized = normalize(arr2_np)
    return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
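
# Hedged usage sketch (an assumption, not taken from the gists below): the
# integration tests presumably compute the codec error along these lines, with
# the variable names and `tol` purely illustrative:
#     codec_error = compute_rmse(audio_values, input_values)
#     assert abs(codec_error - EXPECTED_CODEC_ERROR[model_id][str(bandwidth)]) < tol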


"""
Integration tests for the Encodec model.

Code for expected output can be found below:
- test_integration: https://gist.github.com/ebezzam/2a34e249e729881130d1f5a42229d31f#file-test_encodec-py
- test_batch: https://gist.github.com/ebezzam/2a34e249e729881130d1f5a42229d31f#file-test_encodec_batch-py
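
A minimal sketch of how such expected encoder codes can be regenerated (an
illustration, not the gist code; the dataset id and slicing are assumptions):

    import torch
    from datasets import load_dataset
    from transformers import AutoProcessor, EncodecModel

    model_id = "facebook/encodec_24khz"
    model = EncodecModel.from_pretrained(model_id).eval()
    processor = AutoProcessor.from_pretrained(model_id)
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    inputs = processor(
        raw_audio=ds[0]["audio"]["array"], sampling_rate=processor.sampling_rate, return_tensors="pt"
    )
    with torch.no_grad():
        encoder_outputs = model.encode(inputs["input_values"], bandwidth=1.5)
    print(encoder_outputs.audio_codes[..., :10])  # compare against EXPECTED_ENCODER_CODES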

"""

# fmt: off
# first key is model_id from hub, second key is bandwidth
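# e.g. EXPECTED_ENCODER_CODES["facebook/encodec_24khz"]["1.5"] reads as a
# (batch, num_quantizers, frames) slice of the expected `audio_codes`
# (an illustrative reading of the literals below)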
# -- test_integration
EXPECTED_ENCODER_CODES = {
    "facebook/encodec_24khz": {
        "1.5": torch.tensor([[[  62,  835,  835,  835,  835,  835,  835,  835,  408,  408],
         [1007, 1007, 1007,  544,  424,  424, 1007,  424,  302,  424]]]),
        "3.0": torch.tensor(
            [
                [
                    [62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
                    [1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
                    [786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
                    [741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
                ],
            ]
        ),
        "6.0": torch.tensor(
            [
                [
                    [62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
                    [1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
                    [786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
                    [741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
                    [528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
                    [1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
                    [896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
                    [899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],

                ],
            ]
        ),
        "12.0": torch.tensor(
            [
                [
                    [62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
                    [1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
                    [786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
                    [741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
                    [528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
                    [1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
                    [896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
                    [899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],
                    [827, 807, 938, 320, 699, 470, 909, 628, 301, 827],
                    [963, 801, 630, 477, 717, 354, 205, 359, 874, 744],
                    [1000, 1000, 388, 1000, 408, 740, 568, 364, 709, 843],
                    [413, 835, 382, 840, 742, 1019, 375, 962, 835, 742],
                    [971, 410, 998, 485, 798, 410, 351, 485, 485, 920],
                    [848, 694, 662, 784, 848, 427, 1022, 848, 920, 694],
                    [420, 911, 889, 911, 993, 776, 948, 477, 911, 911],
                    [587, 755, 834, 962, 860, 425, 982, 982, 425, 461],
                ],
            ]
        ),
        "24.0": torch.tensor(
            [
                [
                    [62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
                    [1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
                    [786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
                    [741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
                    [528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
                    [1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
                    [896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
                    [899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],
                    [827, 807, 938, 320, 699, 470, 909, 628, 301, 827],
                    [963, 801, 630, 477, 717, 354, 205, 359, 874, 744],
                    [1000, 1000, 388, 1000, 408, 740, 568, 364, 709, 843],
                    [413, 835, 382, 840, 742, 1019, 375, 962, 835, 742],
                    [971, 410, 998, 485, 798, 410, 351, 485, 485, 920],
                    [848, 694, 662, 784, 848, 427, 1022, 848, 920, 694],
                    [420, 911, 889, 911, 993, 776, 948, 477, 911, 911],
                    [587, 755, 834, 962, 860, 425, 982, 982, 425, 461],
                    [270, 160, 26, 131, 597, 506, 670, 637, 248, 160],
                    [ 15, 215, 134, 69, 215, 155, 1012, 1009, 260, 417],
                    [580, 561, 686, 896, 497, 637, 580, 245, 896, 264],
                    [511, 239, 560, 691, 571, 627, 571, 571, 258, 619],
                    [591, 942, 591, 251, 250, 250, 857, 486, 295, 295],
                    [565, 546, 654, 301, 301, 623, 639, 568, 565, 282],
                    [539, 317, 639, 539, 651, 539, 538, 640, 615, 615],
                    [637, 556, 637, 582, 640, 515, 515, 632, 254, 613],
                    [305, 643, 500, 550, 522, 500, 550, 561, 522, 305],
                    [954, 456, 584, 755, 505, 782, 661, 671, 497, 505],
                    [577, 464, 637, 647, 552, 552, 624, 647, 624, 647],
                    [728, 748, 931, 608, 538, 1015, 294, 294, 666, 538],
                    [602, 535, 666, 665, 655, 979, 574, 535, 571, 781],
                    [321, 620, 557, 566, 511, 910, 672, 623, 853, 674],
                    [621, 556, 947, 474, 610, 752, 1002, 597, 474, 474],
                    [605, 948, 657, 588, 485, 633, 459, 968, 939, 325],
                ],
            ]
        ),
        },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor([[[214, 214, 214, 214, 214, 118, 214, 214, 214, 214],
            [989, 989, 611,  77,  77, 989, 976, 976, 976,  77]]]),
        "6.0": torch.tensor([[[ 214,  214,  214,  214,  214,  118,  214,  214,  214,  214],
            [ 989,  989,  611,   77,   77,  989,  976,  976,  976,   77],
            [ 977, 1009,  538,  925,  925,  977, 1022, 1022, 1022,  925],
            [ 376, 1012, 1023,  725,  725, 1023,  376,  962,  376,  847]]]),
        "12.0": torch.tensor([[[ 214,  214,  214,  214,  214,  118,  214,  214,  214,  214],
            [ 989,  989,  611,   77,   77,  989,  976,  976,  976,   77],
            [ 977, 1009,  538,  925,  925,  977, 1022, 1022, 1022,  925],
            [ 376, 1012, 1023,  725,  725, 1023,  376,  962,  376,  847],
            [ 979, 1012,  323,  695, 1018, 1023,  979, 1023,  979,  650],
            [ 945,  762,  528,  865,  824,  945,  945,  945,  957,  957],
            [ 904,  973, 1014,  681,  582, 1014, 1014, 1014, 1014,  681],
            [ 229,  392,  796,  392,  977, 1017,  250, 1017,  250, 1017]]]),
        "24.0": torch.tensor([[[ 214,  214,  214,  214,  214,  118,  214,  214,  214,  214],
            [ 989,  989,  611,   77,   77,  989,  976,  976,  976,   77],
            [ 977, 1009,  538,  925,  925,  977, 1022, 1022, 1022,  925],
            [ 376, 1012, 1023,  725,  725, 1023,  376,  962,  376,  847],
            [ 979, 1012,  323,  695, 1018, 1023,  979, 1023,  979,  650],
            [ 945,  762,  528,  865,  824,  945,  945,  945,  957,  957],
            [ 904,  973, 1014,  681,  582, 1014, 1014, 1014, 1014,  681],
            [ 229,  392,  796,  392,  977, 1017,  250, 1017,  250, 1017],
            [ 902,  436,  935, 1011, 1023, 1023, 1023,  154, 1023,  392],
            [ 982,  878,  961,  832,  629,  431,  919,  629,  919,  792],
            [ 727,  727,  401,  727,  979,  587,  727,  487,  413,  201],
            [ 928,  924,  965,  934,  840,  480,  924,  920,  924,  486],
            [  10,  625,  712,  552,  712,  259,  394,  131,  726,  516],
            [ 882, 1022,   32,  524,  267,  861,  974,  882,  108,  521],
            [ 304,  841,  306,  415,   69,  376,  928,  510,  381,  104],
            [   0,    0,    0,  484,   83,    0,  307,  262,    0,    0]]])
    }
}
EXPECTED_ENCODER_SCALES = {
    "facebook/encodec_24khz": {
        "1.5": None,
        "3.0": None,
        "6.0": None,
        "12.0": None,
        "24.0": None
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
        4.483359e-02, 1.000000e-08]),
        "6.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
        4.483359e-02, 1.000000e-08]),
        "12.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
        4.483359e-02, 1.000000e-08]),
        "24.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
        4.483359e-02, 1.000000e-08])
    }
}
EXPECTED_DECODER_OUTPUTS = {
    "facebook/encodec_24khz": {
        "1.5": torch.tensor(
            [[ 0.0003, -0.0002, -0.0000, -0.0004, 0.0004, 0.0003, -0.0000, 0.0001, 0.0005, 0.0001, -0.0015, -0.0007, -0.0002, -0.0018, -0.0003, 0.0013, 0.0011, 0.0008, 0.0008, 0.0008, 0.0008, 0.0002, -0.0003, -0.0004, -0.0006, -0.0009, -0.0010, -0.0012, -0.0011, -0.0006, -0.0006, -0.0005, 0.0000, 0.0001, 0.0003, 0.0002, -0.0001, -0.0002, -0.0008, -0.0012, -0.0011, -0.0012, -0.0013, -0.0003, 0.0002, 0.0006, 0.0006, 0.0006, 0.0009, 0.0010]]
        ),
        "3.0": torch.tensor(
            [[ 0.0003, -0.0002, -0.0000, -0.0004, 0.0004, 0.0003, -0.0000, 0.0001, 0.0006, 0.0002, -0.0015, -0.0008, -0.0002, -0.0018, -0.0003, 0.0013, 0.0011, 0.0008, 0.0008, 0.0008, 0.0008, 0.0002, -0.0003, -0.0004, -0.0005, -0.0008, -0.0010, -0.0012, -0.0011, -0.0006, -0.0006, -0.0005, -0.0000, 0.0001, 0.0003, 0.0002, -0.0001, -0.0002, -0.0008, -0.0013, -0.0011, -0.0013, -0.0014, -0.0003, 0.0002, 0.0006, 0.0006, 0.0006, 0.0009, 0.0010]]
        ),
        "6.0": torch.tensor(
            [[ 0.0004, -0.0001, 0.0001, -0.0003, 0.0004, 0.0003, 0.0000, 0.0001, 0.0007, 0.0002, -0.0013, -0.0007, -0.0002, -0.0015, -0.0001, 0.0014, 0.0014, 0.0011, 0.0010, 0.0010, 0.0009, 0.0004, 0.0000, 0.0000, 0.0000, -0.0000, -0.0001, -0.0004, -0.0004, -0.0001, -0.0002, -0.0002, 0.0002, 0.0005, 0.0009, 0.0010, 0.0008, 0.0007, 0.0002, -0.0003, -0.0004, -0.0008, -0.0008, 0.0000, 0.0006, 0.0010, 0.0012, 0.0012, 0.0013, 0.0014]]
        ),
        "12.0": torch.tensor(
            [[ 0.0004, -0.0001, 0.0001, -0.0004, 0.0003, 0.0002, -0.0000, 0.0001, 0.0006, 0.0002, -0.0013, -0.0006, -0.0001, -0.0014, 0.0001, 0.0018, 0.0018, 0.0014, 0.0012, 0.0013, 0.0011, 0.0006, 0.0000, 0.0000, -0.0000, -0.0001, -0.0001, -0.0004, -0.0004, -0.0000, -0.0000, -0.0000, 0.0005, 0.0007, 0.0011, 0.0011, 0.0009, 0.0007, 0.0002, -0.0003, -0.0004, -0.0007, -0.0007, 0.0002, 0.0009, 0.0013, 0.0015, 0.0014, 0.0015, 0.0016]]
        ),
        "24.0": torch.tensor(
            [[ 0.0005, 0.0001, 0.0004, -0.0001, 0.0003, 0.0002, 0.0000, 0.0001, 0.0007, 0.0005, -0.0011, -0.0005, -0.0001, -0.0018, -0.0000, 0.0021, 0.0019, 0.0013, 0.0011, 0.0012, 0.0012, 0.0006, -0.0000, -0.0001, -0.0000, -0.0000, -0.0001, -0.0004, -0.0004, -0.0000, -0.0001, -0.0002, 0.0003, 0.0004, 0.0008, 0.0007, 0.0006, 0.0007, 0.0001, -0.0004, -0.0003, -0.0006, -0.0008, 0.0004, 0.0011, 0.0015, 0.0016, 0.0015, 0.0016, 0.0018]]
        )
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor(
            [
                [0.0034, 0.0028, 0.0037, 0.0041, 0.0029, 0.0022, 0.0021, 0.0020, 0.0021, 0.0023, 0.0021, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0021, 0.0023, 0.0025, 0.0022, 0.0017, 0.0015, 0.0017, 0.0020, 0.0024, 0.0031, 0.0039, 0.0045, 0.0046, 0.0042, 0.0034, 0.0027, 0.0023, 0.0022, 0.0023, 0.0024, 0.0022, 0.0023, 0.0024, 0.0027, 0.0027, 0.0027, 0.0025, 0.0024, 0.0024, 0.0026, 0.0028, 0.0027, 0.0024, 0.0022],
                [ -0.0031, -0.0027, -0.0018, -0.0017, -0.0024, -0.0029, -0.0030, -0.0026, -0.0021, -0.0018, -0.0018, -0.0019, -0.0017, -0.0014, -0.0012, -0.0010, -0.0008, -0.0004, -0.0001, -0.0004, -0.0012, -0.0015, -0.0014, -0.0013, -0.0011, -0.0005, 0.0002, 0.0007, 0.0008, 0.0004, -0.0003, -0.0010, -0.0012, -0.0011, -0.0009, -0.0009, -0.0009, -0.0008, -0.0006, -0.0005, -0.0005, -0.0005, -0.0006, -0.0008, -0.0008, -0.0006, -0.0005, -0.0007, -0.0010, -0.0012],
            ]
        ),
        "6.0": torch.tensor(
            [
                [0.0052, 0.0049, 0.0057, 0.0058, 0.0048, 0.0043, 0.0042, 0.0041, 0.0041, 0.0042, 0.0040, 0.0038, 0.0038, 0.0038, 0.0037, 0.0037, 0.0037, 0.0037, 0.0038, 0.0037, 0.0035, 0.0034, 0.0036, 0.0039, 0.0043, 0.0047, 0.0053, 0.0057, 0.0057, 0.0055, 0.0050, 0.0046, 0.0043, 0.0041, 0.0042, 0.0042, 0.0041, 0.0041, 0.0042, 0.0043, 0.0043, 0.0043, 0.0041, 0.0040, 0.0040, 0.0041, 0.0042, 0.0042, 0.0040, 0.0039],
                [ 0.0001, 0.0006, 0.0013, 0.0011, 0.0005, 0.0001, -0.0001, 0.0001, 0.0003, 0.0005, 0.0005, 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.0010, 0.0013, 0.0015, 0.0014, 0.0010, 0.0008, 0.0010, 0.0012, 0.0015, 0.0019, 0.0023, 0.0026, 0.0026, 0.0024, 0.0020, 0.0016, 0.0013, 0.0013, 0.0014, 0.0015, 0.0015, 0.0016, 0.0017, 0.0017, 0.0017, 0.0016, 0.0015, 0.0013, 0.0013, 0.0013, 0.0013, 0.0012, 0.0010, 0.0009],
            ]
        ),
        "12.0": torch.tensor(
            [
                [0.0014, 0.0012, 0.0021, 0.0024, 0.0017, 0.0013, 0.0012, 0.0011, 0.0011, 0.0012, 0.0011, 0.0010, 0.0009, 0.0009, 0.0008, 0.0008, 0.0009, 0.0010, 0.0012, 0.0012, 0.0009, 0.0008, 0.0010, 0.0013, 0.0017, 0.0024, 0.0031, 0.0036, 0.0036, 0.0033, 0.0028, 0.0023, 0.0020, 0.0020, 0.0022, 0.0022, 0.0022, 0.0022, 0.0023, 0.0024, 0.0024, 0.0023, 0.0021, 0.0021, 0.0021, 0.0023, 0.0024, 0.0024, 0.0022, 0.0021],
                [ -0.0034, -0.0029, -0.0020, -0.0020, -0.0024, -0.0027, -0.0030, -0.0030, -0.0028, -0.0025, -0.0025, -0.0025, -0.0025, -0.0025, -0.0023, -0.0022, -0.0020, -0.0017, -0.0013, -0.0014, -0.0017, -0.0019, -0.0018, -0.0015, -0.0011, -0.0006, 0.0000, 0.0005, 0.0005, 0.0002, -0.0003, -0.0008, -0.0010, -0.0009, -0.0007, -0.0006, -0.0006, -0.0005, -0.0005, -0.0005, -0.0005, -0.0007, -0.0008, -0.0009, -0.0009, -0.0008, -0.0007, -0.0008, -0.0010, -0.0011],
            ]
        ),
        "24.0": torch.tensor(
            [
                [ 0.0010, 0.0008, 0.0018, 0.0021, 0.0014, 0.0011, 0.0009, 0.0007, 0.0006, 0.0006, 0.0005, 0.0003, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0002, 0.0002, 0.0001, -0.0002, -0.0004, -0.0003, 0.0000, 0.0005, 0.0011, 0.0018, 0.0022, 0.0022, 0.0018, 0.0012, 0.0007, 0.0004, 0.0003, 0.0004, 0.0006, 0.0006, 0.0007, 0.0007, 0.0009, 0.0008, 0.0007, 0.0005, 0.0004, 0.0004, 0.0006, 0.0007, 0.0007, 0.0005, 0.0004],
                [-0.0039, -0.0035, -0.0027, -0.0026, -0.0028, -0.0031, -0.0035, -0.0035, -0.0034, -0.0033, -0.0032, -0.0032, -0.0031, -0.0031, -0.0029, -0.0028, -0.0026, -0.0024, -0.0021, -0.0021, -0.0024, -0.0025, -0.0024, -0.0021, -0.0017, -0.0011, -0.0006, -0.0002, -0.0002, -0.0004, -0.0009, -0.0013, -0.0015, -0.0015, -0.0014, -0.0013, -0.0012, -0.0011, -0.0010, -0.0010, -0.0011, -0.0012, -0.0014, -0.0015, -0.0015, -0.0014, -0.0013, -0.0014, -0.0016, -0.0017],
            ]
        )
    }
}
EXPECTED_CODEC_ERROR = {
    "facebook/encodec_24khz": {
        "1.5": 0.0022229827009141445,
        "3.0": 0.001862662611529231,
        "6.0": 0.0015231302240863442,
        "12.0": 0.0013,
        "24.0": 0.0012,
    },
    "facebook/encodec_48khz": {
        "3.0": 0.000840399123262614,
        "6.0": 0.0006692984024994075,
        "12.0": 0.0005328940460458398,
        "24.0": 0.0004473362350836396,
    }
}
# -- test_batch
EXPECTED_ENCODER_CODES_BATCH = {
    "facebook/encodec_24khz": {
        "1.5": torch.tensor(
            [
                [
                    [62, 106, 475, 475, 404, 404, 475, 404, 404, 475, 475, 404, 475, 475, 475, 835, 475, 475, 835, 835,
                     106, 106, 738, 106, 738, 106, 408, 408, 738, 408, 408, 408, 738, 408, 408, 408, 408, 738, 408,
                     1017, 604, 64, 303, 394, 5, 570, 991, 570, 969, 814],
                    [424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007, 729, 1007, 961, 1007, 1007, 961, 969, 1007,
                     1007, 424, 518, 1007, 544, 1007, 518, 913, 424, 424, 544, 424, 518, 518, 518, 302, 424, 424, 424,
                     544, 424, 114, 200, 787, 931, 343, 434, 315, 487, 872, 769, 463],

                ],
                [
                    [835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 408, 835, 738, 408, 408, 408, 408, 408,
                     408, 738, 408, 408, 408, 408, 408, 408, 408, 408, 738, 408, 408, 408, 408, 408, 408, 408, 408, 408,
                     339, 834, 819, 875, 957, 670, 811, 670, 237, 53],
                    [857, 857, 544, 518, 937, 518, 913, 913, 518, 913, 518, 913, 518, 518, 544, 424, 424, 518, 424, 424,
                     424, 544, 424, 424, 424, 518, 424, 518, 518, 937, 544, 424, 518, 302, 518, 424, 424, 518, 424, 424,
                     913, 857, 841, 363, 463, 78, 176, 645, 255, 571],

                ],

            ]

        ),
        "3.0": torch.tensor(
            [
                [
                    [62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
                    [424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
                    [212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
                    [956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],

                ],
                [
                    [835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
                    [857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
                    [705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
                    [366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],

                ],
            ]
        ),
        "6.0": torch.tensor(
            [
                [
                    [62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
                    [424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
                    [212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
                    [956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
                    [712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
                    [939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
                    [853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
                    [899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],

                ],
                [
                    [835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
                    [857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
                    [705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
                    [366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
                    [373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
                    [293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
                    [800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
                    [916, 237, 828, 701, 518, 835, 948, 315, 948, 315],

                ],
            ]
        ),
        "12.0": torch.tensor(
            [
                [
                    [62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
                    [424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
                    [212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
                    [956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
                    [712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
                    [939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
                    [853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
                    [899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
                    [817, 470, 588, 675, 675, 588, 960, 927, 909, 466],
                    [953, 776, 717, 630, 359, 717, 861, 630, 861, 359],
                    [623, 740, 1000, 388, 420, 388, 740, 818, 958, 743],
                    [413, 835, 742, 249, 892, 352, 190, 498, 866, 890],
                    [817, 351, 804, 751, 938, 535, 434, 879, 351, 971],
                    [792, 495, 935, 848, 792, 795, 942, 935, 723, 531],
                    [622, 681, 477, 713, 752, 871, 713, 514, 993, 777],
                    [928, 799, 962, 1005, 860, 439, 312, 922, 982, 922],
                ],
                [
                    [835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
                    [857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
                    [705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
                    [366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
                    [373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
                    [293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
                    [800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
                    [916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
                    [420, 628, 918, 628, 628, 628, 248, 628, 909, 811],
                    [736, 717, 994, 974, 477, 874, 963, 979, 355, 979],
                    [1002, 1002, 894, 875, 388, 709, 534, 408, 881, 709],
                    [735, 828, 763, 742, 640, 835, 828, 375, 840, 375],
                    [898, 938, 556, 658, 410, 951, 486, 658, 877, 877],
                    [ 0, 797, 428, 694, 428, 920, 1022, 1022, 809, 797],
                    [622, 421, 422, 776, 911, 911, 958, 421, 776, 421],
                    [1005, 312, 922, 755, 834, 461, 461, 702, 597, 974],

                ],
            ]
        ),
        "24.0": torch.tensor(
            [
                [
                    [62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
                    [424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
                    [212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
                    [956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
                    [712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
                    [939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
                    [853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
                    [899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
                    [817, 470, 588, 675, 675, 588, 960, 927, 909, 466],
                    [953, 776, 717, 630, 359, 717, 861, 630, 861, 359],
                    [623, 740, 1000, 388, 420, 388, 740, 818, 958, 743],
                    [413, 835, 742, 249, 892, 352, 190, 498, 866, 890],
                    [817, 351, 804, 751, 938, 535, 434, 879, 351, 971],
                    [792, 495, 935, 848, 792, 795, 942, 935, 723, 531],
                    [622, 681, 477, 713, 752, 871, 713, 514, 993, 777],
                    [928, 799, 962, 1005, 860, 439, 312, 922, 982, 922],
                    [939, 637, 861, 506, 861, 61, 475, 264, 1019, 260],
                    [166, 215, 69, 69, 890, 69, 284, 828, 396, 180],
                    [561, 896, 841, 144, 580, 659, 886, 514, 686, 451],
                    [691, 691, 239, 735, 62, 287, 383, 972, 550, 505],
                    [451, 811, 238, 251, 250, 841, 734, 329, 551, 846],
                    [313, 601, 494, 763, 811, 565, 748, 441, 601, 480],
                    [653, 242, 630, 572, 701, 973, 632, 374, 561, 521],
                    [984, 987, 419, 454, 386, 507, 532, 636, 515, 671],
                    [647, 550, 515, 292, 876, 1011, 719, 549, 691, 911],
                    [683, 536, 656, 603, 698, 867, 987, 857, 886, 491],
                    [444, 937, 826, 555, 585, 710, 466, 852, 655, 591],
                    [658, 952, 903, 508, 739, 596, 420, 721, 464, 306],
                    [665, 334, 765, 532, 618, 278, 836, 838, 517, 597],
                    [613, 674, 596, 904, 987, 977, 938, 615, 672, 776],
                    [689, 386, 749, 658, 250, 869, 957, 806, 750, 659],
                    [652, 509, 910, 826, 566, 622, 951, 696, 900, 895],
                ],
                [
                    [835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
                    [857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
                    [705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
                    [366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
                    [373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
                    [293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
                    [800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
                    [916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
                    [420, 628, 918, 628, 628, 628, 248, 628, 909, 811],
                    [736, 717, 994, 974, 477, 874, 963, 979, 355, 979],
                    [1002, 1002, 894, 875, 388, 709, 534, 408, 881, 709],
                    [735, 828, 763, 742, 640, 835, 828, 375, 840, 375],
                    [898, 938, 556, 658, 410, 951, 486, 658, 877, 877],
                    [ 0, 797, 428, 694, 428, 920, 1022, 1022, 809, 797],
                    [622, 421, 422, 776, 911, 911, 958, 421, 776, 421],
                    [1005, 312, 922, 755, 834, 461, 461, 702, 597, 974],
                    [248, 248, 637, 248, 977, 506, 546, 270, 670, 506],
                    [547, 447, 15, 134, 1009, 215, 134, 396, 260, 160],
                    [635, 497, 686, 765, 264, 497, 244, 675, 624, 656],
                    [864, 571, 616, 511, 588, 781, 525, 258, 674, 503],
                    [449, 757, 857, 451, 658, 486, 299, 299, 251, 596],
                    [809, 628, 255, 568, 623, 301, 639, 546, 617, 623],
                    [551, 497, 908, 539, 661, 710, 640, 539, 646, 315],
                    [689, 507, 875, 515, 613, 637, 527, 515, 662, 637],
                    [983, 686, 456, 768, 601, 561, 768, 653, 500, 688],
                    [493, 566, 664, 782, 683, 683, 721, 603, 323, 497],
                    [1015, 552, 411, 423, 607, 646, 687, 1018, 689, 607],
                    [516, 293, 471, 294, 293, 294, 608, 538, 803, 717],
                    [974, 994, 952, 637, 637, 927, 535, 571, 602, 535],
                    [776, 789, 476, 944, 652, 959, 589, 679, 321, 623],
                    [776, 931, 720, 1009, 676, 731, 386, 676, 701, 676],
                    [684, 543, 716, 392, 661, 517, 792, 588, 922, 676],
                ],
            ]
        )
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor([[[790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
            [989, 989,  77, 546, 989, 546, 989, 160, 546, 989]],
            [[214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
            [289, 289, 989, 764, 289, 289, 882, 882, 882, 882]]]),
        "6.0": torch.tensor([[[ 790,  790,  790,  214,  214,  214,  799,  214,  214,  214],
            [ 989,  989,   77,  546,  989,  546,  989,  160,  546,  989],
            [ 977,  977,  977,  977,  538,  977,  977,  960,  977,  977],
            [ 376,  376,  962,  962,  607,  962,  963,  896,  962,  376]],
            [[ 214,  214,  214,  214,  214,  214,  214,  214,  214,  214],
            [ 289,  289,  989,  764,  289,  289,  882,  882,  882,  882],
            [1022, 1022,  471,  925,  821,  821,  267,  925,  925,  267],
            [ 979,  992,  914,  921,    0,    0, 1023,  963,  963, 1023]]]),
        "12.0": torch.tensor([[[ 790,  790,  790,  214,  214,  214,  799,  214,  214,  214],
            [ 989,  989,   77,  546,  989,  546,  989,  160,  546,  989],
            [ 977,  977,  977,  977,  538,  977,  977,  960,  977,  977],
            [ 376,  376,  962,  962,  607,  962,  963,  896,  962,  376],
            [ 979,  979,  979, 1012,  979, 1012,  921,    0, 1002,  695],
            [ 824, 1018,  762,  957,  824,  762,  762, 1007,  957,  336],
            [ 681,  973,  973,  452,  211,  681,  802,  679,  547,  884],
            [ 950, 1017, 1016, 1017,  986, 1017,  229,  607, 1017,  689]],

            [[ 214,  214,  214,  214,  214,  214,  214,  214,  214,  214],
            [ 289,  289,  989,  764,  289,  289,  882,  882,  882,  882],
            [1022, 1022,  471,  925,  821,  821,  267,  925,  925,  267],
            [ 979,  992,  914,  921,    0,    0, 1023,  963,  963, 1023],
            [ 403,  940,  976, 1018,  677, 1002,  979,  677,  677,  677],
            [1018,  794,  762,  444,  485,  485,  974,  548,  548, 1018],
            [ 679,  243,  679, 1005, 1005,  973, 1014, 1005, 1005, 1014],
            [ 810,   13, 1017,  537,  522,  702,  202, 1017, 1017,   15]]]),
        "24.0": torch.tensor(
            [
                [
                    [790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
                    [989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
                    [977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
                    [376, 376, 962, 962, 607, 962, 963, 896, 962, 376],
                    [979, 979, 979, 1012, 979, 1012, 921, 0, 1002, 695],
                    [824, 1018, 762, 957, 824, 762, 762, 1007, 957, 336],
                    [681, 973, 973, 452, 211, 681, 802, 679, 547, 884],
                    [950, 1017, 1016, 1017, 986, 1017, 229, 607, 1017, 689],
                    [1004, 1011, 669, 1023, 1023, 1023, 905, 297, 810, 970],
                    [982, 681, 982, 629, 662, 919, 878, 476, 629, 982],
                    [727, 727, 959, 959, 979, 959, 530, 959, 337, 961],
                    [924, 456, 924, 486, 924, 959, 102, 924, 805, 924],
                    [649, 542, 993, 993, 949, 787, 56, 886, 949, 405],
                    [864, 1022, 1022, 1022, 460, 753, 805, 309, 1022, 32],
                    [953, 0, 0, 180, 352, 10, 581, 516, 322, 452],
                    [300, 0, 1020, 307, 0, 543, 924, 627, 258, 262],
                ],
                [
                    [214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
                    [289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
                    [1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
                    [979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023],
                    [403, 940, 976, 1018, 677, 1002, 979, 677, 677, 677],
                    [1018, 794, 762, 444, 485, 485, 974, 548, 548, 1018],
                    [679, 243, 679, 1005, 1005, 973, 1014, 1005, 1005, 1014],
                    [810, 13, 1017, 537, 522, 702, 202, 1017, 1017, 15],
                    [728, 252, 970, 984, 971, 950, 673, 902, 1011, 810],
                    [332, 1014, 476, 854, 1014, 861, 332, 411, 411, 408],
                    [959, 727, 611, 979, 611, 727, 999, 497, 821, 0],
                    [995, 698, 924, 688, 102, 510, 924, 970, 344, 961],
                    [ 81, 516, 847, 924, 10, 240, 1005, 726, 993, 378],
                    [467, 496, 484, 496, 456, 1022, 337, 600, 456, 1022],
                    [789, 65, 937, 976, 159, 953, 343, 764, 179, 159],
                    [ 10, 790, 483, 10, 1020, 352, 848, 333, 83, 848],
                ],
            ]
        )
    }
}
EXPECTED_ENCODER_SCALES_BATCH = {
    "facebook/encodec_24khz": {
        "1.5": None,
        "3.0": None,
        "6.0": None,
        "12.0": None,
        "24.0": None
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor([[[1.027247e-01],
            [7.877284e-02]],
            [[1.014922e-01],
            [8.696266e-02]],
            [[6.308002e-02],
            [7.748771e-02]],
            [[6.899278e-02],
            [1.045912e-01]],
            [[6.440169e-02],
            [8.843135e-02]],
            [[4.139878e-02],
            [1.000000e-08]],
            [[5.848629e-02],
            [1.000000e-08]],
            [[2.329416e-04],
            [1.000000e-08]],
            [[1.000000e-08],
            [1.000000e-08]]]),
        "6.0": torch.tensor([[[1.027247e-01],
            [7.877284e-02]],
            [[1.014922e-01],
            [8.696266e-02]],
            [[6.308002e-02],
            [7.748771e-02]],
            [[6.899278e-02],
            [1.045912e-01]],
            [[6.440169e-02],
            [8.843135e-02]],
            [[4.139878e-02],
            [1.000000e-08]],
            [[5.848629e-02],
            [1.000000e-08]],
            [[2.329416e-04],
            [1.000000e-08]],
            [[1.000000e-08],
            [1.000000e-08]]]),
        "12.0": torch.tensor([[[1.027247e-01],
            [7.877284e-02]],
            [[1.014922e-01],
            [8.696266e-02]],
            [[6.308002e-02],
            [7.748771e-02]],
            [[6.899278e-02],
            [1.045912e-01]],
            [[6.440169e-02],
            [8.843135e-02]],
            [[4.139878e-02],
            [1.000000e-08]],
            [[5.848629e-02],
            [1.000000e-08]],
            [[2.329416e-04],
            [1.000000e-08]],
            [[1.000000e-08],
            [1.000000e-08]]]),
        "24.0": torch.tensor([[[1.027247e-01],
            [7.877284e-02]],
            [[1.014922e-01],
            [8.696266e-02]],
            [[6.308002e-02],
            [7.748771e-02]],
            [[6.899278e-02],
            [1.045912e-01]],
            [[6.440169e-02],
            [8.843135e-02]],
            [[4.139878e-02],
            [1.000000e-08]],
            [[5.848629e-02],
            [1.000000e-08]],
            [[2.329416e-04],
            [1.000000e-08]],
            [[1.000000e-08],
            [1.000000e-08]]])
    }
}
EXPECTED_DECODER_OUTPUTS_BATCH = {
    "facebook/encodec_24khz": {
        "1.5": torch.tensor(
            [
                [[ 0.0010, 0.0004, 0.0005, 0.0002, 0.0005, -0.0001, -0.0003, -0.0001, 0.0003, 0.0001, -0.0014, -0.0009, -0.0007, -0.0023, -0.0009, 0.0008, 0.0007, 0.0003, 0.0001, 0.0001, 0.0003, -0.0001, -0.0003, -0.0004, -0.0005, -0.0007, -0.0009, -0.0011, -0.0010, -0.0006, -0.0007, -0.0007, -0.0005, -0.0005, -0.0003, -0.0002, -0.0002, -0.0001, -0.0005, -0.0008, -0.0005, -0.0007, -0.0009, -0.0002, 0.0003, 0.0005, 0.0004, 0.0001, 0.0003, 0.0004]],
                [[ -0.0001, -0.0000, 0.0003, 0.0001, 0.0005, 0.0001, -0.0006, -0.0002, 0.0002, 0.0002, -0.0031, -0.0004, 0.0006, -0.0066, -0.0032, 0.0044, 0.0025, -0.0019, -0.0017, 0.0001, 0.0019, -0.0010, -0.0014, -0.0009, -0.0007, -0.0009, -0.0019, -0.0024, -0.0019, -0.0001, -0.0017, -0.0022, -0.0004, 0.0005, -0.0014, -0.0023, 0.0002, 0.0015, -0.0022, -0.0033, 0.0024, 0.0009, -0.0041, 0.0000, 0.0030, 0.0020, -0.0015, -0.0018, 0.0014, 0.0007]],
            ]
        ),
        "3.0": torch.tensor(
            [
                [[ 0.0013, 0.0007, 0.0009, 0.0005, 0.0006, 0.0002, -0.0001, 0.0000, 0.0005, 0.0003, -0.0012, -0.0006, -0.0003, -0.0019, -0.0003, 0.0015, 0.0013, 0.0009, 0.0008, 0.0007, 0.0008, 0.0004, 0.0001, -0.0000, -0.0001, -0.0002, -0.0003, -0.0004, -0.0004, 0.0001, -0.0000, -0.0000, 0.0003, 0.0003, 0.0005, 0.0005, 0.0004, 0.0005, 0.0001, -0.0003, -0.0002, -0.0004, -0.0006, 0.0003, 0.0009, 0.0012, 0.0013, 0.0012, 0.0014, 0.0015]],
                [[ 0.0000, -0.0003, 0.0005, 0.0004, 0.0011, 0.0013, 0.0002, 0.0005, 0.0002, 0.0006, -0.0025, -0.0005, 0.0004, -0.0069, -0.0027, 0.0038, 0.0013, -0.0015, -0.0005, 0.0003, 0.0014, -0.0006, -0.0002, -0.0010, -0.0008, -0.0001, -0.0006, -0.0012, -0.0016, 0.0010, 0.0001, -0.0010, -0.0002, 0.0013, -0.0002, -0.0017, 0.0005, 0.0019, -0.0019, -0.0035, 0.0022, -0.0001, -0.0040, 0.0012, 0.0015, 0.0012, 0.0001, -0.0010, 0.0005, 0.0004]],
            ]
        ),
        "6.0": torch.tensor(
            [
                [[ 0.0010, 0.0005, 0.0007, 0.0001, 0.0003, -0.0000, -0.0002, -0.0001, 0.0003, 0.0001, -0.0014, -0.0007, -0.0004, -0.0019, -0.0004, 0.0013, 0.0012, 0.0008, 0.0007, 0.0007, 0.0008, 0.0003, 0.0001, 0.0001, -0.0000, -0.0001, -0.0001, -0.0002, -0.0001, 0.0002, 0.0002, 0.0001, 0.0005, 0.0005, 0.0008, 0.0008, 0.0007, 0.0008, 0.0004, 0.0001, 0.0002, -0.0001, -0.0002, 0.0006, 0.0012, 0.0015, 0.0016, 0.0014, 0.0016, 0.0017]],
                [[ -0.0005, -0.0001, 0.0003, 0.0001, 0.0010, 0.0012, 0.0002, 0.0004, 0.0012, 0.0003, -0.0023, -0.0003, -0.0005, -0.0063, -0.0026, 0.0040, 0.0024, -0.0018, -0.0005, 0.0016, 0.0004, -0.0008, 0.0009, 0.0002, -0.0015, -0.0003, 0.0004, -0.0011, -0.0013, 0.0012, 0.0001, -0.0019, 0.0007, 0.0021, -0.0009, -0.0016, 0.0015, 0.0013, -0.0022, -0.0015, 0.0016, -0.0014, -0.0033, 0.0017, 0.0025, -0.0004, -0.0005, 0.0010, 0.0005, 0.0001]],
            ]
        ),
        "12.0": torch.tensor(
            [
                [[ 0.0003, 0.0002, 0.0004, -0.0004, -0.0003, -0.0007, -0.0008, -0.0006, -0.0001, -0.0002, -0.0016, -0.0009, -0.0004, -0.0021, -0.0003, 0.0015, 0.0016, 0.0012, 0.0011, 0.0010, 0.0010, 0.0005, 0.0002, 0.0001, 0.0000, -0.0001, -0.0002, -0.0004, -0.0004, 0.0000, -0.0000, -0.0002, 0.0001, 0.0001, 0.0004, 0.0003, 0.0002, 0.0004, -0.0001, -0.0005, -0.0004, -0.0006, -0.0007, 0.0003, 0.0009, 0.0013, 0.0015, 0.0015, 0.0017, 0.0018]],
                [[ -0.0008, -0.0003, 0.0003, -0.0001, 0.0008, 0.0013, 0.0004, 0.0008, 0.0015, 0.0006, -0.0021, -0.0001, -0.0003, -0.0062, -0.0022, 0.0043, 0.0028, -0.0013, -0.0002, 0.0017, 0.0010, -0.0001, 0.0008, 0.0001, -0.0010, 0.0003, 0.0008, -0.0006, -0.0007, 0.0012, 0.0003, -0.0013, 0.0007, 0.0019, -0.0002, -0.0013, 0.0011, 0.0016, -0.0016, -0.0017, 0.0014, -0.0006, -0.0029, 0.0011, 0.0028, 0.0006, -0.0004, 0.0005, 0.0008, 0.0003]],
            ]
        ),
        "24.0": torch.tensor(
            [
                [[ 0.0009, 0.0004, 0.0007, 0.0002, 0.0004, -0.0001, -0.0003, -0.0002, 0.0002, 0.0001, -0.0015, -0.0009, -0.0006, -0.0024, -0.0005, 0.0016, 0.0014, 0.0010, 0.0009, 0.0008, 0.0008, 0.0004, 0.0001, 0.0000, -0.0001, -0.0002, -0.0003, -0.0006, -0.0006, -0.0003, -0.0005, -0.0006, -0.0003, -0.0004, -0.0001, -0.0002, -0.0003, -0.0001, -0.0006, -0.0011, -0.0008, -0.0010, -0.0012, -0.0000, 0.0007, 0.0011, 0.0012, 0.0011, 0.0013, 0.0014]],
                [[ -0.0009, -0.0004, 0.0001, -0.0003, 0.0007, 0.0012, 0.0003, 0.0006, 0.0017, 0.0008, -0.0020, 0.0001, -0.0002, -0.0064, -0.0023, 0.0047, 0.0029, -0.0016, -0.0004, 0.0019, 0.0010, -0.0002, 0.0007, -0.0001, -0.0013, 0.0005, 0.0012, -0.0007, -0.0008, 0.0013, -0.0001, -0.0022, 0.0004, 0.0020, -0.0004, -0.0014, 0.0017, 0.0020, -0.0018, -0.0016, 0.0015, -0.0015, -0.0036, 0.0014, 0.0030, 0.0004, 0.0002, 0.0015, 0.0011, 0.0007]],
            ]
        )
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor([[[ 0.005083,  0.004669,  0.005723,  0.005600,  0.004231,  0.003830,
            0.003684,  0.003349,  0.003032,  0.003055,  0.002768,  0.002370,
            0.002384,  0.002450,  0.002391,  0.002363,  0.002357,  0.002435,
            0.002568,  0.002463,  0.002137,  0.002092,  0.002440,  0.002772,
            0.003035,  0.003473,  0.003963,  0.004288,  0.004315,  0.004087,
            0.003618,  0.003166,  0.002874,  0.002775,  0.002820,  0.002758,
            0.002565,  0.002498,  0.002583,  0.002671,  0.002656,  0.002613,
            0.002433,  0.002236,  0.002215,  0.002302,  0.002287,  0.002113,
            0.001909,  0.001767],
            [-0.003928, -0.002733, -0.001330, -0.001914, -0.002927, -0.003272,
            -0.003677, -0.003615, -0.003341, -0.002907, -0.002764, -0.002742,
            -0.002593, -0.002308, -0.002024, -0.001856, -0.001672, -0.001256,
            -0.000929, -0.001217, -0.001864, -0.002118, -0.002025, -0.001932,
            -0.001816, -0.001572, -0.001214, -0.000885, -0.000829, -0.000976,
            -0.001417, -0.001874, -0.002030, -0.001952, -0.001858, -0.001863,
            -0.001895, -0.001843, -0.001801, -0.001792, -0.001812, -0.001865,
            -0.002008, -0.002120, -0.002132, -0.002093, -0.002170, -0.002370,
            -0.002587, -0.002749]],
            [[ 0.004229,  0.003422,  0.005044,  0.006059,  0.005242,  0.004623,
            0.004231,  0.004050,  0.004314,  0.004701,  0.004559,  0.004105,
            0.003874,  0.003713,  0.003355,  0.003055,  0.003235,  0.003927,
            0.004500,  0.004195,  0.003328,  0.002804,  0.002628,  0.002456,
            0.002693,  0.003883,  0.005604,  0.006791,  0.006702,  0.005427,
            0.003622,  0.002328,  0.002173,  0.002871,  0.003505,  0.003410,
            0.002851,  0.002511,  0.002534,  0.002685,  0.002714,  0.002538,
            0.002110,  0.001697,  0.001786,  0.002415,  0.002940,  0.002856,
            0.002348,  0.001883],
            [-0.003444, -0.002916, -0.000590,  0.000157, -0.000702, -0.001472,
            -0.002032, -0.001891, -0.001283, -0.000670, -0.000590, -0.000875,
            -0.001090, -0.001095, -0.001172, -0.001287, -0.000907,  0.000111,
            0.000858,  0.000471, -0.000532, -0.001127, -0.001463, -0.001853,
            -0.001762, -0.000666,  0.000964,  0.002054,  0.001914,  0.000743,
            -0.000876, -0.001990, -0.001951, -0.001042, -0.000229, -0.000171,
            -0.000558, -0.000752, -0.000704, -0.000609, -0.000594, -0.000723,
            -0.001085, -0.001455, -0.001374, -0.000795, -0.000350, -0.000480,
            -0.000993, -0.001432]]]),
        "6.0": torch.tensor([[[ 5.892794e-03,  5.767163e-03,  7.065284e-03,  7.068626e-03,
            5.825328e-03,  5.601424e-03,  5.582351e-03,  5.209565e-03,
            4.829186e-03,  4.809568e-03,  4.663883e-03,  4.402087e-03,
            4.337528e-03,  4.311915e-03,  4.236566e-03,  4.209972e-03,
            4.179818e-03,  4.196202e-03,  4.309553e-03,  4.267083e-03,
            4.052189e-03,  4.068719e-03,  4.381632e-03,  4.692366e-03,
            4.998885e-03,  5.466312e-03,  5.895300e-03,  6.115717e-03,
            6.055626e-03,  5.773376e-03,  5.316667e-03,  4.826934e-03,
            4.450697e-03,  4.315911e-03,  4.310716e-03,  4.202125e-03,
            4.008702e-03,  3.957694e-03,  4.017603e-03,  4.060654e-03,
            4.036821e-03,  3.923071e-03,  3.659022e-03,  3.427053e-03,
            3.387271e-03,  3.462438e-03,  3.434755e-03,  3.247944e-03,
            3.009581e-03,  2.800536e-03],
            [-1.867314e-03, -6.082351e-04,  9.374358e-04,  5.555808e-04,
            -3.020080e-04, -5.281629e-04, -9.364292e-04, -1.057594e-03,
            -9.703087e-04, -6.292185e-04, -4.193477e-04, -3.605868e-04,
            -2.948678e-04, -1.198237e-04,  4.924605e-05,  1.602105e-04,
            3.162385e-04,  6.700790e-04,  9.868707e-04,  8.484383e-04,
            4.327767e-04,  3.108105e-04,  4.244343e-04,  5.422112e-04,
            7.239584e-04,  1.008546e-03,  1.265120e-03,  1.447669e-03,
            1.436084e-03,  1.271058e-03,  8.684017e-04,  4.149990e-04,
            2.143449e-04,  2.508474e-04,  3.018488e-04,  2.782424e-04,
            2.369677e-04,  3.040710e-04,  3.242530e-04,  2.599912e-04,
            2.211208e-04,  1.311762e-04, -9.807519e-05, -2.752687e-04,
            -3.114068e-04, -2.832832e-04, -3.900219e-04, -6.142824e-04,
            -8.507833e-04, -1.055882e-03]],
            [[ 3.971702e-04, -2.164055e-04,  1.562327e-03,  2.695718e-03,
            2.374928e-03,  2.145125e-03,  1.870762e-03,  1.852614e-03,
            2.074345e-03,  2.312302e-03,  2.222824e-03,  1.876336e-03,
            1.609606e-03,  1.420574e-03,  1.193270e-03,  9.592943e-04,
            1.132237e-03,  1.776782e-03,  2.258269e-03,  1.945908e-03,
            9.930646e-04,  1.733529e-04, -2.533881e-04, -3.138177e-04,
            3.226010e-04,  1.859203e-03,  3.879325e-03,  5.267750e-03,
            5.101699e-03,  3.609465e-03,  1.653315e-03,  2.709297e-04,
            -3.190451e-05,  5.129501e-04,  1.224789e-03,  1.397457e-03,
            1.110794e-03,  8.736057e-04,  8.860155e-04,  1.055910e-03,
            1.100855e-03,  8.834896e-04,  3.825913e-04, -3.267327e-05,
            6.586456e-05,  7.147206e-04,  1.394876e-03,  1.535393e-03,
            1.192172e-03,  7.061819e-04],
            [-6.897163e-03, -6.407891e-03, -4.015491e-03, -3.082125e-03,
            -3.434983e-03, -3.885052e-03, -4.456392e-03, -4.296550e-03,
            -3.861045e-03, -3.553474e-03, -3.547473e-03, -3.800863e-03,
            -4.123025e-03, -4.237277e-03, -4.244958e-03, -4.263899e-03,
            -3.808572e-03, -2.811858e-03, -2.147519e-03, -2.516703e-03,
            -3.550721e-03, -4.353373e-03, -4.846224e-03, -4.960613e-03,
            -4.273535e-03, -2.714785e-03, -7.043980e-04,  6.689885e-04,
            5.069164e-04, -9.122533e-04, -2.816979e-03, -4.124952e-03,
            -4.235019e-03, -3.491365e-03, -2.676077e-03, -2.381226e-03,
            -2.492559e-03, -2.634424e-03, -2.632524e-03, -2.528266e-03,
            -2.536691e-03, -2.746170e-03, -3.187869e-03, -3.553530e-03,
            -3.462211e-03, -2.862707e-03, -2.273719e-03, -2.201617e-03,
            -2.565818e-03, -3.044683e-03]]]),
        "12.0": torch.tensor([[[ 2.237194e-03,  2.508208e-03,  3.986347e-03,  4.020395e-03,
            2.889890e-03,  2.733388e-03,  2.684146e-03,  2.251372e-03,
            1.787451e-03,  1.720550e-03,  1.689184e-03,  1.495478e-03,
            1.321027e-03,  1.185375e-03,  1.098422e-03,  1.055453e-03,
            9.591801e-04,  9.328910e-04,  1.026154e-03,  1.031992e-03,
            9.155220e-04,  9.732856e-04,  1.282264e-03,  1.624059e-03,
            1.920021e-03,  2.333685e-03,  2.730524e-03,  2.919153e-03,
            2.856711e-03,  2.632692e-03,  2.256703e-03,  1.901129e-03,
            1.684760e-03,  1.638201e-03,  1.644909e-03,  1.569378e-03,
            1.448412e-03,  1.478291e-03,  1.580583e-03,  1.633777e-03,
            1.597190e-03,  1.475462e-03,  1.242885e-03,  1.065243e-03,
            1.052842e-03,  1.103825e-03,  1.059115e-03,  9.251673e-04,
            7.235570e-04,  5.053390e-04],
            [-4.534880e-03, -3.111026e-03, -1.486247e-03, -1.739966e-03,
            -2.399862e-03, -2.583335e-03, -3.157276e-03, -3.517166e-03,
            -3.598212e-03, -3.303007e-03, -3.037215e-03, -2.982930e-03,
            -3.026671e-03, -2.958387e-03, -2.836909e-03, -2.775315e-03,
            -2.719575e-03, -2.431532e-03, -2.090512e-03, -2.095603e-03,
            -2.366266e-03, -2.404480e-03, -2.235661e-03, -2.063206e-03,
            -1.888533e-03, -1.640449e-03, -1.407782e-03, -1.250053e-03,
            -1.275359e-03, -1.373277e-03, -1.601508e-03, -1.838720e-03,
            -1.876643e-03, -1.736149e-03, -1.622051e-03, -1.578928e-03,
            -1.564748e-03, -1.455850e-03, -1.391748e-03, -1.418254e-03,
            -1.462577e-03, -1.554713e-03, -1.730076e-03, -1.829485e-03,
            -1.816249e-03, -1.772218e-03, -1.855736e-03, -2.013720e-03,
            -2.196174e-03, -2.378810e-03]],
            [[ 8.993230e-04,  6.808847e-04,  2.595528e-03,  3.586462e-03,
            3.023965e-03,  2.479527e-03,  1.868662e-03,  1.565682e-03,
            1.563900e-03,  1.666364e-03,  1.715061e-03,  1.609638e-03,
            1.294764e-03,  8.647116e-04,  5.122397e-04,  2.899101e-04,
            3.817413e-04,  8.303743e-04,  1.253686e-03,  1.179640e-03,
            6.591807e-04,  1.167982e-04, -3.405492e-04, -5.258832e-04,
            -4.165239e-05,  1.393227e-03,  3.473584e-03,  4.953051e-03,
            4.779391e-03,  3.182305e-03,  1.140233e-03, -2.133392e-04,
            -4.233644e-04,  2.426380e-04,  1.126914e-03,  1.557022e-03,
            1.490265e-03,  1.264647e-03,  1.170405e-03,  1.237709e-03,
            1.112253e-03,  6.990263e-04,  1.700171e-04, -1.761244e-04,
            1.852706e-05,  8.140961e-04,  1.621285e-03,  1.813497e-03,
            1.394625e-03,  7.860070e-04],
            [-4.677887e-03, -3.966209e-03, -1.634288e-03, -8.592710e-04,
            -1.395248e-03, -2.189968e-03, -3.198638e-03, -3.410639e-03,
            -3.241918e-03, -3.051681e-03, -2.845973e-03, -2.786646e-03,
            -3.078280e-03, -3.367662e-03, -3.450923e-03, -3.427895e-03,
            -3.058358e-03, -2.258006e-03, -1.607386e-03, -1.647450e-03,
            -2.164357e-03, -2.647080e-03, -3.110953e-03, -3.304542e-03,
            -2.798792e-03, -1.407999e-03,  5.630683e-04,  1.961336e-03,
            1.813856e-03,  3.529640e-04, -1.526076e-03, -2.695498e-03,
            -2.702039e-03, -1.889018e-03, -9.337939e-04, -3.885011e-04,
            -2.970786e-04, -4.415356e-04, -5.492531e-04, -5.430978e-04,
            -7.051138e-04, -1.102020e-03, -1.577104e-03, -1.846151e-03,
            -1.623901e-03, -8.853760e-04, -1.772702e-04, -4.866864e-05,
            -4.633263e-04, -1.017192e-03]]]),
        "24.0": torch.tensor(
            [
                [
                    [0.0004, 0.0008, 0.0024, 0.0024, 0.0013, 0.0013, 0.0013, 0.0009, 0.0005, 0.0005, 0.0006, 0.0005, 0.0005, 0.0003, 0.0003, 0.0003, 0.0002, 0.0002, 0.0003, 0.0004, 0.0003, 0.0004, 0.0008, 0.0012, 0.0015, 0.0018, 0.0021, 0.0022, 0.0021, 0.0019, 0.0016, 0.0014, 0.0012, 0.0011, 0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0014, 0.0014, 0.0013, 0.0011, 0.0009, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0007],
                    [-0.0055, -0.0040, -0.0024, -0.0026, -0.0031, -0.0031, -0.0036, -0.0039, -0.0039, -0.0035, -0.0031, -0.0029, -0.0028, -0.0027, -0.0026, -0.0024, -0.0023, -0.0020, -0.0017, -0.0016, -0.0017, -0.0017, -0.0015, -0.0012, -0.0010, -0.0008, -0.0006, -0.0004, -0.0004, -0.0005, -0.0006, -0.0007, -0.0006, -0.0004, -0.0002, -0.0001, 0.0001, 0.0002, 0.0003, 0.0004, 0.0004, 0.0003, 0.0001, 0.0001, 0.0000, 0.0001, 0.0000, -0.0001, -0.0002, -0.0004],
                ],
                [
                    [-0.0024, -0.0029, -0.0009, 0.0002, -0.0002, -0.0007, -0.0012, -0.0013, -0.0012, -0.0011, -0.0011, -0.0012, -0.0016, -0.0021, -0.0024, -0.0026, -0.0024, -0.0018, -0.0013, -0.0015, -0.0022, -0.0029, -0.0035, -0.0038, -0.0031, -0.0015, 0.0008, 0.0025, 0.0023, 0.0006, -0.0016, -0.0030, -0.0032, -0.0024, -0.0015, -0.0010, -0.0009, -0.0011, -0.0010, -0.0009, -0.0010, -0.0014, -0.0020, -0.0023, -0.0020, -0.0011, -0.0001, 0.0001, -0.0003, -0.0009],
                    [-0.0086, -0.0081, -0.0059, -0.0050, -0.0053, -0.0061, -0.0071, -0.0071, -0.0069, -0.0067, -0.0066, -0.0066, -0.0070, -0.0073, -0.0074, -0.0073, -0.0069, -0.0060, -0.0053, -0.0055, -0.0061, -0.0067, -0.0072, -0.0074, -0.0067, -0.0052, -0.0031, -0.0015, -0.0016, -0.0029, -0.0048, -0.0059, -0.0059, -0.0051, -0.0041, -0.0036, -0.0034, -0.0034, -0.0034, -0.0033, -0.0035, -0.0039, -0.0043, -0.0046, -0.0043, -0.0035, -0.0027, -0.0025, -0.0029, -0.0034],
                ],
            ]
        )
    }
}
# ---- expected codec error (RMSE) over the whole batch
EXPECTED_CODEC_ERROR_BATCH = {
    "facebook/encodec_24khz": {
        "1.5": 0.0011174238752573729,
        "3.0": 0.0009308119188062847,
        "6.0": 0.0008,
        "12.0": 0.0006830253987573087,
        "24.0": 0.000642190920189023,
    },
    "facebook/encodec_48khz": {
        "3.0": 0.00039895583176985383,
        "6.0": 0.0003249854489695281,
        "12.0": 0.0002540576097089797,
        "24.0": 0.00021899679268244654,
    }
}
# fmt: on


@slow
@require_torch
class EncodecIntegrationTest(unittest.TestCase):
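    # Each (model_id, bandwidth) pair becomes its own test case; os.path.basename turns
    # e.g. "facebook/encodec_24khz" + bandwidth "1.5" into the readable name "encodec_24khz_1p5".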
    @parameterized.expand(
        [
            (f"{os.path.basename(model_id)}_{bandwidth.replace('.', 'p')}", model_id, bandwidth)
            for model_id, v in EXPECTED_ENCODER_CODES.items()
            for bandwidth in v
        ]
    )
    def test_integration(self, name, model_id, bandwidth):
        # load model
        model = EncodecModel.from_pretrained(model_id).to(torch_device)
        processor = AutoProcessor.from_pretrained(model_id)

        # load audio
        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        audio_array = librispeech_dummy[0]["audio"]["array"]
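        # stereo checkpoints (audio_channels > 1, e.g. encodec_48khz) expect multi-channel
        # input, so the mono sample is duplicated across channels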
        if model.config.audio_channels > 1:
            audio_array = np.array([audio_array] * model.config.audio_channels)
        inputs = processor(
            raw_audio=audio_array,
            sampling_rate=processor.sampling_rate,
            return_tensors="pt",
            padding=True,
        ).to(torch_device)

        model = model.eval()
        with torch.no_grad():
            # Compare encoder outputs with expected values
            encoded_frames = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
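            # each entry of "audio_codes" holds the codes for one chunk, shaped
            # (batch, num_quantizers, seq_len); take batch element 0 of every chunk,
            # concatenate along time, and restore a leading batch dim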
            codes = torch.cat([encoded[0] for encoded in encoded_frames["audio_codes"]], dim=-1).unsqueeze(0)
            torch.testing.assert_close(
                codes[..., : EXPECTED_ENCODER_CODES[model_id][bandwidth].shape[-1]],
                EXPECTED_ENCODER_CODES[model_id][bandwidth].to(torch_device),
                rtol=1e-4,
                atol=1e-4,
            )
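            # audio_scales is only populated when the model normalizes its input chunks
            # (the 48 kHz checkpoint); otherwise the entries are None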
            if EXPECTED_ENCODER_SCALES[model_id][bandwidth] is not None:
                scales = torch.tensor([encoded[0].squeeze() for encoded in encoded_frames["audio_scales"]])
                torch.testing.assert_close(scales, EXPECTED_ENCODER_SCALES[model_id][bandwidth], rtol=1e-4, atol=1e-4)

            # Compare decoder outputs with expected values
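            # last_frame_pad_length tells decode() how many padded samples were appended
            # to the final chunk during encoding, so they can be trimmed from the output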
            decoded_frames = model.decode(
                encoded_frames["audio_codes"],
                encoded_frames["audio_scales"],
                inputs["padding_mask"],
                last_frame_pad_length=encoded_frames["last_frame_pad_length"],
            )
            torch.testing.assert_close(
                decoded_frames["audio_values"][0][..., : EXPECTED_DECODER_OUTPUTS[model_id][bandwidth].shape[-1]],
                EXPECTED_DECODER_OUTPUTS[model_id][bandwidth].to(torch_device),
                rtol=1e-4,
                atol=1e-4,
            )

            # Compare codec error with expected values
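            # codec error here is the RMSE between the reconstructed waveform and the input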
            codec_error = compute_rmse(decoded_frames["audio_values"], inputs["input_values"])
            torch.testing.assert_close(codec_error, EXPECTED_CODEC_ERROR[model_id][bandwidth], rtol=1e-4, atol=1e-4)

            # make sure the forward pass and explicit encode-decode give the same result
            full_enc = model(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
            torch.testing.assert_close(
                full_enc["audio_values"],
                decoded_frames["audio_values"],
                rtol=1e-4,
                atol=1e-4,
            )

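    # same naming scheme as test_integration, parameterized over the batched expectations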
    @parameterized.expand(
        [
            (f"{os.path.basename(model_id)}_{bandwidth.replace('.', 'p')}", model_id, bandwidth)
            for model_id, v in EXPECTED_ENCODER_CODES_BATCH.items()
            for bandwidth in v
        ]
    )
    def test_batch(self, name, model_id, bandwidth):
        # load model
        model = EncodecModel.from_pretrained(model_id).to(torch_device)
        processor = AutoProcessor.from_pretrained(model_id)

        # load audio
        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        n_channels = model.config.audio_channels
        if n_channels == 1:
            audio_samples = [audio_sample["array"] for audio_sample in librispeech_dummy[-2:]["audio"]]
        else:
            audio_samples = []
            for _sample in librispeech_dummy[-2:]["audio"]:
                # duplicate the mono channel to reach the target number of channels
                audio_array = np.concatenate([_sample["array"][np.newaxis]] * n_channels, axis=0)
                audio_samples.append(audio_array)
        inputs = processor(
            raw_audio=audio_samples,
            sampling_rate=processor.sampling_rate,
            return_tensors="pt",
            padding=True,
        ).to(torch_device)

        # apply model
        model = model.eval()
        with torch.no_grad():
            # Compare encoder outputs with expected values
            encoded_frames = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
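            # flatten the per-chunk codes into one sequence per sample:
            # (num_chunks, batch, num_quantizers, seq_len) -> (batch, num_quantizers, num_chunks * seq_len)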
            codes = encoded_frames["audio_codes"].permute(1, 2, 0, 3)
            codes = codes.reshape(codes.size(0), codes.size(1), -1)
            torch.testing.assert_close(
                codes[..., : EXPECTED_ENCODER_CODES_BATCH[model_id][bandwidth].shape[-1]],
                EXPECTED_ENCODER_CODES_BATCH[model_id][bandwidth].to(torch_device),
                rtol=1e-4,
                atol=1e-4,
            )
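            # as in test_integration, scales are only set for the normalizing 48 kHz checkpoint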
            if EXPECTED_ENCODER_SCALES_BATCH[model_id][bandwidth] is not None:
                scales = torch.stack(encoded_frames["audio_scales"])
                torch.testing.assert_close(
                    scales, EXPECTED_ENCODER_SCALES_BATCH[model_id][bandwidth].to(torch_device), rtol=1e-4, atol=1e-4
                )

            # Compare decoder outputs with expected values
            decoded_frames = model.decode(
                encoded_frames["audio_codes"],
                encoded_frames["audio_scales"],
                inputs["padding_mask"],
                last_frame_pad_length=encoded_frames["last_frame_pad_length"],
            )
            torch.testing.assert_close(
                decoded_frames["audio_values"][..., : EXPECTED_DECODER_OUTPUTS_BATCH[model_id][bandwidth].shape[-1]],
                EXPECTED_DECODER_OUTPUTS_BATCH[model_id][bandwidth].to(torch_device),
                rtol=1e-4,
                atol=1e-4,
            )

            # Compare codec error with expected values
            codec_error = compute_rmse(decoded_frames["audio_values"], inputs["input_values"])
            torch.testing.assert_close(
                codec_error, EXPECTED_CODEC_ERROR_BATCH[model_id][bandwidth], rtol=1e-4, atol=1e-4
            )

            # make sure the forward pass and explicit encode-decode give the same result
            input_values_dec = model(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
            torch.testing.assert_close(
                input_values_dec["audio_values"], decoded_frames["audio_values"], rtol=1e-4, atol=1e-4
            )
