
"""PyTorch TrOCR decoder model (based on RoBERTa)."""

import math
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_trocr import TrOCRConfig


logger = logging.get_logger(__name__)


class TrOCRLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # TrOCR is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None):
        """`input_ids' shape is expected to be [bsz x seqlen]."""
        if position_ids is None:
            bsz, seq_len = input_ids.shape[:2]
            position_ids = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)
        else:
            position_ids = position_ids.unsqueeze(0)

        return super().forward(position_ids + self.offset)


class TrOCRScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
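    When `config.scale_embedding` is enabled, [`TrOCRDecoder`] sets `embed_scale = sqrt(config.hidden_size)`,
    matching the embedding scaling of the original Transformer; otherwise `embed_scale` is `1.0`.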
    r   r   padding_idxembed_scalec                 6    t         |   |||       || _        y N)r   r    rA   )r!   r   r   r@   rA   r"   s        r#   r    z!TrOCRScaledWordEmbedding.__init__J   s    D&r$   r%   c                 <    t         |   |      | j                  z  S rC   )r   r3   rA   )r!   r%   r"   s     r#   r3   z TrOCRScaledWordEmbedding.forwardN   s    wy)D,<,<<<r$   )      ?)r6   r7   r8   r9   r:   r   floatr    r-   r;   r3   r<   r=   s   @r#   r?   r?   E   sE    's '3 'S '_ghm_n '= = =r$   r?   c            	            e Zd ZdZddededee   f fdZeddededee   fd       Z e	j                         dde	j                  d	efd
       Z	 dde	j                  ded	ee   fdZ xZS )"TrOCRSinusoidalPositionalEmbeddingzDThis module produces sinusoidal positional embeddings of any length.num_positionsr   r@   c                     t         |           d| _        || _        || _        | j                  |||      | _        | j                  dt        j                  d             y )Nr   _float_tensorr   )
r   r    r   r   r@   get_embeddingweightsregister_bufferr-   FloatTensor)r!   rI   r   r@   r"   s       r#   r    z+TrOCRSinusoidalPositionalEmbedding.__init__U   sV    *&))-T_e.?.?.BCr$   r   c                    |dz  }t        j                  d      |dz
  z  }t        j                  t        j                  |t        j
                        j                         | z        }t        j                  | t        j
                        j                         j                  d      |j                  d      z  }t        j                  t        j                  |      t        j                  |      gd      j                  | d      }|dz  dk(  r-t        j                  |t        j                  | d      gd      }|	d||ddf<   |j                  t        j                               S )	z
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".
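
        For position `pos` and even `embedding_dim = d`, column `i < d/2` of the returned table holds
        `sin(pos * 10000**(-i / (d/2 - 1)))` and column `d/2 + i` holds the matching cosine: the sine and
        cosine halves are concatenated rather than interleaved as in the paper.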
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the last column when the dimension is odd
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.size()
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
            input_ids.device
        )

        # expand the embedding table if the requested positions exceed what was pre-computed
        max_pos = self.padding_idx + 1 + seq_len
        if max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
        self.weights = self.weights.to(self._float_tensor)

        x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()

        return x

    def create_position_ids_from_input_ids(
        self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
    ):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.
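
        For example, with `padding_idx = 1` and `input_ids = [[1, 5, 6]]`, the non-padding mask is `[0, 1, 1]`,
        its cumulative sum is `[0, 1, 2]`, and the returned position ids are `[1, 2, 3]`: the pad token keeps
        position `padding_idx` while real tokens count up from `padding_idx + 1`.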
        """
        # The series of casts and type conversions here are carefully balanced to work with both ONNX export and XLA.
        mask = input_ids.ne(padding_idx).int()
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
        return incremental_indices.long() + padding_idx
ee   f fdZ e	ddd      	 	 	 	 	 	 dde
j                  dee
j                     dee   dee
j                     dee
j                     dee   dee
j                     dee
j                  ee
j                     eee
j                        f   fd       Z xZS )TrOCRAttentionz>Multi-headed attention from 'Attention Is All You Need' paper.	embed_dim	num_headskdimvdimdropout
is_decoderbiasis_cross_attention	layer_idxc                 P   t         |           || _        ||n|| _        ||n|| _        || _        || _        ||z  | _        | j                  |z  | j                  k(  st        d| j                   d| d      | j                  dz  | _	        || _
        |
| _        t        j                  | j                  ||      | _        t        j                  | j                  ||      | _        t        j                  |||      | _        t        j                  |||      | _        y )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      ࿩rw   )r   r    rq   rs   rt   rr   ru   head_dim
ValueErrorscalingrv   ry   r   Lineark_projv_projq_projout_proj)r!   configrq   rr   rs   rt   ru   rv   rw   rx   ry   r"   s              r#   r    zTrOCRAttention.__init__   s    	" ,D)	 ,D)	"!Y.	)T^^;MdnnM] ^;b"  }}d*$"ii		94@ii		94@ii	94@		)YTBr$   past_key_valuepast_key_values4.58new_nameversionhidden_stateskey_value_statesattention_masklayer_head_maskoutput_attentionscache_positionreturnc                 
   |du}|j                         \  }	}
}| j                  |      | j                  z  }|St        |t              rA|j
                  j                  | j                        }|r|j                  }n|j                  }n|}|r|n|}|rK|IrGj                  | j                     j                  }|j                  | j                     j                  }n| j                  |      }| j                  |      }|j                  |	d| j                   | j"                        j%                  dd      }|j                  |	d| j                   | j"                        j%                  dd      }|D|s|nd}j'                  ||| j                  d|i      \  }}|rd|j
                  | j                  <   |	| j                   z  d| j"                  f}|j                  |	|
| j                   | j"                        j%                  dd      } |j(                  | } |j(                  | } |j(                  | }|j                  d      }t+        j,                  ||j%                  dd            }|j                         |	| j                   z  |
|fk7  r/t/        d|	| j                   z  |
|f d|j                                |{|j                         |	d|
|fk7  r#t/        d	|	d|
|f d|j                                |j                  |	| j                   |
|      |z   }|j                  |	| j                   z  |
|      }t0        j2                  j5                  |d
      }||j                         | j                   fk7  r*t/        d| j                   f d|j                                |j                  dddd      |j                  |	| j                   |
|      z  }|j                  |	| j                   z  |
|      }|r?|j                  |	| j                   |
|      }|j                  |	| j                   z  |
|      }nd}t0        j2                  j7                  || j6                  | j8                        }t+        j,                  ||      }|j                         |	| j                   z  |
| j"                  fk7  r7t/        d|	| j                   |
| j"                  f d|j                                |j                  |	| j                   |
| j"                        }|j%                  dd      }|j)                  |	|
|      }| j;                  |      }||fS )z#Input shape: Batch x Time x ChannelNr+   r   r   r   Tz$Attention weights should be of size z	, but is z!Attention mask should be of size rQ   z/Head mask for a single layer should be of size ptrainingz `attn_output` should be of size )ra   r   r~   
isinstancer   
is_updatedgetry   cross_attention_cacheself_attention_cachelayerskeysvaluesr   r   rZ   rr   r|   	transposeupdatereshaper-   bmmr}   r   
functionalsoftmaxru   r   r   )r!   r   r   r   r   r   r   r   rx   r4   tgt_lenrq   query_statesr   curr_past_key_valuecurrent_states
key_statesvalue_states
proj_shapesrc_lenattn_weightsattn_weights_reshaped
attn_probsattn_outputs                           r#   r3   zTrOCRAttention.forward   s7    .T9"/"4"4"6Wi {{=1DLL@&/+>?,77;;DNNK
%*9*O*O'*9*N*N'&5#-?)]/"=*,33DNNCHHJ.55dnnELLL^4J;;~6L#b$..$--PZZ[\^_`J',,S"dnndmmT^^_`bcdL*7It+>+E+Ednn?OQ_>`,(
L &AEO..t~~>DNN*B>
#((gt~~t}}U__`acde+|++Z8'Z''4
+|++Z8//!$yyz/C/CAq/IJ3#7'"JJ6dnn8LgW^7_6` a %%'(* 
 %""$a'(BB 7a'8R7SS\]k]p]p]r\st  (,,S$..'7SVddL',,S4>>-A7GTL}},,\r,B&##%$..):: Et~~FWEX Y',,./1  +//2q!<|?P?PQTVZVdVdfmov?wwL',,S4>>-A7GTL
 %1$5$5c4>>7T[$\!055cDNN6JGU\]L$(!]]**<4<<RVR_R_*`
ii
L9#"6!OO2CRVR_R_3`2a b$$&') 
 "&&sDNNGT]]S!++Aq1!))#w	BmmK0111r$   )NN        FTFN)NNNNFN)r6   r7   r8   r9   r:   r   rF   boolr    r   r-   r;   r	   tupler3   r<   r=   s   @r#   rp   rp      s   H #"#&%*#-2$(!C !C 	!C
 sm!C sm!C %!C TN!C tn!C %TN!C D>!CF %0A6R 48+/1526,115p2||p2 #5<<0p2 "%	p2
 !.p2 "%,,/p2 $D>p2 !.p2 
u||Xell3XeELL>Q5RR	Sp2 Sp2r$   rp   c                   H    e Zd Zddef fdZ eddd      	 	 	 	 	 	 	 	 	 ddej                  deej                     d	eej                     d
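

# A worked shape example for the bmm-based attention above: with bsz=2, num_heads=4, head_dim=16
# (embed_dim=64), tgt_len=5 and src_len=7, queries/keys/values are reshaped to
# (bsz * num_heads, seq_len, head_dim) -- (8, 5, 16) for queries, (8, 7, 16) for keys/values -- so that
#   attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))  # (8, 5, 7)
#   attn_output = torch.bmm(attn_probs, value_states)                   # (8, 5, 16)
# after which attn_output is folded back to (bsz, tgt_len, embed_dim) = (2, 5, 64) and passed through out_proj.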
eej                     deej                     deej                     dee	   dee
   dee
   deej                     fd       Z xZS )TrOCRDecoderLayerr   c                 b   t         |           |j                  | _        t	        || j                  |j
                  |j                  d|      | _        |j                  | _        t        |j                     | _        |j                  | _        t        j                  | j                        | _        |j                   rot	        || j                  |j
                  |j"                  |j"                  |j                  dd|	      | _        t        j                  | j                        | _        t        j(                  | j                  |j*                        | _        t        j(                  |j*                  | j                        | _        t        j                  | j                        | _        y )NT)rq   rr   ru   rv   ry   )rq   rr   rs   rt   ru   rv   rx   ry   )r   r    hidden_sizerq   rp   decoder_attention_headsattention_dropout	self_attnru   r   activation_functionactivation_fnactivation_dropoutr   	LayerNormself_attn_layer_normrv   cross_attention_hidden_sizeencoder_attnencoder_attn_layer_normr   decoder_ffn_dimfc1fc2final_layer_norm)r!   r   ry   r"   s      r#   r    zTrOCRDecoderLayer.__init__+  s=   ++'nn44,,
 ~~#F$>$>?"(";";$&LL$@! ... 88777700#'#
!D ,.<<+GD(99T^^V-C-CD99V33T^^D "T^^ <r$   r   r   r   r   r   r   encoder_hidden_statesencoder_attention_maskr   cross_attn_layer_head_maskr   	use_cacher   c           	      2   |}| j                  ||||||
      \  }}t        j                  j                  || j                  | j                        }||z   }| j                  |      }d}|i|}| j                  |||||||
      \  }}t        j                  j                  || j                  | j                        }||z   }| j                  |      }|}| j                  | j                  |            }t        j                  j                  || j                  | j                        }| j                  |      }t        j                  j                  || j                  | j                        }||z   }| j                  |      }|f}|r|||fz  }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size *(decoder_attention_heads,)*.
            past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


@auto_docstring
class TrOCRPreTrainedModel(PreTrainedModel):
    config: TrOCRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TrOCRDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class TrOCRDecoder(TrOCRPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`]
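
    During training, decoder layers can be randomly skipped with probability `config.decoder_layerdrop`
    (LayerDrop, https://huggingface.co/papers/1909.11556).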

    Args:
        config: TrOCRConfig
    """

    def __init__(self, config: TrOCRConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        self.embed_tokens = TrOCRScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
        )

        if config.use_learned_position_embeddings:
            self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
        else:
            self.embed_positions = TrOCRSinusoidalPositionalEmbedding(
                config.max_position_embeddings + self.padding_idx + 1,
                config.hidden_size,
                self.padding_idx,
            )

        if config.layernorm_embedding:
            self.layernorm_embedding = nn.LayerNorm(config.hidden_size)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
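
                Passing a plain tuple is deprecated: it is converted internally with
                `EncoderDecoderCache.from_legacy_cache(past_key_values)`, and tuple support is slated for removal
                in Transformers v4.58.0.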
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
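
        Example (a minimal sketch of driving the bare decoder with toy sizes; real checkpoints use this class
        through [`TrOCRForCausalLM`] inside a [`VisionEncoderDecoderModel`]):

        ```python
        >>> import torch
        >>> from transformers import TrOCRConfig
        >>> from transformers.models.trocr.modeling_trocr import TrOCRDecoder

        >>> config = TrOCRConfig(vocab_size=100, d_model=64, decoder_layers=2, decoder_attention_heads=4, decoder_ffn_dim=128)
        >>> decoder = TrOCRDecoder(config)
        >>> input_ids = torch.randint(0, 100, (1, 8))
        >>> outputs = decoder(input_ids=input_ids)
        >>> list(outputs.last_hidden_state.shape)
        [1, 8, 64]
        ```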
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input.shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(), DynamicCache())
                if encoder_hidden_states is not None
                else DynamicCache()
            )

        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # both the learned and the sinusoidal position embeddings are indexed with the token ids
        embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)

        hidden_states = inputs_embeds + embed_pos

        if self.layernorm_embedding is not None:
            hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        input_shape = input.shape

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # check that head_mask/cross_attn_head_mask specify a mask for each layer, if provided
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The TrOCR Model with a language modeling head. Can be used for summarization.
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """
)
class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.decoder = TrOCRDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


@auto_docstring(
    custom_intro="""
    The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and
    c            "           e Zd ZdgZ fdZd Zd Zd Zd Zd Z	d Z
e	 	 	 	 	 	 	 	 	 	 	 	 	 	 dd	eej                     d
eej                     deej                      deej                     deej                     deej                     deeeej                            deej                      deej                     dee   dee   dee   dee   deej                     deeef   fd       Z xZS )TrOCRForCausalLMzoutput_projection.weightc                     d|_         d|_        t        |   |       t	        |      | _        t        j                  |j                  |j                  d      | _
        | j                          y )NTFr{   )rv   is_encoder_decoderr   r    r  r   r   r   r   r   output_projectionr   r  s     r#   r    zTrOCRForCausalLM.__init__  sZ     $)! (0
!#6+=+=v?P?PW\!] 	r$   c                 B    | j                   j                  j                  S rC   r   r  r   r!   s    r#   get_input_embeddingsz%TrOCRForCausalLM.get_input_embeddings  s    zz!!...r$   c                 :    || j                   j                  _        y rC   r(  )r!   values     r#   set_input_embeddingsz%TrOCRForCausalLM.set_input_embeddings  s    */

'r$   c                     | j                   S rC   r&  r)  s    r#   get_output_embeddingsz&TrOCRForCausalLM.get_output_embeddings  s    %%%r$   c                     || _         y rC   r/  )r!   new_embeddingss     r#   set_output_embeddingsz&TrOCRForCausalLM.set_output_embeddings  s
    !/r$   c                 &    || j                   _        y rC   r   r  )r!   r  s     r#   set_decoderzTrOCRForCausalLM.set_decoder  s    $

r$   c                 .    | j                   j                  S rC   r5  r)  s    r#   get_decoderzTrOCRForCausalLM.get_decoder  s    zz!!!r$   r%   r   r   r   r   r   r   r  labelsr   r   r  r  r   r   c                 F   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }| j                  j                  |||||||||
||||      }| j                  |d         }d}|	Ft               } ||j                  d| j                   j                        |	j                  d            }|s|f|dd z   }||f|z   S |S t        |||j                  |j                  |j                  |j                        S )a
  
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import (
        ...     TrOCRConfig,
        ...     TrOCRProcessor,
        ...     TrOCRForCausalLM,
        ...     ViTConfig,
        ...     ViTModel,
        ...     VisionEncoderDecoderModel,
        ... )
        >>> import requests
        >>> from PIL import Image

        >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
        >>> # init vision2text model with random weights
        >>> encoder = ViTModel(ViTConfig())
        >>> decoder = TrOCRForCausalLM(TrOCRConfig())
        >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

        >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        >>> # load image from the IAM dataset
        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
        >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"

        >>> # training
        >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
        >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
        >>> model.config.vocab_size = model.config.decoder.vocab_size

        >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(pixel_values, labels=labels)
        >>> loss = outputs.loss
        >>> round(loss.item(), 2)
        5.30

        >>> # inference
        >>> generated_ids = model.generate(pixel_values)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> generated_text
        'industry, " Mr. Brown commented icily. " Let us have a'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        logits = self.output_projection(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


__all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"]