package kinesisFHM.eventConsumer;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.Vector;

import kinesisFHM.commonTypes.Event;
import kinesisFHM.commonTypes.Relation;
import kinesisFHM.commonTypes.RelationCount;
import kinesisFHM.commonTypes.WfRType;
import kinesisFHM.utilities.CloudWatchReporter;

import org.apache.commons.codec.binary.Base64;

import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsRequest;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.GetShardIteratorRequest;
import com.amazonaws.services.kinesis.model.GetShardIteratorResult;
import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException;
import com.amazonaws.services.kinesis.model.PutRecordsRequest;
import com.amazonaws.services.kinesis.model.PutRecordsRequestEntry;
import com.amazonaws.services.kinesis.model.PutRecordsResult;
import com.amazonaws.services.kinesis.model.Record;
import com.amazonaws.services.kinesis.model.Shard;

/**
 * Worker thread that reads events from a single shard of a Kinesis input stream,
 * groups them into per-case traces, derives relation counts between the events of
 * each trace and writes both the relation counts and the retired (completed)
 * traces to their own Kinesis output streams.
 *
 * <p>The sequence number of the last processed record is stored in a properties
 * file named {@code <inputStreamName><shardId>.properties} so that a restarted
 * worker resumes after that record.
 *
 * <p>All mutable state is touched only by this thread's own {@link #run()} loop;
 * the class does no synchronization of its own.
 */
public class EventConsumerThread extends Thread implements Runnable {

	// PutRecords accepts at most 500 records per request
	private static final int MAX_OUTPUT_BATCH_SIZE = 499;
	// GetRecords accepts a Limit of at most 10,000 records per request
	private static final int MAX_INPUT_BATCH_SIZE = 10000;
	// lower bound the adaptive input batch size is never reduced below
	private static final int MIN_INPUT_BATCH_SIZE = 100;
	// traces are checked for retirement at most this often
	private static final long RETIREMENT_INTERVAL_MILLIS = 10 * 1000;
	// minimum pause after an unexpected error, so a persistent failure cannot spin
	private static final long ERROR_BACKOFF_MILLIS = 100;

	private String inputStreamName;
	private String relationOutputStreamName;
	private String traceOutputStreamName;
	private AmazonKinesisClient amazonKinesisClient;
	private CloudWatchReporter traceOutputReporter;
	private CloudWatchReporter relationOutputReporter;
	private CloudWatchReporter inputReporter;
	private Shard inputShard;
	private int numOutputShards;

	// base name (without extension) of the checkpoint properties file
	private String fname;

	// largest event timestamp seen so far (or wall-clock time in real-time mode)
	private long lastObservedEvent;

	private Properties props;

	// caseID -> events of the case, ordered by timestamp
	private HashMap<String, SortedMap<Long, String>> activeTraces;
	// caseID -> timestamp of the most recently received event
	private HashMap<String, Long> lastArrivals;
	// caseID -> running mean of the time between consecutive events
	private HashMap<String, Long> interArrivalTimes;
	// caseIDs that were already emitted; later events for them are ignored
	private HashSet<String> retiredTraces;
	// retired traces waiting to be written to the trace output stream
	private ArrayList<SortedMap<Long, String>> emitTraces;
	// relation counts accumulated since the last emission
	private HashMap<Relation, RelationCount> relations;
	private boolean realTime;
	// adaptive pause (milliseconds) between polls and after throttled writes
	private long threadsleep;
	// adaptive GetRecords limit
	private int inputBatchSize;

	// true when any read or write of the current loop iteration was throttled
	private boolean throttled;

	// progress prefix produced by emitRelationsToStream() for the per-iteration log line
	private String outputString;

	/**
	 * Creates a worker for one input shard. The worker does nothing until it is started.
	 *
	 * @param inputStreamName          stream to read events from
	 * @param inputShard               shard of that stream this worker consumes
	 * @param amazonKinesisClient      client used for all stream reads and writes
	 * @param traceOutputReporter      receives the number of traces written
	 * @param relationOutputReporter   receives the number of relation counts written
	 * @param inputReporter            receives the number of records read
	 * @param relationOutputStreamName stream the relation counts are written to
	 * @param traceOutputStreamName    stream the retired traces are written to
	 * @param numOutputShards          upper bound of the random partition keys used for traces
	 * @param realTime                 if true, trace retirement compares against wall-clock time
	 *                                 instead of the newest event timestamp
	 * @param threadsleep              initial pause between polls in milliseconds (negative values
	 *                                 are treated as 0)
	 * @param inputBatchSize           initial GetRecords limit, clamped to [100, 10000]
	 */
	public EventConsumerThread(String inputStreamName, 
								Shard inputShard, 
								AmazonKinesisClient amazonKinesisClient, 
								CloudWatchReporter traceOutputReporter, 
								CloudWatchReporter relationOutputReporter, 
								CloudWatchReporter inputReporter, 
								String relationOutputStreamName, 
								String traceOutputStreamName, 
								int numOutputShards, 
								boolean realTime, 
								long threadsleep, 
								int inputBatchSize) {
		this.inputStreamName = inputStreamName;
		this.relationOutputStreamName = relationOutputStreamName;
		this.traceOutputStreamName = traceOutputStreamName;
		this.amazonKinesisClient = amazonKinesisClient;
		this.inputReporter = inputReporter;
		this.traceOutputReporter = traceOutputReporter;
		this.relationOutputReporter = relationOutputReporter;
		this.inputShard = inputShard;
		this.numOutputShards = numOutputShards;
		this.realTime = realTime;
		this.fname = inputStreamName + inputShard.getShardId();
		// Thread.sleep rejects negative durations
		this.threadsleep = Math.max(0, threadsleep);
		this.inputBatchSize = Math.max(MIN_INPUT_BATCH_SIZE, Math.min(MAX_INPUT_BATCH_SIZE, inputBatchSize));

		throttled = false;
		lastObservedEvent = 0;
		// never null, so saveProps() is safe even before run() has loaded the file
		props = new Properties();
		activeTraces = new HashMap<String, SortedMap<Long, String>>();
		lastArrivals = new HashMap<String, Long>();
		interArrivalTimes = new HashMap<String, Long>();
		retiredTraces = new HashSet<String>();
		emitTraces = new ArrayList<SortedMap<Long, String>>();
		relations = new HashMap<Relation, RelationCount>();

		System.out.println("Created EventConsumerThread " + this.getName() + " for shard " + inputShard.getShardId());
	}

	/**
	 * Sleeps for the given time. An interrupt only cuts the pause short; callers
	 * continue with their retry, as this worker has no shutdown path.
	 */
	private static void pause(long millis) {
		try {
			Thread.sleep(millis);
		} catch (InterruptedException ignored) {
			// deliberately ignored, see method comment
		}
	}

	/**
	 * Lengthens the pause by 5%, but always by at least one millisecond so that
	 * small values (where 5% truncates to nothing) still grow.
	 */
	private void lengthenSleep() {
		threadsleep = Math.max(threadsleep + 1, (long) (threadsleep * 1.05));
	}

	/**
	 * Decodes the Base64-encoded, Java-serialized {@link Event} carried by a record.
	 *
	 * @return the decoded event, or {@code null} if the record has no payload or the
	 *         payload cannot be decoded
	 */
	private Event decodeRecord(Record record) {
		ByteBuffer payload = record.getData();
		if (payload == null) {
			return null;
		}
		// Copy only the readable region through a duplicate view: array() would ignore
		// position/limit and is unavailable for read-only buffers.
		ByteBuffer view = payload.asReadOnlyBuffer();
		byte[] encoded = new byte[view.remaining()];
		view.get(encoded);
		byte[] data = Base64.decodeBase64(encoded);

		Event event = null;
		// NOTE(review): this uses native Java deserialization on stream contents; only
		// safe while every producer writing to the input stream is trusted.
		try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(data))) {
			event = (Event) ois.readObject();
		} catch (Exception e) {
			saveProps();
			System.out.println(this.getName() + " " + e.getMessage());
		}
		return event;
	}

	/**
	 * Moves every active trace that has been silent for more than four times its mean
	 * inter-arrival time (measured against {@code lastObservedEvent}) to the emission
	 * queue and writes the queue to the trace output stream. Traces with a single event
	 * have no inter-arrival time yet and are never retired.
	 *
	 * @return true if writing the traces was throttled
	 */
	private boolean retireTraces() {
		List<String> retireList = new ArrayList<String>();
		for (String caseID : activeTraces.keySet()) {
			Long interarrivalTime = interArrivalTimes.get(caseID);
			if (interarrivalTime == null) {
				// new case with only one event, cannot retire
				continue;
			}
			long lastArrival = lastArrivals.get(caseID);
			// stand-in for the 99% quantile of the arrival distribution
			double q99 = interarrivalTime * 4;
			if (lastArrival + q99 < lastObservedEvent) {
				retireList.add(caseID);
			}
		}
		// retire the cases, includes deleting all bookkeeping for them
		for (String caseID : retireList) {
			emitTraces.add(activeTraces.remove(caseID));
			retiredTraces.add(caseID);
			interArrivalTimes.remove(caseID);
			lastArrivals.remove(caseID);
		}
		if (emitTraces.isEmpty()) {
			return false;
		}
		return emitCasesToStream();
	}

	/**
	 * Writes all queued traces to the trace output stream in batches, retrying
	 * rejected records until the queue is empty. Traces that cannot be serialized
	 * are logged and dropped, since retrying them could never succeed.
	 *
	 * <p>Once any record has been throttled, the method pauses after every following
	 * batch of the same call as well.
	 *
	 * @return true if at least one record was rejected for exceeding the provisioned throughput
	 */
	private boolean emitCasesToStream() {
		boolean outputThrottled = false;
		int traceWriteCount = emitTraces.size();
		while (!emitTraces.isEmpty()) {
			int batchSize = Math.min(emitTraces.size(), MAX_OUTPUT_BATCH_SIZE);
			List<PutRecordsRequestEntry> entries = new ArrayList<PutRecordsRequestEntry>(batchSize);
			// entryPositions.get(k) is the index in emitTraces of the trace behind entries.get(k);
			// the two can differ once a trace failed to serialize
			List<Integer> entryPositions = new ArrayList<Integer>(batchSize);
			// finished[i] marks emitTraces.get(i) for removal (written, or impossible to write)
			boolean[] finished = new boolean[batchSize];

			for (int i = 0; i < batchSize; i++) {
				ByteArrayOutputStream baos = new ByteArrayOutputStream();
				try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
					oos.writeObject(emitTraces.get(i));
				} catch (IOException e) {
					System.out.println(this.getName() + " " + e.getMessage());
					finished[i] = true;
					continue;
				}
				PutRecordsRequestEntry entry = new PutRecordsRequestEntry();
				entry.setData(ByteBuffer.wrap(Base64.encodeBase64(baos.toByteArray())));
				// random key to spread the traces over the output shards
				entry.setPartitionKey(Integer.toString((int) Math.ceil(Math.random() * numOutputShards)));
				entries.add(entry);
				entryPositions.add(i);
			}

			// an empty record list is not a valid PutRecords request
			if (!entries.isEmpty()) {
				PutRecordsRequest putRecordsRequest = new PutRecordsRequest();
				putRecordsRequest.setStreamName(traceOutputStreamName);
				putRecordsRequest.setRecords(entries);
				PutRecordsResult putRecordsResult = amazonKinesisClient.putRecords(putRecordsRequest);
				for (int j = 0; j < putRecordsResult.getRecords().size(); j++) {
					String errorCode = putRecordsResult.getRecords().get(j).getErrorCode();
					if (errorCode == null) {
						finished[entryPositions.get(j)] = true;
					} else if (errorCode.equals("ProvisionedThroughputExceededException")) {
						outputThrottled = true;
						System.out.println("Backing off");
					} else {
						System.out.println(this.getName() + " ERROR : " + putRecordsResult.getRecords().get(j).getErrorMessage());
					}
				}
			}

			// remove back to front so the remaining indices stay valid
			for (int i = batchSize - 1; i >= 0; i--) {
				if (finished[i]) {
					emitTraces.remove(i);
				}
			}

			if (outputThrottled) {
				System.out.println("Trace Output Throttled - Sleeping");
				pause(threadsleep);
			}
		}
		// log the puts to cloudwatch
		traceOutputReporter.add(traceWriteCount);

		if (outputThrottled) {
			lengthenSleep();
		}
		return outputThrottled;
	}

	/**
	 * Counts the relations created by the last event of the given case's trace:
	 * one COUNT for the event itself, a DIRECT relation from its predecessor, a
	 * LOOPTWO relation when the event equals the one two positions back, and a
	 * SUCCESSOR relation from every earlier event of the trace.
	 *
	 * <p>The trace is ordered by timestamp, so "last event" is the one with the
	 * largest timestamp.
	 */
	private void computeNewRelations(String caseID) {
		Object[] trace = activeTraces.get(caseID).values().toArray();
		int last = trace.length - 1;

		String current = (String) trace[last];
		addToRelations(new Relation(current, current, WfRType.COUNT));
		if (trace.length < 2) {
			return;
		}

		String previous = (String) trace[last - 1];
		addToRelations(new Relation(previous, current, WfRType.DIRECT));
		addToRelations(new Relation(previous, current, WfRType.SUCCESSOR));
		if (trace.length < 3) {
			return;
		}

		String beforePrevious = (String) trace[last - 2];
		if (beforePrevious.equals(current)) {
			addToRelations(new Relation(previous, current, WfRType.LOOPTWO));
		}
		addToRelations(new Relation(beforePrevious, current, WfRType.SUCCESSOR));

		// all remaining, older events also precede the current one
		for (int i = 0; i < trace.length - 3; i++) {
			addToRelations(new Relation((String) trace[i], current, WfRType.SUCCESSOR));
		}
	}

	/** Increments the pending count of the given relation, creating the counter on first use. */
	private void addToRelations(Relation r) {
		RelationCount rc = relations.get(r);
		if (rc == null) {
			rc = new RelationCount(r, 0L);
			relations.put(r, rc);
		}
		rc.add();
	}

	/**
	 * Writes all pending relation counts to the relation output stream in batches,
	 * retrying rejected records until every one has been accepted, then clears the
	 * pending counts.
	 *
	 * <p>The emission is synchronous on purpose: a checkpoint for a processed input
	 * record must only be written after the relations derived from it have been
	 * emitted. An asynchronous emission queue would need a table linking input
	 * records to outstanding emission requests to preserve that guarantee.
	 *
	 * <p>Side effect: sets {@code outputString} to a progress prefix (one "." per
	 * clean batch, one "T" per batch sent while throttled) for the caller's log line.
	 *
	 * @return true if at least one record was rejected for exceeding the provisioned throughput
	 */
	private boolean emitRelationsToStream() {
		List<RelationCount> pending = new ArrayList<RelationCount>(relations.values());
		boolean outputThrottled = false;
		int relationWriteCount = relations.size();
		StringBuilder progress = new StringBuilder(this.getName() + " ");

		while (!pending.isEmpty()) {
			int batchSize = Math.min(pending.size(), MAX_OUTPUT_BATCH_SIZE);
			List<PutRecordsRequestEntry> entries = new ArrayList<PutRecordsRequestEntry>(batchSize);

			for (int i = 0; i < batchSize; i++) {
				RelationCount relation = pending.get(i);
				String event1 = relation.getRelation().getEvent1();
				String event2 = relation.getRelation().getEvent2();
				PutRecordsRequestEntry entry = new PutRecordsRequestEntry();
				entry.setData(ByteBuffer.wrap(relation.getBase64()));
				// order-independent key: (a,b) and (b,a) get the same partition key
				if (event1.compareTo(event2) <= 0) {
					entry.setPartitionKey(event1 + ":" + event2);
				} else {
					entry.setPartitionKey(event2 + ":" + event1);
				}
				entries.add(entry);
			}

			PutRecordsRequest putRecordsRequest = new PutRecordsRequest();
			putRecordsRequest.setStreamName(relationOutputStreamName);
			putRecordsRequest.setRecords(entries);
			PutRecordsResult putRecordsResult = amazonKinesisClient.putRecords(putRecordsRequest);

			// result entries line up with request entries, which are pending[0..batchSize);
			// walk backwards so removals do not shift the indices still to be visited
			for (int j = putRecordsResult.getRecords().size() - 1; j >= 0; j--) {
				String errorCode = putRecordsResult.getRecords().get(j).getErrorCode();
				if (errorCode == null) {
					pending.remove(j);
				} else if (errorCode.equals("ProvisionedThroughputExceededException")) {
					outputThrottled = true;
				}
				// any other error: the record stays queued and is retried silently
			}

			if (outputThrottled) {
				progress.append("T");
				pause(threadsleep);
			} else {
				progress.append(".");
			}
		}
		// log the puts to cloudwatch
		relationOutputReporter.add(relationWriteCount);

		relations.clear();
		if (outputThrottled) {
			lengthenSleep();
		}

		progress.append(" emitted " + relationWriteCount + " records, ");
		outputString = progress.toString();
		return outputThrottled;
	}

	/**
	 * Adds the event carried by the record to the trace of its case (creating the
	 * trace if necessary), updates the case's arrival statistics and counts the
	 * relations the event creates. Events of already retired cases and undecodable
	 * records are ignored.
	 *
	 * <p>NOTE(review): traces are keyed by timestamp, so two events of one case with
	 * the same timestamp overwrite each other.
	 */
	private void processRecord(Record record) {
		Event event = decodeRecord(record);
		if (event == null) {
			return;
		}

		String caseID = event.getCaseID();
		String eventID = event.getEventID();
		Long eventTS = event.getTS();

		if (!retiredTraces.contains(caseID)) {
			SortedMap<Long, String> trace = activeTraces.get(caseID);
			if (trace != null) {
				// existing trace: update the running mean of the inter-arrival time
				Long lastArrival = lastArrivals.get(caseID);
				Long currentInterarrivalTime = interArrivalTimes.get(caseID);
				long latestGap = eventTS - lastArrival;
				Long newInterarrivalTime;
				if (currentInterarrivalTime == null) {
					// second event of the case: the first gap is the mean
					newInterarrivalTime = latestGap;
				} else {
					// n stored events span n-1 gaps; the new event adds one more
					int gapsSoFar = trace.size() - 1;
					newInterarrivalTime = (currentInterarrivalTime * gapsSoFar + latestGap) / (gapsSoFar + 1);
				}
				interArrivalTimes.put(caseID, newInterarrivalTime);
			} else {
				// new trace: no inter-arrival time yet, this is the first event
				trace = new TreeMap<Long, String>();
				activeTraces.put(caseID, trace);
			}
			lastArrivals.put(caseID, eventTS);
			trace.put(eventTS, eventID);
			computeNewRelations(caseID);
		}

		// update the last observed event timestamp
		if (eventTS > lastObservedEvent) {
			lastObservedEvent = eventTS;
		}
	}

	/** Loads the checkpoint properties file; a missing or unreadable file leaves the properties empty. */
	private void loadProps() {
		props = new Properties();
		try (FileReader reader = new FileReader(fname + ".properties")) {
			props.load(reader);
		} catch (FileNotFoundException e) {
			System.out.println(this.getName() + " " + e.getMessage());
		} catch (IOException e) {
			System.out.println(this.getName() + " " + e.getMessage());
		}
	}

	/** Writes the checkpoint properties file; failures are logged and otherwise ignored. */
	private void saveProps() {
		// Properties.store leaves the writer open, so it has to be closed here;
		// this method runs on every loop iteration.
		try (FileWriter writer = new FileWriter(fname + ".properties")) {
			props.store(writer, null);
		} catch (IOException e) {
			System.out.println(this.getName() + " " + e.getMessage());
		}
	}

	/**
	 * Points the iterator request just after the given sequence number, or at the
	 * oldest available record when there is no checkpoint.
	 */
	private void positionIteratorRequest(GetShardIteratorRequest request, String checkPoint) {
		if (checkPoint != null) {
			request.setShardIteratorType("AFTER_SEQUENCE_NUMBER");
			request.setStartingSequenceNumber(checkPoint);
		} else {
			request.setShardIteratorType("TRIM_HORIZON");
		}
	}

	/**
	 * Polls the shard until it is closed: reads a batch, processes every record,
	 * emits the resulting relations, checkpoints, periodically retires traces and
	 * then adapts the pause and batch size to the observed load.
	 *
	 * <p>Unexpected exceptions are logged and the same iterator is retried after a
	 * short pause.
	 * NOTE(review): if the failure happened after records were processed, that
	 * batch is read and counted again on the retry.
	 */
	public void run() {

		loadProps();
		String checkPoint = props.getProperty(inputShard.getShardId());

		GetShardIteratorRequest getShardIteratorRequest = new GetShardIteratorRequest();
		getShardIteratorRequest.setStreamName(inputStreamName);
		getShardIteratorRequest.setShardId(inputShard.getShardId());
		positionIteratorRequest(getShardIteratorRequest, checkPoint);

		GetShardIteratorResult getShardIteratorResult = amazonKinesisClient.getShardIterator(getShardIteratorRequest);
		String shardIterator = getShardIteratorResult.getShardIterator();

		long lastRetirement = 0;

		System.out.println(this.getName() + " Worker started");

		long previousBehind = Long.MAX_VALUE;
		long currentBehind = 0L;
		// iterations to wait after a throttle before speeding up again
		int count = 0;

		// a null iterator means the shard is closed and fully read
		while (shardIterator != null) {
			try {
				GetRecordsResult result = null;
				while (result == null) {
					GetRecordsRequest getRecordsRequest = new GetRecordsRequest();
					getRecordsRequest.setShardIterator(shardIterator);
					getRecordsRequest.setLimit(inputBatchSize);
					try {
						result = amazonKinesisClient.getRecords(getRecordsRequest);
					} catch (ExpiredIteratorException e) {
						// resume after the last processed record, not at the start-up position,
						// otherwise everything read since start-up would be processed again
						positionIteratorRequest(getShardIteratorRequest, checkPoint);
						getShardIteratorResult = amazonKinesisClient.getShardIterator(getShardIteratorRequest);
						shardIterator = getShardIteratorResult.getShardIterator();
					} catch (ProvisionedThroughputExceededException e) {
						System.out.println(this.getName() + " Read throughput exceeded, waiting a turn");
						throttled = true;
						// fewer calls per second: either pause longer or fetch more per call
						if (threadsleep > 50) {
							lengthenSleep();
						} else if (inputBatchSize < 20) {
							inputBatchSize++;
						} else {
							inputBatchSize = Math.min(MAX_INPUT_BATCH_SIZE, Math.round(inputBatchSize * 1.05f));
						}
						pause(threadsleep / 10);
					}
				}
				List<Record> records = result.getRecords();
				inputReporter.add(records.size());

				long starttime = System.currentTimeMillis();
				for (Record record : records) {
					processRecord(record);
					checkPoint = record.getSequenceNumber();
				}
				// The call must come first: with "throttled || emit()" a read throttle would
				// short-circuit the emission while the checkpoint below is still written.
				throttled = emitRelationsToStream() || throttled;
				long endtime = System.currentTimeMillis();

				if (checkPoint != null) {
					props.setProperty(inputShard.getShardId(), checkPoint);
					saveProps();
				}

				if (endtime - lastRetirement > RETIREMENT_INTERVAL_MILLIS) {
					if (realTime) {
						lastObservedEvent = endtime;
					}
					// same ordering concern as above: always run the retirement
					throttled = retireTraces() || throttled;
					lastRetirement = endtime;
				}

				previousBehind = currentBehind;
				currentBehind = result.getMillisBehindLatest();

				String status = outputString + "processed " + records.size() + "/" + inputBatchSize + " records in " + (endtime - starttime) + " at " + currentBehind + " behind, sleep " + threadsleep + " traces (" + this.activeTraces.size() + "/" + this.retiredTraces.size() + ")";
				if (throttled) {
					System.out.println(status + " === Throttled ===");
					count = 5;
				} else {
					if (count > 0) {
						count--;
					}
					System.out.println(status);
				}

				boolean fullBatch = records.size() == inputBatchSize;
				boolean fallingBehind = currentBehind > 10000 && currentBehind > previousBehind;
				if (fullBatch && fallingBehind && !throttled && count == 0) {
					// speeding up by shortening our sleep period and increasing the batch size
					if (threadsleep > 20) {
						threadsleep *= 0.95;
					} else {
						// minimum threadsleep is 0
						threadsleep = Math.max(threadsleep - 1, 0);
					}
					// maximum batchsize is MAX_INPUT_BATCH_SIZE
					inputBatchSize = Math.min(MAX_INPUT_BATCH_SIZE, Math.round(inputBatchSize * 1.05f));
				}
				if (records.size() < inputBatchSize / 2) {
					// slowing down by increasing sleep period and reducing the batch size
					if (threadsleep < 20) {
						threadsleep++;
					} else {
						// maximum threadsleep is 500
						threadsleep = Math.min(500, Math.round(threadsleep * 1.05));
					}
					if (inputBatchSize > MIN_INPUT_BATCH_SIZE) {
						inputBatchSize = Math.max(MIN_INPUT_BATCH_SIZE, Math.round(inputBatchSize * 0.95f));
					}
				}
				throttled = false;

				try {
					Thread.sleep(threadsleep);
				} catch (InterruptedException exception) {
					// an interrupt only shortens the pause; persist the checkpoint and carry on
					saveProps();
				}

				shardIterator = result.getNextShardIterator();
			} catch (Exception e) {
				// print the exception itself: its message alone may be null
				System.out.println(this.getName() + " " + e);
				// avoid hammering the service when the failure is persistent
				pause(Math.max(threadsleep, ERROR_BACKOFF_MILLIS));
			}
		}

		System.out.println(this.getName() + " Shard " + inputShard.getShardId() + " is closed, worker finished");
	}
}
