# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Testing suite for the PyTorch Phi model."""

import unittest

from transformers import is_torch_available
from transformers.testing_utils import (
    require_torch,
    slow,
    torch_device,
)

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


if is_torch_available():
    import torch

    from transformers import (
        AutoTokenizer,
        PhiForCausalLM,
        PhiModel,
    )


class PhiModelTester(CausalLMModelTester):
    """Model tester for Phi, built on the shared `CausalLMModelTester`.

    The only customization made here is pointing `base_model_class` at `PhiModel`.
    """

    # `PhiModel` is imported at module level only when torch is available, so the
    # attribute is assigned under the same guard; without it, merely importing this
    # module in a torch-less environment would raise a NameError.
    if is_torch_available():
        base_model_class = PhiModel


@require_torch
class PhiModelTest(CausalLMModelTest, unittest.TestCase):
    """Runs the shared `CausalLMModelTest` suite against Phi via `PhiModelTester`."""

    model_tester_class = PhiModelTester

    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
    def is_pipeline_test_to_skip(
        self,
        pipeline_test_case_name,
        config_class,
        model_architecture,
        tokenizer_name,
        image_processor_name,
        feature_extractor_name,
        processor_name,
    ) -> bool:
        """Return `True` unconditionally; none of the arguments are inspected.

        Going by the method name, this marks every pipeline test case as skipped for
        Phi (NOTE(review): the caller lives outside this file, confirm how the
        result is consumed). The blanket skip is a temporary measure tracked by the
        TODO above.
        """
        return True


@slow
@require_torch
class PhiIntegrationTest(unittest.TestCase):
    """Slow integration tests that run the published Phi checkpoints.

    The logits tests feed one fixed prompt through a checkpoint and compare a small
    slice of the resulting logits with reference values; the generation test
    compares decoded text produced by `generate` with a reference string.
    """

    # Token ids of the prompt shared by all logits tests (one sequence of 10 tokens).
    _PROMPT_TOKEN_IDS = [[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]]

    def _assert_logits_close(self, checkpoint, expected, *, rtol, atol):
        """Run `checkpoint` on the shared prompt and compare a slice of its logits.

        Args:
            checkpoint: Identifier passed to `PhiForCausalLM.from_pretrained`.
            expected: Reference tensor compared against `logits[0, :2, :30]`, i.e.
                the first 30 logits of the first two positions of the only sequence.
                It must already be on `torch_device`.
            rtol: Relative tolerance forwarded to `torch.testing.assert_close`.
            atol: Absolute tolerance forwarded to `torch.testing.assert_close`.

        Raises:
            AssertionError: If the logits slice is not close to `expected`.
        """
        input_ids = torch.tensor(self._PROMPT_TOKEN_IDS, dtype=torch.long, device=torch_device)

        model = PhiForCausalLM.from_pretrained(checkpoint).to(torch_device)
        model.eval()

        # Inference only: disabling autograd avoids retaining activations for a
        # backward pass that never happens; the computed logits are unaffected.
        with torch.no_grad():
            logits = model(input_ids=input_ids).logits

        # `assert_close` takes (actual, expected); the relative tolerance is scaled
        # by |expected|, so the reference values must be the second argument.
        torch.testing.assert_close(logits[0, :2, :30], expected, rtol=rtol, atol=atol)

    def test_model_phi_1_logits(self):
        """Check `microsoft/phi-1` logits against reference values."""
        EXPECTED_OUTPUT = torch.tensor([[2.2671,  6.7684, -2.0107, -1.2440, -1.5335, -2.3828,  6.9186,  6.4245, 3.1548,  0.9998,  0.0760,  4.4653,  4.9857,  4.2956,  1.2308, -1.4178, 0.1361,  0.5191, -0.5699, -2.2201, -3.0750, -3.9600, -4.5936, -3.7394, -2.7777,  6.1874, -0.4148, -1.5684, -0.5967,  0.2395], [1.7004,  4.0383,  0.0546,  0.4530, -0.3619, -0.9021,  1.8355,  1.3587, 1.2406,  2.5775, -0.8834,  5.1910,  4.2565,  4.1406,  3.0752, -0.9099, 1.1595,  0.0264,  0.3243, -1.1803, -1.3945, -2.1406, -3.9939, -1.4438, -2.9546,  3.9204,  1.0851, -1.0598, -1.7819, -0.4827]]).to(torch_device)  # fmt: skip

        self._assert_logits_close("microsoft/phi-1", EXPECTED_OUTPUT, rtol=1e-4, atol=1e-4)

    def test_model_phi_1_5_logits(self):
        """Check `microsoft/phi-1_5` logits against reference values."""
        EXPECTED_OUTPUT = torch.tensor([[12.2922, 13.3507,  8.6963,  9.1355,  9.3502,  9.2667, 14.2027, 13.1363, 13.5446, 11.1337,  9.9279, 16.7195, 13.0768, 14.9141, 11.9965,  8.0233, 10.3129, 10.6118, 10.0204,  9.3827,  8.8344,  8.2806,  8.0153,  8.0540, 7.0964, 16.5743, 11.1256,  9.6987, 11.4770, 10.5440], [12.3323, 14.6050,  8.9986,  8.1580,  9.5654,  6.6728, 12.5966, 12.6662, 12.2784, 11.7522,  8.2039, 16.3102, 11.2203, 13.6088, 12.0125,  9.1021, 9.8216, 10.0987,  9.0926,  8.4260,  8.8009,  7.6547,  6.8075,  7.7881, 7.4501, 15.7451, 10.5053,  8.3129, 10.0027,  9.2612]]).to(torch_device)  # fmt: skip

        self._assert_logits_close("microsoft/phi-1_5", EXPECTED_OUTPUT, rtol=1e-4, atol=1e-4)

    def test_model_phi_2_logits(self):
        """Check `microsoft/phi-2` logits against reference values (looser 1e-3 tolerance)."""
        EXPECTED_OUTPUT = torch.tensor([[6.4830,  6.1644,  3.4055,  2.2848,  5.4654,  2.8360,  5.5975,  5.5391, 7.3101,  4.2498,  2.5913, 10.3885,  6.4359,  8.7982,  5.6534,  0.5150, 2.7498,  3.1930,  2.4334,  1.7781,  1.5613,  1.3067,  0.8291,  0.5633, 0.6522,  9.8191,  5.5771,  2.7987,  4.2845,  3.7030], [6.0642,  7.8242,  3.4634,  1.9259,  4.3169,  2.0913,  6.0446,  3.6804, 6.6736,  4.0727,  2.1791, 11.4139,  5.6795,  7.5652,  6.2039,  2.7174, 4.3266,  3.6930,  2.8058,  2.6721,  2.3047,  2.0848,  2.0972,  2.0441, 1.3160,  9.2085,  4.5557,  3.0296,  2.6045,  2.4059]]).to(torch_device)  # fmt: skip

        self._assert_logits_close("microsoft/phi-2", EXPECTED_OUTPUT, rtol=1e-3, atol=1e-3)

    def test_phi_2_generation(self):
        """Generate 30 new tokens with `microsoft/phi-2` and compare the decoded text."""
        # NOTE(review): unlike the logits tests, neither the model nor the inputs are
        # moved to `torch_device` here, so this runs on the default device. That is
        # kept as is because the expected text below was presumably recorded in this
        # configuration; re-validate the reference before changing the device.
        model = PhiForCausalLM.from_pretrained("microsoft/phi-2")
        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

        # `return_attention_mask=False` leaves the attention mask out of the encoded
        # inputs that are unpacked into `generate` below.
        inputs = tokenizer(
            "Can you help me write a formal email to a potential business partner proposing a joint venture?",
            return_tensors="pt",
            return_attention_mask=False,
        )

        outputs = model.generate(**inputs, max_new_tokens=30)
        output_text = tokenizer.batch_decode(outputs)

        # The decoded output is the prompt followed by the generated continuation.
        EXPECTED_OUTPUT = [
            "Can you help me write a formal email to a potential business partner proposing a joint venture?\nInput: Company A: ABC Inc.\nCompany B: XYZ Ltd.\nJoint Venture: A new online platform for e-commerce"
        ]

        self.assertListEqual(output_text, EXPECTED_OUTPUT)
