# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Common utilities for the SDK."""

import base64
import collections.abc
import datetime
import enum
import functools
import logging
import re
import typing
from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
import uuid
import warnings
import pydantic
from pydantic import alias_generators
from typing_extensions import TypeAlias

# Module-level logger; recursive_dict_update uses it to report type mismatches.
logger = logging.getLogger('google_genai._common')

# Alias for dictionaries keyed by strings with arbitrary values.
StringDict: TypeAlias = dict[str, Any]


class ExperimentalWarning(Warning):
  """Warning category used to flag calls into experimental SDK features."""


def set_value_by_path(
    data: Optional[dict[Any, Any]], keys: list[str], value: Any
) -> None:
  """Sets ``value`` in ``data`` at the nested location described by ``keys``.

  ``data`` is mutated in place and missing intermediate dicts are created. The
  call is a no-op when ``value`` is None or ``data`` is None.

  Supported path segments:
    ``'name'``: descend into ``data['name']``, creating ``{}`` if missing.
    ``'name[]'``: ``data['name']`` is a list of dicts and the rest of the path
      is applied to every element. A list ``value`` is distributed
      element-wise (element ``j`` receives ``value[j]``); any other ``value``
      is applied to every element. If ``data['name']`` is missing, it is
      created with one ``{}`` per item of a list ``value``.
    ``'name[0]'``: apply the rest of the path to the first element of
      ``data['name']``, creating ``[{}]`` if missing.
    ``'_self'`` (final segment): when the current dict has no non-None
      ``'_self'`` entry and ``value`` is a dict, ``value`` is merged into the
      current dict instead of being stored under ``'_self'``.

  If the final key already holds a non-None value, the existing value is kept
  when the new value is falsy or equal to it. Two dicts are shallow-merged.
  Any other combination raises.

  Examples:

  set_value_by_path({}, ['a', 'b'], v)
    -> {'a': {'b': v}}
  set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
    -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
  set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
    -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}

  Args:
    data: The dict to modify in place; None makes the call a no-op.
    keys: Path segments, outermost first.
    value: The value to store; None makes the call a no-op.

  Raises:
    ValueError: If a ``'name[]'`` list must be created but ``value`` is not a
      list, or if the final key already holds a non-None value that differs
      from a truthy ``value`` and the two are not both dicts.
  """
  if value is None:
    return
  for i, key in enumerate(keys[:-1]):
    if key.endswith('[]'):
      key_name = key[:-2]
      if data is not None and key_name not in data:
        if isinstance(value, list):
          # One empty dict per item so each element can receive value[j].
          data[key_name] = [{} for _ in range(len(value))]
        else:
          raise ValueError(
              f'value {value} must be a list given an array path {key}'
          )
      if isinstance(value, list) and data is not None:
        # Distribute value element-wise over the existing list.
        # NOTE(review): value[j] raises IndexError if the existing list is
        # longer than value, and extra items of value are ignored if it is
        # shorter; confirm callers always pass matching lengths.
        for j, d in enumerate(data[key_name]):
          set_value_by_path(d, keys[i + 1 :], value[j])
      else:
        if data is not None:
          # Non-list value: apply the same value to every element.
          for d in data[key_name]:
            set_value_by_path(d, keys[i + 1 :], value)
      # The recursive calls consumed the rest of the path.
      return
    elif key.endswith('[0]'):
      key_name = key[:-3]
      if data is not None and key_name not in data:
        data[key_name] = [{}]
      if data is not None:
        set_value_by_path(data[key_name][0], keys[i + 1 :], value)
      return
    if data is not None:
      # Plain segment: descend, creating an empty dict if missing.
      data = data.setdefault(key, {})

  if data is not None:
    existing_data = data.get(keys[-1])
    # If there is an existing value, merge, not overwrite.
    if existing_data is not None:
      # Don't overwrite existing non-empty value with new empty value.
      # This is triggered when handling tuning datasets.
      if not value:
        pass
      # Don't fail when overwriting value with same value
      elif value == existing_data:
        pass
      # Instead of overwriting dictionary with another dictionary, merge them.
      # This is important for handling training and validation datasets in tuning.
      elif isinstance(existing_data, dict) and isinstance(value, dict):
        # Merging dictionaries. Consider deep merging in the future.
        existing_data.update(value)
      else:
        raise ValueError(
            f'Cannot set value for an existing key. Key: {keys[-1]};'
            f' Existing value: {existing_data}; New value: {value}.'
        )
    else:
      # '_self' as the final key means "merge into the current dict".
      if (
          keys[-1] == '_self'
          and isinstance(data, dict)
          and isinstance(value, dict)
      ):
        data.update(value)
      else:
        data[keys[-1]] = value


def get_value_by_path(
    data: Any, keys: list[str], *, default_value: Any = None
) -> Any:
  """Returns the value found in ``data`` at the nested path ``keys``.

  Path segments:
    ``'name'``: index a dict, or read an attribute of an SDK ``BaseModel``.
    ``'name[]'``: resolve the rest of the path on every element of the list
      ``data['name']`` and return the results as a list.
    ``'name[0]'``: resolve the rest of the path on the first element only.
  ``['_self']`` returns ``data`` itself.

  ``default_value`` is returned when a falsy value is reached before the path
  is exhausted, or when a segment cannot be resolved.

  Examples:

  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
    -> v
  get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
    -> [v1, v2]
  """
  if keys == ['_self']:
    return data

  current = data
  for position, segment in enumerate(keys):
    if not current:
      return default_value
    remaining = keys[position + 1 :]

    if segment.endswith('[]'):
      list_key = segment[:-2]
      if list_key not in current:
        return default_value
      # Fan out: resolve the remaining path on every list element.
      return [
          get_value_by_path(element, remaining, default_value=default_value)
          for element in current[list_key]
      ]

    if segment.endswith('[0]'):
      list_key = segment[:-3]
      if list_key not in current or not current[list_key]:
        return default_value
      return get_value_by_path(
          current[list_key][0], remaining, default_value=default_value
      )

    if segment in current:
      current = current[segment]
    elif isinstance(current, BaseModel) and hasattr(current, segment):
      # SDK models (this module's BaseModel) are read via attribute access.
      current = getattr(current, segment)
    else:
      return default_value

  return current


def move_value_by_path(data: Any, paths: dict[str, str]) -> None:
  """Moves values from source paths to destination paths, in place.

  Each entry of ``paths`` maps a dotted source path to a dotted destination
  path. Source segments may be plain keys, ``name[]`` (visit every element of
  a list) or ``'*'`` (every field at that level whose name does not start with
  ``'_'``). Values are written with ``set_value_by_path``; a ``'*'`` in the
  destination is replaced by each moved field's name.

  Examples:
    move_value_by_path(
      {'requests': [{'content': v1}, {'content': v2}]},
      {'requests[].*': 'requests[].request.*'}
    )
      -> {'requests': [{'request': {'content': v1}}, {'request': {'content':
      v2}}]}
  """
  for source_path, dest_path in paths.items():
    source_keys = source_path.split('.')
    dest_keys = dest_path.split('.')

    # Plain destination segments at or after the wildcard position (e.g.
    # 'request' in 'requests[].request.*') must not be swept up by the
    # wildcard; otherwise an existing destination container would be moved
    # into itself, creating a cyclic reference.
    exclude_keys: set[str] = set()
    if '*' in source_keys:
      wildcard_idx = source_keys.index('*')
      exclude_keys = {
          segment
          for segment in dest_keys[wildcard_idx:]
          if segment != '*'
          and not segment.endswith('[]')
          and not segment.endswith('[0]')
      }

    _move_value_recursive(data, source_keys, dest_keys, 0, exclude_keys)


def _move_value_recursive(
    data: Any,
    source_keys: list[str],
    dest_keys: list[str],
    key_idx: int,
    exclude_keys: set[str],
) -> None:
  """Applies one move starting at ``source_keys[key_idx]``.

  Walks ``data`` along ``source_keys``:
    * ``name[]`` recurses into each element of the list ``data['name']``;
    * a plain key descends into that key when present;
    * ``'*'`` (dicts only) moves every field whose name neither starts with
      ``'_'`` nor is in ``exclude_keys`` to ``dest_keys[key_idx:]`` (with
      ``'*'`` replaced by the field name), then deletes the originals.

  The destination is resolved relative to the current node, so
  ``dest_keys[:key_idx]`` is ignored; the source and destination paths are
  expected to share that prefix. Missing keys are skipped silently.
  """
  if key_idx >= len(source_keys):
    return

  segment = source_keys[key_idx]

  if segment.endswith('[]'):
    list_name = segment[:-2]
    if list_name in data and isinstance(data[list_name], list):
      for element in data[list_name]:
        _move_value_recursive(
            element, source_keys, dest_keys, key_idx + 1, exclude_keys
        )
    return

  if segment != '*':
    if segment in data:
      _move_value_recursive(
          data[segment], source_keys, dest_keys, key_idx + 1, exclude_keys
      )
    return

  if not isinstance(data, dict):
    return

  # Snapshot the fields first so writes below don't affect the selection.
  moved = {
      field: data[field]
      for field in list(data)
      if not field.startswith('_') and field not in exclude_keys
  }
  for field, field_value in moved.items():
    target_path = [field if dk == '*' else dk for dk in dest_keys[key_idx:]]
    set_value_by_path(data, target_path, field_value)
  for field in moved:
    del data[field]


def maybe_snake_to_camel(snake_str: str, convert: bool = True) -> str:
  """Returns ``snake_str`` in camelCase when ``convert`` is true.

  Every underscore followed by an ASCII letter is dropped and the letter is
  upper-cased (``'foo_bar'`` -> ``'fooBar'``). Other underscores are kept.
  With ``convert=False`` the input is returned unchanged.
  """
  if convert:
    return re.sub(
        r'_([a-zA-Z])', lambda found: found.group(1).upper(), snake_str
    )
  return snake_str


def convert_to_dict(obj: object, convert_keys: bool = False) -> Any:
  """Recursively converts a given object to a dictionary.

  Pydantic models are dumped with ``model_dump(exclude_none=True)`` and the
  result is converted again. Dicts and lists are rebuilt with converted
  elements, and anything else is returned as-is.

  Args:
    obj: The object to convert.
    convert_keys: Whether to convert the keys from snake case to camel case.
      This flag is forwarded through models and lists but not into dict
      values, so keys of dicts nested inside another dict keep their original
      spelling.

  Returns:
    A dictionary representation of the object, a list of objects if a list is
    passed, or the object itself if it is not a dictionary, list, or Pydantic
    model.
  """
  if isinstance(obj, pydantic.BaseModel):
    dumped = obj.model_dump(exclude_none=True)
    return convert_to_dict(dumped, convert_keys)
  if isinstance(obj, dict):
    converted = {}
    for raw_key, raw_value in obj.items():
      new_key = maybe_snake_to_camel(raw_key, convert_keys)
      # NOTE(review): convert_keys is not forwarded here (defaults to False).
      # Presumably this keeps user-supplied nested keys (e.g. free-form args)
      # from being renamed; confirm before changing.
      converted[new_key] = convert_to_dict(raw_value)
    return converted
  if isinstance(obj, list):
    return [convert_to_dict(element, convert_keys) for element in obj]
  return obj


def _is_struct_type(annotation: type) -> bool:
  """Returns True iff ``annotation`` is ``list[dict[str, Any]]``.

  The ``typing.List[typing.Dict[str, typing.Any]]`` spelling is accepted as
  well, since ``get_origin`` normalizes both to the builtin ``list``/``dict``.
  This maps to the Struct type in the API.
  """
  if get_origin(annotation) is not list:
    return False

  list_args = get_args(annotation)
  if len(list_args) != 1:
    return False
  (item_annotation,) = list_args

  if get_origin(item_annotation) is not dict:
    return False

  dict_args = get_args(item_annotation)
  if len(dict_args) != 2:
    return False

  key_type, value_type = dict_args
  return key_type is str and value_type is typing.Any


def _unwrap_optional(annotation: Any) -> Any:
  """Returns the first non-None member of a Union annotation.

  ``Optional[X]`` (and ``X | None`` on Python 3.10+) becomes ``X``. Any other
  annotation is returned unchanged.
  """
  import types  # pylint: disable=g-import-not-at-top

  origin = typing.get_origin(annotation)
  union_type = getattr(types, 'UnionType', None)  # Python 3.10+ only.
  if origin is Union or (union_type is not None and origin is union_type):
    for member in typing.get_args(annotation):
      if member is not type(None):
        return member
  return annotation


def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
  """Removes extra fields from the response that are not in the model.

  Mutates the response in place. Keys are matched against both the model's
  field names and their aliases (e.g. camelCase keys such as those of
  UsageMetadata). Nested dicts, and dicts inside lists, are pruned recursively
  against the field's annotated type.

  Recursion only happens when that annotated type exposes ``model_fields``.
  Values typed as ``dict[...]``, ``list[dict[str, Any]]`` (Struct), ``Any`` or
  other non-model types are kept verbatim rather than raising AttributeError.

  Args:
    model: Model class whose ``model_fields`` describe the allowed keys.
    response: Raw response dict to prune in place.
  """
  fields = getattr(model, 'model_fields', None)
  if fields is None:
    # Not a model (e.g. Any or a plain class): nothing to prune against.
    return

  # alias -> field name, built once per model instead of once per key.
  alias_map = {
      field_info.alias: field_name
      for field_name, field_info in fields.items()
  }

  for key, value in list(response.items()):
    if key not in fields and key not in alias_map:
      response.pop(key)
      continue

    annotation = _unwrap_optional(fields[alias_map.get(key, key)].annotation)

    if isinstance(value, dict):
      # dict-typed fields (e.g. FunctionCall.args) hold free-form data.
      if typing.get_origin(annotation) is not dict:
        _remove_extra_fields(annotation, value)
    elif isinstance(value, list):
      # list[dict[str, Any]] maps to the API's Struct type; keep it verbatim.
      if _is_struct_type(annotation):
        continue
      item_args = typing.get_args(annotation)
      if not item_args:
        # Unparameterized list annotation: no item type to prune against.
        continue
      item_annotation = _unwrap_optional(item_args[0])
      for item in value:
        # Assume a list of dicts is a list of models.
        if isinstance(item, dict):
          _remove_extra_fields(item_annotation, item)


# Type variable bound to this module's BaseModel, so classmethods such as
# BaseModel._from_response are typed as returning the concrete subclass.
T = typing.TypeVar('T', bound='BaseModel')


def _pretty_repr(
    obj: Any,
    *,
    indent_level: int = 0,
    indent_delta: int = 2,
    max_len: int = 100,
    max_items: int = 5,
    depth: int = 6,
    visited: Optional[FrozenSet[int]] = None,
) -> str:
  """Returns a human-readable, multi-line representation of ``obj``.

  Pydantic models render as ``ClassName(field=value, ...)`` with fields sorted
  by name; ``None`` values and ``Field(repr=False)`` fields are omitted.
  Mappings render with keys sorted by ``str`` when possible. Lists, tuples and
  sets are delegated to ``_format_collection``. Anything else uses ``repr()``.

  Args:
    obj: The object to render.
    indent_level: Number of spaces the current object is indented by.
    indent_delta: Extra spaces added per nesting level.
    max_len: Bytes values longer than this are truncated with ``...``.
    max_items: Mappings with more entries are summarized as
      ``<dict len=N>``, and collections are capped at this many elements.
    depth: Remaining nesting levels; deeper values become a ``Max depth``
      placeholder.
    visited: ids of the objects on the current path from the root, used to
      detect circular references.

  Returns:
    The rendered string.
  """
  if visited is None:
    visited = frozenset()

  obj_id = id(obj)
  if obj_id in visited:
    return '<... Circular reference ...>'

  if depth < 0:
    return '<... Max depth ...>'

  # Track only the root-to-node path: siblings receive their parent's set, so
  # an object shared by siblings is not mistaken for a cycle.
  visited = visited | {obj_id}

  indent = ' ' * indent_level
  next_indent_str = ' ' * (indent_level + indent_delta)

  if isinstance(obj, pydantic.BaseModel):
    cls_name = obj.__class__.__name__
    items = []
    # Sort fields for consistent output
    fields = sorted(type(obj).model_fields)

    for field_name in fields:
      field_info = type(obj).model_fields[field_name]
      if not field_info.repr:  # Respect Field(repr=False)
        continue

      try:
        value = getattr(obj, field_name)
      except AttributeError:
        continue

      if value is None:
        continue

      value_repr = _pretty_repr(
          value,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      items.append(f'{next_indent_str}{field_name}={value_repr}')

    if not items:
      return f'{cls_name}()'
    return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
  elif isinstance(obj, str):
    if '\n' in obj:
      # Multi-line strings are emitted verbatim inside triple quotes (their
      # lines are not re-indented); only embedded triple quotes are escaped.
      escaped = obj.replace('"""', '\\"\\"\\"')
      return f'"""{escaped}"""'
    return repr(obj)
  elif isinstance(obj, bytes):
    if len(obj) > max_len:
      # repr() chooses the quote character (b'...' or b"..." when the data
      # contains a single quote), so close with the same one it opened with.
      truncated = repr(obj[: max_len - 3])
      return f'{truncated[:-1]}...{truncated[-1]}'
    return repr(obj)
  elif isinstance(obj, collections.abc.Mapping):
    if not obj:
      return '{}'

    # Check if the next level of recursion for keys/values will exceed the depth limit.
    if depth <= 0:
      item_count_str = f"{len(obj)} item{'s' if len(obj) != 1 else ''}"
      return f'{{<... {item_count_str} at Max depth ...>}}'

    if len(obj) > max_items:
      return f'<dict len={len(obj)}>'

    items = []
    try:
      sorted_keys = sorted(obj.keys(), key=str)
    except TypeError:
      sorted_keys = list(obj.keys())

    for k in sorted_keys:
      v = obj[k]
      k_repr = _pretty_repr(
          k,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      v_repr = _pretty_repr(
          v,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      items.append(f'{next_indent_str}{k_repr}: {v_repr}')
    return '{\n' + ',\n'.join(items) + f'\n{indent}}}'
  elif isinstance(obj, (list, tuple, set)):
    return _format_collection(
        obj,
        indent_level=indent_level,
        indent_delta=indent_delta,
        max_len=max_len,
        max_items=max_items,
        depth=depth,
        visited=visited,
    )
  else:
    # Fallback to standard repr, indenting subsequent lines only
    raw_repr = repr(obj)
    # Replace newlines with newline + indent
    return raw_repr.replace('\n', f'\n{next_indent_str}')


def _format_collection(
    obj: Any,
    *,
    indent_level: int,
    indent_delta: int,
    max_len: int,
    max_items: int,
    depth: int,
    visited: FrozenSet[int],
) -> str:
  """Formats a list, tuple or set for ``_pretty_repr``.

  Elements are rendered one per line via ``_pretty_repr`` with ``depth - 1``,
  each followed by a comma. At most ``max_items`` elements are shown, then a
  ``<... N more items ...>`` marker. Empty collections render as ``[]``,
  ``()`` or ``set()``. When ``depth <= 0``, only an item count is shown.

  Raises:
    ValueError: If ``obj`` is not a list, tuple or set.
  """
  if isinstance(obj, list):
    opener, closer = '[', ']'
  elif isinstance(obj, tuple):
    opener, closer = '(', ')'
  elif isinstance(obj, set):
    # '{}' would read as an empty dict, so empty sets are shown as 'set()'.
    opener, closer = ('{', '}') if obj else ('set(', ')')
  else:
    raise ValueError(f'Unsupported collection type: {type(obj)}')

  elements = obj if isinstance(obj, list) else list(obj)
  count = len(elements)
  if count == 0:
    return opener + closer

  # Children would be rendered with depth < 0; summarize instead.
  if depth <= 0:
    noun = 'item' if count == 1 else 'items'
    return f'{opener}<... {count} {noun} at Max depth ...>{closer}'

  indent = ' ' * indent_level
  child_level = indent_level + indent_delta
  child_indent = ' ' * child_level

  lines = []
  for position in range(min(count, max_items)):
    rendered = _pretty_repr(
        elements[position],
        indent_level=child_level,
        indent_delta=indent_delta,
        max_len=max_len,
        max_items=max_items,
        depth=depth - 1,
        visited=visited,
    )
    lines.append(child_indent + rendered)

  hidden = count - max_items
  if hidden > 0:
    lines.append(f'{child_indent}<... {hidden} more items ...>')

  return opener + '\n' + ',\n'.join(lines) + ',\n' + indent + closer


class BaseModel(pydantic.BaseModel):
  # Shared base for SDK data models. Fields are exposed under camelCase aliases
  # (alias_generator) while attribute names are still accepted on input
  # (populate_by_name), and unknown fields are rejected (extra='forbid').

  model_config = pydantic.ConfigDict(
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
      from_attributes=True,
      protected_namespaces=(),
      extra='forbid',
      # This allows us to use arbitrary types in the model. E.g. PIL.Image.
      arbitrary_types_allowed=True,
      ser_json_bytes='base64',
      val_json_bytes='base64',
      ignored_types=(typing.TypeVar,),
  )

  def __repr__(self) -> str:
    # Prefer the pretty multi-line form; if it fails for any reason, fall
    # back to pydantic's default repr rather than raising from repr().
    try:
      text = _pretty_repr(self)
    except Exception:
      text = super().__repr__()
    return text

  @classmethod
  def _from_response(
      cls: typing.Type[T],
      *,
      response: dict[str, object],
      kwargs: dict[str, object],
  ) -> T:
    """Validates a raw API response dict into an instance of ``cls``.

    To maintain forward compatibility, fields unknown to ``cls`` are stripped
    from ``response`` (in place) before validation, because the model config
    forbids extra fields. Stripping is skipped when ``kwargs['config']`` is a
    dict with a truthy ``'include_all_fields'`` entry.
    """
    # Agent Engine opts out via config['include_all_fields'], because users
    # may pass dicts that do not correspond to a BaseModel subclass. If more
    # modules need to skip this, a different approach may be warranted.
    config = kwargs.get('config') if kwargs is not None else None
    include_all_fields = isinstance(config, dict) and bool(
        config.get('include_all_fields')
    )
    if not include_all_fields:
      _remove_extra_fields(cls, response)
    return cls.model_validate(response)

  def to_json_dict(self) -> dict[str, object]:
    # JSON-mode dump (bytes as base64 per ser_json_bytes), dropping None fields.
    return self.model_dump(mode='json', exclude_none=True)


class CaseInSensitiveEnum(str, enum.Enum):
  """Case insensitive enum.

  When a value matches no member, ``_missing_`` retries the lookup by member
  *name* using the upper-cased and then the lower-cased value. If that also
  fails, it warns and returns a pseudo-member wrapping the unknown value
  instead of raising.
  """

  @classmethod
  def _missing_(cls, value: Any) -> Any:
    """Resolves ``value`` case-insensitively, or wraps it as a pseudo-member.

    Returns None, so that Enum raises its usual ValueError, for non-string
    values or when the pseudo-member cannot be constructed.
    """
    if not isinstance(value, str):
      # Only strings can be case-folded. Without this guard value.upper()
      # would surface as an AttributeError instead of Enum's ValueError.
      return None

    for candidate_name in (value.upper(), value.lower()):
      try:
        return cls[candidate_name]
      except KeyError:
        continue

    warnings.warn(f'{value} is not a valid {cls.__name__}')
    try:
      # Build the instance via super().__new__ (the str mixin) rather than
      # cls(value), which would re-enter _missing_ and recurse forever.
      unknown_enum_val = super().__new__(cls, value)
      unknown_enum_val._name_ = str(value)  # pylint: disable=protected-access
      unknown_enum_val._value_ = value  # pylint: disable=protected-access
      return unknown_enum_val
    except Exception:  # pylint: disable=broad-exception-caught
      return None


def timestamped_unique_name() -> str:
  """Builds a name of the form ``YYYYMMDDHHMMSS_xxxxx``.

  The prefix is the current local time. The suffix is the first five hex
  digits of a random UUID4.

  Returns:
      A string representing a unique name.
  """
  now_stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
  random_suffix = uuid.uuid4().hex[:5]
  return '_'.join([now_stamp, random_suffix])


def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
  """Converts unserializable types in dict to json.dumps() compatible types.

  Typically applied to the output of convert_to_dict(). convert_to_dict()
  dumps pydantic models with model_dump() in Python mode (not mode='json'),
  and its input may also contain plain nested dicts produced by converters.
  Bytes and datetime values can therefore survive into its output, outside
  the control of `ser_json_bytes` and Pydantic's JSON-mode datetime handling.

  Values are converted recursively. Bytes become URL-safe base64 strings and
  datetimes become ISO 8601 strings. Dicts and lists are walked, so these
  conversions also apply inside nested and mixed-type containers. Non-dict
  input is returned unchanged.

  Args:
    data: The dictionary to convert. It is not modified; a new dict is built.

  Returns:
    A dictionary with json.dumps() incompatible types (e.g. bytes, datetime)
    converted to compatible types (e.g. base64 encoded string, isoformat date
    string).
  """
  if not isinstance(data, dict):
    return data
  return {
      key: _encode_unserializable_value(value) for key, value in data.items()
  }


def _encode_unserializable_value(value: object) -> object:
  """Returns a json.dumps() compatible version of a single value."""
  if isinstance(value, bytes):
    return base64.urlsafe_b64encode(value).decode('ascii')
  if isinstance(value, datetime.datetime):
    return value.isoformat()
  if isinstance(value, dict):
    return encode_unserializable_types(value)
  if isinstance(value, list):
    # Encode element-wise so all-bytes, all-datetime, mixed and nested lists
    # are all handled uniformly.
    return [_encode_unserializable_value(item) for item in value]
  return value


def experimental_warning(
    message: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  """Returns a decorator that emits ``message`` as an ExperimentalWarning.

  The warning is issued only on the first call of each decorated function;
  later calls go straight through to the wrapped function.

  Args:
    message: Text of the warning.

  Returns:
    A decorator that wraps a callable, preserving its metadata via
    ``functools.wraps``.
  """

  def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    has_warned = False

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
      nonlocal has_warned
      if not has_warned:
        # The flag is set before warning, so even if a warnings filter turns
        # the warning into an exception, later calls will not warn again.
        has_warned = True
        warnings.warn(message, category=ExperimentalWarning, stacklevel=2)
      return func(*args, **kwargs)

    return wrapper

  return decorator


def _normalize_key_for_matching(key_str: str) -> str:
  """Drops underscores and lower-cases ``key_str``.

  This makes e.g. ``'Max_Tokens'`` and ``'maxTokens'`` compare equal.
  """
  without_underscores = ''.join(key_str.split('_'))
  return without_underscores.lower()


def align_key_case(
    target_dict: StringDict, update_dict: StringDict
) -> StringDict:
  """Returns a new dict of ``update_dict``'s items using ``target_dict``'s key casing.

  Keys are matched after normalization (underscores removed, lower-cased), so
  ``'max_tokens'`` in ``update_dict`` becomes ``'maxTokens'`` when the latter
  is a key of ``target_dict``. Unmatched keys are kept as-is. Nested dicts are
  aligned recursively when both sides hold a dict. All other values,
  including lists, are taken from ``update_dict`` unchanged.

  Args:
      target_dict: The dictionary with the target key casing.
      update_dict: The dictionary whose keys need to be aligned.

  Returns:
      A new dictionary with keys aligned to target_dict's key casing.
  """
  casing_by_normalized_key = {
      _normalize_key_for_matching(existing_key): existing_key
      for existing_key in target_dict
  }

  result: StringDict = {}
  for update_key, update_value in update_dict.items():
    matched_key = casing_by_normalized_key.get(
        _normalize_key_for_matching(update_key), update_key
    )
    if isinstance(update_value, dict) and isinstance(
        target_dict.get(matched_key), dict
    ):
      result[matched_key] = align_key_case(
          target_dict[matched_key], update_value
      )
    else:
      # Non-dict values (lists included) from update_dict are treated as the
      # source of truth and copied over verbatim.
      result[matched_key] = update_value
  return result


def recursive_dict_update(
    target_dict: StringDict, update_dict: StringDict
) -> None:
  """Recursively updates a target dictionary with values from an update dictionary.

  Keys of ``update_dict`` are first aligned to the casing used in
  ``target_dict`` (matching ignores underscores and case). Nested dicts
  present on both sides are merged recursively. Every other value from
  ``update_dict`` overwrites the target value.

  We don't enforce that updated values have the same type as the existing
  target_dict values; a mismatch only logs a warning. Users providing the
  update_dict are responsible for constructing correct data. No warning is
  logged when the existing value is None, since there is no type to
  conflict with.

  Args:
      target_dict (dict): The dictionary to be updated in place.
      update_dict (dict): The dictionary containing updates.
  """
  # Python SDK http request may change in camel case or snake case:
  # If the field is directly set via setv() function, then it is camel case;
  # otherwise it is snake case.
  # Align the update_dict key case to target_dict to ensure correct dict update.
  aligned_update_dict = align_key_case(target_dict, update_dict)
  for key, value in aligned_update_dict.items():
    if key not in target_dict:
      target_dict[key] = value
      continue

    existing_value = target_dict[key]
    if isinstance(existing_value, dict) and isinstance(value, dict):
      recursive_dict_update(existing_value, value)
      continue

    if existing_value is not None and not isinstance(
        existing_value, type(value)
    ):
      # Lazy %-style args: the message is only formatted if it is emitted.
      logger.warning(
          "Type mismatch for key '%s'. Existing type: %s, new type: %s."
          ' Overwriting.',
          key,
          type(existing_value),
          type(value),
      )
    target_dict[key] = value
