package eu.dnetlib.enabling.manager.msro.hope;

import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.util.Iterator;
import java.util.NoSuchElementException;

import javax.xml.stream.XMLEventFactory;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLEventWriter;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;

import org.xml.sax.SAXException;

/**
 * Splits an XML document read from an {@link InputStream} into the serialized XML of every element
 * whose local name equals {@code element}, exposing them lazily as an {@link Iterable} of strings.
 * <p>
 * Elements are matched by local name only (prefix and namespace URI are ignored). Each emitted record
 * is rewritten with an unprefixed root tag named {@code element} that carries the attributes and the
 * namespace declarations made directly on the matched start tag; everything up to and including the
 * matching end tag is copied event by event. Namespace declarations inherited from ancestors are not
 * copied. A matching element nested inside another matching element is emitted as part of the outer
 * record, not as a separate one.
 * <p>
 * Every call to {@link #iterator()} opens a new StAX reader on the same stream, which this class never
 * resets or closes, so an instance is meant to be iterated once; closing the stream is the caller's job.
 */
public class IterableXmlInputStreamParser implements Iterable<String> {

	private final ThreadLocal<XMLInputFactory> inputFactory = new ThreadLocal<XMLInputFactory>() {
		@Override
		protected XMLInputFactory initialValue() {
			return XMLInputFactory.newInstance();
		}
	};

	private final ThreadLocal<XMLOutputFactory> outputFactory = new ThreadLocal<XMLOutputFactory>() {
		@Override
		protected XMLOutputFactory initialValue() {
			return XMLOutputFactory.newInstance();
		}
	};

	private final ThreadLocal<XMLEventFactory> eventFactory = new ThreadLocal<XMLEventFactory>() {
		@Override
		protected XMLEventFactory initialValue() {
			return XMLEventFactory.newInstance();
		}
	};

	/** Local name of the elements to extract. */
	private final String element;

	/** Source document; read by the iterator, never closed here. */
	private final InputStream inputStream;

	/**
	 * @param element
	 *            local name of the elements to extract
	 * @param inputStream
	 *            the XML document to split
	 */
	public IterableXmlInputStreamParser(String element, InputStream inputStream) {
		this.element = element;
		this.inputStream = inputStream;
	}

	/**
	 * Returns a lazy iterator over the matching elements.
	 * <p>
	 * {@code hasNext()} may be called any number of times without skipping records, and {@code next()}
	 * works without a preceding {@code hasNext()}. XML errors surface as {@link RuntimeException}s wrapping
	 * the underlying {@link XMLStreamException}.
	 *
	 * @return an iterator yielding one serialized XML string per matching element
	 */
	@Override
	public Iterator<String> iterator() {

		return new Iterator<String>() {

			private final XMLEventReader parser = getParser();

			/** Start tag of the next record: already consumed from the parser but not yet returned. */
			private StartElement pending = null;

			/** Set once the document holds no further matching element; the reader is closed then. */
			private boolean exhausted = false;

			@Override
			public boolean hasNext() {
				if (pending == null && !exhausted) {
					pending = findElement();
					if (pending == null) {
						exhausted = true;
						closeQuietly(parser);
					}
				}
				return pending != null;
			}

			@Override
			public String next() {
				if (!hasNext()) {
					throw new NoSuchElementException("no more <" + element + "> elements");
				}
				final StartElement start = pending;
				// clear first: the start tag is consumed, a later call must look for the following record
				pending = null;
				try {
					return copy(start);
				} catch (XMLStreamException e) {
					throw new RuntimeException(e);
				}
			}

			@Override
			public void remove() {
				throw new UnsupportedOperationException();
			}

			/**
			 * Serializes the record opened by {@code start}, consuming parser events up to and including
			 * its matching end tag.
			 *
			 * @param start
			 *            the already-consumed start tag of the record
			 * @return the record as an XML string
			 * @throws XMLStreamException
			 *             if reading or writing an event fails
			 */
			private String copy(final StartElement start) throws XMLStreamException {
				final StringWriter result = new StringWriter();
				final XMLEventWriter writer = outputFactory.get().createXMLEventWriter(result);

				// new unprefixed root, keeping the attributes and namespace declarations of the original tag
				writer.add(eventFactory.get().createStartElement("", null, element, start.getAttributes(),
						start.getNamespaces()));

				// track nesting depth so nested elements sharing the same name do not end the record early
				int depth = 1;
				while (depth > 0 && parser.hasNext()) {
					final XMLEvent event = parser.nextEvent();
					if (event.isStartElement()) {
						depth++;
					} else if (event.isEndElement()) {
						depth--;
					}
					writer.add(event);
				}
				writer.close();

				return result.toString();
			}

			/**
			 * Advances the parser past the next start tag whose local name equals {@code element}.
			 *
			 * @return that start tag, or {@code null} if the document ends first
			 */
			private StartElement findElement() {
				while (parser.hasNext()) {
					final XMLEvent event = nextEvent(parser);
					if (event.isStartElement() && element.equals(event.asStartElement().getName().getLocalPart())) {
						return event.asStartElement();
					}
				}
				return null;
			}
		};
	}

	/** Reads the next event, rethrowing parse errors unchecked. */
	private XMLEvent nextEvent(XMLEventReader parser) {
		try {
			return parser.nextEvent();
		} catch (XMLStreamException e) {
			throw new RuntimeException(e);
		}
	}

	/** Releases the reader; per the StAX contract this does not close the underlying stream. */
	private static void closeQuietly(XMLEventReader reader) {
		try {
			reader.close();
		} catch (XMLStreamException ignored) {
			// best effort: the document has already been fully consumed, nothing useful to report
		}
	}

	/**
	 * Opens a new event reader on {@link #inputStream}.
	 * <p>
	 * NOTE(review): the factory is used with its default settings; if the input can be untrusted,
	 * consider disabling DTD / external entity support (XXE).
	 */
	private XMLEventReader getParser() {
		try {
			return inputFactory.get().createXMLEventReader(inputStream);
		} catch (XMLStreamException e) {
			throw new RuntimeException(e);
		}
	}

}
