/*
Big Integer Class
Written by John Ryland (C) Copyright 2015
*/
#include "BigInteger.h"
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <vector>
#include <string>
#include <sstream>
#include <iostream>
#include <iomanip>
#include "../../Util.h"
// Three-way outcome of BigInteger::Pimpl::Compare(a, b).
enum ComparisonResult
{
    CR_Equal = 0,       // a == b
    CR_LessThan = 1,    // a <  b
    CR_GreaterThan = 2  // a >  b
};
// Private implementation of BigInteger.
// The magnitude is held in m_data as bytes ordered least-significant first (m_data[0] is the
// lowest byte). Values may carry extra high-order zero bytes until Normalize() trims them.
struct BigInteger::Pimpl
{
    std::vector<uint8_t> m_data; // little-endian byte storage

    // ---- per-instance helpers ----
    void Init(const char* str);          // parse a hexadecimal digit string into m_data
    void Copy(const BigInteger& other);  // take a copy of other's bytes
    void Normalize();                    // drop high-order zero bytes

    // ---- ordering ----
    static ComparisonResult Compare(const BigInteger& a, const BigInteger& b);

    // ---- arithmetic ----
    static BigInteger Add(const BigInteger& a, const BigInteger& b);
    static BigInteger Subtract(const BigInteger& a, const BigInteger& b);
    static BigInteger Multiply(const BigInteger& a, const BigInteger& b);
    static BigInteger Divide(const BigInteger& a, const BigInteger& b, BigInteger& modulusResult);
    static BigInteger Modulus(const BigInteger& a, const BigInteger& b);

    // ---- shifts ----
    static BigInteger ShiftLeft(const BigInteger& a, int shift);
    static BigInteger ShiftRight(const BigInteger& a, int shift);

    // ---- bitwise ----
    static BigInteger Or(const BigInteger& a, const BigInteger& b);
    static BigInteger And(const BigInteger& a, const BigInteger& b);
    static BigInteger Xor(const BigInteger& a, const BigInteger& b);
    static BigInteger Not(const BigInteger& a);
};
BigInteger::BigInteger() // Default constructor
: m_pimpl(std::make_unique<Pimpl>())
{
m_pimpl->Init("0");
}
BigInteger::BigInteger(const BigInteger& other) // Copy constructor
: m_pimpl(std::make_unique<Pimpl>())
{
m_pimpl->Copy(other);
}
BigInteger::BigInteger(uint64_t a_number) // Create from a number
: m_pimpl(std::make_unique<Pimpl>())
{
m_pimpl->Init(int2hex(a_number).c_str());
}
BigInteger::BigInteger(const char* str) // Create from a string
: m_pimpl(std::make_unique<Pimpl>())
{
m_pimpl->Init(str);
}
BigInteger::~BigInteger() // Destructor
{
}
void BigInteger::operator=(const BigInteger& other) // Assignment operator
{
m_pimpl->Copy(other);
}
// True when the value is zero (any number of stored zero bytes counts as zero).
bool BigInteger::operator!() const
{
    const BigInteger zero;
    return Pimpl::Compare(*this, zero) == CR_Equal;
}

bool BigInteger::operator==(const BigInteger& other) const
{
    return Pimpl::Compare(*this, other) == CR_Equal;
}

bool BigInteger::operator!=(const BigInteger& other) const
{
    return !(*this == other);
}

bool BigInteger::operator<(const BigInteger& other) const
{
    return Pimpl::Compare(*this, other) == CR_LessThan;
}

bool BigInteger::operator>(const BigInteger& other) const
{
    return Pimpl::Compare(*this, other) == CR_GreaterThan;
}

// Compare() has exactly three outcomes, so "not greater" means "less or equal".
bool BigInteger::operator<=(const BigInteger& other) const
{
    return Pimpl::Compare(*this, other) != CR_GreaterThan;
}

bool BigInteger::operator>=(const BigInteger& other) const
{
    return Pimpl::Compare(*this, other) != CR_LessThan;
}

// Logical AND/OR treating a value as true when it is non-zero (De Morgan forms; the
// right-hand test is skipped once the result is known).
bool BigInteger::operator&&(const BigInteger& other) const
{
    return !(!*this || !other);
}

bool BigInteger::operator||(const BigInteger& other) const
{
    return !(!*this && !other);
}
// Prefix decrement: subtracts one in place and yields *this.
BigInteger& BigInteger::operator--()
{
    *this -= BigInteger(1);
    return *this;
}

// Prefix increment: adds one in place and yields *this.
BigInteger& BigInteger::operator++()
{
    *this += BigInteger(1);
    return *this;
}

// Postfix increment: yields the value held before the increment.
BigInteger BigInteger::operator++(int)
{
    BigInteger before(*this);
    ++*this;
    return before;
}

// Postfix decrement: yields the value held before the decrement.
BigInteger BigInteger::operator--(int)
{
    BigInteger before(*this);
    --*this;
    return before;
}
// ---- Binary operators -------------------------------------------------------------
// Each returns a new value produced by the matching Pimpl routine; *this is unchanged.
BigInteger BigInteger::operator-(const BigInteger& other) const
{
    return Pimpl::Subtract(*this, other);
}

BigInteger BigInteger::operator+(const BigInteger& other) const
{
    return Pimpl::Add(*this, other);
}

BigInteger BigInteger::operator*(const BigInteger& other) const
{
    return Pimpl::Multiply(*this, other);
}

BigInteger BigInteger::operator/(const BigInteger& other) const
{
    // Divide always produces a remainder as well; only the quotient is wanted here.
    BigInteger discardedRemainder;
    return Pimpl::Divide(*this, other, discardedRemainder);
}

BigInteger BigInteger::operator%(const BigInteger& other) const
{
    return Pimpl::Modulus(*this, other);
}

BigInteger BigInteger::operator&(const BigInteger& other) const
{
    return Pimpl::And(*this, other);
}

BigInteger BigInteger::operator|(const BigInteger& other) const
{
    return Pimpl::Or(*this, other);
}

BigInteger BigInteger::operator^(const BigInteger& other) const
{
    return Pimpl::Xor(*this, other);
}

BigInteger BigInteger::operator~() const
{
    return Pimpl::Not(*this);
}

BigInteger BigInteger::operator>>(int shift) const
{
    return Pimpl::ShiftRight(*this, shift);
}

BigInteger BigInteger::operator<<(int shift) const
{
    return Pimpl::ShiftLeft(*this, shift);
}

// ---- Compound assignment ----------------------------------------------------------
// Each one applies the corresponding binary operator above and stores the result.
BigInteger& BigInteger::operator-=(const BigInteger& other)
{
    *this = *this - other;
    return *this;
}

BigInteger& BigInteger::operator+=(const BigInteger& other)
{
    *this = *this + other;
    return *this;
}

BigInteger& BigInteger::operator*=(const BigInteger& other)
{
    *this = *this * other;
    return *this;
}

BigInteger& BigInteger::operator>>=(int shift)
{
    *this = *this >> shift;
    return *this;
}

BigInteger& BigInteger::operator<<=(int shift)
{
    *this = *this << shift;
    return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other)
{
    *this = *this / other;
    return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other)
{
    *this = *this % other;
    return *this;
}

BigInteger& BigInteger::operator&=(const BigInteger& other)
{
    *this = *this & other;
    return *this;
}

BigInteger& BigInteger::operator|=(const BigInteger& other)
{
    *this = *this | other;
    return *this;
}

BigInteger& BigInteger::operator^=(const BigInteger& other)
{
    *this = *this ^ other;
    return *this;
}
// Narrow to a native integer by parsing the "0x..." text form in base 16.
// NOTE: strtoull saturates (returns ULLONG_MAX) when the value does not fit in 64 bits.
BigInteger::operator uint64_t() const
{
    const std::string hexText = std::string(*this);
    return static_cast<uint64_t>(strtoull(hexText.c_str(), nullptr, 16));
}

// Write the "0x..." text form to standard output (no trailing newline).
void BigInteger::Print() const
{
    const std::string hexText(*this);
    std::cout << hexText;
}

// Quotient of num / denom; the remainder is stored into mod.
BigInteger BigInteger::DivMod(const BigInteger& num, const BigInteger& denom, BigInteger& mod)
{
    BigInteger quotient = Pimpl::Divide(num, denom, mod);
    return quotient;
}
// base ^ exponent % modulus
//
// Modular exponentiation by right-to-left square-and-multiply: the exponent is consumed
// one bit at a time from its least significant end. Each step squares the running power
// of the base. When the current bit is set, that power is also multiplied into the result.
//
// The exponent must NOT be reduced modulo `modulus`. In general
// b^e mod m != b^(e mod m) mod m; for example 2^5 mod 3 == 2, but 2^(5 mod 3) mod 3 == 1.
BigInteger BigInteger::ExpMod(const BigInteger& base, const BigInteger& exponent, const BigInteger& modulus)
{
    // Start from 1 reduced mod modulus, so a zero exponent with modulus 1 yields 0.
    BigInteger result = BigInteger(1) % modulus;
    BigInteger power = base % modulus;   // base^(2^k) mod modulus for the current bit k
    BigInteger bits = exponent;          // remaining exponent bits, shifted down each step
    while (!!bits)
    {
        if (bits[0] & 1)                 // lowest remaining bit set
            result = (result * power) % modulus;
        power = (power * power) % modulus;
        bits >>= 1;
    }
    return result;
}
// Render as "0x" followed by two lowercase hex digits per stored byte, most significant
// byte first. Every stored byte is printed, including high-order zero bytes.
BigInteger::operator std::string() const
{
    const std::vector<uint8_t>& bytes = m_pimpl->m_data;
    std::ostringstream out;
    out << "0x" << std::hex << std::setfill('0');   // hex and fill persist; width does not
    for (auto it = bytes.rbegin(); it != bytes.rend(); ++it)
        out << std::setw(2) << static_cast<int>(*it);
    return out.str();
}
// Number of stored bytes (not necessarily normalized).
size_t BigInteger::Size() const
{
    const std::vector<uint8_t>& bytes = m_pimpl->m_data;
    return bytes.size();
}

// Read byte `index`, where index 0 is the least significant byte. No bounds check.
uint8_t BigInteger::operator[](size_t index) const
{
    const std::vector<uint8_t>& bytes = m_pimpl->m_data;
    return bytes[index];
}

#if defined(BIG_INTEGER_ALLOW_DIRECT_ACCESS)
// Writable access to byte `index` (least significant first). No bounds check.
uint8_t& BigInteger::operator[](size_t index)
{
    std::vector<uint8_t>& bytes = m_pimpl->m_data;
    return bytes[index];
}
#endif
////////////////// PRIVATE //////////////////////////////////////////////////
// Value of one hexadecimal digit, case-insensitive ('0'-'9' -> 0-9, 'a'-'f'/'A'-'F' -> 10-15).
// The input is not validated: any other character yields a meaningless value.
static unsigned int HexCharToInt(unsigned char ch)
{
    const unsigned int lowered = static_cast<unsigned int>(tolower(ch));
    if (lowered < 'a')
        return lowered - '0';
    return lowered - 'a' + 10u;
}
// Parse a hexadecimal digit string (most significant digit first, e.g. "1a2b") into m_data.
// m_data is stored least significant byte first: "1234" becomes { 0x34, 0x12 }.
// An optional "0x"/"0X" prefix is accepted, so the text produced by operator std::string()
// round-trips. A null or empty string yields the value zero as a single zero byte.
// Digits are not validated (see HexCharToInt).
void BigInteger::Pimpl::Init(const char* str)
{
    if (!str)
        str = "";
    if (str[0] == '0' && (str[1] == 'x' || str[1] == 'X'))
        str += 2;
    const size_t len = strlen(str);
    if (len == 0)
    {
        m_data.assign(1, 0);
        return;
    }
    m_data.assign((len + 1) / 2, 0);
    for (size_t i = 0; i < m_data.size(); i++)
    {
        // Byte i takes its low nibble from position len-1-2i, counted from the string's end.
        // For odd lengths the last byte has no high-nibble digit; guard it rather than
        // reading before the start of the string.
        const size_t lo = len - 1 - 2 * i;
        unsigned int val = HexCharToInt(str[lo]);
        if (lo > 0)
            val |= HexCharToInt(str[lo - 1]) << 4;
        m_data[i] = static_cast<uint8_t>(val);
    }
}
// Replace this value's bytes with a copy of other's bytes (a no-op for self-copy).
void BigInteger::Pimpl::Copy(const BigInteger& other)
{
    const std::vector<uint8_t>& source = other.m_pimpl->m_data;
    if (&source != &m_data)
        m_data = source;
}
void BigInteger::Pimpl::Normalize()
{
// Assumes at index 0 is the low order data (eg the smaller placed values)
int idx = m_data.size() - 1;
while (idx && !m_data[idx])
idx--;
m_data.resize(idx + 1);
}
// Three-way magnitude comparison of a and b.
// The operands need not be normalized. Bytes beyond the shorter operand's length are
// compared against an implicit zero, then the common bytes are scanned from most to
// least significant.
ComparisonResult BigInteger::Pimpl::Compare(const BigInteger& a, const BigInteger& b)
{
    const size_t aSize = a.Size();
    const size_t bSize = b.Size();
    const size_t common = (aSize < bSize) ? aSize : bSize;

    // Any non-zero byte above the common length decides the result immediately.
    for (size_t pos = aSize; pos > common; --pos)
        if (a[pos - 1])
            return CR_GreaterThan;
    for (size_t pos = bSize; pos > common; --pos)
        if (b[pos - 1])
            return CR_LessThan;

    // Same effective length: the first differing byte from the top wins.
    for (size_t pos = common; pos > 0; --pos)
    {
        const uint8_t x = a[pos - 1];
        const uint8_t y = b[pos - 1];
        if (x != y)
            return (x > y) ? CR_GreaterThan : CR_LessThan;
    }
    return CR_Equal;
}
// Bitwise OR of a and b.
// The result starts as a copy of the operand with MORE stored bytes (X | 0 -> X), and the
// other operand's bytes are ORed in. The choice is made on storage length, not value:
// an un-normalized operand (e.g. the output of ShiftRight) can store more bytes than a
// numerically larger one. Previously that case wrote past the end of the result.
BigInteger BigInteger::Pimpl::Or(const BigInteger& a, const BigInteger& b)
{
    const bool aLonger = a.Size() >= b.Size();
    const BigInteger& longer = aLonger ? a : b;
    const BigInteger& shorter = aLonger ? b : a;
    BigInteger result = longer;
    const size_t n = shorter.Size();
    for (size_t idx = 0; idx < n; idx++)
        result.m_pimpl->m_data[idx] |= shorter[idx];
    return result;
}
// Bitwise AND of a and b.
// Bytes present in only one operand are ANDed with an implicit zero, so the result needs
// only as many bytes as the operand with FEWER stored bytes (X & 0 -> 0). The choice is
// made on storage length, not value. An un-normalized but numerically smaller operand
// could previously make the loop read past the end of the other operand.
BigInteger BigInteger::Pimpl::And(const BigInteger& a, const BigInteger& b)
{
    const bool aShorter = a.Size() <= b.Size();
    const BigInteger& shorter = aShorter ? a : b;
    const BigInteger& longer = aShorter ? b : a;
    BigInteger result = shorter;
    const size_t n = shorter.Size();
    for (size_t idx = 0; idx < n; idx++)
        result.m_pimpl->m_data[idx] &= longer[idx];
    return result;
}
// Bitwise XOR of a and b.
// The result starts as a copy of the operand with MORE stored bytes (X ^ 0 -> X), and the
// other operand's bytes are XORed in. The choice is made on storage length, not value,
// because an un-normalized operand can store more bytes than a numerically larger one.
// Previously that case wrote past the end of the result.
BigInteger BigInteger::Pimpl::Xor(const BigInteger& a, const BigInteger& b)
{
    const bool aLonger = a.Size() >= b.Size();
    const BigInteger& longer = aLonger ? a : b;
    const BigInteger& shorter = aLonger ? b : a;
    BigInteger result = longer;
    const size_t n = shorter.Size();
    for (size_t idx = 0; idx < n; idx++)
        result.m_pimpl->m_data[idx] ^= shorter[idx];
    return result;
}
// Bitwise complement of every stored byte of a.
// The width of the result equals a's current storage length, so any high-order zero
// bytes a carries become 0xFF.
BigInteger BigInteger::Pimpl::Not(const BigInteger& a)
{
    BigInteger flipped(a);
    for (uint8_t& byte : flipped.m_pimpl->m_data)
        byte = static_cast<uint8_t>(~byte);
    return flipped;
}
// Returns a + b (non-negative magnitudes).
// The result starts as a copy of the operand with more stored bytes. The other operand is
// added byte by byte (least significant first) with a running carry. The carry keeps
// propagating past the shorter operand as long as it is non-zero. A carry out of the top
// byte grows the result by one byte (0x80 + 0x80 = 0x0100).
// Previously the loop visited one index past the end when both operands had equal length
// (an out-of-bounds write), dropped carries beyond that index, and never grew the result.
BigInteger BigInteger::Pimpl::Add(const BigInteger& a, const BigInteger& b)
{
    // Choose by storage length rather than value; un-normalized operands can store more
    // bytes than a numerically larger value.
    const bool aLonger = a.Size() >= b.Size();
    const BigInteger& longer = aLonger ? a : b;
    const BigInteger& shorter = aLonger ? b : a;

    BigInteger result = longer;
    std::vector<uint8_t>& dst = result.m_pimpl->m_data;
    const size_t sn = shorter.Size();
    unsigned int carry = 0;
    for (size_t idx = 0; idx < dst.size(); idx++)
    {
        if (idx >= sn && carry == 0)
            break; // nothing left to add into the remaining high bytes
        const unsigned int addend = (idx < sn) ? shorter[idx] : 0u;
        const unsigned int sum = dst[idx] + addend + carry;
        dst[idx] = static_cast<uint8_t>(sum & 0xFF);
        carry = sum >> 8;
    }
    if (carry)
        dst.push_back(1);
    return result;
}
// Returns a - b for non-negative magnitudes with a >= b.
// Negative results are not supported: if b > a a warning is printed and 0 is returned.
// Subtraction runs byte by byte (least significant first) with a running borrow.
// Previously the loop visited index a.Size(), reading and writing one byte past the end
// of the result.
BigInteger BigInteger::Pimpl::Subtract(const BigInteger& a, const BigInteger& b)
{
    if (b > a)
    {
        // TODO: handle negative numbers
        printf("Subtraction will become negative!!!!!!\n");
        return BigInteger(); // Returning 0 for now
    }
    BigInteger result = a;
    std::vector<uint8_t>& dst = result.m_pimpl->m_data;
    // Since b <= a, any bytes b stores beyond a's length are zero and can be ignored.
    const size_t bn = (b.Size() < dst.size()) ? b.Size() : dst.size();
    int borrow = 0;
    for (size_t idx = 0; idx < dst.size(); idx++)
    {
        if (idx >= bn && borrow == 0)
            break; // the remaining high bytes are unchanged
        const int subtrahend = ((idx < bn) ? static_cast<int>(b[idx]) : 0) + borrow;
        int diff = static_cast<int>(dst[idx]) - subtrahend;
        borrow = 0;
        if (diff < 0)
        {
            diff += 256;
            borrow = 1;
        }
        dst[idx] = static_cast<uint8_t>(diff);
    }
    return result;
}
// big number multiply algorithms
// https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
//
// Schoolbook long multiplication in base 256. Row j adds a[j] * b, shifted by j bytes,
// into an accumulator of a.Size() + b.Size() bytes, which always fits the product.
// Within a row the running value is at most 255 + 255*255 + 255 = 0xFFFF, so it fits in
// 32 bits and the carry fits in one byte. Byte bn + j has not been touched by earlier rows,
// so the row's final carry can be stored there directly.
// The accumulator is zero-filled explicitly rather than relying on the default-constructed
// value's byte. Operand order does not matter, so no size/value comparison is made.
BigInteger BigInteger::Pimpl::Multiply(const BigInteger& a, const BigInteger& b)
{
    const size_t an = a.Size();
    const size_t bn = b.Size();

    BigInteger d;
    std::vector<uint8_t>& acc = d.m_pimpl->m_data;
    acc.assign(an + bn, 0);
    for (size_t j = 0; j < an; j++)
    {
        uint32_t carry = 0;
        for (size_t i = 0; i < bn; i++)
        {
            const uint32_t t = static_cast<uint32_t>(acc[i + j])
                             + static_cast<uint32_t>(a[j]) * static_cast<uint32_t>(b[i])
                             + carry;
            acc[i + j] = static_cast<uint8_t>(t & 0xFF);
            carry = t >> 8;
        }
        acc[bn + j] = static_cast<uint8_t>(carry);
    }
    d.m_pimpl->Normalize();
    return d;
}
// Binary long division of non-negative values.
// Returns the quotient a / b and stores the remainder a % b in `mod`; both are normalized.
// Division by zero prints a warning, returns 0 and sets `mod` to a. Previously it
// silently produced an all-ones quotient. All work happens on local copies, so `mod` may
// alias a or b.
BigInteger BigInteger::Pimpl::Divide(const BigInteger& a, const BigInteger& b, BigInteger& mod)
{
    BigInteger divisor = b;
    BigInteger remainder = a;
    divisor.m_pimpl->Normalize();
    remainder.m_pimpl->Normalize();
    const size_t an = remainder.Size();
    const size_t tn = divisor.Size();

    if (tn == 1 && divisor[0] == 0)
    {
        printf("Division by zero!!!!!!\n");
        mod = remainder;
        return BigInteger();
    }
    if (tn > an) // divisor has more significant bytes, so it exceeds the dividend
    {
        mod = remainder;
        return BigInteger();
    }

    // Align the divisor above the dividend, then walk it back down one bit at a time.
    // Wherever it still fits, subtract it and set the matching quotient bit.
    size_t bitsLarger = (an - tn + 1) << 3;
    BigInteger bit(1);
    bit <<= static_cast<int>(bitsLarger);
    divisor <<= static_cast<int>(bitsLarger);

    BigInteger quotient;
    quotient <<= static_cast<int>(an << 3); // pre-size the quotient's storage
    do {
        if (remainder >= divisor)
        {
            remainder -= divisor;
            quotient += bit;
        }
        divisor >>= 1;
        bit >>= 1;
    } while (bitsLarger--);

    quotient.m_pimpl->Normalize();
    remainder.m_pimpl->Normalize();
    mod = remainder;
    return quotient;
}
// Returns a % b (normalized) for non-negative values.
// Reuses the long division in Divide and discards the quotient. The previous
// implementation stopped after the first successful subtraction (7 % 2 gave 3). It also
// printed debug output and passed size_t to printf("%i"), which is undefined behavior.
BigInteger BigInteger::Pimpl::Modulus(const BigInteger& a, const BigInteger& b)
{
    BigInteger remainder;
    Divide(a, b, remainder);
    remainder.m_pimpl->Normalize();
    return remainder;
}
// Returns a << shift (multiplication by 2^shift).
// Storage is least significant byte first, so whole-byte shifts insert zero bytes at the
// front of the buffer. The remaining 0-7 bits are then shifted through every byte, and a
// non-zero spill from the top byte grows the result by one byte (0x80 << 1 = 0x0100).
// Previously the bit loop ran two bytes past the end of the buffer and the spilled bits
// were lost. A negative shift is treated as a right shift.
BigInteger BigInteger::Pimpl::ShiftLeft(const BigInteger& a, int shift)
{
    if (shift < 0)
        return ShiftRight(a, -shift);

    BigInteger res(a);
    std::vector<uint8_t>& data = res.m_pimpl->m_data;
    const size_t byteShift = static_cast<size_t>(shift) >> 3;
    const unsigned int bitShift = static_cast<unsigned int>(shift) & 7u;

    data.insert(data.begin(), byteShift, static_cast<uint8_t>(0));

    if (bitShift != 0)
    {
        uint8_t overflow = 0;
        for (uint8_t& byte : data)
        {
            const unsigned int wide = static_cast<unsigned int>(byte) << bitShift;
            byte = static_cast<uint8_t>((wide & 0xFF) | overflow);
            overflow = static_cast<uint8_t>(wide >> 8);
        }
        if (overflow)
            data.push_back(overflow);
    }
    return res;
}
// Returns a >> shift (floor division by 2^shift). The stored byte count is unchanged; the
// vacated high-order bytes become zero.
// Storage is least significant byte first, so whole-byte shifts move bytes toward index 0.
// Previously they moved bytes toward the high end (0x1234 >> 8 gave 0x3434), and a shift
// of at least the stored width underflowed the start index. The remaining 0-7 bits are
// shifted from the top byte down, carrying the bits shifted out of each byte into the one
// below. A negative shift is treated as a left shift.
BigInteger BigInteger::Pimpl::ShiftRight(const BigInteger& a, int shift)
{
    if (shift < 0)
        return ShiftLeft(a, -shift);

    BigInteger res(a);
    std::vector<uint8_t>& data = res.m_pimpl->m_data;
    const size_t byteShift = static_cast<size_t>(shift) >> 3;
    const unsigned int bitShift = static_cast<unsigned int>(shift) & 7u;

    if (byteShift >= data.size())
    {
        data.assign(data.size(), 0); // everything shifted out
        return res;
    }
    if (byteShift != 0)
    {
        typedef std::vector<uint8_t>::difference_type Diff;
        data.erase(data.begin(), data.begin() + static_cast<Diff>(byteShift));
        data.insert(data.end(), byteShift, static_cast<uint8_t>(0));
    }
    if (bitShift != 0)
    {
        uint8_t carry = 0;
        for (size_t idx = data.size(); idx-- > 0;)
        {
            const uint8_t byte = data[idx];
            data[idx] = static_cast<uint8_t>((byte >> bitShift) | carry);
            carry = static_cast<uint8_t>(byte << (8u - bitShift));
        }
    }
    return res;
}