/*
  uint2048_t Class
  Written by John Ryland
  (C) Copyright 2015
*/
#include "Integer.h"
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <vector>
#include <string>
#include <sstream>
#include <iostream>
#include <iomanip>
#include "../../Util.h"

// Three-way outcome of uint2048_t::Pimpl::Compare(a, b).
enum ComparisonResult
{
  CR_Equal = 0,      // a == b
  CR_LessThan = 1,   // a < b
  CR_GreaterThan = 2 // a > b
};

// Private implementation of uint2048_t. The value is stored in m_data with
// m_data[0] as the least significant byte (Add/Subtract/Multiply propagate
// carries upward starting from index 0).
struct uint2048_t::Pimpl
{
  // Clears m_data, then parses 'str' as hexadecimal digits (via HexCharToInt),
  // consuming digits from the end of the string, i.e. least significant first.
  void Init(const char* str);
  // Copies the raw bytes of 'other' into m_data.
  void Copy(const uint2048_t& other);
  // Currently an empty function (no-op).
  void Normalize();
  // Index of the highest non-zero byte plus one; returns 1 when the value is zero.
  size_t GetSize();

  // Three-way comparison, scanning bytes from the most significant end.
  static ComparisonResult Compare(const uint2048_t& a, const uint2048_t& b);
  // Add/Subtract/Multiply keep only the low s_size bytes of the result.
  static uint2048_t Add(const uint2048_t& a, const uint2048_t& b);
  static uint2048_t Subtract(const uint2048_t& a, const uint2048_t& b);
  static uint2048_t Multiply(const uint2048_t& a, const uint2048_t& b);
  // Shift-and-subtract long division / remainder.
  static uint2048_t Divide(const uint2048_t& a, const uint2048_t& b);
  static uint2048_t Modulus(const uint2048_t& a, const uint2048_t& b);
  // Bit shifts; bits shifted past either end of m_data are discarded.
  static uint2048_t ShiftLeft(const uint2048_t& a, int shift);
  static uint2048_t ShiftRight(const uint2048_t& a, int shift);
  // Byte-wise bitwise operations over all s_size bytes.
  static uint2048_t Or(const uint2048_t& a, const uint2048_t& b);
  static uint2048_t And(const uint2048_t& a, const uint2048_t& b);
  static uint2048_t Xor(const uint2048_t& a, const uint2048_t& b);
  static uint2048_t Not(const uint2048_t& a);

  // Helper function to optimize the ExpMod function.
  // Multiplies a (an bytes) by b (bn bytes) into the double-width buffer 'dest'
  // (which the caller zeroes), then reduces 'dest' modulo the modulus.
  // 'mod' is not a plain number: ExpMod passes eight copies of the modulus
  // pre-shifted by 0..7 bits, each s_size*2 bytes apart. mn is the modulus byte size.
  static void MulMod(uint8_t* a, uint8_t* b, uint8_t* mod, size_t an, size_t bn, size_t mn, uint8_t* dest);

#define s_size  257 // 257 bytes = 2056 bits: 2048 bits (256*8) plus one spare byte
  uint8_t m_data[s_size];
};

// Default constructor: the value zero.
uint2048_t::uint2048_t()
 : m_pimpl(std::make_unique<Pimpl>())
{
  Pimpl& impl = *m_pimpl;
  impl.Init("0"); // clears every byte
}

// Copy constructor: allocates its own Pimpl and copies the bytes of 'other'.
uint2048_t::uint2048_t(const uint2048_t& other)
 : m_pimpl(std::make_unique<Pimpl>())
{
  Pimpl& impl = *m_pimpl;
  impl.Copy(other);
}

// Create from a number: the 64-bit value occupies the eight least significant
// bytes (m_data[0] is the lowest byte); all higher bytes are zero.
// The bytes are written directly rather than formatting the number as a hex
// string and re-parsing it, so the result does not depend on the exact output
// format of int2hex and no temporary strings are allocated.
uint2048_t::uint2048_t(uint64_t a_number)
 : m_pimpl(std::make_unique<Pimpl>())
{
  m_pimpl->Init("0"); // start from zero
  for (size_t i = 0; i < sizeof(a_number); ++i)
    m_pimpl->m_data[i] = static_cast<uint8_t>(a_number >> (8 * i));
}

// Create from a string of hexadecimal digits (parsed by Pimpl::Init).
uint2048_t::uint2048_t(const char* str)
 : m_pimpl(std::make_unique<Pimpl>())
{
  Pimpl& impl = *m_pimpl;
  impl.Init(str);
}

// Destructor. Pimpl is a complete type in this translation unit, so the
// compiler-generated member destruction can release m_pimpl here.
uint2048_t::~uint2048_t() = default;

// Assignment operator: copies the bytes of 'other' into this value.
// Self-assignment is skipped because Pimpl::Copy uses memcpy, whose behavior is
// undefined when the source and destination overlap (here: are identical).
void uint2048_t::operator=(const uint2048_t& other)
{
  if (this != &other)
    m_pimpl->Copy(other);
}

// Is zero: true when every byte of the value is 0.
// GetSize() returns 1 when all bytes above index 0 are zero, so the value is
// zero exactly when that holds and m_data[0] is also zero. This avoids
// constructing (and heap-allocating) a temporary zero on each call; the check
// runs once per loop iteration in ExpMod ("while (!!b)").
bool uint2048_t::operator!() const
{
  return m_pimpl->GetSize() == 1 && m_pimpl->m_data[0] == 0;
}

// Relational operators. All are derived from the three-way Pimpl::Compare, which
// defines a total order, so each can be expressed through == or <.
bool uint2048_t::operator!=(const uint2048_t& other) const
{
  return !(*this == other);
}

bool uint2048_t::operator==(const uint2048_t& other) const
{
  const ComparisonResult order = Pimpl::Compare(*this, other);
  return order == CR_Equal;
}

bool uint2048_t::operator<(const uint2048_t& other) const
{
  const ComparisonResult order = Pimpl::Compare(*this, other);
  return order == CR_LessThan;
}

bool uint2048_t::operator>(const uint2048_t& other) const
{
  // a > b  <=>  b < a
  return other < *this;
}

bool uint2048_t::operator<=(const uint2048_t& other) const
{
  // a <= b  <=>  !(b < a)
  return !(other < *this);
}

bool uint2048_t::operator>=(const uint2048_t& other) const
{
  // a >= b  <=>  !(a < b)
  return !(*this < other);
}

// Logical AND of the "is non-zero" truth values of both operands.
bool uint2048_t::operator&&(const uint2048_t& other) const
{
  if (!*this)
    return false;
  return !!other;
}

// Logical OR of the "is non-zero" truth values of both operands.
bool uint2048_t::operator||(const uint2048_t& other) const
{
  if (!*this)
    return !!other;
  return true;
}

// Pre-decrement: subtract one (wraps around at zero).
uint2048_t& uint2048_t::operator--()
{
  return *this -= uint2048_t(1);
}

// Pre-increment: add one (wraps around at the maximum value).
uint2048_t& uint2048_t::operator++()
{
  return *this += uint2048_t(1);
}

// Post-increment: returns the value held before incrementing.
uint2048_t uint2048_t::operator++(int)
{
  uint2048_t previous(*this);
  ++*this;
  return previous;
}

// Post-decrement: returns the value held before decrementing.
uint2048_t uint2048_t::operator--(int)
{
  uint2048_t previous(*this);
  --*this;
  return previous;
}

// In-place subtraction (result truncated to the fixed width).
uint2048_t& uint2048_t::operator-=(const uint2048_t& other)
{
  const uint2048_t difference = Pimpl::Subtract(*this, other);
  *this = difference;
  return *this;
}

// In-place addition (result truncated to the fixed width).
uint2048_t& uint2048_t::operator+=(const uint2048_t& other)
{
  const uint2048_t sum = Pimpl::Add(*this, other);
  *this = sum;
  return *this;
}

// In-place multiplication (result truncated to the fixed width).
uint2048_t& uint2048_t::operator*=(const uint2048_t& other)
{
  const uint2048_t product = Pimpl::Multiply(*this, other);
  *this = product;
  return *this;
}

// Binary arithmetic, bitwise and shift operators. Each forwards to the matching
// Pimpl helper and returns a new value; neither operand is modified.
uint2048_t uint2048_t::operator-(const uint2048_t& other) const
{
  uint2048_t difference = Pimpl::Subtract(*this, other);
  return difference;
}

uint2048_t uint2048_t::operator+(const uint2048_t& other) const
{
  uint2048_t sum = Pimpl::Add(*this, other);
  return sum;
}

uint2048_t uint2048_t::operator*(const uint2048_t& other) const
{
  uint2048_t product = Pimpl::Multiply(*this, other);
  return product;
}

uint2048_t uint2048_t::operator/(const uint2048_t& other) const
{
  uint2048_t quotient = Pimpl::Divide(*this, other);
  return quotient;
}

uint2048_t uint2048_t::operator%(const uint2048_t& other) const
{
  uint2048_t remainder = Pimpl::Modulus(*this, other);
  return remainder;
}

uint2048_t uint2048_t::operator&(const uint2048_t& other) const
{
  uint2048_t masked = Pimpl::And(*this, other);
  return masked;
}

uint2048_t uint2048_t::operator|(const uint2048_t& other) const
{
  uint2048_t merged = Pimpl::Or(*this, other);
  return merged;
}

uint2048_t uint2048_t::operator^(const uint2048_t& other) const
{
  uint2048_t toggled = Pimpl::Xor(*this, other);
  return toggled;
}

uint2048_t uint2048_t::operator~() const
{
  uint2048_t inverted = Pimpl::Not(*this);
  return inverted;
}

uint2048_t uint2048_t::operator>>(int shift) const
{
  uint2048_t shifted = Pimpl::ShiftRight(*this, shift);
  return shifted;
}

uint2048_t uint2048_t::operator<<(int shift) const
{
  uint2048_t shifted = Pimpl::ShiftLeft(*this, shift);
  return shifted;
}

// Compound-assignment forms: compute the result with the matching Pimpl helper,
// store it into *this and return *this.
uint2048_t& uint2048_t::operator>>=(int shift)
{
  const uint2048_t shifted = Pimpl::ShiftRight(*this, shift);
  *this = shifted;
  return *this;
}

uint2048_t& uint2048_t::operator<<=(int shift)
{
  const uint2048_t shifted = Pimpl::ShiftLeft(*this, shift);
  *this = shifted;
  return *this;
}

uint2048_t& uint2048_t::operator/=(const uint2048_t& other)
{
  const uint2048_t quotient = Pimpl::Divide(*this, other);
  *this = quotient;
  return *this;
}

uint2048_t& uint2048_t::operator%=(const uint2048_t& other)
{
  const uint2048_t remainder = Pimpl::Modulus(*this, other);
  *this = remainder;
  return *this;
}

uint2048_t& uint2048_t::operator&=(const uint2048_t& other)
{
  const uint2048_t masked = Pimpl::And(*this, other);
  *this = masked;
  return *this;
}

uint2048_t& uint2048_t::operator|=(const uint2048_t& other)
{
  const uint2048_t merged = Pimpl::Or(*this, other);
  *this = merged;
  return *this;
}

uint2048_t& uint2048_t::operator^=(const uint2048_t& other)
{
  const uint2048_t toggled = Pimpl::Xor(*this, other);
  *this = toggled;
  return *this;
}

// Conversion to uint64_t: returns the low 64 bits of the value (truncation,
// like any unsigned narrowing conversion).
// The bytes are assembled directly from m_data (m_data[0] = least significant).
// The previous implementation formatted the whole number as a hex string and
// parsed it back with strtoull. That saturated to ULLONG_MAX for values
// >= 2^64 and allocated on every call.
uint2048_t::operator uint64_t() const
{
  uint64_t value = 0;
  for (size_t i = sizeof(value); i-- > 0; )
    value = (value << 8) | m_pimpl->m_data[i];
  return value;
}

// Writes the "0x"-prefixed hexadecimal form of the value to standard output
// (no trailing newline).
void uint2048_t::Print() const
{
  const std::string text = static_cast<std::string>(*this);
  std::cout << text;
}

// base ^ exponent % modulus
// Computes (base ^ exponent) % modulus with right-to-left binary exponentiation.
//
// Fixes relative to the previous version:
//  * The exponent is no longer reduced modulo 'modulus'. a^e mod m is not equal
//    to a^(e mod m) mod m in general (e.g. 2^5 mod 3 = 2, but 2^(5 mod 3) mod 3 = 1).
//  * The accumulator starts at (1 % modulus), so modulus == 1 yields 0 even for a
//    zero exponent.
//  * The pre-shifted modulus copies keep the bits shifted out of the top byte.
//  * The final, unused squaring of 'a' is skipped.
// NOTE(review): MulMod's scratch sizing appears to assume the modulus fits in
// s_size-1 bytes (2048 bits); a modulus using the 257th byte is not expected - confirm.
uint2048_t uint2048_t::ExpMod(const uint2048_t& base, const uint2048_t& exponent, const uint2048_t& modulus)
{
  uint2048_t dst = uint2048_t(1) % modulus;
  uint2048_t a = base % modulus;
  uint2048_t b = exponent; // must NOT be reduced modulo 'modulus'

  const size_t mn = modulus.m_pimpl->GetSize();
  const uint8_t* modBits = modulus.m_pimpl->m_data;
  uint8_t* aBits = a.m_pimpl->m_data;
  uint8_t* dstBits = dst.m_pimpl->m_data;
  size_t an = a.m_pimpl->GetSize();

  // MulMod reduces using the modulus pre-shifted by 0..7 bits. Copy k starts at
  // tmpModBits[s_size*2*k] and is s_size*2 bytes wide (zero-padded).
  uint8_t tmpModBits[s_size * 16];
  for (int shift = 0; shift < 8; shift++)
  {
    uint32_t overflow = 0;
    uint8_t* mb = &tmpModBits[s_size * 2 * shift];
    memset(mb, 0, s_size * 2);

    for (int i = 0; i < s_size; i++)
    {
      const uint32_t var = static_cast<uint32_t>(modBits[i]) << shift;
      mb[i] = static_cast<uint8_t>(var | overflow);
      overflow = var >> 8;
    }
    mb[s_size] = static_cast<uint8_t>(overflow); // bits carried out of the top byte
  }

  // Double-width scratch buffer that receives each product before reduction.
  uint8_t tmp1Bits[s_size * 2];

  while (!!b)
  {
    if (b[0] & 1)
    {
      // dst = (dst * a) % modulus
      memset(tmp1Bits, 0, sizeof(tmp1Bits));
      const size_t dn = dst.m_pimpl->GetSize();
      Pimpl::MulMod(dstBits, aBits, tmpModBits, dn, an, mn, tmp1Bits);
      memcpy(dstBits, tmp1Bits, s_size);
    }

    b >>= 1;
    if (!b)
      break; // no bits left: squaring 'a' again would be wasted work

    // a = (a * a) % modulus
    memset(tmp1Bits, 0, sizeof(tmp1Bits));
    Pimpl::MulMod(aBits, aBits, tmpModBits, an, an, mn, tmp1Bits);
    memcpy(aBits, tmp1Bits, s_size);
    an = a.m_pimpl->GetSize();
  }

  return dst;
}

// Formats the value as "0x" followed by two lowercase hex digits per byte, from
// the most significant used byte (GetSize()-1) down to byte 0. Zero is "0x00".
uint2048_t::operator std::string() const
{
  static const char hexDigits[] = "0123456789abcdef";
  const size_t used = m_pimpl->GetSize();

  std::string text("0x");
  text.reserve(2 + 2 * used);
  for (size_t idx = used; idx-- > 0; )
  {
    const uint8_t byte = m_pimpl->m_data[idx];
    text += hexDigits[byte >> 4];
    text += hexDigits[byte & 0x0F];
  }
  return text;
}

// Fixed storage capacity in bytes (the dimension of m_data), independent of the value.
size_t uint2048_t::Size() const
{
  return sizeof m_pimpl->m_data / sizeof *m_pimpl->m_data;
}

// Byte access: index 0 is the least significant byte. No bounds checking;
// 'index' must be less than Size().
uint8_t uint2048_t::operator[](size_t index) const
{
  const uint8_t* bytes = m_pimpl->m_data;
  return bytes[index];
}


////////////////// PRIVATE //////////////////////////////////////////////////

// Maps a hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F') to its value 0-15.
// Characters outside that set are not rejected and yield out-of-range values.
static unsigned int HexCharToInt(unsigned char ch)
{
  const unsigned int lowered = static_cast<unsigned int>(tolower(ch));
  if (lowered < 'a')
    return lowered - '0';
  return lowered - 'a' + 10;
}

// Sets the value from a string of hexadecimal digits.
//  * The least significant digit is the last character; each pair of digits
//    fills one byte, starting at m_data[0].
//  * An optional "0x"/"0X" prefix (the form produced by operator std::string)
//    is skipped.
//  * Digits beyond the s_size-byte capacity are ignored, so only the low bytes
//    are kept. Previously, strings longer than 2*s_size digits wrote past m_data.
//  * A null pointer yields zero.
void uint2048_t::Pimpl::Init(const char* str)
{
  memset(m_data, 0, s_size);
  if (!str)
    return;

  if (str[0] == '0' && (str[1] == 'x' || str[1] == 'X'))
    str += 2;

  const size_t len = strlen(str);
  size_t idx = 0;
  for (size_t pos = len; pos > 0 && idx < static_cast<size_t>(s_size); idx++)
  {
    unsigned int val = HexCharToInt(str[--pos]);           // low nibble
    if (pos > 0)
      val |= HexCharToInt(str[--pos]) << 4;                // high nibble
    m_data[idx] = static_cast<uint8_t>(val);
  }
}

void uint2048_t::Pimpl::Copy(const uint2048_t& other)
{
  memcpy(m_data, other.m_pimpl->m_data, sizeof(m_data));
}

// Placeholder: currently performs no work (storage is the fixed-size m_data array).
void uint2048_t::Pimpl::Normalize()
{
  return;
}

// Number of significant bytes: index of the highest non-zero byte plus one.
// Never returns less than 1 (zero reports a size of 1).
size_t uint2048_t::Pimpl::GetSize()
{
  size_t used = s_size;
  while (used > 1 && m_data[used - 1] == 0)
    --used;
  return used;
}

// Three-way comparison of a and b. Bytes are scanned from the most significant
// end; the first differing byte decides the result. If none differ, the values
// are equal. Running time depends on where the first difference occurs, so this
// is not constant-time.
ComparisonResult uint2048_t::Pimpl::Compare(const uint2048_t& a, const uint2048_t& b)
{
  for (size_t idx = s_size; idx-- > 0; )
  {
    const uint8_t lhs = a[idx];
    const uint8_t rhs = b[idx];
    if (lhs != rhs)
      return (lhs < rhs) ? CR_LessThan : CR_GreaterThan;
  }
  return CR_Equal;
}

// Byte-wise OR over all s_size bytes.
uint2048_t uint2048_t::Pimpl::Or(const uint2048_t& a, const uint2048_t& b)
{
  uint2048_t out;
  uint8_t* bytes = out.m_pimpl->m_data;
  for (size_t i = 0; i < static_cast<size_t>(s_size); ++i)
    bytes[i] = static_cast<uint8_t>(a[i] | b[i]);
  return out;
}

// Byte-wise AND over all s_size bytes.
uint2048_t uint2048_t::Pimpl::And(const uint2048_t& a, const uint2048_t& b)
{
  uint2048_t out;
  uint8_t* bytes = out.m_pimpl->m_data;
  for (size_t i = 0; i < static_cast<size_t>(s_size); ++i)
    bytes[i] = static_cast<uint8_t>(a[i] & b[i]);
  return out;
}

// Byte-wise XOR over all s_size bytes.
uint2048_t uint2048_t::Pimpl::Xor(const uint2048_t& a, const uint2048_t& b)
{
  uint2048_t out;
  uint8_t* bytes = out.m_pimpl->m_data;
  for (size_t i = 0; i < static_cast<size_t>(s_size); ++i)
    bytes[i] = static_cast<uint8_t>(a[i] ^ b[i]);
  return out;
}

// Byte-wise complement over all s_size bytes (including the spare top byte).
uint2048_t uint2048_t::Pimpl::Not(const uint2048_t& a)
{
  uint2048_t out;
  uint8_t* bytes = out.m_pimpl->m_data;
  for (size_t i = 0; i < static_cast<size_t>(s_size); ++i)
    bytes[i] = static_cast<uint8_t>(~a[i]);
  return out;
}

// a + b, propagating the carry from byte 0 upward. The carry out of the top
// byte is discarded, so the sum wraps modulo 2^(8*s_size).
uint2048_t uint2048_t::Pimpl::Add(const uint2048_t& a, const uint2048_t& b)
{
  uint2048_t sum;
  uint8_t* out = sum.m_pimpl->m_data;
  unsigned int carry = 0;
  for (size_t idx = 0; idx < static_cast<size_t>(s_size); ++idx)
  {
    const unsigned int column = static_cast<unsigned int>(a[idx]) + b[idx] + carry;
    out[idx] = static_cast<uint8_t>(column & 0xFF);
    carry = column >> 8; // 0 or 1
  }
  return sum;
}

// a - b, propagating the borrow from byte 0 upward. A final borrow is
// discarded, so a < b wraps modulo 2^(8*s_size).
uint2048_t uint2048_t::Pimpl::Subtract(const uint2048_t& a, const uint2048_t& b)
{
  uint2048_t difference;
  uint8_t* out = difference.m_pimpl->m_data;
  int borrow = 0;
  for (size_t idx = 0; idx < static_cast<size_t>(s_size); ++idx)
  {
    const int column = static_cast<int>(a[idx]) - static_cast<int>(b[idx]) - borrow;
    out[idx] = static_cast<uint8_t>(column); // negative values wrap modulo 256
    borrow = (column < 0) ? 1 : 0;
  }
  return difference;
}

// big number multiply algorithms
// https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
uint2048_t uint2048_t::Pimpl::Multiply(const uint2048_t& a, const uint2048_t& b)
{
  // Schoolbook long multiplication in base 256 into a double-width buffer.
  // Only the low s_size bytes of the product are kept.
  const size_t aLen = a.m_pimpl->GetSize();
  const size_t bLen = b.m_pimpl->GetSize();

  uint8_t product[s_size * 2] = {};
  for (size_t row = 0; row < aLen; ++row)
  {
    const uint32_t digit = a[row];
    // Two carries are tracked separately: one from the byte product and one
    // from accumulating into the partial result.
    uint32_t mulCarry = 0;
    uint32_t addCarry = 0;
    for (size_t col = 0; col < bLen; ++col)
    {
      const uint32_t partial = digit * static_cast<uint32_t>(b[col]) + mulCarry;
      const uint32_t sum = static_cast<uint32_t>(product[row + col]) + addCarry + (partial & 0xFF);
      mulCarry = partial >> 8;
      addCarry = sum >> 8;
      product[row + col] = static_cast<uint8_t>(sum);
    }
    product[row + bLen] = static_cast<uint8_t>(product[row + bLen] + mulCarry + addCarry);
  }

  uint2048_t result;
  memcpy(result.m_pimpl->m_data, product, s_size);
  return result;
}

// Integer quotient a / b by binary long division (shift-and-subtract).
//
// The divisor is aligned using exact bit lengths. 'b << shift' then has the same
// bit length as 'a', so it always fits in s_size bytes. The previous byte-granular
// alignment shifted by 8*(bytes(a)-bytes(b)+1) bits. When 'a' used the top byte,
// that shifted the divisor's high bits out of the buffer and produced a wrong
// quotient (e.g. b << 2056 == 0).
// Division by zero returns the all-ones value (no trap).
uint2048_t uint2048_t::Pimpl::Divide(const uint2048_t& a, const uint2048_t& b)
{
  // Number of significant bits in x (0 for the value zero).
  auto bitLength = [](const uint2048_t& x) -> size_t {
    const size_t bytes = x.m_pimpl->GetSize(); // always >= 1
    uint8_t top = x[bytes - 1];
    size_t bits = (bytes - 1) * 8;
    for (; top; top >>= 1)
      bits++;
    return bits;
  };

  const size_t aBits = bitLength(a);
  const size_t bBits = bitLength(b);
  if (bBits == 0)
    return Not(uint2048_t());
  if (bBits > aBits)
    return uint2048_t();

  size_t shift = aBits - bBits;
  uint2048_t bit = uint2048_t(1) << static_cast<int>(shift);
  uint2048_t tmp = b << static_cast<int>(shift);
  uint2048_t rem = a;
  uint2048_t res;
  for (;;)
  {
    if (rem >= tmp)
    {
      rem -= tmp;
      res += bit;
    }
    if (shift == 0)
      break;
    shift--;
    tmp >>= 1;
    bit >>= 1;
  }
  return res;
}

// Remainder a % b by binary long division (shift-and-subtract).
//
// The divisor is aligned using exact bit lengths, so 'b << shift' never exceeds
// the bit length of 'a' and always fits in s_size bytes. The previous byte-granular
// alignment could shift the divisor's high bits out of the buffer when 'a' used
// the top byte. For example, a 257-byte 'a' % 1 returned 'a' unchanged.
// a % 0 returns a (as before).
uint2048_t uint2048_t::Pimpl::Modulus(const uint2048_t& a, const uint2048_t& b)
{
  // Number of significant bits in x (0 for the value zero).
  auto bitLength = [](const uint2048_t& x) -> size_t {
    const size_t bytes = x.m_pimpl->GetSize(); // always >= 1
    uint8_t top = x[bytes - 1];
    size_t bits = (bytes - 1) * 8;
    for (; top; top >>= 1)
      bits++;
    return bits;
  };

  const size_t aBits = bitLength(a);
  const size_t bBits = bitLength(b);
  if (bBits == 0 || bBits > aBits)
    return a; // modulus of zero, or a < b

  size_t shift = aBits - bBits;
  uint2048_t tmp = b << static_cast<int>(shift);
  uint2048_t rem = a;
  for (;;)
  {
    if (rem >= tmp)
      rem -= tmp;
    if (shift == 0)
      break;
    shift--;
    tmp >>= 1;
  }
  return rem;
}

void uint2048_t::Pimpl::MulMod(uint8_t* a, uint8_t* b, uint8_t* mod, size_t an, size_t bn, size_t mn, uint8_t* dest)
{
  // Computes dest = (a * b) mod M, where:
  //  * a has 'an' bytes and b has 'bn' bytes (least significant byte first);
  //  * 'mod' holds M pre-shifted by 0..7 bits, copy k at mod + s_size*2*k, as
  //    built by ExpMod, and 'mn' is M's byte size;
  //  * 'dest' is a caller-zeroed double-width buffer.
  // The reduction is binary long division without materializing shifted copies
  // of M. "M << shift" is viewed as the (shift & 7)-bit copy offset by
  // (shift >> 3) whole bytes, whose low bytes are implicitly zero.

  // Schoolbook multiply into dest.
  for (size_t j = 0; j < an; j++)
  {
    uint32_t c1 = 0, c2 = 0; // carries from the byte product and from the accumulation
    for (size_t i = 0; i < bn; i++)
    {
      const uint32_t t1 = (static_cast<uint32_t>(a[j]) * static_cast<uint32_t>(b[i])) + c1;
      const uint32_t t2 = static_cast<uint32_t>(dest[i + j]) + c2 + (t1 & 0xFF);
      c1 = t1 >> 8;
      c2 = t2 >> 8;
      dest[i + j] = static_cast<uint8_t>(t2);
    }
    dest[bn + j] = static_cast<uint8_t>(dest[bn + j] + c1 + c2);
  }

  // Upper bound on the product's byte size. If M is wider, the product is
  // already reduced.
  const size_t dn = an + bn;
  if (mn > dn)
    return;

  const size_t byteShift = dn - mn + 1;
  // Byte width over which dest and the shifted modulus are compared and subtracted.
  const size_t siz = s_size + byteShift;

  for (size_t shift = byteShift << 3; ; shift--)
  {
    const uint8_t* tmpPtr = mod + s_size * 2 * (shift & 7);
    const size_t byt = shift >> 3; // whole-byte part of the shift

    // Compare dest with (M << shift), scanning from the top. Bytes below 'byt'
    // in the shifted modulus are zero, so if all bytes from 'byt' upward are
    // equal then dest >= tmp. The index stays >= byt, so it cannot wrap below
    // zero. The previous scan read dest[SIZE_MAX] when byt == 0 and the values
    // were equal.
    size_t idx = siz;
    while (idx > byt && dest[idx - 1] == tmpPtr[idx - 1 - byt])
      idx--;
    const bool geq = (idx == byt) || (dest[idx - 1] > tmpPtr[idx - 1 - byt]);

    if (geq)
    {
      // dest -= (M << shift); bytes below 'byt' are unaffected.
      uint32_t borrow = 0;
      for (size_t i = byt; i < siz; i++)
      {
        const uint32_t x = static_cast<uint32_t>(dest[i]) - borrow - static_cast<uint32_t>(tmpPtr[i - byt]);
        dest[i] = static_cast<uint8_t>(x);
        borrow = (x >> 8) & 1;
      }
    }

    if (shift == 0)
      break;
  }
}

// a << shift. Storage is least-significant-byte first, so a left shift moves
// bytes to higher indices; within a byte the bits still shift left. Bits moved
// past the top byte are discarded.
// A negative shift is treated as a right shift by -shift (previously it
// produced out-of-bounds reads). A shift of s_size*8 or more yields zero.
uint2048_t uint2048_t::Pimpl::ShiftLeft(const uint2048_t& a, int shift)
{
  if (shift < 0)
    return (shift <= -(s_size * 8)) ? uint2048_t() : ShiftRight(a, -shift);

  uint2048_t res;
  if (shift >= s_size * 8)
    return res;

  const size_t byteShift = static_cast<size_t>(shift) >> 3;
  const int bitShift = shift & 7;
  uint8_t overflow = 0; // bits carried from the byte below
  for (size_t i = byteShift; i < static_cast<size_t>(s_size); i++)
  {
    const uint32_t var = static_cast<uint32_t>(a[i - byteShift]) << bitShift;
    res.m_pimpl->m_data[i] = static_cast<uint8_t>(var | overflow);
    overflow = static_cast<uint8_t>(var >> 8);
  }
  return res;
}

// a >> shift. Storage is least-significant-byte first, so a right shift moves
// bytes to lower indices; within a byte the bits still shift right, and the
// low bits of the next higher byte fill in from the top.
// A shift of s_size*8 or more yields zero. Previously it wrote to
// m_data[s_size-1-byteShift] with a wrapped, out-of-bounds index.
// A negative shift is treated as a left shift by -shift.
uint2048_t uint2048_t::Pimpl::ShiftRight(const uint2048_t& a, int shift)
{
  if (shift < 0)
    return (shift <= -(s_size * 8)) ? uint2048_t() : ShiftLeft(a, -shift);

  uint2048_t res; // starts at zero, so vacated high bytes stay zero
  if (shift >= s_size * 8)
    return res;

  const size_t byteShift = static_cast<size_t>(shift) >> 3;
  const int bitShift = shift & 7;
  for (size_t i = byteShift; i + 1 < static_cast<size_t>(s_size); i++)
  {
    const uint8_t low = static_cast<uint8_t>(a[i] >> bitShift);
    // For bitShift == 0 this is (x << 8) truncated to 8 bits, i.e. 0.
    const uint8_t high = static_cast<uint8_t>(a[i + 1] << (8 - bitShift));
    res.m_pimpl->m_data[i - byteShift] = static_cast<uint8_t>(low | high);
  }
  res.m_pimpl->m_data[s_size - 1 - byteShift] = static_cast<uint8_t>(a[s_size - 1] >> bitShift);
  return res;
}

