
    haC                     d   d dl Z d dlmZmZ d dlZd dlmZ ddlmZ ddlm	Z	 ddl
mZ ddlmZmZ d	d
lmZmZ d	dlmZ d	dlmZmZ d	dlmZmZ d	dlmZmZ d	dlmZmZ ddl m!Z!m"Z" dejF                  de$fdZ% G d de      Z& G d de      Z' G d de      Z( G d de      Z) G d de      Z* G d d e      Z+ G d! d"e      Z, G d# d$e      Z- G d% d&ej\                        Z/ G d' d(ej`                        Z1 G d) d*e      Z2 G d+ d,e2      Z3 G d- d.e      Z4e G d/ d0ee             Z5g d1Z6y)2    N)OptionalUnion)nn   )GenerationMixin)BaseModelOutput)PreTrainedModel)auto_docstringcan_return_tuple   )Aimv2AttentionAimv2EncoderLayer)	AutoModel)LlamaMLPLlamaRMSNorm)LlavaForConditionalGeneration
LlavaModel)LlavaNextCausalLMOutputWithPastLlavaNextModelOutputWithPast)SiglipEncoderSiglipVisionEmbeddings   )Ovis2ConfigOvis2VisionConfiglogitsdimc                     | j                  |      }|j                  |d      d   }t        j                  | t        j                        j                  ||d      }||j                         z
  |z   }|S )NT)keepdimr   )memory_formatg      ?)softmaxmaxtorch
zeros_likelegacy_contiguous_formatscatter_detach)r   r   y_softindexy_hardrets         e/var/www/html/eduruby.in/venv/lib/python3.12/site-packages/transformers/models/ovis2/modular_ovis2.pyhard_softmaxr,   #   sk    ^^C FJJsDJ)!,EfE4R4RS\\]`bgilmF
6==?
"V
+CJ    c                       e Zd Zy)Ovis2ModelOutputWithPastN__name__
__module____qualname__ r-   r+   r/   r/   -       r-   r/   c                       e Zd Zy)Ovis2CausalLMOutputWithPastNr0   r4   r-   r+   r7   r7   1   r5   r-   r7   c                       e Zd Zy)Ovis2RMSNormNr0   r4   r-   r+   r9   r9   5   r5   r-   r9   c                       e Zd Zy)Ovis2VisionMLPNr0   r4   r-   r+   r;   r;   9   r5   r-   r;   c                   b     e Zd Zdef fdZd Zdej                  dej                  fdZ	 xZ
S )Ovis2VisionEmbeddingsconfigc                 n    t         |   |       t        |j                  |j                        | _        y N)super__init__r9   hidden_sizerms_norm_epsrms_normselfr>   	__class__s     r+   rB   zOvis2VisionEmbeddings.__init__>   s*     $V%7%79L9LMr-   c                     t        d      NzNot needed for Ovis2)NotImplementedErrorrG   s    r+   interpolate_pos_encodingz.Ovis2VisionEmbeddings.interpolate_pos_encodingB   s    !"899r-   pixel_valuesreturnc                 (   | j                   j                  j                  }| j                  |j                  |            }|j	                  d      j                  dd      }| j                  |      }|| j                  | j                        z   }|S )Ndtyper   r   )	patch_embeddingweightrR   toflatten	transposerE   position_embeddingposition_ids)rG   rN   target_dtypepatch_embeds
embeddingss        r+   forwardzOvis2VisionEmbeddings.forwardE   s    ++2288++LOO,O,OP!))!,66q!<
]]:.
$"9"9$:K:K"LL
r-   )r1   r2   r3   r   rB   rM   r"   FloatTensorTensorr]   __classcell__rH   s   @r+   r=   r=   =   s4    N0 N:E$5$5 %,, r-   r=   c                       e Zd Zy)Ovis2VisionAttentionNr0   r4   r-   r+   rc   rc   P   r5   r-   rc   c                       e Zd Zy)Ovis2VisionEncoderLayerNr0   r4   r-   r+   re   re   T   r5   r-   re   c                   $     e Zd Zdef fdZ xZS )Ovis2VisionEncoderr>   c                     t         |   |       t        j                  t	        |j
                        D cg c]  }t        |       c}      | _        y c c}w r@   )rA   rB   r   
ModuleListrangenum_hidden_layersre   layers)rG   r>   _rH   s      r+   rB   zOvis2VisionEncoder.__init__Y   s@     mmeTZTlTlNm$n%<V%D$no$ns   A)r1   r2   r3   r   rB   r`   ra   s   @r+   rg   rg   X   s    p0 p pr-   rg   c            	       p     e Zd Zdef fdZe	 	 	 ddeej                     dee	   dee	   fd       Z
 xZS )Ovis2VisionTransformerr>   c                     t         |           || _        t        |      | _        t        |      | _        t        |j                  |j                        | _
        d| _        y )NF)rA   rB   r>   r=   r\   rg   encoderr9   rC   rD   rE   gradient_checkpointingrF   s     r+   rB   zOvis2VisionTransformer.__init___   sO    /7)&1$V%7%79L9LM&+#r-   attention_maskoutput_attentionsoutput_hidden_statesc                 (   ||n| j                   j                  }||n| j                   j                  }| j                  |      }| j	                  ||||d      }|d   }| j                  |      }t        ||j                  |j                        S )NT)inputs_embedsrs   rt   ru   return_dictr   )last_hidden_statehidden_states
attentions)	r>   rt   ru   r\   rq   rE   r   rz   r{   )rG   rN   rs   rt   ru   rz   encoder_outputsry   s           r+   r]   zOvis2VisionTransformer.forwardg   s     2C1N-TXT_T_TqTq$8$D $++JjJj 	 5,,')/!5 ' 
 ,A. MM*;</)77&11
 	
r-   )NNN)r1   r2   r3   r   rB   r   r   r"   r_   boolr]   r`   ra   s   @r+   ro   ro   ^   s^    ,0 ,  26,0/3
 !.
 $D>	

 'tn
 
r-   ro   c                   P     e Zd Zdej                  dej                  f fdZ xZS )Ovis2VisualEmbeddingTablevisual_tokensrO   c                    |j                   t        j                  t        j                  t        j                  t        j
                  t        j                  fv rt        | !  |      S t        j                  || j                        S r@   )rR   r"   int8int16int32int64longrA   r]   matmulrT   )rG   r   rH   s     r+   r]   z!Ovis2VisualEmbeddingTable.forward   sW    5::u{{EKKV[V`V`"aa7?=11||M4;;77r-   )r1   r2   r3   r"   r_   r]   r`   ra   s   @r+   r   r      s#    8U\\ 8ell 8 8r-   r   c                   B    e Zd ZU eed<   dZdZdgZdZdZ	dZ
dZdZdZdZy)Ovis2PreTrainedModelr>   modelTrc   past_key_valuesN)r1   r2   r3   r   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_cache_class_supports_flash_attn_supports_flex_attn_supports_sdpa_can_compile_fullgraph_supports_attention_backendr4   r-   r+   r   r      sF    &*#/0"3 N!"&r-   r   c                        e Zd ZU eed<   def fdZdej                  deej                  ej                  f   fdZ
 xZS )Ovis2VisionModelr>   c                    t         |   |       || _        t        |      | _        |j
                  | _        |j                  | _        t        j                  |j                  |j                  z  |j                  z  | j                  | j
                  z
  d      | _        t        j                  | j                  | j
                  z
        | _        y NF)bias)rA   rB   r>   ro   transformernum_visual_indicator_tokens
vocab_sizer   LinearrC   hidden_stridehead_linear	LayerNorm	head_normrF   s     r+   rB   zOvis2VisionModel.__init__   s     1&9+1+M+M( ++99!5!558L8LLOOd>>>

 doo8X8X&XYr-   rN   rO   c           	         | j                  |      }|j                  }| j                  j                  dkD  r|j                  \  }}}| j                  j                  }t        t        j                  |            }||z  |k7  rt        d      |||z  z
  |z  }	t        j                  j                  |ddd|	d|	fdd      }||	z  }|j                  |||z  |||z  ||      }|j                  dddddd      }|j                  |d	||z  |z        }| j                  |      }
| j                  |
      }
| j                  j                   d
k(  r$t        j                  j#                  |
d	d      }|S | j                  j                   dk(  rt%        |
d	      }|S | j                  j                   dk(  r!t        j                  j'                  |
d	      }S )Nr   z.Token sequence length must be a perfect squarer   constantr   r         gumbel_argmaxT)r   hard	st_argmaxr   r    )r   ry   r>   r   shapeintmathsqrt
ValueErrorr   
functionalpadreshapepermuter   r   tokenize_functiongumbel_softmaxr,   r    )rG   rN   outputsry   
num_imagesseq_len
hidden_dimr   sqrt_lpad_sizer   
prob_tokens               r+   r]   zOvis2VisionModel.forward   s   ""<0#55;;$$q(.?.E.E+J KK55M7+,F') !QRR%-)?@MQH " 1 12CaAxYZ\dEegqst uhF 1 9 9Fm3]FmD[]jlv! !2 9 9!Q1a K 1 9 9B =
 J! !!"34';;((O;55f"45PJ  [[**k9%f"5J  [[**i7..v2.>Jr-   )r1   r2   r3   r   r   rB   r"   r^   tupler_   r]   r`   ra   s   @r+   r   r      sC    Z0 Z"E$5$5 "%ell@Z:[ "r-   r   c            !           e Zd Zi Zdef fdZdej                  dej                  fdZe	e
	 	 	 	 	 	 	 	 	 	 	 	 	 ddej                  dej                  deej                     deej                     d	eeej                        d
eej                     deej                     dee   dee   dee   dee   deej                     deeej                  f   deeef   fd              Z xZS )
Ovis2Modelr>   c                 |   t         |   |       t        |j                        | _        t        |j                  j                  |j                        | _        |j                  j                  | _	        |j                  | _        |j                  | _
        t        j                  |j                        | _        | `y r@   )rA   rB   r   vision_configvision_towerr   r   rC   visual_embeddings_tablevisual_vocab_sizevisual_indicator_token_idsr   from_configtext_configlanguage_modelmulti_modal_projectorrF   s     r+   rB   zOvis2Model.__init__   s     ,V-A-AB'@AUAUA`A`bhbtbt'u$!'!5!5!@!@ ++*0*K*K''33F4F4FG&r-   rN   rO   c                 4   | j                  |      }|j                  \  }}}t        j                  ||| j                   j                  f|j
                  |j                  d|j                        }t        j                  ||gd      }| j                  |      }t        j                  | j                  | j                   j                  z
  | j                  t        j                        j                  |j                        }| j                  |      }||fS )NF)rR   devicerequires_gradlayoutr   r   rQ   )r   r   r"   zerosr   rR   r   r   catr   aranger   r   rU   )	rG   rN   image_features
batch_sizeimg_seq_lenrm   padding_tensorvisual_indicatorvisual_indicator_featuress	            r+   get_image_featureszOvis2Model.get_image_features   s     **<8%3%9%9"
Kd&7&7&S&ST &&!((!((
 NN#CK55nE <<""T%6%6%R%RR""**
 "^""
#	 	
 %)$@$@AQ$R!888r-   	input_idsrs   rY   r   rw   labels	use_cachert   ru   rx   cache_positionlogits_to_keepc                    |	|	n| j                   j                  }	|
|
n| j                   j                  }
|d u |d uz  rt        d      | | j	                         |      }| | j                  |      \  }}| j                  |||      }|j                  ||      }t        | j                        D ]  \  }}|Y| | j	                         t        j                  |t        j                  |j                              k(  }|j                  d      }n||k(  j                  |j                        }|j!                         s||   j#                  ||         j                  |j                  |j$                        ||<     | j&                  d	||||||	|
d||d
|}t)        |j*                  |j,                  |j.                  |j0                  |      S d       S )
Nz:You must specify exactly one of input_ids or inputs_embedsrN   )rw   r   )rR   r   r   T)
rs   rY   r   rw   r   rt   ru   rx   r   r   )ry   r   rz   r{   image_hidden_statesr4   )r>   rt   ru   r   get_input_embeddingsr   get_placeholder_maskmasked_scatter	enumerater   r"   tensorr   r   allrU   any	expand_asrR   r   r/   ry   r   rz   r{   )rG   r   rN   rs   rY   r   rw   r   r   rt   ru   rx   r   r   kwargsr   r   special_image_maskivisual_indicator_idmaskr   s                         r+   r]   zOvis2Model.forward   s"   & 2C1N-TXT_T_TqTq$8$D $++JjJj 	 -t";<YZZ 7D557	BM#8<8O8O]i8O8j5N5!%!:!:+- "; "
 *889K^\M*3D4S4S*T &&$(,GD,E,E,G%8

S`SgSgh- D  88B<D%)<<@@AUAUVD88:1!4"=#67M00-2E2EF "$'  &$%% 
)%+'/!5))
 
 (%77#33!//))2>2J
 	

 QU
 	
r-   NNNNNNNNNNNNr   )r1   r2   r3   _checkpoint_conversion_mappingr   rB   r"   r^   r   r   r
   
LongTensorr   r_   listr}   r   r   r   r/   r]   r`   ra   s   @r+   r   r      s   %'"	'{ 	'9''9 
		92  '+*.1537=A59-1$(,0/3&*5934J
##J
 ''J
 !.	J

 u//0J
 "$u'8'8"9:J
   1 12J
 ))*J
 D>J
 $D>J
 'tnJ
 d^J
 !!1!12J
 c5<</0J
  
u..	/!J
  J
r-   r   c            !           e Zd Zi Zdef fdZed        Zdej                  fdZ
ee	 	 	 	 	 	 	 	 	 	 	 	 	 ddej                  dej                  deej                     deej                     d	eeej                        d
eej                     deej                     dee   dee   dee   dee   deej                     deeej                  f   deeef   fd              Z xZS )Ovis2ForConditionalGenerationr>   c                     t         |   |       t        j                  |j                  |j
                  d      | _        y r   )rA   rB   r   r   rC   r   lm_headrF   s     r+   rB   z&Ovis2ForConditionalGeneration.__init__M  s0     yy!3!3V5F5FUSr-   c                     t        d      rJ   )AttributeErrorrL   s    r+   r   z3Ovis2ForConditionalGeneration.multi_modal_projectorQ  s    344r-   rN   c                 :    | j                   j                  |      S )Nr   )r   r   )rG   rN   s     r+   r   z0Ovis2ForConditionalGeneration.get_image_featuresU  s    zz,,,,GGr-   r   rs   rY   r   rw   r   r   rt   ru   rx   r   r   rO   c                    |	|	n| j                   j                  }	|
|
n| j                   j                  }
 | j                  d||||||||	|
d|d|}|d   }t	        |t
              rt        | d      n|}| j                  |dd|ddf         }d}|4 | j                  d||| j                   j                  j                  d|}t        |||j                  |j                  |j                  |j                        S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration

        >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
        >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")

        >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
        >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
        ```NT)r   rN   rs   rY   r   rw   r   rt   ru   rx   r   r   )r   r   r   )lossr   r   rz   r{   r   r4   )r>   rt   ru   r   
isinstancer   slicer   loss_functionr   r   r7   r   rz   r{   r   )rG   r   rN   rs   rY   r   rw   r   r   rt   ru   rx   r   r   r   r   rz   slice_indicesr   r   s                       r+   r]   z%Ovis2ForConditionalGeneration.forwardX  s7   \ 2C1N-TXT_T_TqTq$8$D $++JjJj 	 $** 
%)%+'/!5)
 
  
8B>SV8W~ot4]kmA}a,?@A%4%% f9P9P9[9[_eD +#33!//)) ' ; ;
 	
r-   r   )r1   r2   r3   r   r   rB   propertyr   r"   r^   r   r   r
   r   r   r_   r   r}   r   r   r   r7   r]   r`   ra   s   @r+   r   r   I  s   %'"T{ T 5 5Hu/@/@ H  '+*.1537=A59-1$(,0/3&*5934R
##R
 ''R
 !.	R

 u//0R
 "$u'8'8"9:R
   1 12R
 ))*R
 D>R
 $D>R
 'tnR
 d^R
 !!1!12R
 c5<</0R
  
u11	2!R
  R
r-   r   )r   r   r   )7r   typingr   r   r"   r   
generationr   modeling_outputsr   modeling_utilsr	   utilsr
   r   aimv2.modeling_aimv2r   r   autor   llama.modeling_llamar   r   llava.modeling_llavar   r   llava_next.modeling_llava_nextr   r   siglip.modeling_siglipr   r   configuration_ovis2r   r   r_   r   r,   r/   r7   r9   r;   r=   rc   re   rg   Modulero   	Embeddingr   r   r   r   r   __all__r4   r-   r+   <module>r     s%     "   ) / - 5 D  9 L j J ? C 	; 		"A 		< 		X 	2 &	> 		/ 	p p'
RYY '
T8 8'? '2+ 2js
 s
l b
$A? b
 b
J Rr-   