from langchain.text_splitter import TokenTextSplitter
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import PyPDFLoader
import os
import json
import argparse
from dotenv import load_dotenv
load_dotenv()

# load_dotenv() exports variables from a local .env file; fail fast if the key is missing
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set. Add it to your environment or a .env file.")
# Step 1: Define the Prompt Template
mcq_prompt_template = """
You are an expert at creating multiple-choice questions (MCQs) with correct answers and detailed explanations.
Your goal is to prepare students for their exams by creating MCQs from the given material.
------------
{text}
------------
For each question, use exactly this format, separating each part with a blank line:

Question: Write a clear and concise question.

Options: Provide four options (labeled A, B, C, and D) with one correct answer.

Correct Answer: Clearly specify which option is correct.

Explanation: Provide a detailed explanation of the correct answer to help the student understand the concept.

MCQs:
"""

prompt = PromptTemplate(
    input_variables=["text"],
    template=mcq_prompt_template
)

# Step 2: Initialize the LLMChain
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
mcq_chain = LLMChain(llm=llm, prompt=prompt)

# Step 3: Function to Split Text Using `TokenTextSplitter`
def chunk_text_with_token_splitter(file_path, chunk_size=1500, chunk_overlap=100):
    """
    Loads a PDF or JSON file and splits its text into token-based chunks
    using TokenTextSplitter.

    Args:
        file_path (str): Path to the input file (.pdf or .json).
        chunk_size (int): Maximum number of tokens per chunk.
        chunk_overlap (int): Number of tokens to overlap between chunks.

    Returns:
        list: List of text chunks.
    """
    file_type = file_path.split('.')[-1].lower()
    if file_type == 'pdf':
        # Concatenate the page content from all pages of the PDF
        loader = PyPDFLoader(file_path)
        pages = loader.load()
        question_gen = "".join(page.page_content for page in pages)
    elif file_type == 'json':
        # Expects a JSON array of strings; join them into one text blob
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        question_gen = " ".join(data)
    else:
        raise ValueError(f"Unsupported file type: {file_type}. Use a .pdf or .json file.")

    splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_text(question_gen)
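
# Example (illustrative path, not from the original project):
#   chunks = chunk_text_with_token_splitter("lecture_notes.pdf")
# returns a list of ~1500-token chunks with 100 tokens of overlap.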

# Step 4: Process Each Chunk and Generate MCQs


def mcq_main(file_path, output_file):
    """
    Splits the input file's text into chunks, generates MCQs for each chunk,
    and saves the results into a JSON file.
    """
    # Chunk the text using TokenTextSplitter
    text_chunks = chunk_text_with_token_splitter(file_path, chunk_size=1500, chunk_overlap=200)
    print(f"Total Chunks Created: {len(text_chunks)}")

    mcqs = []  # To store generated MCQs
    for idx, chunk in enumerate(text_chunks):
        print(f"Processing Chunk {idx + 1}/{len(text_chunks)}...")
        try:
            # Generate MCQs for the current chunk
            response = mcq_chain.run({"text": chunk})
            # Split the response into blocks separated by blank lines, dropping
            # "---" separator lines and empty blocks
            mcq_list = [block.strip() for block in response.split("\n\n")
                        if block.strip() and block.strip() != "---"]
            # Group every four blocks into one MCQ:
            # (Question, Options, Correct Answer, Explanation); incomplete
            # trailing groups are skipped to avoid index errors
            for i in range(0, len(mcq_list) - 3, 4):
                mcqs.append(mcq_list[i:i + 4])
        except Exception as e:
            print(f"Error processing chunk {idx + 1}: {e}")

    # Write the generated MCQs to a JSON file
    with open(output_file, "w", encoding="utf-8") as json_file:
        json.dump(mcqs, json_file, indent=4, ensure_ascii=False)

    print(f"MCQs saved to {output_file}")
    return output_file


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="StudyBuddy app for generating questions and answers")
    parser.add_argument(
        "--file_path",
        type=str,
        required=True,
        help="Path to the input file (.pdf or .json)"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output file path for storing the generated MCQs"
    )

    args = parser.parse_args()
    mcq_main(args.file_path, args.output_path)
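
# Example invocation (script and file names are illustrative, not from the original project):
#   python mcq_generator.py --file_path lecture_notes.pdf --output_path mcqs.json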