# Copyright 2025 Eduard Durech and SGLang team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
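#
# A single test can be selected with pytest's -k flag, e.g.:
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py -k test_flash_attention_2_3_parity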

import unittest

import pytest
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow


class FlashAttentionParityTest(unittest.TestCase):
    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
    def _lcs(self, X, Y):
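        # Standard O(m * n) dynamic-programming length of the longest common
        # subsequence, e.g. _lcs("AGGTAB", "GXTXAYB") == 4 via "GTAB"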
        m = len(X)
        n = len(Y)
        L = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(m + 1):
            for j in range(n + 1):
                if i == 0 or j == 0:
                    L[i][j] = 0
                elif X[i - 1] == Y[j - 1]:
                    L[i][j] = L[i - 1][j - 1] + 1
                else:
                    L[i][j] = max(L[i - 1][j], L[i][j - 1])

        return L[m][n]

    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
    def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
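        # Character-level ROUGE-L F-measure per string pair; identical strings
        # score 1.0, e.g. _calculate_rouge_l(["abcd"], ["abef"]) == [0.5]
        # because the LCS "ab" gives precision = recall = 2 / 4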
        rouge_l_scores = []

        for s1, s2 in zip(output_strs_list1, output_strs_list2):
            lcs_len = self._lcs(s1, s2)
            precision = lcs_len / len(s1) if len(s1) > 0 else 0
            recall = lcs_len / len(s2) if len(s2) > 0 else 0
            if precision + recall > 0:
                fmeasure = (2 * precision * recall) / (precision + recall)
            else:
                fmeasure = 0.0
            rouge_l_scores.append(fmeasure)

        return rouge_l_scores

    def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
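        # Warm-up generations so one-off costs (kernel autotuning, memory
        # allocation) do not contaminate the timed runs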
        for _ in range(n_warmup):
            model.generate(**inputs, max_new_tokens=20, do_sample=False)
        torch.cuda.synchronize()

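        # CUDA events time the work on the GPU itself; elapsed_time() is in ms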
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)

        start_time.record()
        for _ in range(n_runs):
            model.generate(**inputs, max_new_tokens=20, do_sample=False)
        end_time.record()
        torch.cuda.synchronize()

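        # Mean latency per generate() call, in milliseconds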
        return start_time.elapsed_time(end_time) / n_runs

    @pytest.mark.flash_attn_3_test
    @require_torch_gpu
    @require_flash_attn
    @require_flash_attn_3
    @slow
    def test_flash_attention_2_3_parity(self):
        model_id = "meta-llama/Llama-3.2-1B-Instruct"
        prompts = ["The ETH AI Center is", "What is life?"]

        # 1. Load model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
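        # Llama tokenizers ship without a pad token; reuse EOS for padding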
        tokenizer.pad_token_id = tokenizer.eos_token_id

        # 2. Generate with the same model under both attention implementations
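        # Decoder-only models need left padding for batched generation so that
        # new tokens are appended immediately after each prompt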
        inputs = tokenizer(prompts, padding=True, padding_side="left", return_tensors="pt").to("cuda")

        with torch.no_grad():
            output_2 = model.generate(
                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
            )
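            # Switch the attention backend in place; the weights stay loaded,
            # so only the attention kernels differ between the two runs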
            model.set_attn_implementation("flash_attention_3")
            output_3 = model.generate(
                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
            )

        # 3. Correctness check
        # 3a. Logits
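        # output.scores is a tuple with one [batch, vocab] logits tensor per
        # generated token; stacking yields [steps, batch, vocab]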
        logits_2 = torch.stack(output_2.scores)
        logits_3 = torch.stack(output_3.scores)
        torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
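        # Also track the worst-case log-probability gap for the report below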
        logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
        logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
        max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()

        # 3b. Generated text
        text_2s = [tokenizer.decode(seq, skip_special_tokens=True) for seq in output_2.sequences]
        text_3s = [tokenizer.decode(seq, skip_special_tokens=True) for seq in output_3.sequences]

        rouge_scores = self._calculate_rouge_l(text_2s, text_3s)
        for i, score in enumerate(rouge_scores):
            assert score > 0.99, f"Generated texts at prompt {i} do not match (ROUGE-L: {score})"

        # 4. Performance check
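        # Flash Attention 3 is still the active backend from step 2, so
        # benchmark it first, then switch back to Flash Attention 2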
        with torch.no_grad():
            time_3 = self._benchmark_generation(model, inputs)
            model.set_attn_implementation("flash_attention_2")
            time_2 = self._benchmark_generation(model, inputs)

        print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
        print(f"Prompt: '{prompt}'")
        print(f"Generated text with Flash Attention 2: {text_2s}")
        print(f"Generated text with Flash Attention 3: {text_3s}")
        print(f"ROUGE-L: {rouge_scores}")
        print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
        print(f"Flash Attention 2 latency: {time_2:.2f} ms")
        print(f"Flash Attention 3 latency: {time_3:.2f} ms")
        print(f"Speed-up: {time_2 / time_3:.2f}x")
        print("---")
