// Import/applications/MakePDF/Security/BigInteger/BigInteger.cpp
/*
  Big Integer Class
  Written by John Ryland (C) Copyright 2015
*/
#include "BigInteger.h"
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <vector>
#include <string>
#include <sstream>
#include <iostream>
#include <iomanip>
#include "../../Util.h"

// Outcome of a three-way comparison between two BigInteger magnitudes.
enum ComparisonResult
{
  CR_Equal = 0,       // a == b
  CR_LessThan = 1,    // a <  b
  CR_GreaterThan = 2  // a >  b
};

// Private implementation of BigInteger. The magnitude is held in m_data as
// little-endian bytes: m_data[0] is the least significant byte.
struct BigInteger::Pimpl
{
  std::vector<uint8_t> m_data;

  // Per-instance helpers.
  void Init(const char* str);          // parse a hexadecimal string into m_data
  void Copy(const BigInteger& other);  // take a copy of other's bytes
  void Normalize();                    // drop redundant high-order zero bytes

  // Three-way comparison of two magnitudes.
  static ComparisonResult Compare(const BigInteger& a, const BigInteger& b);

  // Arithmetic.
  static BigInteger Add(const BigInteger& a, const BigInteger& b);
  static BigInteger Subtract(const BigInteger& a, const BigInteger& b);
  static BigInteger Multiply(const BigInteger& a, const BigInteger& b);
  static BigInteger Divide(const BigInteger& a, const BigInteger& b, BigInteger& modulusResult);
  static BigInteger Modulus(const BigInteger& a, const BigInteger& b);

  // Bit shifts.
  static BigInteger ShiftLeft(const BigInteger& a, int shift);
  static BigInteger ShiftRight(const BigInteger& a, int shift);

  // Bitwise logic.
  static BigInteger Or(const BigInteger& a, const BigInteger& b);
  static BigInteger And(const BigInteger& a, const BigInteger& b);
  static BigInteger Xor(const BigInteger& a, const BigInteger& b);
  static BigInteger Not(const BigInteger& a);
};

// Default constructor: the value zero (delegates to the string constructor).
BigInteger::BigInteger()
  : BigInteger("0")
{
}

// Copy constructor: takes its own copy of other's bytes.
BigInteger::BigInteger(const BigInteger& other)
  : m_pimpl(std::make_unique<Pimpl>())
{
  m_pimpl->Copy(other);
}

// Construct from a native integer by way of its hex text.
// NOTE(review): assumes int2hex (Util.h) yields a hexadecimal digit string — confirm.
BigInteger::BigInteger(uint64_t a_number)
  : m_pimpl(std::make_unique<Pimpl>())
{
  m_pimpl->Init(int2hex(a_number).c_str());
}

// Construct from a hexadecimal string (most significant digit first).
BigInteger::BigInteger(const char* str)
  : m_pimpl(std::make_unique<Pimpl>())
{
  m_pimpl->Init(str);
}

// Destructor: the unique_ptr releases the Pimpl (defined here, where Pimpl is complete).
BigInteger::~BigInteger() = default;

// Assignment: replace this value's bytes with a copy of other's.
// Self-assignment is safe (a vector assigned to itself is unchanged).
void BigInteger::operator=(const BigInteger& other)
{
  const std::vector<uint8_t>& source = other.m_pimpl->m_data;
  m_pimpl->m_data = source;
}

// True when the value is zero.
bool BigInteger::operator!() const
{
  const BigInteger zero;
  return Pimpl::Compare(*this, zero) == CR_Equal;
}

bool BigInteger::operator!=(const BigInteger& other) const
{
  return !(*this == other);
}

bool BigInteger::operator==(const BigInteger& other) const
{
  const ComparisonResult order = Pimpl::Compare(*this, other);
  return order == CR_Equal;
}

bool BigInteger::operator<(const BigInteger& other) const
{
  return CR_LessThan == Pimpl::Compare(*this, other);
}

bool BigInteger::operator>(const BigInteger& other) const
{
  return CR_GreaterThan == Pimpl::Compare(*this, other);
}

// Compare has exactly three outcomes, so "not greater" means "less or equal".
bool BigInteger::operator<=(const BigInteger& other) const
{
  return Pimpl::Compare(*this, other) != CR_GreaterThan;
}

// ...and "not less" means "greater or equal".
bool BigInteger::operator>=(const BigInteger& other) const
{
  return Pimpl::Compare(*this, other) != CR_LessThan;
}

// Logical AND of the two values' non-zero-ness; other is only tested if needed.
bool BigInteger::operator&&(const BigInteger& other) const
{
  if (!*this)
    return false;
  return !!other;
}

// Logical OR of the two values' non-zero-ness; other is only tested if needed.
bool BigInteger::operator||(const BigInteger& other) const
{
  if (!!*this)
    return true;
  return !!other;
}

// Prefix decrement. Subtract reports and yields 0 when the result would be negative.
BigInteger& BigInteger::operator--()
{
  *this -= BigInteger(1);
  return *this;
}

// Prefix increment.
BigInteger& BigInteger::operator++()
{
  *this += BigInteger(1);
  return *this;
}

// Postfix increment: returns the value held before incrementing.
BigInteger BigInteger::operator++(int)
{
  BigInteger before(*this);
  ++(*this);
  return before;
}

// Postfix decrement: returns the value held before decrementing.
BigInteger BigInteger::operator--(int)
{
  BigInteger before(*this);
  --(*this);
  return before;
}

// In-place subtraction.
BigInteger& BigInteger::operator-=(const BigInteger& other)
{
  BigInteger difference = Pimpl::Subtract(*this, other);
  *this = difference;
  return *this;
}

// In-place addition.
BigInteger& BigInteger::operator+=(const BigInteger& other)
{
  BigInteger sum = Pimpl::Add(*this, other);
  *this = sum;
  return *this;
}

// In-place multiplication.
BigInteger& BigInteger::operator*=(const BigInteger& other)
{
  BigInteger product = Pimpl::Multiply(*this, other);
  *this = product;
  return *this;
}

// Arithmetic operators: thin wrappers over the Pimpl algorithms.
BigInteger BigInteger::operator-(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::Subtract(lhs, other);
}

BigInteger BigInteger::operator+(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::Add(lhs, other);
}

BigInteger BigInteger::operator*(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::Multiply(lhs, other);
}

// Quotient only; the remainder computed alongside it is discarded.
BigInteger BigInteger::operator/(const BigInteger& other) const
{
  BigInteger remainder;
  return DivMod(*this, other, remainder);
}

BigInteger BigInteger::operator%(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::Modulus(lhs, other);
}

// Bitwise operators.
BigInteger BigInteger::operator&(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::And(lhs, other);
}

BigInteger BigInteger::operator|(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::Or(lhs, other);
}

BigInteger BigInteger::operator^(const BigInteger& other) const
{
  const BigInteger& lhs = *this;
  return Pimpl::Xor(lhs, other);
}

// Complements every stored byte (width = current buffer size).
BigInteger BigInteger::operator~() const
{
  const BigInteger& operand = *this;
  return Pimpl::Not(operand);
}

// Shift operators.
BigInteger BigInteger::operator>>(int shift) const
{
  const BigInteger& operand = *this;
  return Pimpl::ShiftRight(operand, shift);
}

BigInteger BigInteger::operator<<(int shift) const
{
  const BigInteger& operand = *this;
  return Pimpl::ShiftLeft(operand, shift);
}

// Compound assignments, each expressed through its binary counterpart.
BigInteger& BigInteger::operator>>=(int shift)
{
  *this = *this >> shift;
  return *this;
}

BigInteger& BigInteger::operator<<=(int shift)
{
  *this = *this << shift;
  return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other)
{
  *this = *this / other;
  return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other)
{
  *this = *this % other;
  return *this;
}

BigInteger& BigInteger::operator&=(const BigInteger& other)
{
  *this = *this & other;
  return *this;
}

BigInteger& BigInteger::operator|=(const BigInteger& other)
{
  *this = *this | other;
  return *this;
}

BigInteger& BigInteger::operator^=(const BigInteger& other)
{
  *this = *this ^ other;
  return *this;
}

// Convert to a native integer by round-tripping through the "0x..." hex text.
// strtoull saturates at ULLONG_MAX for values wider than 64 bits.
BigInteger::operator uint64_t() const
{
  const std::string hex = std::string(*this);
  return static_cast<uint64_t>(strtoull(hex.c_str(), nullptr, 16));
}

// Write the hex representation to std::cout (no trailing newline).
void BigInteger::Print() const
{
  const std::string text = std::string(*this);
  std::cout << text;
}

// Returns num / denom and stores num % denom in mod.
BigInteger BigInteger::DivMod(const BigInteger& num, const BigInteger& denom, BigInteger& mod)
{
  BigInteger quotient = Pimpl::Divide(num, denom, mod);
  return quotient;
}

// base ^ exponent % modulus
// Right-to-left binary (square-and-multiply) modular exponentiation.
// The exponent is used as-is: reducing it modulo 'modulus' (as an earlier
// version did) is mathematically wrong — exponents reduce modulo the group
// order, not the modulus. Starting from 1 % modulus makes modulus == 1 yield 0.
BigInteger BigInteger::ExpMod(const BigInteger& base, const BigInteger& exponent, const BigInteger& modulus)
{
  BigInteger result = BigInteger(1) % modulus;
  BigInteger a = base % modulus;   // running base^(2^k) mod modulus
  BigInteger e = exponent;         // remaining exponent bits
  while (!!e)
  {
    if (e[0] & 1)  // lowest exponent bit set: fold the current power in
      result = (result * a) % modulus;
    a = (a * a) % modulus;
    e >>= 1;
  }
  return result;
}

// "0x" followed by two lowercase hex digits per stored byte, most significant
// byte first (any redundant leading zero bytes are printed too).
BigInteger::operator std::string() const
{
  const std::vector<uint8_t>& bytes = m_pimpl->m_data;
  std::ostringstream oss;
  oss << "0x" << std::hex << std::setfill('0');
  for (auto it = bytes.rbegin(); it != bytes.rend(); ++it)
    oss << std::setw(2) << static_cast<int>(*it);
  return oss.str();
}

// Number of stored bytes (may include redundant high-order zero bytes).
size_t BigInteger::Size() const
{
  const std::vector<uint8_t>& bytes = m_pimpl->m_data;
  return bytes.size();
}

// Unchecked read of one byte; index 0 is the least significant byte.
uint8_t BigInteger::operator[](size_t index) const
{
  const std::vector<uint8_t>& bytes = m_pimpl->m_data;
  return bytes[index];
}

#ifdef BIG_INTEGER_ALLOW_DIRECT_ACCESS
// Unchecked mutable access to one byte; index 0 is the least significant byte.
uint8_t& BigInteger::operator[](size_t index)
{
  std::vector<uint8_t>& bytes = m_pimpl->m_data;
  return bytes[index];
}
#endif


////////////////// PRIVATE //////////////////////////////////////////////////

// Value of one hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F').
// Other characters are not rejected and produce meaningless values.
static unsigned int HexCharToInt(unsigned char ch)
{
  const unsigned int lower = static_cast<unsigned int>(tolower(ch));
  if (lower < 'a')
    return lower - '0';
  return lower - 'a' + 10;
}

// Parse a hexadecimal string written most significant digit first (e.g. "1a2b")
// into little-endian bytes: m_data[0] receives the string's last two digits.
// An optional "0x"/"0X" prefix (the format operator std::string produces) is
// skipped. A null or empty string yields zero, stored as a single 0 byte.
// Odd digit counts are handled without reading before the start of the string.
void BigInteger::Pimpl::Init(const char* str)
{
  if (!str)
    str = "";
  if (str[0] == '0' && (str[1] == 'x' || str[1] == 'X'))
    str += 2;
  const size_t len = strlen(str);
  // Two digits per byte; keep at least one byte so zero is never an empty buffer.
  m_data.assign(len ? (len + 1) / 2 : 1, static_cast<uint8_t>(0));
  for (size_t i = 0; i < len; i++)
  {
    // The i-th digit counted from the end of the string is nibble i of the value.
    const unsigned int nibble = HexCharToInt(static_cast<unsigned char>(str[len - 1 - i])) & 0xF;
    m_data[i / 2] |= static_cast<uint8_t>(nibble << ((i & 1) ? 4 : 0));
  }
}

// Replace this value's bytes with a copy of other's (safe when other is *this).
void BigInteger::Pimpl::Copy(const BigInteger& other)
{
  const std::vector<uint8_t>& source = other.m_pimpl->m_data;
  m_data = source;
}

void BigInteger::Pimpl::Normalize()
{
  // Assumes at index 0 is the low order data (eg the smaller placed values)
  int idx = m_data.size() - 1;
  while (idx && !m_data[idx])
    idx--;
  m_data.resize(idx + 1);
}

// Three-way comparison of magnitudes. The operands need not be normalized:
// bytes beyond an operand's buffer are treated as zero, and the scan runs
// from the most significant byte of the wider operand downwards.
ComparisonResult BigInteger::Pimpl::Compare(const BigInteger& a, const BigInteger& b)
{
  const size_t aSize = a.Size();
  const size_t bSize = b.Size();
  size_t pos = (aSize > bSize) ? aSize : bSize;
  while (pos-- > 0)
  {
    const uint8_t lhs = (pos < aSize) ? a[pos] : 0;
    const uint8_t rhs = (pos < bSize) ? b[pos] : 0;
    if (lhs != rhs)
      return (lhs > rhs) ? CR_GreaterThan : CR_LessThan;
  }
  return CR_Equal;
}

// Bitwise OR. Starts from the operand with the longer buffer (X | 0 == X) and
// ORs the other in. Choosing by buffer size rather than numeric value keeps
// every write in bounds even when an operand carries leading zero bytes.
BigInteger BigInteger::Pimpl::Or(const BigInteger& a, const BigInteger& b)
{
  const bool aIsWider = a.Size() >= b.Size();
  const BigInteger& wide = aIsWider ? a : b;
  const BigInteger& narrow = aIsWider ? b : a;
  BigInteger result = wide;
  const size_t n = narrow.Size();
  for (size_t idx = 0; idx < n; idx++)
    result.m_pimpl->m_data[idx] |= narrow[idx];
  return result;
}

// Bitwise AND. Starts from the operand with the shorter buffer: any bytes past
// it would be ANDed with zero anyway (X & 0 == 0). Choosing by buffer size
// rather than numeric value keeps every read of the other operand in bounds.
BigInteger BigInteger::Pimpl::And(const BigInteger& a, const BigInteger& b)
{
  const bool aIsNarrower = a.Size() <= b.Size();
  const BigInteger& narrow = aIsNarrower ? a : b;
  const BigInteger& wide = aIsNarrower ? b : a;
  BigInteger result = narrow;
  const size_t n = narrow.Size();
  for (size_t idx = 0; idx < n; idx++)
    result.m_pimpl->m_data[idx] &= wide[idx];
  return result;
}

// Bitwise XOR. Starts from the operand with the longer buffer (X ^ 0 == X) and
// XORs the other in. Choosing by buffer size rather than numeric value keeps
// every write in bounds even when an operand carries leading zero bytes.
BigInteger BigInteger::Pimpl::Xor(const BigInteger& a, const BigInteger& b)
{
  const bool aIsWider = a.Size() >= b.Size();
  const BigInteger& wide = aIsWider ? a : b;
  const BigInteger& narrow = aIsWider ? b : a;
  BigInteger result = wide;
  const size_t n = narrow.Size();
  for (size_t idx = 0; idx < n; idx++)
    result.m_pimpl->m_data[idx] ^= narrow[idx];
  return result;
}

// Bitwise complement of every stored byte; the width is a's current buffer size.
BigInteger BigInteger::Pimpl::Not(const BigInteger& a)
{
  BigInteger result(a);
  for (uint8_t& byte : result.m_pimpl->m_data)
    byte = static_cast<uint8_t>(~byte);
  return result;
}

// Multi-byte addition with carry propagation.
// The sum is accumulated into a copy of the operand with the longer buffer,
// so every index touched exists; a carry out of the top byte grows the result
// by one byte (e.g. 0x80 + 0x80 == 0x0100, 0xFFFF + 1 == 0x10000).
BigInteger BigInteger::Pimpl::Add(const BigInteger& a, const BigInteger& b)
{
  const bool aIsWider = a.Size() >= b.Size();
  const BigInteger& wide = aIsWider ? a : b;
  const BigInteger& narrow = aIsWider ? b : a;
  BigInteger result = wide;
  std::vector<uint8_t>& dst = result.m_pimpl->m_data;
  const size_t narrowSize = narrow.Size();
  unsigned int carry = 0;
  for (size_t idx = 0; idx < dst.size(); idx++)
  {
    if (idx >= narrowSize && carry == 0)
      break; // nothing left to add into the remaining high bytes
    const unsigned int sum = dst[idx] + carry + (idx < narrowSize ? narrow[idx] : 0u);
    dst[idx] = static_cast<uint8_t>(sum & 0xFF);
    carry = sum >> 8;
  }
  if (carry)
    dst.push_back(static_cast<uint8_t>(carry));
  return result;
}

// Multi-byte subtraction a - b with borrow propagation.
// Only non-negative values are representable: when b > a a message is printed
// and 0 is returned. Otherwise the result keeps a's buffer size; bytes of b
// beyond a's buffer must be zero (guaranteed by b <= a) and are not read.
BigInteger BigInteger::Pimpl::Subtract(const BigInteger& a, const BigInteger& b)
{
  if (b > a)
  {
    // TODO: handle negative numbers
    printf("Subtraction will become negative!!!!!!\n");
    return BigInteger(); // Returning 0 for now
  }

  BigInteger result = a;
  std::vector<uint8_t>& dst = result.m_pimpl->m_data;
  const size_t bSize = b.Size();
  int borrow = 0;
  for (size_t idx = 0; idx < dst.size(); idx++)
  {
    if (idx >= bSize && borrow == 0)
      break; // remaining high bytes are unchanged
    const int subtrahend = (idx < bSize) ? static_cast<int>(b[idx]) : 0;
    const int diff = static_cast<int>(dst[idx]) - subtrahend - borrow;
    borrow = (diff < 0) ? 1 : 0;
    dst[idx] = static_cast<uint8_t>(diff + 256 * borrow);
  }
  return result;
}

// big number multiply algorithms
// Schoolbook long multiplication, one base-256 digit (byte) at a time:
// https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
BigInteger BigInteger::Pimpl::Multiply(const BigInteger& a, const BigInteger& b)
{
  // Let the numerically larger operand drive the outer loop.
  if (a < b)
    return Multiply(b, a);

  const size_t outerLen = a.Size();
  const size_t innerLen = b.Size();

  // The product of an m-byte and an n-byte number fits in m + n bytes.
  BigInteger product;
  std::vector<uint8_t>& out = product.m_pimpl->m_data;
  out.resize(outerLen + innerLen);

  for (size_t row = 0; row < outerLen; row++)
  {
    // Add a[row] * b into out, starting at byte 'row', with a single running carry.
    uint32_t carry = 0;
    for (size_t col = 0; col < innerLen; col++)
    {
      const uint32_t cell = static_cast<uint32_t>(out[row + col])
                          + static_cast<uint32_t>(a[row]) * static_cast<uint32_t>(b[col])
                          + carry;
      out[row + col] = static_cast<uint8_t>(cell & 0xFF);
      carry = cell >> 8;
    }
    // out[row + innerLen] has not been written by any earlier row.
    out[row + innerLen] = static_cast<uint8_t>(carry);
  }
  product.m_pimpl->Normalize();
  return product;
}

// Binary long division of a by b, settling one quotient bit per iteration.
// Returns the quotient and writes the remainder to 'mod'. Both results are
// normalized before returning (except the early-out remainder, which is a as given).
// NOTE(review): there is no check for b == 0. A zero divisor lets every trial
// subtraction succeed, so the returned quotient is meaningless; callers must avoid it.
BigInteger BigInteger::Pimpl::Divide(const BigInteger& a, const BigInteger& b, BigInteger& mod)
{
  // Work on normalized copies so byte counts reflect significant bytes only.
  BigInteger tmp = b;
  BigInteger res = a;
  tmp.m_pimpl->Normalize();
  res.m_pimpl->Normalize();
  size_t an = res.Size();
  size_t tn = tmp.Size();
  // More significant bytes in the divisor than the dividend: quotient is 0.
  if (tn > an)
  {
    mod = a;
    return BigInteger();
  }
  // Scale the divisor (and the matching quotient bit) up past the dividend,
  // then walk both back down one bit per iteration.
  size_t bitsLarger = (an - tn + 1) << 3;
  BigInteger bit(1);
  bit <<= bitsLarger;
  tmp <<= bitsLarger;
  // 'mod' holds the running remainder; 'res' accumulates the quotient.
  mod = res;
  res = BigInteger();
  res <<= (an << 3); // resizes the array
  // Runs bitsLarger + 1 times, covering bit positions bitsLarger down to 0.
  do {
    if (mod >= tmp) {
      mod -= tmp;
      res += bit;
    }
    tmp >>= 1;
    bit >>= 1;
  } while (bitsLarger--);
  res.m_pimpl->Normalize(); // normalize the result
  mod.m_pimpl->Normalize(); // normalize the result
  return res;
}

// Remainder of a / b by binary long division, without building the quotient.
// The divisor is scaled up past the dividend and then walked back down one
// bit at a time, subtracting wherever it fits. (An earlier version stopped
// after the first subtraction, e.g. giving 10 % 3 == 4, and printed debug
// output with mismatched printf formats.)
// A zero divisor returns a (normalized) unchanged.
BigInteger BigInteger::Pimpl::Modulus(const BigInteger& a, const BigInteger& b)
{
  BigInteger tmp = b;
  BigInteger mod = a;
  tmp.m_pimpl->Normalize();
  mod.m_pimpl->Normalize();

  // Nothing to reduce: divisor is zero or already larger than the dividend.
  if (!tmp || mod < tmp)
    return mod;

  const size_t an = mod.Size();
  const size_t tn = tmp.Size();
  size_t bitsLarger = (an - tn + 1) << 3;
  tmp <<= static_cast<int>(bitsLarger);
  // Runs bitsLarger + 1 times: divisor * 2^bitsLarger down to divisor * 2^0.
  do {
    if (mod >= tmp)
      mod -= tmp;
    tmp >>= 1;
  } while (bitsLarger--);
  mod.m_pimpl->Normalize(); // normalize the result
  return mod;
}

// Left shift by 'shift' bits. Data is stored low-order byte first, so the
// whole-byte part inserts zero bytes at index 0, and the sub-byte part
// shifts each byte left, carrying the spilled bits into the next byte up.
// Bits pushed out of the top byte grow the result by one byte, so no
// significant bits are lost. A negative shift is treated as a right shift.
BigInteger BigInteger::Pimpl::ShiftLeft(const BigInteger& a, int shift)
{
  if (shift < 0)
    return ShiftRight(a, -shift);
  const size_t byteShift = static_cast<size_t>(shift) >> 3;
  const unsigned int bitShift = static_cast<unsigned int>(shift) & 7;

  BigInteger res(a);
  std::vector<uint8_t>& data = res.m_pimpl->m_data;
  // Fill in with zeros at the low-order end (like the zeros in 10000).
  data.insert(data.begin(), byteShift, static_cast<uint8_t>(0));

  if (bitShift != 0)
  {
    unsigned int carry = 0;
    for (uint8_t& byte : data)
    {
      const unsigned int wide = (static_cast<unsigned int>(byte) << bitShift) | carry;
      byte = static_cast<uint8_t>(wide & 0xFF);
      carry = wide >> 8;
    }
    if (carry)
      data.push_back(static_cast<uint8_t>(carry));
  }
  return res;
}

// Right shift by 'shift' bits; the buffer size is kept unchanged.
// Data is stored low-order byte first, so output byte i is assembled from
// input bytes i + byteShift (low bits) and i + byteShift + 1 (high bits).
// Bytes vacated at the top become zero, and shifting past the end yields zero.
// (An earlier version moved whole bytes toward higher indices, e.g. giving
// 0x0100 >> 8 == 0.) A negative shift is treated as a left shift.
BigInteger BigInteger::Pimpl::ShiftRight(const BigInteger& a, int shift)
{
  if (shift < 0)
    return ShiftLeft(a, -shift);
  const size_t byteShift = static_cast<size_t>(shift) >> 3;
  const unsigned int bitShift = static_cast<unsigned int>(shift) & 7;

  BigInteger res(a);
  std::vector<uint8_t>& data = res.m_pimpl->m_data;
  const size_t n = data.size();
  for (size_t i = 0; i < n; i++)
  {
    // Read from 'a' (never from 'res'), so nothing is overwritten before it is read.
    const size_t src = i + byteShift;
    const unsigned int lo = (src < n) ? a[src] : 0u;
    const unsigned int hi = (src + 1 < n) ? a[src + 1] : 0u;
    data[i] = static_cast<uint8_t>(((lo >> bitShift) | (hi << (8 - bitShift))) & 0xFF);
  }
  return res;
}