"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""

from __future__ import annotations
from .codeinterpretertool import CodeInterpreterTool, CodeInterpreterToolTypedDict
from .completionargs import CompletionArgs, CompletionArgsTypedDict
from .conversationinputs import ConversationInputs, ConversationInputsTypedDict
from .documentlibrarytool import DocumentLibraryTool, DocumentLibraryToolTypedDict
from .functiontool import FunctionTool, FunctionToolTypedDict
from .imagegenerationtool import ImageGenerationTool, ImageGenerationToolTypedDict
from .websearchpremiumtool import WebSearchPremiumTool, WebSearchPremiumToolTypedDict
from .websearchtool import WebSearchTool, WebSearchToolTypedDict
from mistralai.types import BaseModel, Nullable, OptionalNullable, UNSET, UNSET_SENTINEL
from mistralai.utils import get_discriminator
from pydantic import Discriminator, Tag, model_serializer
from typing import List, Literal, Optional, Union
from typing_extensions import Annotated, NotRequired, TypeAliasType, TypedDict


# Allowed values for the `handoff_execution` field of a streaming conversation
# request (used by both ConversationStreamRequest and its TypedDict form).
# NOTE(review): the meaning of "client" vs "server" is not visible in this file;
# presumably it selects which side executes agent handoffs — confirm against
# the API documentation.
ConversationStreamRequestHandoffExecution = Literal["client", "server"]

# Plain-dict (TypedDict) union of the tool definitions accepted in the `tools`
# list of ConversationStreamRequestTypedDict. Unlike the pydantic-side
# ConversationStreamRequestTools alias, this is a bare Union with no
# discriminator attached, so member order is kept as generated.
ConversationStreamRequestToolsTypedDict = TypeAliasType(
    "ConversationStreamRequestToolsTypedDict",
    Union[
        WebSearchToolTypedDict,
        WebSearchPremiumToolTypedDict,
        CodeInterpreterToolTypedDict,
        ImageGenerationToolTypedDict,
        FunctionToolTypedDict,
        DocumentLibraryToolTypedDict,
    ],
)


# Pydantic-side discriminated union of the tool models accepted in
# `ConversationStreamRequest.tools`. Each member is labelled with a Tag; during
# validation pydantic calls the Discriminator callable and picks the member
# whose Tag equals the returned value.
ConversationStreamRequestTools = Annotated[
    Union[
        Annotated[CodeInterpreterTool, Tag("code_interpreter")],
        Annotated[DocumentLibraryTool, Tag("document_library")],
        Annotated[FunctionTool, Tag("function")],
        Annotated[ImageGenerationTool, Tag("image_generation")],
        Annotated[WebSearchTool, Tag("web_search")],
        Annotated[WebSearchPremiumTool, Tag("web_search_premium")],
    ],
    # `get_discriminator` lives in mistralai.utils; presumably it reads a "type"
    # key from dict input or a `type` attribute from model instances (hence the
    # two "type" arguments) — confirm against its definition.
    Discriminator(lambda m: get_discriminator(m, "type", "type")),
]


class ConversationStreamRequestTypedDict(TypedDict):
    # Plain-dict counterpart of the ConversationStreamRequest model: the same
    # keys, with nested values given in their TypedDict forms.
    # Only `inputs` is required; every other key is NotRequired and may be omitted.
    inputs: ConversationInputsTypedDict
    # The corresponding model field (`ConversationStreamRequest.stream`)
    # defaults to True.
    stream: NotRequired[bool]
    # NOTE(review): `Nullable[...]` comes from mistralai.types; presumably it
    # permits an explicit None alongside the wrapped type — confirm there.
    store: NotRequired[Nullable[bool]]
    handoff_execution: NotRequired[Nullable[ConversationStreamRequestHandoffExecution]]
    instructions: NotRequired[Nullable[str]]
    # Each entry is one of the shapes in ConversationStreamRequestToolsTypedDict.
    tools: NotRequired[Nullable[List[ConversationStreamRequestToolsTypedDict]]]
    completion_args: NotRequired[Nullable[CompletionArgsTypedDict]]
    name: NotRequired[Nullable[str]]
    description: NotRequired[Nullable[str]]
    agent_id: NotRequired[Nullable[str]]
    model: NotRequired[Nullable[str]]


class ConversationStreamRequest(BaseModel):
    # Request body for a streamed conversation. Only `inputs` is required.
    # `stream` defaults to True; every other field defaults to the UNSET marker
    # so that `serialize_model` can distinguish "never provided" from an
    # explicit None. Field order here also fixes the key order of the
    # serialized output.

    inputs: ConversationInputs

    stream: Optional[bool] = True

    store: OptionalNullable[bool] = UNSET

    handoff_execution: OptionalNullable[ConversationStreamRequestHandoffExecution] = UNSET

    instructions: OptionalNullable[str] = UNSET

    tools: OptionalNullable[List[ConversationStreamRequestTools]] = UNSET

    completion_args: OptionalNullable[CompletionArgs] = UNSET

    name: OptionalNullable[str] = UNSET

    description: OptionalNullable[str] = UNSET

    agent_id: OptionalNullable[str] = UNSET

    model: OptionalNullable[str] = UNSET

    @model_serializer(mode="wrap")
    def serialize_model(self, handler):
        """Serialize the model, omitting fields that were never provided.

        The default pydantic serialization (``handler``) is post-filtered. A
        field's serialized value is kept only when it is not the UNSET
        sentinel and at least one of the following holds:

        * the value is not None;
        * the field is not optional (i.e. it is required);
        * the field is both optional and nullable, and either the caller set
          it explicitly or it is listed in ``null_default``, so that an
          explicit None is emitted as null.

        Keys returned by ``handler`` that do not correspond to a model field
        are not copied into the result.
        """
        optional = {
            "stream",
            "store",
            "handoff_execution",
            "instructions",
            "tools",
            "completion_args",
            "name",
            "description",
            "agent_id",
            "model",
        }
        nullable = {
            "store",
            "handoff_execution",
            "instructions",
            "tools",
            "completion_args",
            "name",
            "description",
            "agent_id",
            "model",
        }
        # Fields whose None default should still be emitted; none for this model.
        null_default = set()

        serialized = handler(self)
        explicitly_set = self.__pydantic_fields_set__  # pylint: disable=no-member

        output = {}
        for field_name, info in type(self).model_fields.items():
            key = info.alias or field_name
            value = serialized.pop(key, None)

            keep = value != UNSET_SENTINEL and (
                value is not None
                or key not in optional
                or (
                    key in nullable
                    and (field_name in explicitly_set or key in null_default)
                )
            )
            if keep:
                output[key] = value

        return output
