package eu.dnetlib.dhp.solr;

import com.clearspring.analytics.util.Lists;

import eu.dnetlib.dhp.solr.mapping.RowToSolrInputDocumentMapper;
import org.apache.solr.client.solrj.SolrClient;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.request.UpdateRequest;
import org.apache.solr.client.solrj.response.UpdateResponse;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.SolrInputDocument;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.apache.zookeeper.KeeperException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.List;
import java.util.concurrent.ExecutionException;

import static eu.dnetlib.dhp.utils.SparkSessionSupport.runWithSparkSession;

/**
 * Reads records from a JSON dataset with Spark and indexes them into a Solr collection,
 * sending the documents in batches and retrying failed batches with an increasing delay.
 */
public class RecordImporter implements Serializable {

    private static final Logger log = LoggerFactory.getLogger(RecordImporter.class);

    /** Batch size constant exposed to callers; it is not referenced inside this class. */
    public static final int BATCH_SIZE = 1000;

    /** Maximum number of times a batch is sent through the same client before giving up. */
    private static final int MAX_RETRIES = 3;

    /** Delay in milliseconds before the first retry; it is doubled after every further failure. */
    public static final int RETRY_DELAY = 3000;

    /** Schema applied when reading the input JSON: two string columns, {@code xml} and {@code json}. */
    private static final StructType PAYLOAD_SCHEMA = StructType.fromDDL("xml STRING, json STRING");

    /**
     * Indexes the JSON records found at {@code path} into the given Solr collection and optionally commits.
     *
     * @param conf         Spark configuration used to obtain the session
     * @param zkHost       ZooKeeper host string used to build the Solr cloud client parameters
     * @param collection   name of the target Solr collection
     * @param path         location of the JSON input, read with the {@code xml}/{@code json} schema
     * @param batchSize    number of documents accumulated before an update request is sent
     * @param shouldCommit when {@code true}, a commit is issued on the collection once indexing is done
     */
    public static void importRecords(SparkConf conf, String zkHost, String collection, String path, int batchSize, boolean shouldCommit) {
        runWithSparkSession(conf, true, spark -> {
            CloudClientParams params = new CloudClientParams(zkHost, collection);
            indexDocs(params, batchSize, spark.read().schema(PAYLOAD_SCHEMA).json(path));
            log.info("record import completed");
            if (shouldCommit) {
                CloudSolrClient client = CacheCloudSolrClient.getCachedCloudClient(params);
                UpdateResponse commitRsp = client.commit(collection);
                if (commitRsp.getStatus() != 0) {
                    Exception commitFailure = commitRsp.getException();
                    log.error("got exception during commit operation", commitFailure);
                    if (commitFailure == null) {
                        // A non-zero status may come without an exception object; "throw null" would
                        // surface as an uninformative NullPointerException, so fail explicitly instead.
                        throw new IllegalStateException(String.format(
                                "commit on collection %s failed with status %d", collection, commitRsp.getStatus()));
                    }
                    throw commitFailure;
                }
                log.info("commit done");
            }
        });
    }

    /**
     * Maps every row of each partition to a Solr document and sends the documents in batches.
     *
     * @param params    Solr cloud client parameters (used to obtain the client and the collection name)
     * @param batchSize number of documents accumulated before a batch is flushed
     * @param docs      rows to be indexed
     */
    private static void indexDocs(CloudClientParams params, int batchSize, Dataset<Row> docs) {
        docs.foreachPartition(solrDocs -> {
            try {
                final CloudSolrClient client = CacheCloudSolrClient.getCachedCloudClient(params);
                final List<SolrInputDocument> batch = Lists.newArrayList();
                while (solrDocs.hasNext()) {
                    SolrInputDocument doc = RowToSolrInputDocumentMapper.map(solrDocs.next());

                    // Flush before adding, so that a batch never grows beyond batchSize documents.
                    if (wouldBatchBeFull(batch.size(), batchSize)) {
                        sendBatchToSolrWithRetry(params, client, batch);
                        batch.clear();
                    }
                    batch.add(doc);
                }
                // Send whatever is left once the partition is exhausted.
                if (!batch.isEmpty()) {
                    sendBatchToSolrWithRetry(params, client, batch);
                    batch.clear();
                }
            } catch (ExecutionException e) {
                // The checked exception is converted to an unchecked one to leave the lambda.
                throw new RuntimeException(e);
            }
        });
    }

    /**
     * Sends a batch and, when the failure is rooted in an expired ZooKeeper session or a ZooKeeper
     * operation timeout, invalidates the cached client and repeats the whole retry cycle with a new one.
     *
     * @param params     Solr cloud client parameters
     * @param solrClient client used for the first retry cycle
     * @param batch      documents to index
     * @throws ExecutionException when the batch could not be indexed
     */
    private static void sendBatchToSolrWithRetry(CloudClientParams params, SolrClient solrClient, List<SolrInputDocument> batch) throws ExecutionException {
        try {
            sendBatchToSolr(params, solrClient, batch, 1, RETRY_DELAY);
        } catch (Exception e) {
            // sendBatchToSolr keeps the last failure as the cause, so the root cause seen here
            // is the actual Solr/ZooKeeper error.
            Throwable rootCause = SolrException.getRootCause(e);
            boolean zkSessionProblem = rootCause instanceof KeeperException.SessionExpiredException
                    || rootCause instanceof KeeperException.OperationTimeoutException;

            if (zkSessionProblem) {
                log.error("Error indexing batch to collection {} ; will retry ... \n\nERROR: {}", params.getCollection(), e.toString());

                CacheCloudSolrClient.invalidateCachedClient(params);
                CloudSolrClient newClient = CacheCloudSolrClient.getCachedCloudClient(params);
                sendBatchToSolr(params, newClient, batch, 1, RETRY_DELAY);
            } else if (e instanceof ExecutionException) {
                // Already of the declared type: rethrow as is instead of wrapping it once more.
                throw (ExecutionException) e;
            } else {
                throw new ExecutionException(e.getMessage(), e);
            }
        }
    }

    /**
     * Sends a batch to Solr, retrying up to {@link #MAX_RETRIES} attempts in total and doubling the
     * delay between attempts.
     *
     * @param params     Solr cloud client parameters (provides the collection name)
     * @param solrClient client used to send the update request
     * @param batch      documents to index
     * @param attempt    1-based number of the first attempt performed by this call
     * @param retryDelay delay in milliseconds before the first retry
     * @throws ExecutionException when all attempts failed (the last failure is the cause), or when the
     *                            thread is interrupted while waiting for the next attempt
     */
    private static void sendBatchToSolr(CloudClientParams params, SolrClient solrClient, List<SolrInputDocument> batch, int attempt, int retryDelay) throws ExecutionException {
        Exception lastFailure = null;
        long delay = retryDelay;

        for (int currentAttempt = attempt; currentAttempt <= MAX_RETRIES; currentAttempt++) {
            UpdateRequest req = new UpdateRequest();
            req.setParam("collection", params.getCollection());
            long initialTime = System.currentTimeMillis();

            log.info("Sending batch of {} to collection {} attempt {}", batch.size(), params.getCollection(), currentAttempt);

            req.add(batch);

            try {
                solrClient.request(req);
                double timeTaken = (System.currentTimeMillis() - initialTime) / 1000.0;

                log.info("Took '{}' secs to index {} documents", timeTaken, batch.size());
                return;
            } catch (Exception e) {
                lastFailure = e;
            }

            if (currentAttempt >= MAX_RETRIES) {
                // No attempt left: skip the pointless wait and report the failure below.
                break;
            }

            log.error("Error indexing batch to collection {} ; attempt {} ; will retry ... \n\nERROR: {}", params.getCollection(), currentAttempt, lastFailure.toString());
            try {
                Thread.sleep(delay);
            } catch (InterruptedException ie) {
                // Restore the interrupt flag for the callers and stop retrying.
                Thread.currentThread().interrupt();
                throw new ExecutionException(
                        String.format("Interrupted while waiting to retry indexing batch to collection %s", params.getCollection()), ie);
            }
            delay *= 2;
        }

        String msg = String.format("Reached max number of allowed retries %d, failing...", MAX_RETRIES);
        log.error(msg, lastFailure);
        // Keep the last failure as the cause so that callers can inspect the real root cause.
        throw new ExecutionException(msg, lastFailure);
    }

    /**
     * Tells whether the batch has to be flushed before another document is added.
     *
     * @param numDocsInBatch number of documents currently in the batch
     * @param batchSize      configured batch size
     * @return {@code true} when the batch is not empty and holds at least {@code batchSize} documents
     */
    private static boolean wouldBatchBeFull(int numDocsInBatch, int batchSize) {
        return numDocsInBatch > 0 && numDocsInBatch >= batchSize;
    }

}
