import cv2
import pytesseract
pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration
import matplotlib.pyplot as plt
from ultralytics import YOLO
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import io
import os
import json
from datetime import datetime
import argparse
from dotenv import load_dotenv
load_dotenv()

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
def correct_and_filter_extracted_text(extracted_text):
    """
    Corrects and reconstructs text extracted from images using the OpenAI GPT API,
    filtering out any irrelevant or off-topic content.

    Args:
        extracted_text (str): The raw text extracted from an image.

    Returns:
        str: The corrected, refined, and filtered version of the input text.
    """
    # Prompt for GPT API
    prompt_template ="""
      You are an assistant skilled in correcting and refining text.
      Here is the text extracted from an image:
      {extracted_text}
      Please do the following:
      1. Correct grammar, spelling, and structural errors.
      2. Format the text for clarity and readability, retaining relevant technical or scientific details.
      3. If any off-topic lines seem related to the context, analyze them and rephrase to make them coherent.
      4. Just return the refined text, no extra lines like (eg: Here is the refined text, this is your cleaned text,etc.)"
    """
    
    try:
        prompt = ChatPromptTemplate.from_template(template=prompt_template)
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        chain = prompt | llm
        # chain.invoke returns an AIMessage; return its string content so the
        # return value matches the documented str type.
        corrected_text = chain.invoke({"extracted_text": extracted_text})
        return corrected_text.content

    except Exception as e:
        return f"An error occurred: {e}"

class Text_Detector:
  def __init__(self, model_path):
    self.model = YOLO(model_path)

  def get_img(self,img_path):
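    """Load an image from disk and convert it from OpenCV's BGR order to RGB."""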
    img = cv2.imread(img_path)
    if img is None:
      raise FileNotFoundError(f"Could not read image file: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

  def predict(self,img):
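    """Run the YOLO model on the image and return the first Results object."""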
    results = self.model.predict(img)
    result = results[0]
    return result


  def generate(self,img):
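    """
    Run detection on the image and return a list of boxes, each as
    [x1, y1, x2, y2, class_name, confidence].
    """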
    result = self.predict(img)
    output = []
    for box in result.boxes:
      x1, y1, x2, y2 = [round(x) for x in box.xyxy[0].tolist()]
      class_id = int(box.cls[0].item())
      prob = round(box.conf[0].item(), 2)
      output.append([x1, y1, x2, y2, result.names[class_id], prob])
    return output

class Text_Extractor(Text_Detector):
    def __init__(self, model_path):
        super().__init__(model_path)

        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


    def enhance_img(self, img):
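      """
      Normalize illumination and emphasize strokes before OCR: convert to
      grayscale, invert, apply a heavy Gaussian blur, invert the blur, then
      divide the grayscale image by the result (a dodge-style enhancement).
      """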
      gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
      invert_img = cv2.bitwise_not(gray_img)
      blur_img = cv2.GaussianBlur(invert_img, (111, 111), 0)
      invert_blur_img = cv2.bitwise_not(blur_img)
      trace_img = cv2.divide(gray_img, invert_blur_img, scale=256.0)
      return trace_img

    def ocr(self, img):
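        """Run Tesseract on the enhanced image and return the stripped text."""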
        img = self.enhance_img(img)
        text = pytesseract.image_to_string(img)
        return text.strip()
    
    def analyze_figure_with_blip(self, cropped_img):
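        """Generate a natural-language caption for a cropped figure region with BLIP."""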
        pil_img = Image.fromarray(cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB))
        inputs = self.processor(pil_img, return_tensors="pt")
        outputs = self.blip_model.generate(**inputs)
        caption = self.processor.decode(outputs[0], skip_special_tokens=True)
        return caption
    
    def draw_bounding_boxes(self, img_path, save_path=None):
        """
        Draw bounding boxes around detected text regions.

        If save_path (a file path) is given, the annotated image is written to
        disk; otherwise it is shown with matplotlib.
        """
        img = self.get_img(img_path)
        detections = self.generate(img)

        for box in detections:
            x1, y1, x2, y2, label, prob = box
            # Draw rectangle
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # Add label
            cv2.putText(
                img,
                f"{label} ({prob})",
                (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 0, 0),
                2,
            )

        # Save or show the image with bounding boxes
        if save_path:
            cv2.imwrite(save_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        else:
            plt.imshow(img)
            plt.axis("off")
            plt.show()
    

    def extract_text(self, img):
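      """
      OCR the whole page, caption any detected "figure" regions with confidence
      >= 0.70 using BLIP, and append the captions to the extracted text.
      Returns None if no text is found on the page.
      """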
      boxes = self.generate(img)
      text = self.ocr(img)
      if text == "":
        return None
      figures = []
      figure_flag = False
      for box in boxes:
        x1, y1, x2, y2, label, prob = box
        if prob >= 0.70:
          cropped = img[y1:y2, x1:x2]
          if label == "figure":
            figures.append(self.analyze_figure_with_blip(cropped))
            figure_flag = True
      if figure_flag:
        line = "\nThe figures found in the page are about:\n" + "\n".join(figures)
        text += line
      return text


def process_image(image, extractor):
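  """
  Accept a file path or a NumPy array, run YOLO/OCR extraction on it, and
  return the LLM-corrected text (an empty string if no text is detected).
  """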
  if isinstance(image, str):  # A file path
    file_type = image.split(".")[-1].lower()
    if file_type not in ["png", "jpg", "jpeg"]:
      raise ValueError(f"Unsupported image type: {file_type}")
    loaded = cv2.imread(image)
    if loaded is None:
      raise FileNotFoundError(f"Could not read image file: {image}")
    image = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
  elif isinstance(image, np.ndarray):  # An in-memory image, assumed BGR as read by OpenCV
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  else:
    raise ValueError("Unsupported input type. Must be a file path or a NumPy array.")
  raw_text = extractor.extract_text(image)
  if raw_text is None:  # No text detected on the page
    return ""
  return correct_and_filter_extracted_text(raw_text)
        


def main(model_path, images, output_path):
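    """
    Extract text from one or more images with a Text_Extractor and either save
    the results as a JSON array to output_path or print them to stdout.
    """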
    model = Text_Extractor(model_path)
    results = []
    if isinstance(images, list):  # Multiple images
        for image in images:
            extracted_text = process_image(image, model)
            results.append(extracted_text)
    elif isinstance(images, (np.ndarray, str)):  # A single image
        extracted_text = process_image(images, model)
        results.append(extracted_text)
    if output_path:
        # Write the results directly as a JSON array
        with open(output_path, "w") as f:
            json.dump(results, f, indent=4)
        print(f"Results saved to {output_path}")
    else:
        print("Extracted Text Results:")
        for idx, text in enumerate(results):
            print(f"Image {idx + 1} extracted text:\n{text}")
    return results

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Text extraction from images using a YOLO-based model.")
    parser.add_argument("--model_path", required=True, type=str, help="Path to the model file (e.g., .pt file).")
    parser.add_argument(
        "--images",
        required=True,
        nargs="+",
        help="Paths to the image files (space-separated if multiple).",
    )

    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Optional path to save the extracted text as a JSON file.",
    )

    args = parser.parse_args()

    # Pass the arguments to the main function
    main(args.model_path, args.images, args.output_file)
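    # Example invocation (the script name and file paths below are illustrative placeholders):
    #   python text_extraction.py --model_path weights/text_detector.pt \
    #       --images page1.jpg page2.png --output_file extracted_text.json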
