#### [SOLVED] Modular Exponentiation for high numbers in C++

I've recently been working on an implementation of the Miller-Rabin primality test. I am limiting it to 32-bit numbers, because this is a just-for-fun project to familiarize myself with C++, and I don't want to work with anything 64-bit for a while. An added bonus is that the algorithm is deterministic for all 32-bit numbers, so I can significantly increase efficiency because I know exactly which witnesses to test.
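For context, here is a minimal sketch of what a deterministic 32-bit Miller-Rabin test can look like. It relies on the known result that the witnesses 2, 7 and 61 are sufficient for every n below 4,759,123,141, which covers all 32-bit values. The names `mod_pow_u32` and `is_prime_u32` are made up for this illustration, and the sketch assumes a 64-bit intermediate type is available, which is not what the question originally wanted.

```cpp
#include <cstdint>

// Computes (base^exp) % mod with 64-bit intermediates so no product can wrap.
// Assumes mod > 0. Hypothetical helper name.
std::uint32_t mod_pow_u32(std::uint32_t base, std::uint32_t exp, std::uint32_t mod) {
    std::uint64_t result = 1 % mod;
    std::uint64_t b = base % mod;
    for (; exp; exp >>= 1) {
        if (exp & 1) result = result * b % mod;
        b = b * b % mod;
    }
    return static_cast<std::uint32_t>(result);
}

// Deterministic Miller-Rabin for 32-bit n using the witnesses 2, 7 and 61.
bool is_prime_u32(std::uint32_t n) {
    if (n < 2) return false;
    const std::uint32_t witnesses[] = {2, 7, 61};
    for (std::uint32_t a : witnesses) {
        if (n == a) return true;
        if (n % a == 0) return false;
    }
    // Here n is odd and greater than 2: write n - 1 = d * 2^r with d odd.
    std::uint32_t d = n - 1;
    int r = 0;
    while ((d & 1) == 0) { d >>= 1; ++r; }

    for (std::uint32_t a : witnesses) {
        std::uint64_t x = mod_pow_u32(a, d, n);
        if (x == 1 || x == n - 1) continue;  // this base does not prove n composite
        bool composite = true;
        for (int i = 1; i < r; ++i) {
            x = x * x % n;
            if (x == n - 1) { composite = false; break; }
        }
        if (composite) return false;
    }
    return true;
}
```

Calling `is_prime_u32(n)` for any `std::uint32_t` value returns `true` only for primes under the assumption above; the modular exponentiation inside it is exactly the step the rest of this question is about.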

For low numbers, the algorithm works exceptionally well. However, part of the process relies on modular exponentiation, that is, (num ^ pow) % mod. For example:

``````3 ^ 2 % 5 =
9 % 5 =
4
``````
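The definition above can be written down directly as a slow reference implementation, which is handy for cross-checking a faster version on small inputs. This is only a sketch: the name `mod_pow_naive` is invented here, and it assumes `0 < mod <= 65536` so that every product fits in 32 bits.

```cpp
#include <cstdint>

// Reference implementation: multiply by num exactly `pow` times,
// reducing modulo `mod` after every multiplication.
// Assumes 0 < mod <= 65536, so result * num always fits in 32 bits.
std::uint32_t mod_pow_naive(std::uint32_t num, std::uint32_t pow, std::uint32_t mod) {
    std::uint32_t result = 1 % mod;
    num %= mod;
    for (std::uint32_t i = 0; i < pow; ++i)
        result = (result * num) % mod;
    return result;
}
```

With this helper, `mod_pow_naive(3, 2, 5)` gives 4, matching the worked example. It takes `pow` multiplications, so it is not practical for exponents like 168277 in a hot loop.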

Here is the code I have been using for this modular exponentiation:

``````unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned test;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    }

    return test;
}
``````

As you might have already guessed, problems arise when the arguments are all exceptionally large numbers. For example, if I want to test the number 673109 for primality, I will at one point have to find:

(2 ^ 168277) % 673109

Now, 2 ^ 168277 is an exceptionally large number, and somewhere in the process it overflows test, which results in an incorrect evaluation.
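To make the failure concrete: with base 2 and modulus 673109, the successive values of `num` in the loop are 2, 4, 16, 256 and 65536, and the next squaring is 2^32, which does not fit in 32 bits. The sketch below isolates that one step. The function names are hypothetical, and it assumes `unsigned int` is 32 bits wide, as on the asker's machine.

```cpp
#include <cstdint>

// Squares num modulo mod using only 32-bit arithmetic, as the code above does.
// The product wraps modulo 2^32 before the % operator is applied.
std::uint32_t square_mod_32(std::uint32_t num, std::uint32_t mod) {
    return (num * num) % mod;
}

// Same step, but the product is formed in 64 bits and reduced afterwards.
std::uint32_t square_mod_64(std::uint32_t num, std::uint32_t mod) {
    return static_cast<std::uint32_t>((static_cast<std::uint64_t>(num) * num) % mod);
}
```

For `num = 65536` and `mod = 673109`, `square_mod_32` returns 0 because the product wraps to zero, while `square_mod_64` returns 531876. Once `num` becomes 0, every later multiplication is wrong as well.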

Conversely, arguments such as

4000111222 ^ 3 % 1608

also evaluate incorrectly, for much the same reason.
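For this second case, a small modulus with a large base, a possible 32-bit-only workaround is to reduce the base modulo `mod` before the loop starts. The sketch below is a variant of the function above under the assumption that `0 < mod <= 65536`, so that every product stays below 2^32; the name `mod_pow_small` is invented for this illustration.

```cpp
#include <cstdint>

// Square-and-multiply using only 32-bit arithmetic.
// Assumes 0 < mod <= 65536: after num %= mod, both test and num are
// below mod, so test * num and num * num are below 2^32.
std::uint32_t mod_pow_small(std::uint32_t num, std::uint32_t pow, std::uint32_t mod) {
    num %= mod;                       // the extra step: reduce the base first
    std::uint32_t test = 1 % mod;
    for (; pow; pow >>= 1) {
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    }
    return test;
}
```

This handles `mod_pow_small(4000111222u, 3, 1608)`, but it does nothing for the first example, where the modulus itself is larger than 2^16.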

Does anyone have suggestions for performing modular exponentiation in a way that prevents this overflow and/or manipulates it to produce the correct result? (The way I see it, overflow is just another form of modulo, that is, num % (UINT_MAX+1).) #### @abkds 2014-09-09 05:39:59

`LL` is for `long long int`

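``````typedef long long LL;

// Modulus. P is not defined in the original snippet; the value 673109
// from the question is assumed here.
const LL P = 673109;

LL power_mod(LL a, LL k) {
    if (k == 0)
        return 1;
    LL temp = power_mod(a, k/2);   // recurse on half the exponent
    LL res;

    res = ( ( temp % P ) * (temp % P) ) % P;
    if (k % 2 == 1)
        res = ((a % P) * (res % P)) % P;
    return res;
}
``````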

Use the recursive function above to compute the modular exponentiation. It does not overflow as long as `P * P` fits in a `long long`, because every intermediate product is reduced modulo `P` before it is used again.
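The same recursion can also take the modulus as an argument instead of relying on a global `P`. The following is a sketch under that assumption, with an invented name `power_mod_rec` and unsigned 64-bit types; it assumes `0 < mod <= 2^32` so that each product fits in 64 bits.

```cpp
#include <cstdint>

// Recursive square-and-multiply with an explicit modulus.
// Assumes 0 < mod <= 2^32, so half * half and a * res fit in 64 bits.
std::uint64_t power_mod_rec(std::uint64_t a, std::uint64_t k, std::uint64_t mod) {
    if (k == 0)
        return 1 % mod;
    std::uint64_t half = power_mod_rec(a, k / 2, mod);  // a^(k/2) mod m
    std::uint64_t res = half * half % mod;              // square it
    if (k % 2 == 1)
        res = (a % mod) * res % mod;                    // odd exponent: one more factor
    return res;
}
```

The recursion depth is about log2(k), so even a 32-bit exponent needs at most 33 nested calls.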

A sample run with `a = 2` and `k = 168277` outputs 518358, which is correct, and the function runs in `O(log(k))` time. #### @ShowLove 2013-10-03 21:19:48

``````package playTime;

public class play {

    public static long count = 0;
    public static long binSlots = 10;
    public static long y = 645;
    public static long finalValue = 1;
    public static long x = 11;

    public static void main(String[] args){

        int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1};

        x = BME(x, count, binArray);

        System.out.print("\nfinal value:"+finalValue);

    }

    public static long BME(long x, long count, int[] binArray){

        if(count == binSlots){
            return finalValue;
        }

        if(binArray[(int) count] == 1){
            finalValue = finalValue*x%y;
        }

        x = (x*x)%y;
        System.out.print("Array("+binArray[(int) count]+") "
                +"x("+x+")" +" finalVal(" + finalValue + ")\n");

        count++;

        return BME(x, count,binArray);
    }

}
`````` #### @ShowLove 2013-10-03 21:20:32

That was code I wrote in Java very quickly. The example I used was 11^644 mod 645 = 1. The binary representation of 644 is 1010000100, which the array stores least significant bit first. I kind of cheated and hard-coded the variables, but it works fine. #### @ShowLove 2013-10-03 21:22:51

output was Array(0) x(121) finalVal(1) Array(0) x(451) finalVal(1) Array(1) x(226) finalVal(451) Array(0) x(121) finalVal(451) Array(0) x(451) finalVal(451) Array(0) x(226) finalVal(451) Array(0) x(121) finalVal(451) Array(1) x(451) finalVal(391) Array(0) x(226) finalVal(391) Array(1) x(121) finalVal(1) final value:1 #### @clinux 2010-07-15 07:06:53

I wrote something for this recently for RSA in C++, bit messy though.

``````#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>

BigInteger::BigInteger() {
digits.push_back(0);
negative = false;
}

BigInteger::~BigInteger() {
}

void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
int sum_n_carry = 0;
int n = (int)a.digits.size();
if (n < (int)b.digits.size()) {
n = b.digits.size();
}
c.digits.resize(n);
for (int i = 0; i < n; ++i) {
unsigned short a_digit = 0;
unsigned short b_digit = 0;
if (i < (int)a.digits.size()) {
a_digit = a.digits[i];
}
if (i < (int)b.digits.size()) {
b_digit = b.digits[i];
}
sum_n_carry += a_digit + b_digit;
c.digits[i] = (sum_n_carry & 0xFFFF);
sum_n_carry >>= 16;
}
if (sum_n_carry != 0) {
putCarryInfront(c, sum_n_carry);
}
while (c.digits.size() > 1 && c.digits.back() == 0) {
c.digits.pop_back();
}
//std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}

void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
int sub_n_borrow = 0;
int n = a.digits.size();
if (n < (int)b.digits.size())
n = (int)b.digits.size();
c.digits.resize(n);
for (int i = 0; i < n; ++i) {
unsigned short a_digit = 0;
unsigned short b_digit = 0;
if (i < (int)a.digits.size())
a_digit = a.digits[i];
if (i < (int)b.digits.size())
b_digit = b.digits[i];
sub_n_borrow += a_digit - b_digit;
if (sub_n_borrow >= 0) {
c.digits[i] = sub_n_borrow;
sub_n_borrow = 0;
} else {
c.digits[i] = 0x10000 + sub_n_borrow;
sub_n_borrow = -1;
}
}
while (c.digits.size() > 1 && c.digits.back() == 0) {
c.digits.pop_back();
}
//std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}

int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
int n = (int)a.digits.size();
if (n < (int)b.digits.size())
n = (int)b.digits.size();
//std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
for (int i = n-1; i >= 0; --i) {
unsigned short a_digit = 0;
unsigned short b_digit = 0;
if (i < (int)a.digits.size())
a_digit = a.digits[i];
if (i < (int)b.digits.size())
b_digit = b.digits[i];
if (a_digit < b_digit) {
//std::cout << "-1" << std::endl;
return -1;
} else if (a_digit > b_digit) {
//std::cout << "+1" << std::endl;
return +1;
}
}
//std::cout << "0" << std::endl;
return 0;
}

void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
unsigned int mult_n_carry = 0;
c.digits.clear();
c.digits.resize(a.digits.size());
for (int i = 0; i < (int)a.digits.size(); ++i) {
unsigned short a_digit = 0;
unsigned short b_digit = b;
if (i < (int)a.digits.size())
a_digit = a.digits[i];
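// Widen first: unsigned short operands promote to int, and 0xFFFF * 0xFFFF overflows a 32-bit int.
mult_n_carry += (unsigned int)a_digit * b_digit;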
c.digits[i] = (mult_n_carry & 0xFFFF);
mult_n_carry >>= 16;
}
if (mult_n_carry != 0) {
putCarryInfront(c, mult_n_carry);
}
//std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}

void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
b.digits.resize(a.digits.size() + times);
for (int i = 0; i < times; ++i) {
b.digits[i] = 0;
}
for (int i = 0; i < (int)a.digits.size(); ++i) {
b.digits[i + times] = a.digits[i];
}
}

void BigInteger::shiftRight(BigInteger& a) {
//std::cout << "shr " << a.toString() << " == ";
for (int i = 0; i < (int)a.digits.size(); ++i) {
a.digits[i] >>= 1;
if (i+1 < (int)a.digits.size()) {
if ((a.digits[i+1] & 0x1) != 0) {
a.digits[i] |= 0x8000;
}
}
}
//std::cout << a.toString() << std::endl;
}

void BigInteger::shiftLeft(BigInteger& a) {
bool lastBit = false;
for (int i = 0; i < (int)a.digits.size(); ++i) {
bool bit = (a.digits[i] & 0x8000) != 0;
a.digits[i] <<= 1;
if (lastBit)
a.digits[i] |= 1;
lastBit = bit;
}
if (lastBit) {
a.digits.push_back(1);
}
}

void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
BigInteger b;
b.negative = a.negative;
b.digits.resize(a.digits.size() + 1);
b.digits[a.digits.size()] = carry;
for (int i = 0; i < (int)a.digits.size(); ++i) {
b.digits[i] = a.digits[i];
}
a.digits.swap(b.digits);
}

void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
c.digits.clear();
c.digits.push_back(0);
BigInteger two("2");
BigInteger e = b;
BigInteger f("1");
BigInteger g = a;
BigInteger one("1");
while (cmpWithoutSign(g, e) >= 0) {
shiftLeft(e);
shiftLeft(f);
}
shiftRight(e);
shiftRight(f);
while (cmpWithoutSign(g, b) >= 0) {
g -= e;
c += f;
while (cmpWithoutSign(g, e) < 0) {
shiftRight(e);
shiftRight(f);
}
}
e = c;
e *= b;
f = a;
f -= e;
d = f;
}

BigInteger::BigInteger(const BigInteger& other) {
digits = other.digits;
negative = other.negative;
}

BigInteger::BigInteger(const char* other) {
digits.push_back(0);
negative = false;
BigInteger ten;
ten.digits[0] = 10;
const char* c = other;
bool make_negative = false;
if (*c == '-') {
make_negative = true;
++c;
}
while (*c != 0) {
BigInteger digit;
digit.digits[0] = *c - '0';
*this *= ten;
*this += digit;
++c;
}
negative = make_negative;
}

bool BigInteger::isOdd() const {
return (digits[0] & 0x1) != 0;
}

BigInteger& BigInteger::operator=(const BigInteger& other) {
if (this == &other) // handle self assignment
return *this;
digits = other.digits;
negative = other.negative;
return *this;
}

BigInteger& BigInteger::operator+=(const BigInteger& other) {
BigInteger result;
if (negative) {
if (other.negative) {
result.negative = true;
addWithoutSign(result, *this, other);
} else {
int a = cmpWithoutSign(*this, other);
if (a < 0) {
result.negative = false;
subWithoutSign(result, other, *this);
} else if (a > 0) {
result.negative = true;
subWithoutSign(result, *this, other);
} else {
result.negative = false;
result.digits.clear();
result.digits.push_back(0);
}
}
} else {
if (other.negative) {
int a = cmpWithoutSign(*this, other);
if (a < 0) {
result.negative = true;
subWithoutSign(result, other, *this);
} else if (a > 0) {
result.negative = false;
subWithoutSign(result, *this, other);
} else {
result.negative = false;
result.digits.clear();
result.digits.push_back(0);
}
} else {
result.negative = false;
addWithoutSign(result, *this, other);
}
}
negative = result.negative;
digits.swap(result.digits);
return *this;
}

BigInteger& BigInteger::operator-=(const BigInteger& other) {
BigInteger neg_other = other;
neg_other.negative = !neg_other.negative;
return *this += neg_other;
}

BigInteger& BigInteger::operator*=(const BigInteger& other) {
BigInteger result;
for (int i = 0; i < (int)digits.size(); ++i) {
BigInteger mult;
multByDigitWithoutSign(mult, other, digits[i]);
BigInteger shift;
shiftLeftByBase(shift, mult, i);
result += shift; // accumulate this partial product
}
if (negative != other.negative) {
result.negative = true;
} else {
result.negative = false;
}
//std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
negative = result.negative;
digits.swap(result.digits);
return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other) {
BigInteger result, tmp;
divideWithoutSign(result, tmp, *this, other);
result.negative = (negative != other.negative);
negative = result.negative;
digits.swap(result.digits);
return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other) {
BigInteger c, d;
divideWithoutSign(c, d, *this, other);
*this = d;
return *this;
}

bool BigInteger::operator>(const BigInteger& other) const {
if (negative) {
if (other.negative) {
return cmpWithoutSign(*this, other) < 0;
} else {
return false;
}
} else {
if (other.negative) {
return true;
} else {
return cmpWithoutSign(*this, other) > 0;
}
}
}

BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
BigInteger zero("0");
BigInteger one("1");
BigInteger e = exponent;
BigInteger base = *this;
*this = one;
while (cmpWithoutSign(e, zero) != 0) {
//std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
if (e.isOdd()) {
*this *= base;
*this %= modulus;
}
shiftRight(e);
base *= BigInteger(base);
base %= modulus;
}
return *this;
}

std::string BigInteger::toString() const {
std::ostringstream os;
if (negative)
os << "-";
BigInteger tmp = *this;
BigInteger zero("0");
BigInteger ten("10");
tmp.negative = false;
std::stack<char> s;
while (cmpWithoutSign(tmp, zero) != 0) {
BigInteger tmp2, tmp3;
divideWithoutSign(tmp2, tmp3, tmp, ten);
s.push((char)(tmp3.digits[0] + '0'));
tmp = tmp2;
}
while (!s.empty()) {
os << s.top();
s.pop();
}
/*
for (int i = digits.size()-1; i >= 0; --i) {
os << digits[i];
if (i != 0) {
os << ",";
}
}
*/
return os.str();
}
``````

And an example usage.

``````BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344");

// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);
``````

It's fast too, and supports an unlimited number of digits. #### @darius 2016-06-23 15:18:27

Thanks for sharing! A question: is `digits` a `std::vector<unsigned short>`? #### @clinux 2017-09-30 18:54:36

Yes, but working in base 65536 under the hood, not base 10. #### @dirkgently 2010-02-05 12:22:56

Two things:

• Are you using the appropriate data type? In other words, does UINT_MAX allow you to have 673109 as an argument?

Your code does not work because at one point you have `num = 2^16`, and the `num = ...` assignment then overflows. Use a bigger data type to hold this intermediate value.

• How about taking the modulo at every possible overflow opportunity, such as:

`test = ((test % mod) * (num % mod)) % mod;`

Edit:

``````#include <cstdio>

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned long long test;
    unsigned long long n = num;   /* 64-bit copy so n * n cannot overflow */
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = ((test % mod) * (n % mod)) % mod;
        n = ((n % mod) * (n % mod)) % mod;
    }

    return test; /* always fits in unsigned: test is 1 or a value reduced modulo mod */
}

int main(int argc, char* argv[])
{
    /* (2 ^ 168277) % 673109 */
    printf("%u\n", mod_pow(2, 168277, 673109));
    return 0;
}
`````` #### @Steve Jessop 2010-02-05 12:51:16

Exponentiation by squaring still "works" for modular exponentiation. Your problem isn't that `2 ^ 168277` is an exceptionally large number; it's that one of your intermediate results is a fairly large number (bigger than 2^32), because 673109 is bigger than 2^16.

So I think the following will do. It's possible I've missed a detail, but the basic idea works, and this is how "real" crypto code might do large mod-exponentiation (although not with 32- and 64-bit numbers, rather with bignums that never need more than about 2 * log2(modulus) bits):

• Perform the actual squaring in a 64-bit unsigned integer.
• Reduce modulo 673109 at each step to get back within the 32-bit range, as you do.
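The two steps above can be sketched as follows. This is one possible reading of the suggestion, not necessarily the exact code the answer had in mind: the names `mul_mod_wide` and `mod_pow_wide` are invented here, and the sketch assumes `std::uint64_t` is available and `mod > 0`.

```cpp
#include <cstdint>

// Multiplies in 64 bits, then reduces back into the 32-bit range.
std::uint32_t mul_mod_wide(std::uint32_t a, std::uint32_t b, std::uint32_t mod) {
    return static_cast<std::uint32_t>(static_cast<std::uint64_t>(a) * b % mod);
}

// Exponentiation by squaring; only the multiplication step is widened.
std::uint32_t mod_pow_wide(std::uint32_t num, std::uint32_t pow, std::uint32_t mod) {
    std::uint32_t test = 1 % mod;
    num %= mod;
    for (; pow; pow >>= 1) {
        if (pow & 1)
            test = mul_mod_wide(test, num, mod);
        num = mul_mod_wide(num, num, mod);
    }
    return test;
}
```

Keeping the widening inside one small helper means the loop itself still works on 32-bit values, and the helper is the only piece that would need replacing on a platform without a 64-bit type.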

Obviously that's a bit awkward if your C++ implementation doesn't have a 64 bit integer, although you can always fake one.
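One way to avoid a 64-bit type altogether, shown here only as a sketch and not as what the answer proposes, is to build the modular product from modular additions (the "double and add" scheme). It trades one multiplication for up to 32 loop iterations. The names `add_mod_narrow`, `mul_mod_narrow` and `mod_pow_narrow` are hypothetical, and `mod > 0` is assumed.

```cpp
#include <cstdint>

// Returns (a + b) % m for a, b < m without ever forming a + b when it
// would exceed 32 bits.
std::uint32_t add_mod_narrow(std::uint32_t a, std::uint32_t b, std::uint32_t m) {
    return (a >= m - b) ? a - (m - b) : a + b;
}

// Returns (a * b) % m using only 32-bit arithmetic: walk the bits of b,
// doubling a each step and adding it in where the bit is set.
std::uint32_t mul_mod_narrow(std::uint32_t a, std::uint32_t b, std::uint32_t m) {
    a %= m;
    b %= m;
    std::uint32_t result = 0;
    for (; b; b >>= 1) {
        if (b & 1)
            result = add_mod_narrow(result, a, m);
        a = add_mod_narrow(a, a, m);
    }
    return result;
}

// Exponentiation by squaring built on the 32-bit-only multiplication.
std::uint32_t mod_pow_narrow(std::uint32_t num, std::uint32_t pow, std::uint32_t mod) {
    std::uint32_t test = 1 % mod;
    num %= mod;
    for (; pow; pow >>= 1) {
        if (pow & 1)
            test = mul_mod_narrow(test, num, mod);
        num = mul_mod_narrow(num, num, mod);
    }
    return test;
}
```

This works for any 32-bit modulus, but each modular multiplication now costs a loop over the bits of one operand, so it is noticeably slower than widening to 64 bits where that is available.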

There's an example on slide 22 here: http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf, although it uses very small numbers (less than 2^16), so it may not illustrate anything you don't already know.

Your other example, `4000111222 ^ 3 % 1608` would work in your current code if you just reduce `4000111222` modulo `1608` before you start. `1608` is small enough that you can safely multiply any two mod-1608 numbers in a 32 bit int. #### @Axel Magnuson 2010-02-05 14:01:54

Thanks man, that did the trick. Just out of curiosity, do you know of any methods that wouldn't require a larger data type? I'm sure they would come in handy. #### @Steve Jessop 2010-02-05 16:03:08

Not that I know of. You need to multiply together two numbers up to 673108, mod 673109. Obviously you could break things up and do long multiplication with smaller "digits", say 2^10. But as soon as you're implementing multiplication and division in software, you might as well just implement it for the special case of multiplying together two 32 bit values to give a 64 bit result, then dividing to extract a 32 bit remainder. There might be some hard-core optimisations that do the bare minimum you need, but I don't know them and faking a 64 bit int in C++ isn't that hard. #### @Alexander Poluektov 2010-02-05 12:25:38

You could use the following identity:

(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)

Try using it in a straightforward way, then improve incrementally.

``````    if (pow & 1)
test = ((test % mod) * (num % mod)) % mod;
num = ((num % mod) * (num % mod)) % mod;
`````` #### @Axel Magnuson 2010-02-05 12:33:52

Thank you both for your suggestions, but by the nature of the algorithm both test and num will always be less than mod, so { (test % mod) = test } and { (num % mod) = num }. The identity therefore can't help me, because the function fails even when num and test are less than mod. Also, unsigned ints allow me to have 673109 as an argument: UINT_MAX = 4 294 967 295 on my computer.
