# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
# Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
# Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
# Run a specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
import os
import tempfile
import warnings

from safetensors import safe_open

from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_available
from transformers.integrations.tensor_parallel import get_packed_weights, get_tensor_shard, repack_weights
from transformers.testing_utils import (
    TestCasePlus,
    backend_device_count,
    get_torch_dist_unique_port,
    require_huggingface_hub_greater_or_equal,
    require_torch_multi_accelerator,
    torch_device,
)


if is_torch_available():
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


def global_wrapper(rank, func, tp, port, func_args, func_kwargs):
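    """Per-process entry point for `mp.spawn`: set up the distributed env, run `func(rank, ...)`, then tear down the process group."""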
    def setup_dist_env(rank, world_size, port):
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)

    world_size = tp
    setup_dist_env(rank, world_size, port)

    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    else:
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    func(rank, *func_args, **func_kwargs)

    dist.barrier()
    dist.destroy_process_group()


def init_distributed(tp: int):
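    """Decorator factory that runs the wrapped function on `tp` spawned processes.

    Typical usage (as in TestTensorParallelBase below):
        init_distributed(tp=2)(_test_model_dense_forward_impl)("eval")
    The wrapped function receives the process rank as its first positional argument.
    """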
    def _init_distributed(func):
        def wrapper(*args, **kwargs):
            world_size = tp
            port = get_torch_dist_unique_port()
            spawn_args = (func, tp, port, args, kwargs)
            mp.spawn(global_wrapper, args=spawn_args, nprocs=world_size)

        return wrapper

    return _init_distributed


class TestTensorParallelUtils(TestCasePlus):
    def test_packed_unpacked_conversion(self):
        WORLD_SIZE = 2
        PACKED_BLOCK_SIZE = 800
        SHARDING_DIM = 2
        NUM_BLOCKS = 2

        original_packed_weights = torch.randn(4, 512, 2 * PACKED_BLOCK_SIZE)
        original_packed_weights.get_dtype = lambda: "F32"  # mock the safetensors-slice interface (get_dtype) that get_packed_weights expects
        empty_param = torch.empty(4, 512, 2 * PACKED_BLOCK_SIZE)

        class MockDeviceMesh:
            def size(self):
                return WORLD_SIZE

        # get_packed_weights only calls `.size()` on the mesh, so a mock avoids an actual distributed run
        mock_mesh = MockDeviceMesh()

        packed_weights_0 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 0, SHARDING_DIM)
        packed_weights_1 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 1, SHARDING_DIM)

        # simulate all gather of sharded weights
        packed_weights = torch.cat([packed_weights_0, packed_weights_1], dim=SHARDING_DIM)
        unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS)

        assert torch.allclose(unpacked_weights, original_packed_weights)


class TestTensorParallelProperties(TestCasePlus):
    def test_tp_plan_property_setter_getter(self):
        """Test that tp_plan property can be set and retrieved correctly."""
        model_id = "JackFram/llama-68m"
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")

        # Test setting empty plan
        model.tp_plan = {}
        self.assertEqual(model.tp_plan, {})

        # Test setting a valid plan
        valid_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
        model.tp_plan = valid_plan
        self.assertEqual(model.tp_plan, valid_plan)

        # Test updating the plan
        model.tp_plan.update({"model.layers.*.self_attn.k_proj": "colwise"})
        expected_plan = {"model.layers.*.self_attn.q_proj": "colwise", "model.layers.*.self_attn.k_proj": "colwise"}
        self.assertEqual(model.tp_plan, expected_plan)

        # Test overriding existing entry
        model.tp_plan.update({"model.layers.*.self_attn.q_proj": "colwise_rep"})
        expected_plan = {
            "model.layers.*.self_attn.q_proj": "colwise_rep",
            "model.layers.*.self_attn.k_proj": "colwise",
        }
        self.assertEqual(model.tp_plan, expected_plan)

    def test_tp_plan_validation_invalid_style(self):
        """Test that invalid parallel styles are rejected."""
        model_id = "JackFram/llama-68m"
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")

        # Test invalid parallel style
        with self.assertRaises(ValueError) as context:
            model.tp_plan = {"layers.*.self_attn.q_proj": "invalid_style"}

        self.assertIn("Unsupported tensor parallel style 'invalid_style'", str(context.exception))
        self.assertIn("Supported styles are", str(context.exception))

    def test_tp_plan_validation_nonexistent_layer_warning(self):
        """Test that warnings are issued for non-existent layer patterns."""

        model_id = "JackFram/llama-68m"
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")

        # Test warning for non-existent layer pattern
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            model.tp_plan = {"nonexistent.*.layer": "colwise"}

            # Check that a warning was issued
            self.assertTrue(len(w) > 0)
            warning_message = str(w[0].message)
            self.assertIn("Layer pattern 'nonexistent.*.layer' does not match any parameters", warning_message)

    def test_tp_plan_valid_layer_patterns(self):
        """Test that valid layer patterns are accepted without warnings."""
        model_id = "JackFram/llama-68m"
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")

        # Test valid layer patterns that should match the model structure
        valid_plans = [
            {"model.layers.*.self_attn.q_proj": "colwise"},
            {"model.layers.*.self_attn.k_proj": "rowwise"},
            {"model.layers.*.mlp.gate_proj": "colwise_rep"},
        ]

        for plan in valid_plans:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                model.tp_plan = plan

                # Filter out any warnings that are not about layer patterns
                layer_warnings = [
                    warning
                    for warning in w
                    if "Layer pattern" in str(warning.message)
                    and "does not match any parameters" in str(warning.message)
                ]

                # Should not have layer pattern warnings for valid patterns
                self.assertEqual(
                    len(layer_warnings),
                    0,
                    f"Unexpected warning for valid pattern {plan}: {[str(w.message) for w in layer_warnings]}",
                )

        # Verify the final plan was set correctly
        self.assertEqual(model.tp_plan, valid_plans[-1])

    def test_tp_plan_none_handling(self):
        """Test that None values are handled correctly."""
        model_id = "JackFram/llama-68m"
        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")

        # Test setting None
        model.tp_plan = None
        self.assertEqual(model.tp_plan, {})

        # Test setting a plan after None
        model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
        self.assertEqual(model.tp_plan, {"model.layers.*.self_attn.q_proj": "colwise"})


# ====== TEST FUNCTIONS ======
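# Each `_test_*_impl` below runs inside a worker process spawned via `init_distributed`; `global_wrapper`
# passes the process rank as the first argument and the process group is already initialized.
# (`_test_model_dense_save_impl` is additionally called directly, without a process group, in
# `test_model_dense_save`, hence its `dist.is_initialized()` check.)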
def _test_model_dense_forward_impl(rank, mode):
    """Implementation for comparing TP and non-TP model outputs."""
    model_id = "JackFram/llama-68m"

    # Ensure same random seed for reproducibility
    torch.manual_seed(0)

    # Load tokenizer and prepare inputs - same for both models
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    prompt = "Can I help"
    inputs = tokenizer(prompt, return_tensors="pt")

    # Load TP model first to determine device
    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
    dist.barrier()
    if mode == "eval":
        model_tp.eval()
    else:
        model_tp.train()

    # Load non-TP model and move to same device as TP model
    device = model_tp.device
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
    model = model.to(device)

    if mode == "eval":
        model.eval()
    else:
        model.train()

    # Prepare inputs on the same device
    input_ids = inputs.input_ids.to(device)

    # Run forward pass on both models
    with torch.no_grad():
        # Non-TP model output
        outputs = model(input_ids)
        logits = outputs.logits

        # TP model output
        outputs_tp = model_tp(input_ids)
        logits_tp = outputs_tp.logits

    # Compare outputs - they should match
    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
    )

    dist.barrier()


def _test_model_dense_backward_pass_impl(rank):
    """Implementation for comparing TP and non-TP model backward passes."""
    model_id = "JackFram/llama-68m"

    torch.manual_seed(0)

    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
    dist.barrier()
    model_tp.train()

    device = model_tp.device
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
    model = model.to(device)
    model.train()

    batch_size, seq_length = 2, 10
    torch.manual_seed(42)  # Fixed seed (distinct from the model seed) so the inputs are deterministic
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
    labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)

    outputs = model(input_ids, labels=labels)
    loss = outputs.loss
    loss.backward()

    outputs_tp = model_tp(input_ids, labels=labels)
    loss_tp = outputs_tp.loss
    loss_tp.backward()

    assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
    )

    # Compare gradients for matching parameters
    # Note: TP model may have sharded parameters (DTensors), so we slice the reference gradient to match
    for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
        if param.grad is not None and param_tp.grad is not None:
            grad = param.grad
            grad_tp = param_tp.grad

            if isinstance(param_tp.data, dist.tensor.DTensor):
                placement = param_tp.data.placements[0]
                if hasattr(placement, "dim") and placement.dim is not None:
                    grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
                else:
                    grad_shard = grad
            else:
                grad_shard = grad

            grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp

            assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
                f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
            )

    dist.barrier()


def _test_model_dense_forward_compile_impl(rank, mode):
    """Implementation for comparing TP and non-TP model outputs with torch.compile."""
    model_id = "JackFram/llama-68m"

    torch.manual_seed(0)

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    prompt = "Can I help"
    inputs = tokenizer(prompt, return_tensors="pt")

    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
    dist.barrier()
    if mode == "eval":
        model_tp.eval()
    else:
        model_tp.train()

    device = model_tp.device
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
    model = model.to(device)

    if mode == "eval":
        model.eval()
    else:
        model.train()

    # Compile both models
    model.forward = torch.compile(model.forward)
    model_tp.forward = torch.compile(model_tp.forward)

    input_ids = inputs.input_ids.to(device)

    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

        outputs_tp = model_tp(input_ids)
        logits_tp = outputs_tp.logits

    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
    )

    dist.barrier()


def _test_model_dense_save_impl(rank, tmp_dir):
    """Implementation of test_model_save for distributed execution."""
    model_id = "JackFram/llama-68m"

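    # When invoked through init_distributed the process group is initialized, so save the TP copy;
    # when called directly (non-distributed) save the plain copy used for comparison in test_model_dense_save.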
    if dist.is_initialized():
        kwargs = {"tp_plan": "auto"}
        result_dir = f"{tmp_dir}/tp"
    else:
        kwargs = {}
        result_dir = f"{tmp_dir}/nontp"

    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    model.save_pretrained(result_dir)


class TestTensorParallelBase(TestCasePlus):
    """Base class for tensor parallel tests. Subclasses must set nproc_per_node."""

    nproc_per_node = None

    @require_torch_multi_accelerator
    def test_model_dense_forward_eval(self):
        """Test that TP and non-TP models produce the same outputs in eval mode."""
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        if backend_device_count(torch_device) < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("eval")

    @require_torch_multi_accelerator
    def test_model_dense_forward_train(self):
        """Test that TP and non-TP models produce the same outputs in train mode."""
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        if backend_device_count(torch_device) < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("train")

    @require_torch_multi_accelerator
    def test_model_dense_backward_pass(self):
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        if backend_device_count(torch_device) < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

        init_distributed(tp=self.nproc_per_node)(_test_model_dense_backward_pass_impl)()

    @require_torch_multi_accelerator
    def test_model_dense_forward_compile_eval(self):
        """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        if backend_device_count(torch_device) < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("eval")

    @require_torch_multi_accelerator
    def test_model_dense_forward_compile_train(self):
        """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        if backend_device_count(torch_device) < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("train")

    @require_huggingface_hub_greater_or_equal("0.31.4")
    @require_torch_multi_accelerator
    def test_model_dense_save(self):
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        if backend_device_count(torch_device) < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # First run with TP (distributed)
            init_distributed(tp=self.nproc_per_node)(_test_model_dense_save_impl)(tmp_dir)

            # Then run without TP (non-distributed)
            _test_model_dense_save_impl(0, tmp_dir)

            non_tp_model_path = os.path.join(tmp_dir, "nontp")
            tp_model_path = os.path.join(tmp_dir, "tp")

            for filename in os.listdir(non_tp_model_path):
                if not filename.endswith(".safetensors"):
                    continue

                non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
                tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
                for non_tp_key in non_tp_model.keys():
                    non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
                    tp_tensor = tp_model.get_tensor(non_tp_key)
                    assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
                    del non_tp_tensor, tp_tensor


class TestTensorParallel2Proc(TestTensorParallelBase):
    """Test tensor parallel with 2 processes."""

    nproc_per_node = 2


class TestTensorParallel4Proc(TestTensorParallelBase):
    """Test tensor parallel with 4 processes."""

    nproc_per_node = 4
