package eu.dnetlib.miscutils.iterators.xml;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringWriter;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.util.Iterator;

import javax.xml.stream.XMLEventFactory;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLEventWriter;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;

import org.xml.sax.SAXException;

/**
 * Iterates over an XML stream and returns, one at a time, the serialized form of every element whose
 * local name equals the configured {@code element} name.
 *
 * <p>
 * The input bytes are always decoded as UTF-8; malformed or unmappable byte sequences are replaced rather
 * than reported as errors. Each returned record is re-rooted on a start tag that carries the configured
 * local name with an empty prefix and no namespace; the attributes of the original start tag are copied,
 * its namespace declarations are not. Everything between the start tag and its matching end tag is copied
 * as delivered by the parser.
 * </p>
 *
 * <p>
 * Instances keep mutable, unsynchronized parser state. This class never closes the supplied stream.
 * Parsing errors are reported as {@link RuntimeException} wrapping the underlying
 * {@link XMLStreamException}.
 * </p>
 */
public class XMLIterator implements Iterator<String> {

	/** Per-thread input factory, configured not to resolve external entities. */
	private static final ThreadLocal<XMLInputFactory> INPUT_FACTORY = new ThreadLocal<XMLInputFactory>() {
		@Override
		protected XMLInputFactory initialValue() {
			final XMLInputFactory factory = XMLInputFactory.newInstance();
			// XXE hardening: never fetch external entities referenced by the parsed document.
			factory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, Boolean.FALSE);
			return factory;
		}
	};

	/** Per-thread output factory used to serialize each record. */
	private static final ThreadLocal<XMLOutputFactory> OUTPUT_FACTORY = new ThreadLocal<XMLOutputFactory>() {
		@Override
		protected XMLOutputFactory initialValue() {
			return XMLOutputFactory.newInstance();
		}
	};

	/** Per-thread event factory used to build the synthetic record root. */
	private static final ThreadLocal<XMLEventFactory> EVENT_FACTORY = new ThreadLocal<XMLEventFactory>() {
		@Override
		protected XMLEventFactory initialValue() {
			return XMLEventFactory.newInstance();
		}
	};

	/** Name of the charset used to decode the input stream. */
	public static final String UTF_8 = "UTF-8";

	/** Event reader over the sanitized input. */
	final XMLEventReader parser;

	/** Start element of the next record to return (already consumed from the parser), or null when exhausted. */
	private XMLEvent current = null;

	/** Local name of the elements to extract. */
	private final String element;

	/** Raw input the parser reads from. */
	private final InputStream inputStream;

	/**
	 * Creates the iterator and positions it on the first matching element, if any.
	 *
	 * @param element
	 *            local name of the elements to extract, must not be null
	 * @param inputStream
	 *            XML input, decoded as UTF-8, must not be null
	 * @throws NullPointerException
	 *             if an argument is null
	 * @throws RuntimeException
	 *             if the parser cannot be created or the stream cannot be parsed up to the first record
	 */
	public XMLIterator(String element, InputStream inputStream) {
		super();
		if (element == null) {
			throw new NullPointerException("element");
		}
		if (inputStream == null) {
			throw new NullPointerException("inputStream");
		}
		this.element = element;
		this.inputStream = inputStream;
		this.parser = getParser();

		this.current = findElement(parser);
	}

	/**
	 * @return true while a matching element has been located and not yet returned
	 */
	@Override
	public boolean hasNext() {
		return current != null;
	}

	/**
	 * Serializes the current record and advances to the next matching element.
	 *
	 * @return the serialized record
	 * @throws java.util.NoSuchElementException
	 *             if there is no further record
	 * @throws RuntimeException
	 *             if the underlying stream cannot be parsed or written
	 */
	@Override
	public String next() {
		if (current == null) {
			throw new java.util.NoSuchElementException();
		}
		try {
			final String result = copy(parser);
			current = findElement(parser);
			return result;
		} catch (XMLStreamException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Removal is not supported.
	 *
	 * @throws UnsupportedOperationException
	 *             always
	 */
	@Override
	public void remove() {
		throw new UnsupportedOperationException();
	}

	/**
	 * Serializes the record whose start element is {@code current}. The start element has already been
	 * consumed from the parser, so a fresh root is written first and the remaining events are copied until
	 * the matching end tag has been written.
	 *
	 * @param parser
	 *            reader positioned right after the record's start element
	 * @return the serialized record
	 * @throws XMLStreamException
	 *             if reading from the parser or writing the record fails
	 */
	private String copy(final XMLEventReader parser) throws XMLStreamException {
		final StringWriter result = new StringWriter();
		final XMLEventWriter writer = OUTPUT_FACTORY.get().createXMLEventWriter(result);

		// new root record: configured local name, no prefix/namespace, original attributes
		final StartElement newRecord = EVENT_FACTORY.get().createStartElement("", null, element, current.asStartElement().getAttributes(), null);
		writer.add(newRecord);

		// Copy the rest as it is. Depth tracking (rather than matching the first end tag with the same
		// name) keeps nested elements that share the record's name inside the record.
		int depth = 1;
		while (depth > 0 && parser.hasNext()) {
			final XMLEvent event = parser.nextEvent();
			if (event.isStartElement()) {
				depth++;
			} else if (event.isEndElement()) {
				depth--;
			}
			writer.add(event);
		}
		writer.close();

		return result.toString();
	}

	/**
	 * Consumes events until a start element with the configured local name has been read.
	 *
	 * <p>
	 * The matching start element is always consumed, so that {@link #copy(XMLEventReader)} never reads it a
	 * second time (which would duplicate the opening tag of records that directly follow one another).
	 * </p>
	 *
	 * @param parser
	 *            reader to advance
	 * @return the matching start element, or null when the stream holds no further match
	 */
	private XMLEvent findElement(XMLEventReader parser) {
		while (parser.hasNext()) {
			final XMLEvent event = nextEvent(parser);
			if (event != null && event.isStartElement()) {
				final String name = event.asStartElement().getName().getLocalPart();
				if (element.equals(name)) {
					return event;
				}
			}
		}
		return null;
	}

	/**
	 * Reads the next event, converting the checked parsing exception into an unchecked one.
	 */
	private XMLEvent nextEvent(XMLEventReader parser) {
		try {
			return parser.nextEvent();
		} catch (XMLStreamException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Creates the event reader over the sanitized input stream.
	 */
	private XMLEventReader getParser() {
		try {
			return INPUT_FACTORY.get().createXMLEventReader(sanitize(inputStream));
		} catch (XMLStreamException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Wraps the stream in a UTF-8 reader that replaces malformed or unmappable input instead of failing.
	 */
	private Reader sanitize(final InputStream in) {
		final CharsetDecoder charsetDecoder = Charset.forName(UTF_8).newDecoder();
		charsetDecoder.onMalformedInput(CodingErrorAction.REPLACE);
		charsetDecoder.onUnmappableCharacter(CodingErrorAction.REPLACE);
		return new InputStreamReader(in, charsetDecoder);
	}

}
