
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ...utils.deprecation import deprecate_kwarg
from ..gemma.modeling_gemma import (
    GemmaAttention,
    GemmaForCausalLM,
    GemmaForSequenceClassification,
    GemmaForTokenClassification,
    GemmaMLP,
    GemmaModel,
    GemmaRMSNorm,
    apply_rotary_pos_emb,
    repeat_kv,
)


logger = logging.get_logger(__name__)


class Gemma2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma2-7B.
    e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Gemma2Model`]
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            Scaling factor used on the attention scores.
        sliding_window (`int`, *optional*, defaults to 4096):
            In Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Each entry is either `"sliding_attention"` or `"full_attention"`;
            if not set, the layers alternate between the two, starting with sliding attention.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0):
            Scaling factor used when applying tanh softcapping to the final logits.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
            Scaling factor used when applying tanh softcapping to the attention scores.
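
    Both softcapping values are applied as `x -> softcap * tanh(x / softcap)` (see `eager_attention_forward` and
    `Gemma2ForCausalLM.forward` in this file), which smoothly bounds the capped scores to `(-softcap, softcap)`.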

    ```python
    >>> from transformers import Gemma2Model, Gemma2Config
    >>> # Initializing a Gemma2 gemma2-7b style configuration
    >>> configuration = Gemma2Config()
    >>> # Initializing a model from the gemma2-7b style configuration
    >>> model = Gemma2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
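    >>> # With the default `layer_types`, attention alternates between sliding-window and full attention
    >>> configuration.layer_types[:2]
    ['sliding_attention', 'full_attention']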
    ```"""

    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=2304,
        intermediate_size=9216,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        query_pre_attn_scalar=256,
        sliding_window=4096,
        layer_types=None,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.layer_types = layer_types

        if self.layer_types is None:
            # Alternate sliding-window and full attention, starting with a sliding-window layer
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)


class Gemma2RMSNorm(GemmaRMSNorm):
    pass


class Gemma2MLP(GemmaMLP):
    def __init__(self, config):
        super().__init__(config)
        self.act_fn = ACT2FN[config.hidden_activation]


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if softcap is not None:
        # tanh softcapping: softcap * tanh(scores / softcap)
        attn_weights = attn_weights / softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * softcap
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Gemma2Attention(GemmaAttention):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            softcap=self.attn_logit_softcapping,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Gemma2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class Gemma2Model(GemmaModel):
    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # The mask mapping may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # One mask per attention type, matching `config.layer_types`
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds

        # position embeddings shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # scale the input embeddings by sqrt(hidden_size), computed in the embedding dtype
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class Gemma2ForCausalLM(GemmaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
    pass


class Gemma2ForTokenClassification(GemmaForTokenClassification):
    pass


__all__ = [
    "Gemma2Config",
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
    "Gemma2ForSequenceClassification",
    "Gemma2ForTokenClassification",
]