# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import random
import numpy


def _read_words(filename):
  """Return every whitespace-separated token in *filename*, in file order.

  Line breaks are treated like any other whitespace, so tokens from all
  lines end up in one flat list.
  """
  with open(filename, "r") as handle:
    text = handle.read()
  # str.split() with no argument already splits on newlines as well.
  return text.split()


def _build_vocab(filename):
  """Build a token -> integer-id mapping from the tokens in *filename*.

  Ids are assigned by descending token frequency, ties broken by ascending
  token order, so the most frequent token receives id 0.

  Args:
    filename: path of a whitespace-tokenised text file.

  Returns:
    dict mapping each distinct token to a unique id in range(len(vocab)).
    A file with no tokens yields an empty dict.
  """
  counter = collections.Counter(_read_words(filename))
  count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
  # enumerate() replaces the former `words, _ = list(zip(*count_pairs))`
  # unpacking, which raised ValueError when the file contained no tokens.
  return {word: word_id for word_id, (word, _) in enumerate(count_pairs)}

def numBatches(filename, batchSize, minLen):
  """Return how many full batches the qualifying lines of *filename* fill.

  A line qualifies when it has more than ``minLen + 1`` whitespace-separated
  tokens. The count of qualifying lines is divided by *batchSize* and
  truncated with int().
  """
  threshold = minLen + 1
  with open(filename, "r") as f:
    qualifying = sum(1 for line in f if len(line.split()) > threshold)
  return int(qualifying / batchSize)

def readBatchOfSentences(filename, batchSize, minLen, startpos):
  """Read one batch of (prefix, continuation) token pairs from *filename*.

  Only lines with more than ``minLen + 1`` tokens are considered. Each such
  line is split into its first ``minLen`` tokens (prefix) and the remaining
  tokens (continuation). The batch consists of the ``batchSize`` consecutive
  qualifying lines starting at index ``startpos``. The whole file is read and
  tokenised again on every call.

  Args:
    filename: path of the text file, one sentence per line.
    batchSize: number of sentences in the batch.
    minLen: number of leading tokens placed in the prefix.
    startpos: index of the first qualifying line to use.

  Returns:
    (prefixes, continuations): an object array of shape (batchSize, minLen)
    of prefix tokens, and a 1-D object array of length batchSize whose
    elements are the continuation token lists.

  Raises:
    IndexError: if startpos + batchSize exceeds the number of qualifying
      lines.
  """
  with open(filename, "r") as f:
    lines = f.readlines()

  cutoff = minLen + 1
  pairs = []
  for tokens in (line.split() for line in lines):
    if len(tokens) > cutoff:
      pairs.append((tokens[0:minLen], tokens[minLen:]))

  prefixes = numpy.empty((batchSize, minLen), 'object')
  continuations = numpy.empty(batchSize, 'object')
  for row in range(batchSize):
    prefix, continuation = pairs[startpos + row]
    prefixes[row] = prefix
    continuations[row] = continuation
  return prefixes, continuations

def write_dict(filename, dict):
  """Write a word -> id mapping to *filename*, one ``"<word> <id>"`` per line.

  Entries are written in the mapping's iteration order and the file is
  overwritten. The parameter name shadows the builtin ``dict``; it is kept
  so existing keyword callers keep working.

  Args:
    filename: destination path.
    dict: mapping whose items are written as ``key value`` lines.
  """
  # The `with` block closes the file; the former explicit f.close() inside
  # it was redundant.
  with open(filename, "w") as f:
    for word, word_id in dict.items():
      f.write("{0} {1}\n".format(word, word_id))

def vocabSize(data_path=None, stem=None):
  """Return the number of distinct tokens in the file ``data_path/stem``.

  Both arguments must be supplied; they are joined with os.path.join.
  """
  vocabulary = _build_vocab(os.path.join(data_path, stem))
  return len(vocabulary)

def _file_to_word_ids(filename, word_to_id):
  """Map every token of *filename* to its id in *word_to_id*, in file order.

  Raises:
    KeyError: if a token is not present in *word_to_id*.
  """
  tokens = _read_words(filename)
  return [word_to_id[token] for token in tokens]


def dataFolds(data_path=None, input_stem=None, target_stem=None, numFolds=10, shuffle=True):
  """Split a pair of line-aligned trace files into cross-validation folds.

  Line i of the input file is treated as the counterpart of line i of the
  target file. Shuffling and folding keep that pairing intact.

  Args:
    data_path: directory containing both files.
    input_stem: file name of the input traces inside data_path.
    target_stem: file name of the target traces inside data_path.
    numFolds: number of folds to produce.
    shuffle: if true, randomly permute the lines with numpy's global RNG
      before folding.

  Returns:
    A 6-tuple (inputFolds, targetFolds, inputWord2Id, targetWord2Id,
    inputVocabulary, targetVocabulary):
      - inputFolds and targetFolds are lists of numFolds slices of lines.
        The slices are numpy object arrays when shuffled and plain lists
        otherwise.
      - The Word2Id values are vocabularies built over the whole files.
      - The *Vocabulary values are the sizes of those vocabularies.
    If the two files have different line counts, every element of the tuple
    is None.

  Note:
    Each fold holds len(lines) // numFolds lines. The remaining
    len(lines) % numFolds lines are not placed in any fold. Because the
    files are split on '\n', a trailing newline yields a final empty line.
  """
  inputFName = os.path.join(data_path, input_stem)
  targetFName = os.path.join(data_path, target_stem)

  inputWord2Id = _build_vocab(inputFName)
  targetWord2Id = _build_vocab(targetFName)

  with open(inputFName, "r") as inputFile:
    inputTraces = inputFile.read().split('\n')
  with open(targetFName, "r") as targetFile:
    targetTraces = targetFile.read().split('\n')

  if len(inputTraces) != len(targetTraces):
    # Same arity as the success path, so callers unpacking six values get
    # Nones rather than a "not enough values to unpack" ValueError.
    return (None,) * 6

  if shuffle:
    # numpy.put(dst, indices, src) stores src[j] at dst[indices[j]]. Using one
    # permutation for both arrays keeps input/target lines aligned.
    # dtype=object: the numpy.object alias was removed in NumPy 1.24.
    indices = numpy.random.permutation(len(inputTraces))
    shuffledInputTraces = numpy.empty(len(inputTraces), dtype=object)
    numpy.put(shuffledInputTraces, indices, inputTraces)
    shuffledTargetTraces = numpy.empty(len(targetTraces), dtype=object)
    numpy.put(shuffledTargetTraces, indices, targetTraces)
  else:
    shuffledInputTraces = inputTraces
    shuffledTargetTraces = targetTraces

  nrow = len(shuffledInputTraces)
  foldSize = nrow // numFolds

  inputFolds = [shuffledInputTraces[k * foldSize:(k + 1) * foldSize]
                for k in range(numFolds)]
  targetFolds = [shuffledTargetTraces[k * foldSize:(k + 1) * foldSize]
                 for k in range(numFolds)]

  inputVocabulary = len(inputWord2Id)
  targetVocabulary = len(targetWord2Id)

  return inputFolds, targetFolds, inputWord2Id, targetWord2Id, inputVocabulary, targetVocabulary


def validationAndTrainingData(inputFolds, targetFolds, inputWord2Id, targetWord2Id, numFold):
  """Use fold numFold as validation data and all other folds as training data.

  Each line in a fold is treated as space-separated tokens. Lines are joined
  in fold order and every token is mapped through the matching vocabulary.
  With a single fold, the training data is the same as the validation data.

  Args:
    inputFolds: sequence of folds of input lines. Folds are expected to be
      equally sized, since numpy.delete stacks them into one array.
    targetFolds: sequence of folds of target lines, same layout.
    inputWord2Id: vocabulary for input tokens.
    targetWord2Id: vocabulary for target tokens.
    numFold: index of the fold held out for validation.

  Returns:
    Tuple (inputValidationIds, inputTrainingIds, targetValidationIds,
    targetTrainingIds) of lists of ids.

  Raises:
    KeyError: if a token is missing from its vocabulary.
  """
  def flatten(folds):
    # Concatenate every line of the stacked folds into one space-joined string.
    return " ".join(numpy.reshape(folds, (-1,)))

  def toIds(text, vocab):
    return [vocab[token] for token in text.split()]

  inputValidationWords = " ".join(inputFolds[numFold])
  targetValidationWords = " ".join(targetFolds[numFold])

  if len(inputFolds) != 1:
    inputTrainingWords = flatten(numpy.delete(inputFolds, numFold, 0))
    targetTrainingWords = flatten(numpy.delete(targetFolds, numFold, 0))
  else:
    inputTrainingWords, targetTrainingWords = inputValidationWords, targetValidationWords

  inputValidationIds = toIds(inputValidationWords, inputWord2Id)
  inputTrainingIds = toIds(inputTrainingWords, inputWord2Id)
  targetValidationIds = toIds(targetValidationWords, targetWord2Id)
  targetTrainingIds = toIds(targetTrainingWords, targetWord2Id)
  return inputValidationIds, inputTrainingIds, targetValidationIds, targetTrainingIds

def words_iterator(input_raw_data, target_raw_data, batch_size, num_steps, overlapData):
  """Iterate over paired input/target id sequences in minibatches.

  Both sequences are truncated to the largest multiple of batch_size that
  fits the input length. Each is then split row-wise into batch_size
  contiguous streams, and windows of num_steps columns are taken along those
  streams.

  Args:
    input_raw_data: sequence of integer ids (converted to int32).
    target_raw_data: sequence of integer ids, at least as long as the used
      part of input_raw_data; only its first
      (len(input_raw_data) // batch_size) * batch_size entries are used.
    batch_size: int, the number of parallel streams (rows per batch).
    num_steps: int, the number of columns per batch (unroll length).
    overlapData: if truthy, consecutive windows advance by one column and
      overlap. Otherwise they advance by num_steps and are disjoint.

  Yields:
    Pairs (x, y) of int32 arrays of shape [batch_size, num_steps]. x comes
    from the input streams. y comes from the target streams at the same
    columns shifted right by one.

  Raises:
    ValueError: if batch_size or num_steps are too high for the data, so
      that no window fits. This is raised on the first next() call, because
      this is a generator.
  """
  input_raw_data = numpy.array(input_raw_data, dtype=numpy.int32)
  target_raw_data = numpy.array(target_raw_data, dtype=numpy.int32)

  batch_len = len(input_raw_data) // batch_size
  usable = batch_size * batch_len
  # Row i holds the contiguous slice [i * batch_len, (i + 1) * batch_len).
  inputdata = input_raw_data[:usable].reshape(batch_size, batch_len)
  targetdata = target_raw_data[:usable].reshape(batch_size, batch_len)

  if overlapData:
    # Sliding window with stride 1.
    # NOTE(review): i + 1 + num_steps never reaches batch_len here, so the
    # last valid window is skipped. Kept as-is to preserve the epoch length.
    stride = 1
    epoch_size = batch_len - 1 - num_steps
  else:
    # Disjoint windows. The "- 1" reserves one column for y's right shift.
    stride = num_steps
    epoch_size = (batch_len - 1) // num_steps

  # "<= 0" rather than "== 0": for small data both formulas can go negative.
  # Previously that slipped past the check and silently yielded nothing.
  if epoch_size <= 0:
    raise ValueError("epoch_size <= 0, decrease batch_size or num_steps")

  for i in range(epoch_size):
    start = i * stride
    x = inputdata[:, start:start + num_steps]
    y = targetdata[:, start + 1:start + 1 + num_steps]
    yield (x, y)

