# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import patch

from transformers import (
    MODEL_FOR_CAUSAL_LM_MAPPING,
    TextGenerationPipeline,
    logging,
    pipeline,
)
from transformers.testing_utils import (
    CaptureLogger,
    is_pipeline_test,
    require_accelerate,
    require_torch,
    require_torch_accelerator,
    torch_device,
)

from .test_pipelines_common import ANY


@is_pipeline_test
@require_torch
class TextGenerationPipelineTests(unittest.TestCase):
    model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING

    @require_torch
    def test_small_model_pt(self):
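        """Tests basic pipeline behavior with a tiny random Llama model: single prompt, batched prompts, and token-id output."""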
        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-random-LlamaForCausalLM",
            max_new_tokens=10,
        )
        # Using `do_sample=False` to force deterministic output
        outputs = text_generator("This is a test", do_sample=False)
        self.assertEqual(outputs, [{"generated_text": "This is a testкт MéxicoWSAnimImportдели pip letscosatur"}])

        outputs = text_generator(["This is a test", "This is a second test"], do_sample=False)
        self.assertEqual(
            outputs,
            [
                [{"generated_text": "This is a testкт MéxicoWSAnimImportдели pip letscosatur"}],
                [{"generated_text": "This is a second testкт MéxicoWSAnimImportдели Düsseld bootstrap learn user"}],
            ],
        )

        outputs = text_generator("This is a test", do_sample=True, num_return_sequences=2, return_tensors=True)
        self.assertEqual(
            outputs,
            [
                {"generated_token_ids": ANY(list)},
                {"generated_token_ids": ANY(list)},
            ],
        )

    @require_torch
    def test_small_chat_model_pt(self):
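        """Tests that chat-format inputs return the input chat with an assistant turn appended, for single and batched chats."""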
        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-gpt2-with-chatml-template",
        )
        # Using `do_sample=False` to force deterministic output
        chat1 = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
        ]
        chat2 = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a second test"},
        ]
        outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
        expected_chat1 = chat1 + [
            {
                "role": "assistant",
                "content": " factors factors factors factors factors factors factors factors factors factors",
            }
        ]
        self.assertEqual(
            outputs,
            [
                {"generated_text": expected_chat1},
            ],
        )

        outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10)
        expected_chat2 = chat2 + [
            {
                "role": "assistant",
                "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
            }
        ]

        self.assertEqual(
            outputs,
            [
                [{"generated_text": expected_chat1}],
                [{"generated_text": expected_chat2}],
            ],
        )

    @require_torch
    def test_small_chat_model_continue_final_message(self):
        # Here we check that passing a chat that ends in an assistant message is handled correctly
        # by continuing the final message rather than starting a new one
        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-gpt2-with-chatml-template",
        )
        # Using `do_sample=False` to force deterministic output
        chat1 = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
            {"role": "assistant", "content": "This is"},
        ]
        outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)

        # Assert that we continued the last message and there isn't a sneaky <|im_end|>
        self.assertEqual(
            outputs,
            [
                {
                    "generated_text": [
                        {"role": "system", "content": "This is a system message."},
                        {"role": "user", "content": "This is a test"},
                        {
                            "role": "assistant",
                            "content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
                        },
                    ]
                }
            ],
        )

    @require_torch
    def test_small_chat_model_continue_final_message_override(self):
        # Here we check that passing `continue_final_message=True` forces the pipeline to continue
        # the final message (here a user message) rather than starting a new assistant message
        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-gpt2-with-chatml-template",
        )
        # Using `do_sample=False` to force deterministic output
        chat1 = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
        ]
        outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)

        # Assert that we continued the last message and there isn't a sneaky <|im_end|>
        self.assertEqual(
            outputs,
            [
                {
                    "generated_text": [
                        {"role": "system", "content": "This is a system message."},
                        {
                            "role": "user",
                            "content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
                        },
                    ]
                }
            ],
        )

    @require_torch
    def test_small_chat_model_with_dataset_pt(self):
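        """Tests that chat-format inputs can be streamed from a torch `Dataset` wrapped in `KeyDataset`."""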
        from torch.utils.data import Dataset

        from transformers.pipelines.pt_utils import KeyDataset

        class MyDataset(Dataset):
            data = [
                [
                    {"role": "system", "content": "This is a system message."},
                    {"role": "user", "content": "This is a test"},
                ],
            ]

            def __len__(self):
                return 1

            def __getitem__(self, i):
                return {"text": self.data[i]}

        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-gpt2-with-chatml-template",
        )

        dataset = MyDataset()
        key_dataset = KeyDataset(dataset, "text")

        for outputs in text_generator(key_dataset, do_sample=False, max_new_tokens=10):
            expected_chat = dataset.data[0] + [
                {
                    "role": "assistant",
                    "content": " factors factors factors factors factors factors factors factors factors factors",
                }
            ]
            self.assertEqual(
                outputs,
                [
                    {"generated_text": expected_chat},
                ],
            )

    @require_torch
    def test_small_chat_model_with_iterator_pt(self):
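        """Tests that chat-format inputs can be streamed from a plain generator."""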
        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-gpt2-with-chatml-template",
        )

        # Using `do_sample=False` to force deterministic output
        chat1 = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
        ]
        chat2 = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a second test"},
        ]
        expected_chat1 = chat1 + [
            {
                "role": "assistant",
                "content": " factors factors factors factors factors factors factors factors factors factors",
            }
        ]
        expected_chat2 = chat2 + [
            {
                "role": "assistant",
                "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
            }
        ]

        def data():
            yield from [chat1, chat2]

        outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
        outputs = list(outputs)
        self.assertEqual(
            outputs,
            [
                [{"generated_text": expected_chat1}],
                [{"generated_text": expected_chat2}],
            ],
        )

    @require_torch
    def test_small_chat_model_with_response_parsing(self):
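        """Tests that a `response_schema` set on the tokenizer is used to parse the generated text into a structured message."""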
        text_generator = pipeline(
            task="text-generation",
            model="hf-internal-testing/tiny-gpt2-with-chatml-template",
        )
        # Using `do_sample=False` to force deterministic output
        chat = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
        ]
        text_generator.tokenizer.response_schema = {
            # A real response schema should probably have things like "role" and "content"
            # and "reasoning_content" but it's unlikely we'd get a tiny model to reliably
            # output anything like that, so let's keep it simple.
            "type": "object",
            "properties": {
                "first_word": {"type": "string", "x-regex": r"^\s*([a-zA-Z]+)"},
                "last_word": {"type": "string", "x-regex": r"([a-zA-Z]+)\s*$"},
            },
        }
        outputs = text_generator(chat, do_sample=False, max_new_tokens=10)
        parsed_message = outputs[0]["generated_text"][-1]
        self.assertEqual(parsed_message, {"first_word": "factors", "last_word": "factors"})

    def get_test_pipeline(
        self,
        model,
        tokenizer=None,
        image_processor=None,
        feature_extractor=None,
        processor=None,
        dtype="float32",
    ):
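        """Builds a `TextGenerationPipeline` and example inputs for the shared pipeline tests."""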
        text_generator = TextGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            image_processor=image_processor,
            processor=processor,
            dtype=dtype,
            max_new_tokens=5,
        )
        return text_generator, ["This is a test", "Another test"]

    def test_stop_sequence_stopping_criteria(self):
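        """Tests that `stop_sequence` truncates generation at the first occurrence of the stop string."""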
        prompt = """Hello I believe in"""
        text_generator = pipeline(
            "text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=5, do_sample=False
        )
        output = text_generator(prompt)
        self.assertEqual(
            output,
            [{"generated_text": "Hello I believe in fe fe fe fe fe"}],
        )

        output = text_generator(prompt, stop_sequence=" fe")
        self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}])

    def run_pipeline_test(self, text_generator, _):
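        """Checks `return_full_text`, multiple return sequences, invalid argument combinations, empty prompts,
        and long-input handling for models in `model_mapping`."""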
        model = text_generator.model
        tokenizer = text_generator.tokenizer

        outputs = text_generator("This is a test")
        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
        self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))

        outputs = text_generator("This is a test", return_full_text=False)
        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
        self.assertNotIn("This is a test", outputs[0]["generated_text"])

        text_generator = pipeline(
            task="text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=5
        )
        outputs = text_generator("This is a test")
        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
        self.assertNotIn("This is a test", outputs[0]["generated_text"])

        outputs = text_generator("This is a test", return_full_text=True)
        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
        self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))

        outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
        self.assertEqual(
            outputs,
            [
                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
            ],
        )

        if text_generator.tokenizer.pad_token is not None:
            outputs = text_generator(
                ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
            )
            self.assertEqual(
                outputs,
                [
                    [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
                    [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
                ],
            )

        with self.assertRaises(ValueError):
            outputs = text_generator("test", return_full_text=True, return_text=True)
        with self.assertRaises(ValueError):
            outputs = text_generator("test", return_full_text=True, return_tensors=True)
        with self.assertRaises(ValueError):
            outputs = text_generator("test", return_text=True, return_tensors=True)

        # An empty prompt is slightly special: it requires a BOS token to exist.
        # Pegasus is a special case, since it always appends an EOS token and
        # therefore works even without a BOS token.
        if (
            text_generator.tokenizer.bos_token_id is not None
            or "Pegasus" in tokenizer.__class__.__name__
            or "Git" in model.__class__.__name__
        ):
            outputs = text_generator("")
            self.assertEqual(outputs, [{"generated_text": ANY(str)}])
        else:
            with self.assertRaises((ValueError, AssertionError)):
                outputs = text_generator("", add_special_tokens=False)

        # Models that can handle arbitrarily long inputs don't need this check; they already work.
        # XGLM is also skipped, since its sinusoidal positional embeddings are resized on the fly.
        EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = [
            "RwkvForCausalLM",
            "XGLMForCausalLM",
            "GPTNeoXForCausalLM",
            "GPTNeoXJapaneseForCausalLM",
            "FuyuForCausalLM",
            "LlamaForCausalLM",
        ]
        if (
            tokenizer.model_max_length < 10000
            and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
        ):
            # Handling of inputs that exceed the model's maximum length
            if str(text_generator.device) == "cpu":
                with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
                    text_generator("This is a test" * 500, max_new_tokens=5)

            outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=5)
            # The "hole" strategy cannot work when the requested generation exceeds the model's max length
            if str(text_generator.device) == "cpu":
                with self.assertRaises(ValueError):
                    text_generator(
                        "This is a test" * 500,
                        handle_long_generation="hole",
                        max_new_tokens=tokenizer.model_max_length + 10,
                    )

    @require_torch
    @require_accelerate
    @require_torch_accelerator
    def test_small_model_pt_bloom_accelerate(self):
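        """Tests that `device_map` and `dtype` can be passed via `model_kwargs` or as pipeline arguments,
        and that `dtype` defaults to bfloat16 for this checkpoint when not provided."""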
        import torch

        # Classic `model_kwargs`
        pipe = pipeline(
            model="hf-internal-testing/tiny-random-bloom",
            model_kwargs={"device_map": "auto", "dtype": torch.bfloat16},
            max_new_tokens=5,
            do_sample=False,
        )
        self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
        out = pipe("This is a test")
        self.assertEqual(
            out,
            [{"generated_text": ("This is a test test test test test test")}],
        )

        # `device_map` and `dtype` are also accepted as top-level pipeline arguments
        # (they are simply forwarded to the model, as they're unlikely to mean anything else)
        pipe = pipeline(
            model="hf-internal-testing/tiny-random-bloom",
            device_map="auto",
            dtype=torch.bfloat16,
            max_new_tokens=5,
            do_sample=False,
        )
        self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
        out = pipe("This is a test")
        self.assertEqual(
            out,
            [{"generated_text": ("This is a test test test test test test")}],
        )

        # If not provided, `dtype` is automatically set to torch.bfloat16 (see https://github.com/huggingface/transformers/pull/38882)
        pipe = pipeline(
            model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False
        )
        self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
        out = pipe("This is a test")
        self.assertEqual(
            out,
            [{"generated_text": ("This is a test test test test test test")}],
        )

    @require_torch
    @require_torch_accelerator
    def test_small_model_fp16(self):
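        """Tests that the pipeline loads and runs in float16 on an accelerator device."""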
        import torch

        pipe = pipeline(
            model="hf-internal-testing/tiny-random-bloom",
            device=torch_device,
            dtype=torch.float16,
            max_new_tokens=3,
        )
        pipe("This is a test")

    @require_torch
    @require_accelerate
    @require_torch_accelerator
    def test_pipeline_accelerate_top_p(self):
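        """Tests that top-p sampling runs on a `device_map`-placed float16 model."""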
        import torch

        pipe = pipeline(
            model="hf-internal-testing/tiny-random-bloom",
            device_map=torch_device,
            dtype=torch.float16,
            max_new_tokens=3,
        )
        pipe("This is a test", do_sample=True, top_p=0.5)

    def test_pipeline_length_setting_warning(self):
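        """Tests that a warning is logged only when both `max_length` and `max_new_tokens` are set by the user."""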
        prompt = """Hello world"""
        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=5)
        logger = logging.get_logger("transformers.generation.utils")
        logger_msg = "Both `max_new_tokens`"  # The beginning of the message to be checked in this test

        # Both are set by the user -> log warning
        with CaptureLogger(logger) as cl:
            _ = text_generator(prompt, max_length=10, max_new_tokens=1)
        self.assertIn(logger_msg, cl.out)

        # The user only sets one -> no warning
        with CaptureLogger(logger) as cl:
            _ = text_generator(prompt, max_new_tokens=1)
        self.assertNotIn(logger_msg, cl.out)

        with CaptureLogger(logger) as cl:
            _ = text_generator(prompt, max_length=10, max_new_tokens=None)
        self.assertNotIn(logger_msg, cl.out)

    def test_return_dict_in_generate(self):
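        """Tests that `return_dict_in_generate=True` surfaces `logits` and `scores` in the output."""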
        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=2)
        out = text_generator(
            ["This is great !", "Something else"], return_dict_in_generate=True, output_logits=True, output_scores=True
        )
        self.assertEqual(
            out,
            [
                [
                    {
                        "generated_text": ANY(str),
                        "logits": ANY(list),
                        "scores": ANY(list),
                    },
                ],
                [
                    {
                        "generated_text": ANY(str),
                        "logits": ANY(list),
                        "scores": ANY(list),
                    },
                ],
            ],
        )

    @require_torch
    def test_pipeline_assisted_generation(self):
        """Tests that we can run assisted generation in the pipeline"""
        model = "hf-internal-testing/tiny-random-MistralForCausalLM"
        pipe = pipeline("text-generation", model=model, assistant_model=model, max_new_tokens=2)

        # We can run the pipeline
        prompt = "Hello world"
        _ = pipe(prompt)

        # It is running assisted generation under the hood (e.g. flags incompatible with assisted generation raise an error)
        with self.assertRaises(ValueError):
            _ = pipe(prompt, generate_kwargs={"num_beams": 2})

    @require_torch
    def test_pipeline_skip_special_tokens(self):
        """Tests that we can use `skip_special_tokens=False` to get the special tokens in the output"""
        model_id = "google/gemma-3-270m-it"
        chat = [{"role": "user", "content": "What's your name?"}]
        generator = pipeline("text-generation", model=model_id)

        # normal pipeline use
        output = generator(chat, max_new_tokens=20, do_sample=False)
        self.assertNotIn("<end_of_turn>", str(output[0]["generated_text"]))

        # forcing special tokens to be included in the output
        output = generator(chat, max_new_tokens=1000, do_sample=False, skip_special_tokens=False)
        self.assertIn("<end_of_turn>", str(output[0]["generated_text"]))

    @require_torch
    def test_forward_tokenizer_kwargs(self):
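        """Tests that `tokenizer_encode_kwargs` are forwarded to the tokenizer's `apply_chat_template`."""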
        chat = [
            {"role": "system", "content": "This is a system message."},
            {"role": "user", "content": "This is a test"},
        ]
        model = "hf-internal-testing/tiny-gpt2-with-chatml-template"
        text_generator = pipeline("text-generation", model, max_new_tokens=5)
        tokenizer = text_generator.tokenizer

        with patch.object(tokenizer, "apply_chat_template", wraps=tokenizer.apply_chat_template) as mock:
            text_generator(chat, tokenizer_encode_kwargs={"enable_thinking": True})
            self.assertGreater(mock.call_count, 0)
            kw_call_args = mock.call_args[1]
            self.assertIn("enable_thinking", kw_call_args)
            self.assertEqual(kw_call_args["enable_thinking"], True)
