import { OptionsState } from "../types";

/** Python env-var name paired with the placeholder shown in the local-runtime template. */
type SecretSpec = readonly [string, string];

/** Connection settings collected for the target database (keys differ per database). */
type DatabaseFields = Readonly<Record<string, unknown>>;

/** Settings for one supported embedding model, interpolated into the notebook. */
interface EmbeddingSpec {
  /** Value assigned to `model_name`. */
  modelName: string;
  /** Tokenizer class imported from `sycamore.functions.tokenizer`. */
  tokenizer: string;
  /** Value assigned to `max_tokens` (used by the merger and `split_elements`). */
  maxTokens: number;
  /** Value assigned to `dimensions` (used when creating the target index). */
  dimensions: number;
  /** Embedder class imported from `sycamore.transforms.embed`. */
  embedder: string;
  /** Whether the `local-inference` extra is added to the pip install line. */
  localInference: boolean;
  /** Secrets the model needs, appended after the database secrets. */
  secrets: readonly SecretSpec[];
}

/** Per-database Python snippets; every entry must define all fields. */
interface DatabaseSpec {
  /** Imports placed right after the base imports. */
  imports: readonly string[];
  /** Imports placed after the Sycamore imports. */
  extraImports: readonly string[];
  /** Secrets the database needs, appended after the Aryn key. */
  secrets: readonly SecretSpec[];
  /** Renders the cell that writes `embedded_ds` to the database. */
  writeBlock: (fields: DatabaseFields) => string;
  /** Cell that reads the data back; relies on names defined by `writeBlock`. */
  verifyBlock: string;
}

const ARYN_SECRET: SecretSpec = ["ARYN_API_KEY", "YOUR_ARYN_API_KEY"];
const VERIFY_HEADER =
  "# Verify data has been loaded using DocSet Query to retrieve chunks";
const SHOW_QUERY_DOCS = "query_docs.show(show_embedding=False)";

/**
 * Render a value as a double-quoted Python string literal.
 *
 * JSON string escapes (`\"`, `\\`, `\n`, `\uXXXX`, ...) are also valid Python
 * escapes, so values containing quotes or backslashes (e.g. Windows paths such
 * as `C:\data\file.pdf`, where `\f` would otherwise become a form feed) stay
 * intact. Plain values render exactly as `"${value}"` did.
 */
function toPythonStr(value: unknown): string {
  return JSON.stringify(String(value));
}

/**
 * Render a value that is pasted into Python source unquoted (ports, flags).
 * JS booleans become `True`/`False`; anything else is stringified unchanged.
 */
function toPythonLiteral(value: unknown): string {
  if (typeof value === "boolean") {
    return value ? "True" : "False";
  }
  return String(value);
}

/** `True` only for a literal `true` option; anything else (including undefined) is `False`. */
function pyFlag(flag: unknown): string {
  return flag === true ? "True" : "False";
}

const EMBEDDING_SPECS = new Map<string, EmbeddingSpec>([
  [
    "MiniLM-L6-v2",
    {
      modelName: "sentence-transformers/all-MiniLM-L6-v2",
      tokenizer: "HuggingFaceTokenizer",
      maxTokens: 512,
      dimensions: 384,
      embedder: "SentenceTransformerEmbedder",
      localInference: true,
      secrets: [],
    },
  ],
  [
    "text-embedding-3-small",
    {
      modelName: "text-embedding-3-small",
      tokenizer: "OpenAITokenizer",
      maxTokens: 8191,
      dimensions: 1536,
      embedder: "OpenAIEmbedder",
      localInference: false,
      secrets: [["OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"]],
    },
  ],
]);

/** Chained DocSet step emitted for each supported chunking strategy. */
const MERGE_STEPS = new Map<string, string>([
  [
    "GreedyTextElementMerger",
    [
      ".merge(merger=GreedyTextElementMerger(",
      "      tokenizer=tokenizer,  max_tokens=max_tokens, merge_across_pages=False",
      "    ))",
    ].join("\n"),
  ],
  [
    "GreedySectionMerger",
    [
      ".merge(merger=GreedySectionMerger(",
      "      tokenizer=tokenizer,  max_tokens=max_tokens, merge_across_pages=False",
      "    ))",
    ].join("\n"),
  ],
  [
    "MarkedMerger",
    [
      ".mark_bbox_preset(tokenizer=tokenizer)",
      "    .merge(merger=MarkedMerger())",
    ].join("\n"),
  ],
]);

const DATABASE_SPECS = new Map<string, DatabaseSpec>([
  [
    "DuckDB",
    {
      imports: ["import duckdb"],
      extraImports: [],
      secrets: [],
      writeBlock: (f) =>
        [
          `db_url = ${toPythonStr(f.databasePath || "demo.db")}`,
          `table_name = ${toPythonStr(f.tableName || "demo_table")}`,
          "# Execute the write operation to DuckDB",
          "embedded_ds.write.duckdb(",
          "    db_url=db_url,",
          "    table_name=table_name,",
          "    dimensions=dimensions",
          ")",
        ].join("\n"),
      verifyBlock: [
        VERIFY_HEADER,
        "# If you previously used a DuckDB in Colab with a different number of vector dimensions, you may need to restart the runtime.",
        'query = f"SELECT * from {table_name}"',
        "query_docs = ctx.read.duckdb(db_url=db_url, table_name=table_name, query=query)",
        SHOW_QUERY_DOCS,
      ].join("\n"),
    },
  ],
  [
    "OpenSearch",
    {
      imports: ["from opensearchpy import OpenSearch"],
      extraImports: [],
      secrets: [
        ["OS_USER_NAME", "YOUR_OPENSEARCH_USER_NAME"],
        ["OS_PASSWORD", "YOUR_OPENSEARCH_PASSWORD"],
      ],
      writeBlock: (f) =>
        [
          `index_name = ${toPythonStr(f.indexName)}`,
          "# Configure the OpenSearch client arguments",
          "os_client_args = {",
          `    "hosts": [{"host": ${toPythonStr(f.hostUrl)}, "port": ${toPythonLiteral(f.hostPort)}}],`,
          '    "http_auth": (os.getenv("OS_USER_NAME"), os.getenv("OS_PASSWORD")),',
          '    "verify_certs": False,',
          '    "use_ssl": True,',
          "}",
          "",
          "# Configure the settings and mappings for the OpenSearch index",
          "index_settings = {",
          '    "body": {',
          '        "settings": {',
          '            "index.knn": True,',
          "        },",
          '        "mappings": {',
          '            "properties": {',
          '                "embedding": {',
          '                    "type": "knn_vector",',
          '                    "dimension": dimensions,',
          '                    "method": {"name": "hnsw", "engine": "faiss"},',
          "                },",
          "            },",
          "        },",
          "    },",
          "}",
          "",
          "# Write the docset to the specified OpenSearch index",
          "embedded_ds.write.opensearch(",
          "    os_client_args=os_client_args,",
          "    index_name=index_name,",
          "    index_settings=index_settings,",
          ")",
        ].join("\n"),
      verifyBlock: [
        VERIFY_HEADER,
        'query_docs = ctx.read.opensearch(os_client_args=os_client_args, index_name=index_name, query={"query": {"match_all": {}}})',
        SHOW_QUERY_DOCS,
      ].join("\n"),
    },
  ],
  [
    "Pinecone",
    {
      imports: ["from pinecone import Pinecone"],
      extraImports: ["from pinecone import ServerlessSpec"],
      secrets: [["PINECONE_API_KEY", "YOUR_PINECONE_API_KEY"]],
      writeBlock: (f) =>
        [
          "# Create an instance of ServerlessSpec with the specified cloud provider and region",
          `spec = ServerlessSpec(cloud=${toPythonStr(f.cloudProvider)}, region=${toPythonStr(f.region)})`,
          `index_name = ${toPythonStr(f.indexName)}`,
          "# Write data to a Pinecone index ",
          "embedded_ds.write.pinecone(index_name=index_name, ",
          "    dimensions=dimensions, ",
          '    distance_metric="cosine",',
          "    index_spec=spec",
          ")",
        ].join("\n"),
      verifyBlock: [
        VERIFY_HEADER,
        "query_docs = ctx.read.pinecone(index_name=index_name, api_key=os.getenv('PINECONE_API_KEY'))",
        SHOW_QUERY_DOCS,
        "",
      ].join("\n"),
    },
  ],
  [
    "Qdrant",
    {
      imports: ["from qdrant_client import QdrantClient"],
      extraImports: [],
      secrets: [["QDRANT_API_KEY", "YOUR_QDRANT_API_KEY"]],
      writeBlock: (f) =>
        [
          `location = ${toPythonStr(f.hostUrl)}`,
          `collection_name = ${toPythonStr(f.collectionName)}`,
          "embedded_ds.write.qdrant(",
          '    client_params={"location": location, "api_key": os.getenv("QDRANT_API_KEY")},',
          '    collection_params={"collection_name": collection_name, "vectors_config": { "size": dimensions, "distance": "Cosine"}}',
          ")",
        ].join("\n"),
      verifyBlock: [
        VERIFY_HEADER,
        "query_docs = ctx.read.qdrant(",
        "    {",
        '        "url": location,',
        '        "api_key": os.getenv("QDRANT_API_KEY"),',
        "    },",
        '    {"collection_name": collection_name, "limit": 100, "using": "{optional_vector_name}", "with_vectors": True}',
        ")",
        SHOW_QUERY_DOCS,
      ].join("\n"),
    },
  ],
  [
    "Weaviate",
    {
      imports: ["from weaviate.client import WeaviateClient"],
      extraImports: [
        "from weaviate.client import ConnectionParams",
        "from weaviate.collections.classes.config import Configure",
        "from weaviate.auth import AuthApiKey",
        "from weaviate.auth import Auth",
      ],
      secrets: [["WEAVIATE_API_KEY", "YOUR_WEAVIATE_API_KEY"]],
      writeBlock: (f) =>
        [
          `collection_name = ${toPythonStr(f.collectionName)}`,
          "# Client configuration for connecting to a Weaviate instance",
          "client_args = {",
          '    "connection_params": ConnectionParams.from_params(',
          `        http_host=${toPythonStr(f.http_host)},`,
          `        http_port=${toPythonLiteral(f.http_port)},`,
          `        http_secure=${toPythonLiteral(f.http_secure)},`,
          `        grpc_host=${toPythonStr(f.grpc_host)},`,
          `        grpc_port=${toPythonLiteral(f.grpc_port)},`,
          `        grpc_secure=${toPythonLiteral(f.grpc_secure)}`,
          "    ),",
          '    "auth_client_secret": Auth.api_key(api_key=os.getenv("WEAVIATE_API_KEY"))',
          "}",
          "# Configuration for the collection",
          "collection_config = {",
          '    "name": collection_name,',
          '    "vectorizer_config": [Configure.NamedVectors.none(name="embedding")],',
          "}",
          "# Writing data to the Weaviate collection",
          "embedded_ds.write.weaviate(",
          "    wv_client_args=client_args,",
          "    collection_name=collection_name,",
          "    collection_config=collection_config,",
          `    flatten_properties=${toPythonLiteral(f.flattenProperties)}`,
          ")",
          "    ",
        ].join("\n"),
      verifyBlock: [
        VERIFY_HEADER,
        "query_docs = ctx.read.weaviate(",
        "    wv_client_args=client_args, collection_name=collection_name",
        ")",
        SHOW_QUERY_DOCS,
      ].join("\n"),
    },
  ],
  [
    "Elasticsearch",
    {
      imports: ["from elasticsearch import Elasticsearch"],
      extraImports: [],
      secrets: [["ELASTICSEARCH_API_KEY", "YOUR_ELASTICSEARCH_API_KEY"]],
      writeBlock: (f) =>
        [
          "# Write to a persistent Elasticsearch Index. Note: You must have a specified elasticsearch instance running for this to work.",
          "# For more information on how to set one up, refer to https://www.elastic.co/guide/en/elasticsearch/reference/current/install-elasticsearch.html",
          `url = ${toPythonStr(f.hostUrl)}`,
          `index_name = ${toPythonStr(f.indexName)}`,
          "embedded_ds.write.elasticsearch(",
          "    url=url, ",
          "    index_name=index_name,",
          '    es_client_args={"api_key": os.getenv("ELASTICSEARCH_API_KEY")},',
          "    mappings={",
          '        "properties": {',
          '            "embeddings": {',
          '                "type": "dense_vector",',
          '                "dims": dimensions,',
          '                "index": True,',
          '                "similarity": "cosine",',
          "                },",
          "            }",
          "        }",
          "    )",
        ].join("\n"),
      verifyBlock: [
        VERIFY_HEADER,
        'query_params = {"match_all": {}}',
        "query_docs = ctx.read.elasticsearch(url=url, ",
        "                                    index_name=index_name, ",
        "                                    query=query_params,",
        '                                    es_client_args={"api_key": os.getenv("ELASTICSEARCH_API_KEY")})',
        SHOW_QUERY_DOCS,
      ].join("\n"),
    },
  ],
]);

/** Colab cell: load every secret from Colab's `userdata` store inside a try/except. */
function renderColabSecrets(secrets: readonly SecretSpec[]): string {
  return [
    "# API Keys",
    "from google.colab import userdata",
    "# Set your secrets in the colab notebook. Navigate to the left pane",
    "# and choose the key option to set your keys. Make sure to enable Notebook access",
    "try:",
    // Sits directly above the first secret, which is always the Aryn key.
    "  # Visit https://www.aryn.ai/get-started to get a key.",
    ...secrets.map(([name]) => `  os.environ["${name}"] = userdata.get('${name}')`),
    "except Exception as e:",
    '  print("YOU ARE MISSING REQUIRED API KEYS FOR THIS PIPELINE. Add your API keys to the Secrets page (icon is a key) in the Colab left navigation panel. It is case sensitive.")',
  ].join("\n");
}

/** Local-runtime cell: commented-out `os.environ` assignments for the user to fill in. */
function renderLocalSecrets(secrets: readonly SecretSpec[]): string {
  return [
    "# It's best to store API keys in a configuration file or set them as environment variables.  ",
    "# For quick testing, you can define them here:",
    "#",
    ...secrets.map(
      ([name, placeholder]) => `# os.environ["${name}"] = "${placeholder}"`
    ),
  ].join("\n");
}

/** Chained `.partition(...)` step configured from the partitioning options. */
function buildPartitionStep(options: OptionsState): string {
  const threshold =
    options.auto_threshold === true ? '"auto"' : options.threshold;
  return [
    ".partition(partitioner=ArynPartitioner(",
    `        threshold=${threshold},`,
    `        use_ocr=${pyFlag(options.use_ocr)},`,
    `        extract_table_structure=${pyFlag(options.extract_table_structure)},`,
    `        extract_images=${pyFlag(options.extract_images)},`,
    '        source="docprep"',
    "    ))",
  ].join("\n");
}

/**
 * Build a Jupyter notebook (nbformat 4.2) whose code cells run a Sycamore
 * DocPrep pipeline. The cells, in order:
 * 1. pip install
 * 2. apt install
 * 3. imports
 * 4. API keys
 * 5. an optional Colab file upload
 * 6. read → partition → materialize → merge → split
 * 7. spread/explode/embed
 * 8. write to the target database
 * 9. read the data back to verify.
 *
 * @param inputState    Source settings. `fileType` defaults to "pdf"; other
 *                      types add LibreOffice and a PDF conversion step.
 *                      `filePath` is the path to read. `sourceLocation` is
 *                      compared against "s3" and "colab".
 * @param optionsState  Partitioning options (threshold, OCR, tables, images).
 * @param chunkingState `strategy` selects the merger (null or unknown means
 *                      no merge step). `maxTokens` is currently not used.
 * @param databaseState `targetDatabase` name plus its connection `fields`.
 * @param embedding     Embedding model key ("MiniLM-L6-v2" or
 *                      "text-embedding-3-small").
 * @param runtime       "local" emits commented-out key placeholders and a
 *                      relative materialize path. "colab" uses `/content`
 *                      for the materialize path. Any non-"local" value reads
 *                      keys from Colab secrets.
 * @returns The notebook object. Returns `null` when the target database or
 *          embedding model is missing or unsupported.
 */
export function generateNotebook(
  inputState: any,
  optionsState: OptionsState,
  chunkingState: {
    strategy: string | null;
    maxTokens: number | undefined;
  },
  databaseState: any,
  embedding: string,
  runtime: string
) {
  // Refuse unsupported configurations instead of throwing on a missing
  // database name or emitting Python with empty class names.
  const targetDatabase = databaseState?.targetDatabase;
  if (typeof targetDatabase !== "string") {
    return null;
  }
  const databaseSpec = DATABASE_SPECS.get(targetDatabase);
  const embeddingSpec = EMBEDDING_SPECS.get(embedding);
  if (databaseSpec === undefined || embeddingSpec === undefined) {
    return null;
  }
  const fields: DatabaseFields = databaseState.fields ?? {};

  // Normalize once so a missing file type is treated as PDF everywhere: in
  // binary_format, the upload note, and the (skipped) PDF conversion.
  const fileType = String(inputState.fileType || "pdf");
  const needsPdfConversion = fileType !== "pdf";
  const isS3 = inputState.sourceLocation === "s3";
  const isColabUpload = inputState.sourceLocation === "colab";

  // NOTE(review): chunkingState.maxTokens is not used; the merger receives the
  // embedding model's `max_tokens`. Confirm whether it should be wired in.
  const mergeStep = chunkingState.strategy
    ? MERGE_STEPS.get(chunkingState.strategy)
    : undefined;

  // Step 0: environment setup.
  const pipExtras = embeddingSpec.localInference
    ? `${targetDatabase.toLowerCase()},local-inference`
    : targetDatabase.toLowerCase();
  const installBlock =
    `!pip install sycamore-ai[${pipExtras}]` +
    "\n# DocPrep code uses the Sycamore document ETL library: https://github.com/aryn-ai/sycamore ";

  const aptBlock = [
    "!apt-get install poppler-utils",
    ...(needsPdfConversion ? ["!apt-get install -y libreoffice"] : []),
  ].join("\n");

  // NOTE(review): SummarizeImages is imported but no generated cell uses it;
  // enabling it would also need OPENAI_API_KEY in the secrets cell.
  const importsBlock = [
    "import pyarrow.fs",
    "import sycamore",
    "import json",
    "import os",
    ...databaseSpec.imports,
    `from sycamore.functions.tokenizer import ${embeddingSpec.tokenizer}`,
    "from sycamore.llms import OpenAIModels, OpenAI",
    "from sycamore.transforms import COALESCE_WHITESPACE",
    // Only import a merger the pipeline actually uses (never `import null`).
    ...(mergeStep === undefined
      ? []
      : [`from sycamore.transforms.merge_elements import ${chunkingState.strategy}`]),
    "from sycamore.transforms.partition import ArynPartitioner",
    `from sycamore.transforms.embed import ${embeddingSpec.embedder}`,
    "from sycamore.materialize_config import MaterializeSourceMode",
    "from sycamore.utils.pdf_utils import show_pages",
    "from sycamore.transforms.summarize_images import SummarizeImages",
    "from sycamore.context import ExecMode",
    ...databaseSpec.extraImports,
    ...(needsPdfConversion
      ? ["from sycamore.utils.fileformat_tools import binary_representation_to_pdf"]
      : []),
  ].join("\n");

  const secrets: SecretSpec[] = [
    ARYN_SECRET,
    ...databaseSpec.secrets,
    ...embeddingSpec.secrets,
  ];
  const apiKeyBlock =
    runtime === "local"
      ? renderLocalSecrets(secrets)
      : renderColabSecrets(secrets);

  // Step 1: input selection.
  const uploadBlock = isColabUpload
    ? [
        "# Upload the document to the Colab environment. ",
        `# Note: Please ensure the document is of type ${fileType.toUpperCase()}`,
        "from google.colab import files",
        "uploaded = files.upload()",
      ].join("\n")
    : null;
  const paths = isColabUpload
    ? "list(uploaded.keys())"
    : `[${toPythonStr(inputState.filePath)}]`;

  // Steps 2-3: partition, cache, chunk, split.
  const readStep =
    `    ctx.read.binary(paths, binary_format="${fileType}"${isS3 ? ", filesystem=fsys" : ""})` +
    (needsPdfConversion ? "\n    .map(binary_representation_to_pdf)" : "");
  const processingBlock = [
    "# Sycamore uses lazy execution for efficiency, so the ETL pipeline will only execute when running cells with specific functions.",
    "  ",
    `paths = ${paths}`,
    isS3
      ? '# Configure your AWS credentials here if the bucket is private\nfsys = pyarrow.fs.S3FileSystem(region="us-east-1", anonymous=True)'
      : "",
    "# Initialize the Sycamore context",
    "ctx = sycamore.init(ExecMode.LOCAL)",
    "# Set the embedding model and its parameters",
    `model_name = "${embeddingSpec.modelName}"`,
    `max_tokens = ${embeddingSpec.maxTokens}`,
    `dimensions = ${embeddingSpec.dimensions}`,
    "# Initialize the tokenizer",
    `tokenizer = ${embeddingSpec.tokenizer}(model_name)`,
    "",
    "ds = (",
    readStep,
    "    # Partition and extract tables and images",
    `    ${buildPartitionStep(optionsState)}`,
    "    # Use materialize to cache output. If changing upstream code or input files, change setting from USE_STORED to RECOMPUTE to create a new cache.",
    `    .materialize(path="${runtime === "colab" ? "/content" : "."}/materialize/partitioned", source_mode=MaterializeSourceMode.USE_STORED)`,
    "    # Merge elements into larger chunks",
    `    ${mergeStep ?? ""}`,
    "    # Split elements that are too big to embed",
    "    .split_elements(tokenizer=tokenizer, max_tokens=max_tokens)",
    ")",
    "",
    "ds.execute()",
    "",
    "# Display the first 3 pages after chunking",
    "show_pages(ds, limit=3)",
    "",
  ].join("\n");

  const embeddingBlock = [
    "embedded_ds = (",
    "    # Copy document properties to each Document's sub-elements",
    '    ds.spread_properties(["path", "entity"])',
    "    # Convert all Elements to Documents",
    "    .explode() ",
    "    # Embed each Document. You can change the embedding model. Make sure your target vector index matches this number of dimensions.",
    `    .embed(embedder=${embeddingSpec.embedder}(model_name=model_name))`,
    ")",
    "# To know more about docset transforms, please visit https://sycamore.readthedocs.io/en/latest/sycamore/transforms.html",
    "",
  ].join("\n");

  // Step 4: load into the database and read it back.
  const sources = [
    installBlock,
    aptBlock,
    importsBlock,
    apiKeyBlock,
    // The upload cell must run before `paths = list(uploaded.keys())`.
    ...(uploadBlock === null ? [] : [uploadBlock]),
    processingBlock,
    embeddingBlock,
    databaseSpec.writeBlock(fields),
    databaseSpec.verifyBlock,
  ];

  return {
    cells: sources.map((source) => ({
      cell_type: "code",
      metadata: {},
      source: [source],
    })),
    metadata: {},
    nbformat: 4,
    nbformat_minor: 2,
  };
}
