# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
from pathlib import Path

from transformers import is_torch_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_accelerate,
    require_torch_multi_accelerator,
    run_first,
    slow,
)


if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        HfArgumentParser,
        Trainer,
        TrainingArguments,
    )


class TestContextParallel(TestCasePlus):
    """Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig."""

    @require_torch_multi_accelerator
    @require_accelerate
    @slow
    @run_first
    def test_cp_equivalence(self):
        """Test that CP produces the same losses as without CP."""

        # Shared setup; this same file doubles as the training script passed to `accelerate launch`
        world_size = 2
        script_path = __file__

        # Step 1: Run with CP enabled (cp_size=world_size)
        cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        cp_yes_config_path = cp_yes_output_dir / "context_parallel_config.yaml"
        cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json"

        # Write config file inline (self-contained test)
        with open(cp_yes_config_path, "w") as f:
            f.write(
                f"""distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
mixed_precision: bf16
num_processes: {world_size}
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
  parallelism_config_cp_size: {world_size}
  parallelism_config_cp_comm_strategy: alltoall
"""
            )

        cmd_cp_yes = f"""
            accelerate launch
            --config_file {cp_yes_config_path}
            {script_path}
            --output_dir {cp_yes_output_dir}
            --report_to none
            --max_steps 10
            --per_device_train_batch_size 1
            --gradient_accumulation_steps 1
            --logging_steps 1
            --remove_unused_columns False
            --seed 42
            --loss_output_file {cp_yes_losses_path}
        """.split()

        execute_subprocess_async(cmd_cp_yes, env=self.get_env())

        # Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config)
        cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        cp_no_config_path = cp_no_output_dir / "context_parallel_config.yaml"
        cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json"

        # Write config file inline (self-contained test)
        with open(cp_no_config_path, "w") as f:
            f.write(
                """distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
mixed_precision: bf16
num_processes: 1
"""
            )

        cmd_cp_no = f"""
            accelerate launch
            --config_file {cp_no_config_path}
            {script_path}
            --output_dir {cp_no_output_dir}
            --report_to none
            --max_steps 10
            --per_device_train_batch_size 1
            --gradient_accumulation_steps 1
            --logging_steps 1
            --remove_unused_columns False
            --seed 42
            --loss_output_file {cp_no_losses_path}
        """.split()

        execute_subprocess_async(cmd_cp_no, env=self.get_env())

        # Compare losses - they should be very close, since CP only splits the computation along the sequence dimension
        with open(cp_yes_losses_path) as f:
            cp_yes_losses = json.load(f)
        with open(cp_no_losses_path) as f:
            cp_no_losses = json.load(f)

        assert len(cp_yes_losses) == len(cp_no_losses), (
            f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
        )

        # CP should produce very similar results (small numerical differences expected)
        # The differences come from:
        # - Different gradient reduction patterns in distributed training
        # - BF16 mixed precision accumulated differences
        # - Sequence splitting and gathering in CP mode
        cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
        cp_no_losses_tensor = torch.tensor(cp_no_losses)

        # Use torch.testing.assert_close with rtol=2% and atol=0.02
        # Testing shows actual differences are typically <1.5%
        torch.testing.assert_close(
            cp_yes_losses_tensor,
            cp_no_losses_tensor,
            rtol=2e-2,  # 2% relative tolerance
            atol=2e-2,  # 0.02 absolute tolerance
            msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
        )


if __name__ == "__main__":
    # Pop custom arguments (not TrainingArguments parameters) before HfArgumentParser sees them
    loss_output_file = None

    if "--loss_output_file" in sys.argv:
        idx = sys.argv.index("--loss_output_file")
        loss_output_file = sys.argv[idx + 1]
        sys.argv.pop(idx)
        sys.argv.pop(idx)

    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]

    # Use SmolLM (small Llama-based model that works with CP)
    model_name = "HuggingFaceTB/SmolLM-135M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",  # CP requires SDPA
    )

    # Create simple dataset: just tokenize some text
    texts = [
        "The quick brown fox jumps over the lazy dog. " * 10,
        "Hello world, this is a test sentence for training. " * 10,
    ] * 4  # 8 samples total

    def tokenize_function(text):
        return tokenizer(text, max_length=128, truncation=True, padding="max_length")

    train_dataset = [tokenize_function(text) for text in texts]
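    # A plain list of tokenized feature dicts is enough as train_dataset here: Trainer hands it to a
    # DataLoader, and the collator below batches the individual dicts.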

    # Use standard DataCollatorForLanguageModeling for causal LM
    # pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2)
    # Trainer will automatically generate position_ids and shift_labels as needed
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal language modeling
        pad_to_multiple_of=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    # Train for a few steps
    trainer.train()

    # Verify training completed
    assert trainer.state.global_step > 0, "Training should have completed at least one step"

    # Save losses to file if requested (for equivalence testing)
    if loss_output_file and training_args.process_index == 0:
        losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
        with open(loss_output_file, "w") as f:
            json.dump(losses, f)
