import os

import libcst as cst


# Files belonging to external (non-transformers) libraries whose dependencies must not be tracked.
# Keys are library names; each value lists file descriptors made of the module "name" (without the
# `.py` extension) and a "type" tag. For instance, habana's `modeling_all_models.py` is not part of
# the transformers library, so its dependencies are skipped.
EXCLUDED_EXTERNAL_FILES = dict(
    habana=[dict(name="modeling_all_models", type="modeling")],
)


def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Convert a relative libcst.ImportFrom node into an absolute one,
    using the file path and package name.

    Args:
        import_node: A relative import node (e.g. `from ..utils import helper`)
        file_path: Path to the file containing the import (can be absolute or relative; a relative
            path is resolved against the current working directory)
        package_name: The top-level package name (e.g. 'transformers')

    Returns:
        A new ImportFrom node with the absolute import path, or `import_node` itself when it is
        already absolute.

    Raises:
        ValueError: If `package_name` is not a component of the file path, or if the number of
            leading dots climbs above the package root.
    """
    if not import_node.relative:
        return import_node  # Already absolute

    file_path = os.path.abspath(file_path)
    rel_level = len(import_node.relative)

    # Strip file extension and split into parts
    file_parts = file_path.removesuffix(".py").split(os.path.sep)

    # Ensure the file path includes the package name
    if package_name not in file_parts:
        raise ValueError(f"Package name '{package_name}' not found in file path '{file_path}'")

    # Anchor on the LAST occurrence of the package name. Checkouts are commonly stored in a directory
    # named after the package (e.g. `.../transformers/src/transformers/models/...`); anchoring on the
    # first occurrence would wrongly treat `src` and the package directory as submodules.
    pkg_index = len(file_parts) - 1 - file_parts[::-1].index(package_name)
    # Module path of the file itself below the package, e.g. ['models', 'bert', 'modeling_bert']
    module_parts = file_parts[pkg_index + 1 :]
    if len(module_parts) < rel_level:
        raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")

    # A single dot refers to the file's own package, so drop `rel_level` trailing components
    # (the file's module name plus `rel_level - 1` parent packages).
    base_parts = module_parts[:-rel_level]

    # Flatten the module being imported (if any)
    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
        if not module:
            return []
        if isinstance(module, cst.Name):
            return [module.value]
        elif isinstance(module, cst.Attribute):
            parts = []
            while isinstance(module, cst.Attribute):
                parts.insert(0, module.attr.value)
                module = module.value
            if isinstance(module, cst.Name):
                parts.insert(0, module.value)
            return parts
        return []

    import_parts = flatten_module(import_node.module)

    # Combine to get the full absolute import path
    full_parts = [package_name] + base_parts + import_parts

    # Handle special case where the import comes from a namespace package (e.g. optimum with `optimum.habana`, `optimum.intel` instead of `src.optimum`)
    if package_name != "transformers" and file_parts[pkg_index - 1] != "src":
        full_parts = [file_parts[pkg_index - 1]] + full_parts

    # Build the dotted module path
    dotted_module: cst.BaseExpression | None = None
    for part in full_parts:
        name = cst.Name(part)
        dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)

    # Return a new ImportFrom node with absolute import
    return import_node.with_changes(module=dotted_module, relative=[])


def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    Imports of `package_name`, `package_name.x.y` and the namespaced form `optimum.package_name.x.y`
    are converted; anything else is returned untouched. Leading packages that the imported module
    shares with the file's own location are folded into the dots, so the shortest relative form is
    produced (e.g. `from .configuration import X` for a module in the same package).

    Parameters:
    - import_node: The ImportFrom node to possibly transform.
    - file_path: Path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
    - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
    - A possibly modified ImportFrom node. It is returned unchanged when it is already relative, when
      it does not import from `package_name`, or when `package_name` does not appear in `file_path`.
    """
    if import_node.relative:
        return import_node  # Already relative import

    # Flatten a (Name | Attribute) module expression into its dotted components
    def get_module_parts(module) -> list[str]:
        parts = []
        while isinstance(module, cst.Attribute):
            parts.append(module.attr.value)
            module = module.value
        if isinstance(module, cst.Name):
            parts.append(module.value)
        parts.reverse()
        return parts

    module_name = ".".join(get_module_parts(import_node.module))

    # Strip the package prefix (the namespaced `optimum.<package>.` form is checked first)
    namespaced_prefix = f"optimum.{package_name}."
    if module_name.startswith(namespaced_prefix):
        target = module_name[len(namespaced_prefix) :]
    elif module_name.startswith(package_name + "."):
        target = module_name[len(package_name) + 1 :]
    elif module_name == package_name:
        target = ""
    else:
        return import_node  # Not from target package
    # Module path below the package root, e.g. ['transformers', 'models', 'llama', 'configuration']
    target_parts = target.split(".") if target else []

    # Locate the package root inside the file path
    parts = os.path.normpath(file_path).split(os.sep)
    if package_name not in parts:
        # Package name not found in path — we can't resolve the relative depth
        return import_node
    # Anchor on the LAST occurrence of the package name: checkouts are commonly stored in a directory
    # named after the package, and the first occurrence would then be the checkout directory.
    pkg_index = len(parts) - 1 - parts[::-1].index(package_name)

    # Packages between the package root and the file (the file name itself is excluded)
    file_packages = parts[pkg_index + 1 : -1]

    # Count the leading packages the imported module shares with the file's location
    shared = 0
    for file_package, target_part in zip(file_packages, target_parts):
        if file_package != target_part:
            break
        shared += 1

    # One dot is the file's own package; each extra dot climbs one level towards the package root.
    # Distinct Dot nodes are created so no node instance is aliased inside the tree.
    n_dots = len(file_packages) - shared + 1
    relative = [cst.Dot() for _ in range(n_dots)]

    # Build the module node from what remains after the shared packages (None for e.g. `from .. import x`)
    new_module = None
    for part in target_parts[shared:]:
        name = cst.Name(part)
        new_module = name if new_module is None else cst.Attribute(value=new_module, attr=name)

    return import_node.with_changes(module=new_module, relative=relative)


class AbsoluteImportTransformer(cst.CSTTransformer):
    """Rewrite every relative `from ... import ...` statement of a module into absolute form.

    Each `ImportFrom` node is passed to `convert_relative_import_to_absolute` with the file path and
    library name given at construction time; imports that are already absolute come back unchanged.
    """

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        # Path of the file being transformed, used to resolve the leading dots
        self.relative_path = relative_path
        # Top-level package name the relative imports are resolved against
        self.source_library = source_library

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # Transform `updated_node` (the node after its children were visited), as libcst expects
        absolute_node = convert_relative_import_to_absolute(
            updated_node,
            self.relative_path,
            self.source_library,
        )
        return absolute_node


class RelativeImportTransformer(cst.CSTTransformer):
    """Rewrite absolute imports of the source library into relative `from .x import y` form.

    Each `ImportFrom` node is passed to `convert_to_relative_import` with the file path and library
    name given at construction time; imports of other packages come back unchanged.
    """

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        # Path of the file being transformed, used to compute how many dots are needed
        self.relative_path = relative_path
        # Top-level package whose absolute imports should become relative
        self.source_library = source_library

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # Transform `updated_node` (the node after its children were visited), as libcst expects
        relative_node = convert_to_relative_import(
            import_node=updated_node,
            file_path=self.relative_path,
            package_name=self.source_library,
        )
        return relative_node
