//
// Big Number Class
// Written by John Ryland (C) Copyright 2015
//
#ifndef BIGNUM_CLASS_H
#define BIGNUM_CLASS_H
#include <ctype.h>
#include <memory.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <vector>
/*
Based in part on the work linked below, which is a more C-style version of this kind of code:
Source: http://dmitry.gr/index.php?r=05.Projects&proj=10.%20Shoving%20RSA%20into%20small%20places
License: you may use this code for any non-commercial purpose as long as you give me credit in your write-up/source. For commercial uses, contact me. One final note: trying this on something like an ATtiny, which lacks hardware multiplier will be very slow (you'll notice) but it will in fact work. Comments, suggestions, grievances, complaints, requests? [dmitrygr@gmail.com].
*/
// Arbitrary-precision unsigned integer.
//
// Representation: m_data holds the magnitude as little-endian bytes (m_data[0] is the least
// significant byte). A value does not have to be normalized: high-order zero bytes are allowed
// and are preserved by construction, copying, -= and >>=, so callers can keep a fixed width.
// Multiply() and a performed %= return normalized values (no high-order zero bytes, but always
// at least one byte).
//
// Storage is owned by std::vector, so copy construction, assignment and destruction are the
// compiler-generated ones (self-assignment is safe). When compiled as C++11 or later a
// moved-from BigNumber may hold an empty vector; every operation treats that as zero.
class BigNumber
{
public:
    // Constructs the value zero, one byte wide.
    BigNumber() : m_data(1, 0)
    {
    }

    // Constructs from a big-endian hexadecimal string such as "01F1" or "0x01f1".
    // Intentionally not explicit so existing call sites that pass string literals keep working.
    // See Init() for the exact parsing rules.
    BigNumber(const char* str)
    {
        Init(str);
    }

    // Returns true when the value is zero (every byte is zero).
    bool operator!() const
    {
        for (size_t i = 0; i < m_data.size(); ++i)
            if (m_data[i])
                return false;
        return true;
    }

    // Three-way comparison by numeric value: returns 1 if a > b, -1 if a < b, 0 if equal.
    // Neither operand has to be normalized; high-order zero bytes are ignored.
    static int Compare(const BigNumber& a, const BigNumber& b)
    {
        size_t an = a.m_data.size();
        size_t bn = b.m_data.size();
        // Any non-zero byte above the other operand's width decides the comparison outright.
        while (an > bn)
            if (a.m_data[--an])
                return 1;
        while (bn > an)
            if (b.m_data[--bn])
                return -1;
        // Same width from here on: compare from the most significant byte downwards.
        while (an--)
        {
            if (a.m_data[an] > b.m_data[an])
                return 1;
            if (a.m_data[an] < b.m_data[an])
                return -1;
        }
        return 0;
    }

    bool operator==(const BigNumber& other) const
    {
        return Compare(*this, other) == 0;
    }

    bool operator<(const BigNumber& other) const
    {
        return Compare(*this, other) < 0;
    }

    bool operator>(const BigNumber& other) const
    {
        return Compare(*this, other) > 0;
    }

    bool operator<=(const BigNumber& other) const
    {
        return Compare(*this, other) <= 0;
    }

    bool operator>=(const BigNumber& other) const
    {
        return Compare(*this, other) >= 0;
    }

    // a -= b. Precondition: a >= b (not checked). If it is violated the result wraps modulo
    // 256^width. The width of *this is left unchanged.
    BigNumber& operator-=(const BigNumber& other)
    {
        const size_t otherLen = other.m_data.size();
        unsigned int borrow = 0;
        for (size_t i = 0; i < m_data.size(); ++i)
        {
            const unsigned int minuend = m_data[i];
            // Bytes of other beyond its width count as zero; nothing is read past its end.
            const unsigned int subtrahend = borrow + (i < otherLen ? other.m_data[i] : 0u);
            borrow = (minuend < subtrahend) ? 1u : 0u;
            m_data[i] = static_cast<uint8_t>((minuend - subtrahend) & 0xFFu);
        }
        return *this;
    }

    // a += b. *this is widened as needed so neither the high bytes of other nor a final carry
    // are lost.
    BigNumber& operator+=(const BigNumber& other)
    {
        // Captured first: if other aliases *this the size must not be re-read after growth.
        const size_t otherLen = other.m_data.size();
        if (m_data.size() < otherLen)
            m_data.resize(otherLen, 0);
        unsigned int carry = 0;
        // Past the end of other only a pending carry still needs propagating.
        for (size_t i = 0; i < m_data.size() && (i < otherLen || carry); ++i)
        {
            const unsigned int sum = m_data[i] + carry + (i < otherLen ? other.m_data[i] : 0u);
            m_data[i] = static_cast<uint8_t>(sum & 0xFFu);
            carry = sum >> 8;
        }
        if (carry)
            m_data.push_back(1);
        return *this;
    }

    // Logical right shift by shift bits (divide by 2^shift). The width is unchanged; vacated
    // high-order bits become zero. A shift of zero or less is a no-op.
    BigNumber& operator>>=(int shift)
    {
        if (shift <= 0)
            return *this;
        const size_t width = m_data.size();
        const size_t byteShift = static_cast<size_t>(shift) / 8;
        const unsigned int bitShift = static_cast<unsigned int>(shift) % 8;
        // Walking upwards is safe in place: every source index is >= the destination index.
        for (size_t i = 0; i < width; ++i)
        {
            const size_t src = i + byteShift;
            unsigned int value = 0;
            if (src < width)
                value = static_cast<unsigned int>(m_data[src]) >> bitShift;
            if (bitShift && src + 1 < width)
                value |= static_cast<unsigned int>(m_data[src + 1]) << (8 - bitShift);
            m_data[i] = static_cast<uint8_t>(value & 0xFFu);
        }
        return *this;
    }

    // Left shift by shift bits (multiply by 2^shift). The value grows by shift / 8 bytes, plus
    // one more byte when bits are carried out of the top byte, so no bits are ever lost.
    // A shift of zero or less is a no-op.
    BigNumber& operator<<=(int shift)
    {
        if (shift <= 0)
            return *this;
        const size_t byteShift = static_cast<size_t>(shift) / 8;
        const unsigned int bitShift = static_cast<unsigned int>(shift) % 8;
        if (bitShift)
        {
            unsigned int carry = 0;
            for (size_t i = 0; i < m_data.size(); ++i)
            {
                const unsigned int value = (static_cast<unsigned int>(m_data[i]) << bitShift) | carry;
                m_data[i] = static_cast<uint8_t>(value & 0xFFu);
                carry = value >> 8;
            }
            if (carry)
                m_data.push_back(static_cast<uint8_t>(carry));
        }
        // Whole-byte part: new low-order zero bytes go at the front of the little-endian buffer.
        m_data.insert(m_data.begin(), byteShift, static_cast<uint8_t>(0));
        return *this;
    }

    // a %= b by binary long division; other is not modified.
    //  - If other is zero, *this keeps its value and is normalized.
    //  - If other (normalized) is wider than *this, *this is already reduced and is left
    //    untouched, width included.
    //  - Otherwise *this becomes the normalized remainder.
    BigNumber& operator%=(const BigNumber& other)
    {
        BigNumber divisor(other);
        divisor.Normalize();
        if (divisor.m_data.size() > m_data.size())
            return *this;
        if (!divisor)
        {
            Normalize();
            return *this;
        }
        // Align the divisor one whole byte above the top of *this. That guarantees the shifted
        // divisor starts out greater than *this, so at every step a single conditional
        // subtraction is enough.
        const size_t shiftBytes = m_data.size() - divisor.m_data.size() + 1;
        divisor.m_data.insert(divisor.m_data.begin(), shiftBytes, static_cast<uint8_t>(0));
        size_t steps = shiftBytes * 8 + 1; // one step per bit position, including shift 0
        while (steps--)
        {
            if (*this >= divisor)
                *this -= divisor;
            divisor >>= 1;
        }
        Normalize();
        return *this;
    }

    // Returns a * b, normalized. Schoolbook long multiplication, one byte per digit.
    static BigNumber Multiply(const BigNumber& a, const BigNumber& b)
    {
        const size_t an = a.m_data.size();
        const size_t bn = b.m_data.size();
        BigNumber product;
        // A product of an and bn digits never needs more than an + bn digits.
        product.m_data.assign((an + bn) ? (an + bn) : 1, 0);
        for (size_t j = 0; j < an; ++j)
        {
            const uint32_t digit = a.m_data[j];
            if (!digit)
                continue;
            uint32_t carry = 0;
            for (size_t i = 0; i < bn; ++i)
            {
                // At most 255 * 255 + 255 + 255 = 65535, so the carry always fits in a byte.
                const uint32_t value = digit * b.m_data[i] + product.m_data[i + j] + carry;
                product.m_data[i + j] = static_cast<uint8_t>(value & 0xFFu);
                carry = value >> 8;
            }
            // Slot j + bn has not been written by any earlier row, so it is still zero.
            product.m_data[j + bn] = static_cast<uint8_t>(carry);
        }
        product.Normalize();
        return product;
    }

    // Returns (base ^ exponent) % modulus using right-to-left binary exponentiation.
    // The exponent is used as given; it is NOT reduced modulo the modulus.
    // A zero modulus follows operator%= (reduction is a no-op), i.e. the full power is computed.
    static BigNumber ExpMod(const BigNumber& base, const BigNumber& exponent, const BigNumber& modulus)
    {
        BigNumber result("01");
        result %= modulus; // makes x^0 % 1 come out as 0
        BigNumber square(base);
        square %= modulus;
        BigNumber remaining(exponent);
        while (!!remaining)
        {
            if (remaining.m_data[0] & 1)
            {
                result = Multiply(result, square);
                result %= modulus;
            }
            remaining >>= 1;
            if (!remaining)
                break; // the next square would never be used
            square = Multiply(square, square);
            square %= modulus;
        }
        return result;
    }

    // Drops high-order zero bytes, always keeping at least one byte.
    void Normalize()
    {
        while (m_data.size() > 1 && m_data.back() == 0)
            m_data.pop_back();
    }

    // Prints the bytes to stdout as space-separated two-digit hex, most significant byte first.
    void Print() const
    {
        for (size_t i = m_data.size(); i-- > 0; )
            printf("%02X ", m_data[i]);
    }

    // Returns "0x" followed by two upper-case hex digits per byte, most significant byte first.
    std::string ToString() const
    {
        static const char hexChrs[] = "0123456789ABCDEF";
        std::string str = "0x";
        str.reserve(2 + 2 * m_data.size());
        for (size_t i = m_data.size(); i-- > 0; )
        {
            str += hexChrs[(m_data[i] >> 4) & 0xF];
            str += hexChrs[m_data[i] & 0xF];
        }
        return str;
    }

    // Returns a copy of the bytes in big-endian order (most significant byte first).
    std::vector<uint8_t> Data() const
    {
        return std::vector<uint8_t>(m_data.rbegin(), m_data.rend());
    }

private:
    // Parses a big-endian hex string into m_data.
    //  - NULL or "" yields zero.
    //  - An optional leading "0x" / "0X" is skipped, so ToString() output round-trips.
    //  - Digits are consumed in pairs from the right; an odd leading digit becomes its own
    //    high-order byte instead of being dropped.
    //  - Leading "00" pairs are kept, so the string also fixes the width of the number.
    void Init(const char* str)
    {
        m_data.clear();
        size_t len = str ? strlen(str) : 0;
        if (len >= 2 && str[0] == '0' && (str[1] == 'x' || str[1] == 'X'))
        {
            str += 2;
            len -= 2;
        }
        m_data.reserve(len / 2 + 1);
        size_t pos = len;
        while (pos >= 2)
        {
            pos -= 2;
            const unsigned int hi = HexCharToInt(static_cast<unsigned char>(str[pos]));
            const unsigned int lo = HexCharToInt(static_cast<unsigned char>(str[pos + 1]));
            m_data.push_back(static_cast<uint8_t>((hi << 4) | lo));
        }
        if (pos == 1)
            m_data.push_back(static_cast<uint8_t>(HexCharToInt(static_cast<unsigned char>(str[0]))));
        if (m_data.empty())
            m_data.push_back(0);
    }

    // Maps one hex digit (either case) to 0..15. Any other character maps to 0.
    static unsigned int HexCharToInt(unsigned char ch)
    {
        if (ch >= '0' && ch <= '9')
            return ch - '0';
        if (ch >= 'a' && ch <= 'f')
            return ch - 'a' + 10;
        if (ch >= 'A' && ch <= 'F')
            return ch - 'A' + 10;
        return 0;
    }

    std::vector<uint8_t> m_data; // magnitude, little-endian (index 0 = least significant byte)
};
#endif // BIGNUM_CLASS_H