# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch DPT model."""

import unittest

from transformers import Dinov2Config, DPTConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils.import_utils import get_torch_major_and_minor_version

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import DPTForDepthEstimation
    from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES


if is_vision_available():
    from PIL import Image

    from transformers import DPTImageProcessor


class DPTModelTester:
    """Builds a tiny `DPTConfig` (with a `Dinov2Config` backbone) and random inputs for the DPT tests.

    Args:
        parent: Test case whose `assertEqual` is used by the `create_and_check_*` helpers.
        batch_size: Number of images in the generated batch.
        num_channels: Number of channels of the generated images (also forwarded to the backbone config).
        image_size: Height and width of the generated (square) images.
        patch_size: Patch size forwarded to the backbone config.
        use_labels: Whether `prepare_config_and_inputs` also generates labels.
        num_labels: Upper bound handed to `ids_tensor` for the labels; also written to `config.num_labels`.
        is_training: Forwarded as-is to the backbone config.
        hidden_size: Forwarded to the backbone config.
        num_hidden_layers: Forwarded to the backbone config.
        num_attention_heads: Forwarded to the backbone config.
        intermediate_size: Forwarded to the backbone config.
        out_features: Backbone stage names forwarded to the backbone config. Copied into a fresh list per
            instance; `None` is passed through unchanged.
        apply_layernorm: Stored on the tester only (see the note in `get_backbone_config`).
        reshape_hidden_states: Forwarded to the backbone config.
        neck_hidden_sizes: Forwarded to `DPTConfig`. Copied into a fresh list per instance; `None` is passed
            through unchanged.
        fusion_hidden_size: Forwarded to `DPTConfig`.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        num_channels=3,
        image_size=32,
        patch_size=16,
        use_labels=True,
        num_labels=3,
        is_training=True,
        hidden_size=4,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=8,
        out_features=("stage1", "stage2"),
        apply_layernorm=False,
        reshape_hidden_states=False,
        neck_hidden_sizes=(2, 2),
        fusion_hidden_size=6,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        # Immutable defaults + a per-instance copy: a list default would be shared by every tester instance,
        # so mutating it on one instance (or on a config built from it) would leak into later tests.
        self.out_features = list(out_features) if out_features is not None else None
        self.apply_layernorm = apply_layernorm
        self.reshape_hidden_states = reshape_hidden_states
        self.use_labels = use_labels
        self.num_labels = num_labels
        self.is_training = is_training
        self.neck_hidden_sizes = list(neck_hidden_sizes) if neck_hidden_sizes is not None else None
        self.fusion_hidden_size = fusion_hidden_size
        # DPT's sequence length: number of patches plus one (presumably the CLS token - TODO confirm)
        self.seq_length = (self.image_size // self.patch_size) ** 2 + 1

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values, labels)`; `labels` is `None` when `use_labels` is False."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)

        config = self.get_config()

        return config, pixel_values, labels

    def get_config(self):
        """Return a `DPTConfig` that uses the tester's backbone config (`backbone=None`)."""
        return DPTConfig(
            backbone_config=self.get_backbone_config(),
            backbone=None,
            neck_hidden_sizes=self.neck_hidden_sizes,
            fusion_hidden_size=self.fusion_hidden_size,
        )

    def get_backbone_config(self):
        """Return the `Dinov2Config` used as the DPT backbone config."""
        # NOTE(review): `self.apply_layernorm` is stored in `__init__` but not forwarded here, so the backbone
        # uses whatever default `Dinov2Config` defines - confirm whether that is intentional.
        return Dinov2Config(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            is_training=self.is_training,
            out_features=self.out_features,
            reshape_hidden_states=self.reshape_hidden_states,
        )

    def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
        """Run `DPTForDepthEstimation` in eval mode and check the shape of `predicted_depth`.

        `labels` is accepted so the tuple from `prepare_config_and_inputs` can be splatted in; it is not used.
        """
        config.num_labels = self.num_labels
        model = DPTForDepthEstimation(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` only holds `pixel_values`."""
        config, pixel_values, _ = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Common model tests for DPT. Some tests of test_modeling_common.py are overwritten or skipped here, as DPT does
    not use input_ids, inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (DPTForDepthEstimation,) if is_torch_available() else ()
    pipeline_model_mapping = {"depth-estimation": DPTForDepthEstimation} if is_torch_available() else {}

    test_resize_embeddings = False
    test_torch_exportable = True
    # Evaluates to False only when the reported torch version string is exactly "2.7"
    test_torch_exportable_strictly = get_torch_major_and_minor_version() != "2.7"

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = DPTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DPTConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        """Run the shared configuration checks for `DPTConfig`."""
        self.config_tester.run_common_tests()

    @unittest.skip(reason="DPT with AutoBackbone does not have a base model and hence no input_embeddings")
    def test_inputs_embeds(self):
        pass

    def test_for_depth_estimation(self):
        """Forward pass through `DPTForDepthEstimation` with a shape check on the predicted depth."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_depth_estimation(*prepared)

    def test_training(self):
        """Run one forward/backward pass in train mode for every model class that is not skipped."""
        for candidate in self.all_model_classes:
            # DPTForDepthEstimation is skipped by name
            if candidate.__name__ == "DPTForDepthEstimation":
                continue

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.return_dict = True

            # Classes listed in MODEL_MAPPING_NAMES are skipped as well
            if candidate.__name__ in MODEL_MAPPING_NAMES.values():
                continue

            model = candidate(config)
            model.to(torch_device)
            model.train()
            batch = self._prepare_for_class(inputs_dict, candidate, return_labels=True)
            outputs = model(**batch)
            outputs.loss.backward()

    def test_training_gradient_checkpointing(self):
        """Same as `test_training`, with gradient checkpointing enabled on the model."""
        for candidate in self.all_model_classes:
            # DPTForDepthEstimation is skipped by name
            if candidate.__name__ == "DPTForDepthEstimation":
                continue

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.use_cache = False
            config.return_dict = True

            # Skip classes listed in MODEL_MAPPING_NAMES and classes without gradient checkpointing support
            if candidate.__name__ in MODEL_MAPPING_NAMES.values():
                continue
            if not candidate.supports_gradient_checkpointing:
                continue

            model = candidate(config)
            model.to(torch_device)
            model.gradient_checkpointing_enable()
            model.train()
            batch = self._prepare_for_class(inputs_dict, candidate, return_labels=True)
            outputs = model(**batch)
            outputs.loss.backward()

    @unittest.skip(reason="DPT with AutoBackbone does not have a base model and hence no input_embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        """Check that the `Intel/dpt-large` checkpoint loads into `DPTForDepthEstimation`."""
        loaded = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
        self.assertIsNotNone(loaded)


def prepare_img():
    """Open the COCO fixture image (a picture of cute cats) that the integration tests verify results on."""
    return Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")


@require_torch
@require_vision
@slow
class DPTModelIntegrationTest(unittest.TestCase):
    """Slow integration tests comparing predicted depth of pretrained DPT checkpoints against reference values."""

    def _check_predicted_depth(self, checkpoint, expected_shape, expected_values):
        """Run `checkpoint` on the fixture image and verify the predicted depth.

        Args:
            checkpoint: Name passed to `from_pretrained` for both the image processor and the model.
            expected_shape: Expected shape of `predicted_depth`, given as a tuple.
            expected_values: Mapping handed to `Expectations`; the selected entry is compared with the
                top-left 3x3 slice of the first predicted depth map.
        """
        image_processor = DPTImageProcessor.from_pretrained(checkpoint)
        model = DPTForDepthEstimation.from_pretrained(checkpoint).to(torch_device)

        inputs = image_processor(images=prepare_img(), return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            predicted_depth = model(**inputs).predicted_depth

        # verify the predicted depth
        self.assertEqual(predicted_depth.shape, torch.Size(expected_shape))

        reference = Expectations(expected_values).get_expectation()
        expected_slice = torch.tensor(reference).to(torch_device)

        torch.testing.assert_close(predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)

    def test_inference_depth_estimation_dinov2(self):
        self._check_predicted_depth(
            "facebook/dpt-dinov2-small-kitti",
            (1, 576, 736),
            {
                (None, None): [[6.0336, 7.1502, 7.4130], [6.8977, 7.2383, 7.2268], [7.9180, 8.0525, 8.0134]],
                ("cuda", 8): [[6.0350, 7.1518, 7.4144], [6.8992, 7.2396, 7.2280], [7.9194, 8.0538, 8.0145]],
            },
        )

    def test_inference_depth_estimation_beit(self):
        self._check_predicted_depth(
            "Intel/dpt-beit-base-384",
            (1, 384, 384),
            {
                (None, None): [
                    [2669.7061, 2663.7144, 2674.9399],
                    [2633.9326, 2650.9092, 2665.4270],
                    [2621.8271, 2632.0129, 2637.2290],
                ],
                ("cuda", 8): [
                    [2669.4292, 2663.4121, 2674.6233],
                    [2633.7400, 2650.7026, 2665.2085],
                    [2621.6572, 2631.8452, 2637.0525],
                ],
            },
        )

    def test_inference_depth_estimation_swinv2(self):
        self._check_predicted_depth(
            "Intel/dpt-swinv2-tiny-256",
            (1, 256, 256),
            {
                (None, None): [
                    [1032.7719, 1025.1886, 1030.2661],
                    [1023.7619, 1021.0075, 1024.9121],
                    [1022.5667, 1018.8522, 1021.4145],
                ],
                ("cuda", 8): [
                    [1032.7170, 1025.0629, 1030.1941],
                    [1023.7309, 1020.9786, 1024.8594],
                    [1022.5233, 1018.8235, 1021.3312],
                ],
            },
        )
