package kinesisFHM.relationConsumer;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;

import kinesisFHM.commonTypes.DependencyRelation;
import kinesisFHM.commonTypes.EventPair;
import kinesisFHM.commonTypes.RelationCount;
import kinesisFHM.commonTypes.WfRType;
import kinesisFHM.utilities.CloudWatchReporter;

import org.apache.commons.codec.binary.Base64;

import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient;
import com.amazonaws.services.cloudwatch.model.Dimension;
import com.amazonaws.services.cloudwatch.model.MetricDatum;
import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest;
import com.amazonaws.services.cloudwatch.model.StandardUnit;
import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsRequest;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.GetShardIteratorRequest;
import com.amazonaws.services.kinesis.model.GetShardIteratorResult;
import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException;
import com.amazonaws.services.kinesis.model.PutRecordsRequest;
import com.amazonaws.services.kinesis.model.PutRecordsRequestEntry;
import com.amazonaws.services.kinesis.model.PutRecordsResult;
import com.amazonaws.services.kinesis.model.Record;
import com.amazonaws.services.kinesis.model.Shard;

public class RelationConsumerThread extends Thread implements Runnable {

	// Name of the Kinesis stream this thread reads relation counts from.
	private String inputStreamName;
	// Name of the Kinesis stream the computed dependency relations are written to.
	private String outputStreamName;
	// Client used for both reading (getRecords) and writing (putRecords).
	private AmazonKinesisClient amazonKinesisClient;
	// Receives the number of records read per batch.
	private CloudWatchReporter inputReporter;
	// Receives the number of dependency relations emitted per batch.
	private CloudWatchReporter outputReporter;
	// The single shard of the input stream that this thread consumes.
	private Shard inputShard;

	// Base name (input stream name + shard id) of the ".properties" checkpoint file
	// and the ".persistedCounts" state file.
	private String fname;	
	// Checkpoint store: maps the shard id to the last processed sequence number.
	private Properties props;

	// PutRecords accepts at most 500 records per request; this stays one below that limit.
	private static final int maxOutputBatchSize = 499;
	// GetRecords accepts a Limit of up to 10,000 records per call (see the Kinesis API reference).
	private static final int maxInputBatchSize = 10000;

	// Pause in milliseconds between polls; adapted at runtime by the throttling/back-off logic.
	private long threadsleep;
	// Number of records requested per getRecords call; adapted at runtime, capped at maxInputBatchSize.
	private int inputBatchSize;
	// Set when a read or a write was throttled during the current polling cycle.
	private boolean throttled;
	// Progress text built while emitting; printed by run() as the start of each status line.
	private String outputString;
	
	// Accumulated counts per event pair, one map per relation type
	// (COUNT, DIRECT, LOOPTWO and SUCCESSOR respectively). These four maps are persisted to disk.
	private HashMap<EventPair, Long> counts = new HashMap<EventPair, Long>();
	private HashMap<EventPair, Long> followsCount = new HashMap<EventPair, Long>();
	private HashMap<EventPair, Long> L2LCount = new HashMap<EventPair, Long>();
	private HashMap<EventPair, Long> LDCount = new HashMap<EventPair, Long>();
	// Dependency relations recomputed since the last emit; cleared after every emit.
	private HashMap<EventPair, DependencyRelation> followsDep = new HashMap<EventPair, DependencyRelation>();
	private HashMap<EventPair, DependencyRelation> L2LDep = new HashMap<EventPair, DependencyRelation>();
	private HashMap<EventPair, DependencyRelation> LDDep = new HashMap<EventPair, DependencyRelation>();
	
	// Minimum dependency value a relation must reach to be written to the output stream.
	protected static float mindep = 0.5f;
	
	public RelationConsumerThread(	String inputStreamName, 
									Shard inputShard, 
									AmazonKinesisClient amazonKinesisClient, 
									CloudWatchReporter inputReporter,
									CloudWatchReporter outputReporter,
									String outputStreamName,
									long threadsleep,
									int inputBatchSize) {
		this.inputStreamName = inputStreamName;
		this.outputStreamName = outputStreamName;
		this.amazonKinesisClient = amazonKinesisClient;
		this.inputReporter = inputReporter;
		this.outputReporter = outputReporter;
		this.inputShard = inputShard;
		this.fname = inputStreamName + inputShard.getShardId();
		this.threadsleep = threadsleep;
		this.inputBatchSize = Math.min(inputBatchSize, maxInputBatchSize);
		this.throttled = false;
		
		System.err.println("Created RelationConsumerThread " + this.getName() + " for shard " + inputShard.getShardId());
		
		File f = new File(fname+".persistedCounts");
		ObjectInputStream ois;
		try {
			ois = new ObjectInputStream(new FileInputStream(f));
			counts = (HashMap<EventPair, Long>) ois.readObject();
			followsCount = (HashMap<EventPair, Long>) ois.readObject();
			L2LCount = (HashMap<EventPair, Long>) ois.readObject();
			LDCount = (HashMap<EventPair, Long>) ois.readObject();
		} catch (FileNotFoundException e) {
			System.err.println(this.getName() + " " + e.getMessage());			
		} catch (IOException e) {
			System.err.println(this.getName() + " " + e.getMessage());
		} catch (ClassNotFoundException e) {
			System.err.println(this.getName() + " " + e.getMessage());
		}
	}
	
	/**
	 * Turns the payload of a Kinesis record back into a {@link RelationCount}.
	 *
	 * <p>The payload is Base64-decoded and then read as a serialized Java object.
	 *
	 * @param record the record whose data should be decoded
	 * @return the decoded relation count, or {@code null} when deserialization
	 *         fails; in that case the properties are saved and the error
	 *         message is printed to stderr
	 */
	private RelationCount decodeRecord(Record record) {
		final byte[] serialized = Base64.decodeBase64(record.getData().array());
		try {
			final ObjectInputStream objectStream =
					new ObjectInputStream(new ByteArrayInputStream(serialized));
			return (RelationCount) objectStream.readObject();
		} catch (Exception e) {
			saveProps();
			System.err.println(this.getName() + " " + e.getMessage());
			return null;
		}
	}
	
	/**
	 * Writes every pending dependency relation whose value lies in
	 * {@code [mindep, 1]} to the output stream and then clears the pending maps.
	 *
	 * <p>Relations are sent in batches of at most {@code maxOutputBatchSize}.
	 * Entries the service rejected stay in the queue and are retried until the
	 * queue is empty. After a batch with any rejected entry the thread pauses
	 * for {@code threadsleep} milliseconds before retrying, so a persistent
	 * failure does not turn into a busy loop. When a write was throttled, the
	 * {@code throttled} flag is set and {@code threadsleep} is lengthened.
	 *
	 * <p>An interrupt received while pausing does not abort the emit (no
	 * relation is dropped); the interrupt status is restored before returning.
	 */
	private void emitDependenciesToStream() {
		List<DependencyRelation> relSet = new ArrayList<DependencyRelation>();
		collectEmittable(followsDep, relSet);
		collectEmittable(L2LDep, relSet);
		collectEmittable(LDDep, relSet);

		int writeCount = relSet.size();
		outputString = this.getName() + " emitted " + writeCount + " ";

		boolean interrupted = false;
		try {
			while (!relSet.isEmpty()) {
				int batchSize = Math.min(relSet.size(), maxOutputBatchSize);
				List<PutRecordsRequestEntry> putRecordsRequestEntryList = new ArrayList<>(batchSize);
				for (DependencyRelation rel : relSet.subList(0, batchSize)) {
					PutRecordsRequestEntry putRecordsRequestEntry = new PutRecordsRequestEntry();
					putRecordsRequestEntry.setData(ByteBuffer.wrap(rel.getBase64()));
					putRecordsRequestEntry.setPartitionKey(rel.getEvent1() + ":" + rel.getEvent2());
					putRecordsRequestEntryList.add(putRecordsRequestEntry);
				}

				PutRecordsRequest putRecordsRequest = new PutRecordsRequest();
				putRecordsRequest.setStreamName(outputStreamName);
				putRecordsRequest.setRecords(putRecordsRequestEntryList);
				PutRecordsResult putRecordsResult = amazonKinesisClient.putRecords(putRecordsRequest);
				outputString += ".";

				// Result entry j corresponds to relSet entry j. Walking backwards keeps
				// the remaining indices valid while successful entries are removed.
				boolean anyFailed = false;
				for (int j = putRecordsResult.getRecords().size() - 1; j >= 0; j--) {
					String errorCode = putRecordsResult.getRecords().get(j).getErrorCode();
					if (errorCode == null) {
						relSet.remove(j);
					} else {
						anyFailed = true;
						if (errorCode.equals("ProvisionedThroughputExceededException")) {
							throttled = true;
							System.out.println("Backing off");
						} else {
							System.out.println(this.getName() + " ERROR : " + putRecordsResult.getRecords().get(j).getErrorMessage());
						}
					}
				}

				if (throttled || anyFailed) {
					if (throttled) {
						System.out.println("Update Output Throttled - Sleeping");
					}
					try {
						Thread.sleep(threadsleep);
					} catch (InterruptedException e) {
						// Finish the emit first; the status is restored in the finally block.
						interrupted = true;
					}
				}
			}
		} finally {
			if (interrupted) {
				Thread.currentThread().interrupt();
			}
		}

		// log the puts to cloudwatch
		outputReporter.add(writeCount);

		if (throttled) {
			// Grow by 5%, but by at least 1 ms: the plain "*= 1.05" truncates back
			// to the same value for any sleep below 20 ms and so never backs off.
			threadsleep = Math.max(threadsleep + 1, (long) (threadsleep * 1.05));
		}
		followsDep.clear();
		L2LDep.clear();
		LDDep.clear();
	}

	/** Appends every relation of {@code source} whose dependency lies in [mindep, 1] to {@code target}. */
	private void collectEmittable(HashMap<EventPair, DependencyRelation> source, List<DependencyRelation> target) {
		for (DependencyRelation rel : source.values()) {
			if (rel.getDep() >= mindep && rel.getDep() <= 1) {
				target.add(rel);
			}
		}
	}

	/**
	 * Applies one input record to the in-memory counts.
	 *
	 * <p>The record is decoded into a relation count; records that cannot be
	 * decoded are ignored. The count is added to the map that belongs to the
	 * relation's type. For DIRECT and LOOPTWO the dependency is then recomputed
	 * for the pair and for its inverse, for SUCCESSOR only for the pair itself,
	 * and for COUNT no dependency is recomputed.
	 *
	 * @param record the Kinesis record to process
	 */
	private void processRecord(Record record) {
		final RelationCount decoded = decodeRecord(record);
		if (decoded == null) {
			return;
		}

		final WfRType relationType = decoded.getRelation().getType();
		final EventPair pair =
				new EventPair(decoded.getRelation().getEvent1(), decoded.getRelation().getEvent2());
		final EventPair inverse =
				new EventPair(decoded.getRelation().getEvent2(), decoded.getRelation().getEvent1());

		switch (relationType) {
		case DIRECT : {
			final Long seen = followsCount.get(pair);
			followsCount.put(pair, (seen == null ? 0l : seen) + decoded.getCount());
			update_follows_dep(pair);
			update_follows_dep(inverse);
			break;
		}
		case LOOPTWO : {
			final Long seen = L2LCount.get(pair);
			L2LCount.put(pair, (seen == null ? 0l : seen) + decoded.getCount());
			update_L2L_dep(pair);
			update_L2L_dep(inverse);
			break;
		}
		case SUCCESSOR : {
			final Long seen = LDCount.get(pair);
			LDCount.put(pair, (seen == null ? 0l : seen) + decoded.getCount());
			update_LD_dep(pair);
			break;
		}
		case COUNT : {
			final Long seen = counts.get(pair);
			counts.put(pair, (seen == null ? 0l : seen) + decoded.getCount());
			break;
		}
		}
	}
	
	/**
	 * Recomputes the SUCCESSOR dependency of a pair and stores it as pending.
	 *
	 * <p>The value is {@code 2 * (LD(ep) - |c1 - c2|) / (c1 + c2 + 1)}, where
	 * {@code c1} and {@code c2} are the entries of {@code counts} for the pairs
	 * (event1, event1) and (event2, event2). Nothing happens while either of
	 * those two entries is missing.
	 *
	 * @param ep the pair whose dependency is recomputed
	 */
	private void update_LD_dep(EventPair ep) {
		final Long firstTotal = counts.get(new EventPair(ep.getEvent1(), ep.getEvent1()));
		final Long secondTotal = counts.get(new EventPair(ep.getEvent2(), ep.getEvent2()));
		if (firstTotal == null || secondTotal == null) {
			return;
		}

		final float successorCount = (float) LDCount.get(ep);
		final float imbalance = (float) Math.abs(firstTotal - secondTotal);
		final float dependency = 2 * (successorCount - imbalance) / (firstTotal + secondTotal + 1);
		LDDep.put(ep, new DependencyRelation(ep.getEvent1(), ep.getEvent2(), WfRType.SUCCESSOR, dependency));
	}

	/**
	 * Recomputes the LOOPTWO dependency of a pair and stores it as pending.
	 *
	 * <p>Missing counts for the pair or its mirror image are initialised to 0 in
	 * {@code L2LCount}. The value is {@code (ep - pe) / (ep + pe + 1)}.
	 * The original FHM papers use {@code (ep + pe) / (ep + pe + 1)}, which
	 * would make the relation symmetrical; this implementation deliberately
	 * uses the difference instead.
	 *
	 * @param ep the pair whose dependency is recomputed
	 */
	private void update_L2L_dep(EventPair ep) {
		final EventPair mirrored = new EventPair(ep.getEvent2(), ep.getEvent1());

		Long forward = L2LCount.get(ep);
		if (forward == null) {
			forward = 0l;
			L2LCount.put(ep, forward);
		}

		Long backward = L2LCount.get(mirrored);
		if (backward == null) {
			backward = 0l;
			L2LCount.put(mirrored, backward);
		}

		final float difference = (float) (forward - backward);
		final float total = (float) (forward + backward + 1);
		L2LDep.put(ep, new DependencyRelation(ep.getEvent1(), ep.getEvent2(), WfRType.LOOPTWO, difference / total));
	}

	/**
	 * Recomputes the DIRECT (directly-follows) dependency of a pair and stores
	 * it as pending.
	 *
	 * <p>A missing count for the pair is initialised to 0 in {@code followsCount}.
	 * For a pair of two different events the mirror pair is initialised the same
	 * way and the value is {@code (ep - pe) / (ep + pe + 1)}; for a pair of an
	 * event with itself the value is {@code ep / (ep + 1)}.
	 *
	 * @param ep the pair whose dependency is recomputed
	 */
	private void update_follows_dep(EventPair ep) {
		Long forward = followsCount.get(ep);
		if (forward == null) {
			forward = 0l;
			followsCount.put(ep, forward);
		}

		final float dependency;
		if (ep.getEvent1().equals(ep.getEvent2())) {
			// self loop: there is no distinct mirror pair to compare against
			dependency = (float) (forward) / ((float) (forward + 1));
		} else {
			final EventPair mirrored = new EventPair(ep.getEvent2(), ep.getEvent1());
			Long backward = followsCount.get(mirrored);
			if (backward == null) {
				backward = 0l;
				followsCount.put(mirrored, backward);
			}
			dependency = (float) (forward - backward) / ((float) (forward + backward + 1));
		}

		followsDep.put(ep, new DependencyRelation(ep.getEvent1(), ep.getEvent2(), WfRType.DIRECT, dependency));
	}

	/**
	 * Loads the checkpoint properties from {@code <fname>.properties}.
	 *
	 * <p>{@code props} is always replaced by a fresh instance first, so it is
	 * non-null and simply empty when the file is missing or unreadable; the
	 * error message is printed to stderr in that case.
	 */
	private void loadProps() {
		props = new Properties();
		// try-with-resources: the reader was previously never closed
		try (FileReader reader = new FileReader(fname + ".properties")) {
			props.load(reader);
		} catch (IOException e) {
			// Includes FileNotFoundException (no checkpoint written yet).
			System.err.println(this.getName() + " " + e.getMessage());
		}
	}
	
	/**
	 * Writes the checkpoint properties to {@code <fname>.properties},
	 * replacing the previous content. I/O errors are printed to stderr.
	 *
	 * <p>{@code Properties.store} flushes but does not close the writer it is
	 * given. This method runs after every batch, so the writer is closed
	 * here to avoid leaking one file descriptor per checkpoint.
	 */
	private void saveProps() {
		try (FileWriter writer = new FileWriter(fname + ".properties")) {
			props.store(writer, null);
		} catch (IOException e) {
			System.err.println(this.getName() + " " + e.getMessage());
		}
	}
	
	/** Minimum pause, in milliseconds, after an unexpected failure in the polling loop. */
	private static final long ERROR_BACKOFF_MILLIS = 1000L;

	/**
	 * Polls the input shard until it is closed or the thread is interrupted.
	 *
	 * <p>Each cycle fetches a batch of records, applies them to the counts,
	 * emits the resulting dependency relations, stores the checkpoint and the
	 * counts, and then adapts {@code threadsleep} and {@code inputBatchSize} to
	 * the observed load before sleeping.
	 *
	 * <p>Failure handling:
	 * <ul>
	 * <li>An expired shard iterator is renewed from the last processed sequence
	 *     number, so records are not processed twice.</li>
	 * <li>The next shard iterator is taken as soon as a batch was fetched, so an
	 *     exception later in the cycle does not replay the same batch.</li>
	 * <li>A record that cannot be processed is logged and skipped.</li>
	 * <li>After an unexpected exception the loop pauses before continuing.</li>
	 * <li>An interrupt, or a closed shard (no next iterator), ends the loop.</li>
	 * </ul>
	 * The properties are saved whenever the method ends, normally or not.
	 */
	@Override
	public void run() {
		try {
			loadProps();
			String checkPoint = props.getProperty(inputShard.getShardId());
			String shardIterator = fetchShardIterator(checkPoint);

			System.err.println(this.getName() + " Worker started");

			long previous_behind;
			long current_behind = 0l;
			int count = 0;

			while (shardIterator != null && !Thread.currentThread().isInterrupted()) {
				try {
					GetRecordsResult result = null;
					while (result == null && !Thread.currentThread().isInterrupted()) {
						// Built per attempt so it always carries the current iterator and batch size.
						GetRecordsRequest getRecordsRequest = new GetRecordsRequest();
						getRecordsRequest.setShardIterator(shardIterator);
						getRecordsRequest.setLimit(inputBatchSize);
						try {
							result = amazonKinesisClient.getRecords(getRecordsRequest);
						} catch (ExpiredIteratorException e) {
							// Resume after the last record that was processed, not at the
							// position this thread originally started from.
							shardIterator = fetchShardIterator(checkPoint);
						} catch (ProvisionedThroughputExceededException e) {
							System.err.println(this.getName() + " Read throughput exceeded, waiting a turn");
							throttled = true;
							// NOTE(review): with a short sleep the batch size is increased here,
							// presumably to need fewer calls per second - confirm this is intended.
							if (threadsleep > 50) {
								threadsleep*=1.05;
							} else if (inputBatchSize < 20) {
								inputBatchSize++;
							} else {
								inputBatchSize = Math.min( maxInputBatchSize, Math.round(inputBatchSize*1.05f) );
							}
							try {
								Thread.sleep(threadsleep/10);
							} catch (InterruptedException ee) {
								Thread.currentThread().interrupt();
							}
						}
					}
					if (result == null) {
						// interrupted while waiting for a batch
						break;
					}

					// Advance immediately: if anything below throws, the next cycle must
					// continue with the following batch instead of replaying this one.
					shardIterator = result.getNextShardIterator();

					List<Record> records = result.getRecords();
					inputReporter.add(records.size());

					long starttime = System.currentTimeMillis();
					for (Record record : records) {
						try {
							processRecord(record);
						} catch (RuntimeException e) {
							// One bad record must not block the rest of the shard.
							System.err.println(this.getName() + " skipping record " + record.getSequenceNumber() + ": " + e);
						}
						checkPoint = record.getSequenceNumber();
					}
					emitDependenciesToStream();
					long endtime = System.currentTimeMillis();

					if (checkPoint != null) {
						props.setProperty(inputShard.getShardId(), checkPoint);
						saveProps();
						persistCounts();
					}

					previous_behind = current_behind;
					current_behind = result.getMillisBehindLatest();
					if (!throttled) {
						if (count>0) count--;
						System.err.println(outputString + ", processed " + records.size() + "/" + inputBatchSize + " records in " + (endtime-starttime) + " millis at " + current_behind + " millis behind tip, threadsleep " + threadsleep + " count " + count);
					}
					else {
						System.err.println(outputString + ", processed " + records.size() + "/" + inputBatchSize + " records in " + (endtime-starttime) + " millis at " + current_behind + " millis behind tip, threadsleep " + threadsleep + " count " + count + " === Throttled ===");
						count=5;
					}

					if ( (records.size() == inputBatchSize) && (current_behind > 10000) && (current_behind > previous_behind) && !throttled && count==0) {
						// falling behind with full batches: speed up by sleeping less,
						// or, once the sleep is already short, by reading more per call
						if (threadsleep > 50) {
							threadsleep*=0.95;
						} else if (inputBatchSize < 20) {
							inputBatchSize++;
						} else {
							inputBatchSize = Math.min( maxInputBatchSize, Math.round(inputBatchSize*1.05f) );
						}
					}
					if ( records.size() < inputBatchSize/2 ) {
						// mostly idle: slow down by sleeping longer, then by reading less per call
						if (threadsleep < 400) {
							// Grow by 5%, but by at least 1 ms: the plain "*= 1.05" truncates
							// back to the same value for any sleep below 20 ms.
							threadsleep = Math.max(threadsleep + 1, (long) (threadsleep * 1.05));
						} else {
							inputBatchSize*=0.95;
						}
					}
					throttled=false;

					try {
						Thread.sleep(threadsleep);
					} catch (InterruptedException exception) {
						// restore the status so the loop condition ends the worker
						Thread.currentThread().interrupt();
					}
				} catch (Exception e) {
					System.err.println(this.getName() + " " + e.getLocalizedMessage());
					// Pause before the next attempt so a persistent failure is not a busy loop.
					try {
						Thread.sleep(Math.max(threadsleep, ERROR_BACKOFF_MILLIS));
					} catch (InterruptedException ie) {
						Thread.currentThread().interrupt();
					}
				}
			}

			saveProps();
			System.err.println(this.getName() + " Worker stopped");
		} catch (Exception e) {
			saveProps();
			throw e;
		}
	}

	/**
	 * Requests a shard iterator positioned right after {@code checkPoint}, or at
	 * the oldest available record (TRIM_HORIZON) when {@code checkPoint} is null.
	 */
	private String fetchShardIterator(String checkPoint) {
		GetShardIteratorRequest getShardIteratorRequest = new GetShardIteratorRequest();
		getShardIteratorRequest.setStreamName(inputStreamName);
		getShardIteratorRequest.setShardId(inputShard.getShardId());
		if (checkPoint != null) {
			getShardIteratorRequest.setShardIteratorType("AFTER_SEQUENCE_NUMBER");
			getShardIteratorRequest.setStartingSequenceNumber(checkPoint);
		} else {
			getShardIteratorRequest.setShardIteratorType("TRIM_HORIZON");
		}
		GetShardIteratorResult getShardIteratorResult = amazonKinesisClient.getShardIterator(getShardIteratorRequest);
		return getShardIteratorResult.getShardIterator();
	}

	/** Writes the four count maps to {@code <fname>.persistedCounts}, replacing the previous file. */
	private void persistCounts() {
		File f = new File(fname+".persistedCounts");
		// try-with-resources: the streams were previously never closed (one leak per batch)
		try (FileOutputStream fos = new FileOutputStream(f, false);
				ObjectOutputStream oos = new ObjectOutputStream(fos)) {
			oos.writeObject(counts);
			oos.writeObject(followsCount);
			oos.writeObject(L2LCount);
			oos.writeObject(LDCount);
		} catch (IOException e) {
			System.err.println(this.getName() + " " + e.getMessage());
		}
	}

}
