//
// bmath.cxx
//
// Copyright (C) 1996-7 by Leonard Janke (janke@unixg.ubc.ca)
 
#include <linteger/bmath.hxx>
#include <lmisc/lmisc.hxx>
#include <iostream.h>
#include <iomanip.h>

// Width of one arithmetic word (an unsigned int) in bits.  The literal 8
// assumes 8-bit bytes.
const int BMath::bitsPerUInt=sizeof(unsigned int)*8;

// Ascending table of every prime from 2 through 1999 (303 entries).
// SmallPrimesAvailable is the table length; it and the array bound below
// must be kept in sync if the table is ever extended.
const int BMath::SmallPrimesAvailable=303;
const unsigned int BMath::SmallPrime[303]= {
  2u,  3u,   5u,    7u,  11u,  13u,  17u,  19u,
  23u,  29u,  31u,  37u,  41u,  43u,  47u,  53u,
  59u,  61u,  67u,  71u,  73u,  79u,  83u,  89u,
  97u, 101u, 103u, 107u, 109u, 113u, 127u, 131u,

  137u, 139u, 149u, 151u, 157u, 163u, 167u, 173u,
  179u, 181u, 191u, 193u, 197u, 199u, 211u, 223u,
  227u, 229u, 233u, 239u, 241u, 251u, 257u, 263u,
  269u, 271u, 277u, 281u, 283u, 293u, 307u, 311u,
  
  313u, 317u, 331u, 337u, 347u, 349u, 353u, 359u,
  367u, 373u, 379u, 383u, 389u, 397u, 401u, 409u,
  419u, 421u, 431u, 433u, 439u, 443u, 449u, 457u,
  461u, 463u, 467u, 479u, 487u, 491u, 499u, 503u,
  
  509u, 521u, 523u, 541u, 547u, 557u, 563u, 569u,
  571u, 577u, 587u, 593u, 599u, 601u, 607u, 613u,
  617u, 619u, 631u, 641u, 643u, 647u, 653u, 659u,
  661u, 673u, 677u, 683u, 691u, 701u, 709u, 719u,
  
  727u, 733u, 739u, 743u, 751u, 757u, 761u, 769u,
  773u, 787u, 797u, 809u, 811u, 821u, 823u, 827u,
  829u, 839u, 853u, 857u, 859u, 863u, 877u, 881u,
  883u, 887u, 907u, 911u, 919u, 929u, 937u, 941u,
  
  947u, 953u, 967u, 971u, 977u, 983u, 991u, 997u,
  1009u, 1013u, 1019u, 1021u, 1031u, 1033u, 1039u, 1049u, 
  1051u, 1061u, 1063u, 1069u, 1087u, 1091u, 1093u, 1097u, 
  1103u, 1109u, 1117u, 1123u, 1129u, 1151u, 1153u, 1163u, 
  
  1171u, 1181u, 1187u, 1193u, 1201u, 1213u, 1217u, 1223u,
  1229u, 1231u, 1237u, 1249u, 1259u, 1277u, 1279u, 1283u,
  1289u, 1291u, 1297u, 1301u, 1303u, 1307u, 1319u, 1321u,
  1327u, 1361u, 1367u, 1373u, 1381u, 1399u, 1409u, 1423u,
  
  1427u, 1429u, 1433u, 1439u, 1447u, 1451u, 1453u, 1459u,
  1471u, 1481u, 1483u, 1487u, 1489u, 1493u, 1499u, 1511u,
  1523u, 1531u, 1543u, 1549u, 1553u, 1559u, 1567u, 1571u,
  1579u, 1583u, 1597u, 1601u, 1607u, 1609u, 1613u, 1619u,
  
  1621u, 1627u, 1637u, 1657u, 1663u, 1667u, 1669u, 1693u,
  1697u, 1699u, 1709u, 1721u, 1723u, 1733u, 1741u, 1747u,
  1753u, 1759u, 1777u, 1783u, 1787u, 1789u, 1801u, 1811u,
  1823u, 1831u, 1847u, 1861u, 1867u, 1871u, 1873u, 1877u,
  
  1879u, 1889u, 1901u, 1907u, 1913u, 1931u, 1933u, 1949u,
  1951u, 1973u, 1979u, 1987u, 1993u, 1997u, 1999u };

// z = x + y, for multi-word numbers stored most significant word first.
//
// x has digitsX words, y has digitsY words (digitsX >= digitsY is assumed),
// and z receives digitsX words.  The return value is the carry out of the
// most significant word of z.
char BMath::Add(const unsigned int* x,
		const int digitsX,
		const unsigned int* y,
		const int digitsY,
		unsigned int* z)
{
  // Leading words of x that have no counterpart in y pass through unchanged.
  const int extraWords=digitsX-digitsY;

  LMisc::MemCopy(z,x,extraWords);

  // Add the aligned low-order parts.
  const char lowCarry=BasicAdd(x+extraWords,y,z+extraWords,digitsY);

  // No carry, or no leading words to carry into: report it as is.
  if ( !lowCarry || extraWords==0 )
    return lowCarry;

  // Otherwise push the carry up through the copied leading words.
  return Increment(z,extraWords);
}

// z = x - y, for multi-word numbers stored most significant word first.
//
// x has digitsX words, y has digitsY words, and z receives digitsX words.
// Presumably x >= y is required (no final borrow is reported) -- confirm
// with callers.
void BMath::Subtract(const unsigned int* x,
		     const int digitsX,
		     const unsigned int* y,
		     const int digitsY,
		     unsigned int* z)
{
  // Leading words of x above the top word of y are copied through first.
  const int highWords=digitsX-digitsY;
  unsigned int* const zLow=z+highWords;

  LMisc::MemCopy(z,x,highWords);

  // A borrow out of the aligned low part is taken from the copied high part.
  const char borrow=BasicSubtract(x+highWords,y,zLow,digitsY);
  if ( borrow )
    RippleDecrement(z,highWords);
}

void BMath::Multiply(const unsigned int* x, 
		     int digitsX, 
		     const unsigned int* y, 
		     int digitsY, 
		     unsigned int* z) 
{
  // Basic recursive Multiplication described in Knuth
  // z needs to be zeroed for this to work
  //
  // Computes z = x*y.  Words are ordered most significant first (x[0] is
  // the top word), so "x+k" addresses the low-order digitsX-k words of x.
  // z must provide digitsX+digitsY zeroed words.
  //
  // After arranging digitsX <= digitsY:
  //   * digitsX == 1 : one word-times-vector multiply.
  //   * digitsX even : equal lengths -> one Karatsuba step;
  //                    unequal      -> split y into a low digitsX-word part
  //                                    and the remaining high part.
  //   * digitsX odd  : peel off x[0] and recurse on the even-length rest.
  //
  // NOTE(review): the odd case calls BasicMultiply on a z that already
  // holds partial products, so BasicMultiply appears to ADD its product
  // into the output rather than overwrite it -- confirm against its
  // definition.


  if ( digitsX > digitsY )
    {
      LC_Swap(x,y);
      LC_Swap(digitsX,digitsY);
    }

  // can assume digitsY >= digitsX  now

  if ( digitsX == 1 ) 
    {
      BasicMultiply(y,x[0],z,digitsY);
      return;
    }

  // digitsX != 1

  if ( digitsX%2 == 0 )
    {
      if ( digitsX==digitsY )
	{
	  if ( digitsX==2 )
	    {
	      MultDouble(x,y,z);
	      return;
	    }

	  // digitsX != 2
	  //
	  // One Karatsuba step.  With n=digitsX/2, U1=x[0..n-1] (high half),
	  // U0=x[n..2n-1] (low half), V1/V0 likewise for y, and B the word
	  // radix:
	  //   x*y = (B^2n+B^n)U1V1 + B^n(U1-U0)(V0-V1) + (B^n+1)U0V0
	  //
	  // A single zeroed allocation holds all temporaries:
	  //   firstTerm  digitsX words   U1*V1
	  //   middleTerm digitsX words   |U1-U0|*|V0-V1|
	  //   |U1-U0|    n words
	  //   |V0-V1|    n words
	  //   lastTerm   digitsX words   U0*V0
	  // (zeroed because the recursive Multiply calls require zeroed output)

	  unsigned int* firstTerm=new unsigned int[4*digitsX]; 
	  unsigned int* middleTerm=firstTerm+digitsX;
	  unsigned int* absu1minusu0=middleTerm+digitsX;
	  unsigned int* absv0minusv1=absu1minusu0+digitsX/2;
	  unsigned int* lastTerm=absv0minusv1+digitsX/2; 
	  LMisc::MemZero(firstTerm,4*digitsX);

	  // last term: U0V0 added at weight B^0 (z+digitsX) and B^n (z+n)

	  Multiply(x+digitsX/2,digitsX/2,y+digitsX/2,digitsX/2,lastTerm);
	  BasicAdd(lastTerm,z+digitsX,z+digitsX,digitsX);
	  BasicAdd(lastTerm,z+digitsX/2,z+digitsX/2,digitsX);

	  // first term: U1V1 added at weight B^n (z+n) and B^2n (z)

	  Multiply(x,digitsX/2,y,digitsX/2,firstTerm);
	  RippleAdd(firstTerm,z+digitsX/2,z+digitsX/2,digitsX);
	  BasicAdd(firstTerm,z,z,digitsX);

	  // middle term: (U1-U0)(V0-V1) at weight B^n, computed from absolute
	  // values.  cmp1 <=> U1>=U0, cmp2 <=> V0>=V1; when exactly one of
	  // them holds the product is negative and is subtracted.

	  char cmp1=GreaterThanOrEqualTo(x,x+digitsX/2,digitsX/2);
	  char cmp2=GreaterThanOrEqualTo(y+digitsX/2,y,digitsX/2);
	  char cmp=cmp1^cmp2;

	  if ( cmp1 )
	    RippleSubtract(x,x+digitsX/2,absu1minusu0,digitsX/2);
	  else
	    RippleSubtract(x+digitsX/2,x,absu1minusu0,digitsX/2);

	  if ( cmp2 )
	    RippleSubtract(y+digitsX/2,y,absv0minusv1,digitsX/2);
	  else
	    RippleSubtract(y,y+digitsX/2,absv0minusv1,digitsX/2);

	  Multiply(absu1minusu0,digitsX/2,absv0minusv1,digitsX/2,middleTerm);
	  if ( cmp )
	    RippleSubtract(z+digitsX/2,middleTerm,z+digitsX/2,digitsX);
	  else
	    RippleAdd(middleTerm,z+digitsX/2,z+digitsX/2,digitsX);

	  delete[] firstTerm;

	  return;
	}

      // digitsX != digitsY
      //
      // Split y = Yhigh*B^digitsX + Ylow, Ylow being the last digitsX words.
      // x*Ylow goes directly into the low 2*digitsX words of z.  x*Yhigh is
      // digitsY words long; it is built in temp1 and then added at weight
      // B^digitsX, i.e. into z[0..digitsY-1].

      unsigned int* temp1=new unsigned int[digitsY];
      LMisc::MemZero(temp1,digitsY);
      Multiply(x,digitsX,y+digitsY-digitsX,digitsX,z+digitsY-digitsX);
      Multiply(x,digitsX,y,digitsY-digitsX,temp1);
      BasicAdd(temp1,z,z,digitsY);
      delete[] temp1;

      return;
    }

  // digitsX % 2 == 1
  //
  // Write x = x[0]*B^(digitsX-1) + Xlow (Xlow = x+1, even length) and split
  // y into Ylow (its last digitsX-1 words) and Yhigh (the leading
  // digitsY-digitsX+1 words).  Then
  //   x*y = Xlow*Ylow + Xlow*Yhigh*B^(digitsX-1) + x[0]*y*B^(digitsX-1)
  // The three products are placed at z+digitsY-digitsX+2, z+1 and z.

  Multiply(x+1,digitsX-1,y+digitsY-digitsX+1,digitsX-1,z+digitsY-digitsX+2);

  unsigned int* temp2=new unsigned int[digitsY];
  LMisc::MemZero(temp2,digitsY);

  Multiply(x+1,digitsX-1,y,digitsY-digitsX+1,temp2);
  RippleAdd(z+1,temp2,z+1,digitsY);
  delete[] temp2;

  BasicMultiply(y,x[0],z,digitsY);
}

void BMath::Square(const unsigned int* x, 
		   const int digitsX, 
		   unsigned int* y)
{
  // Recursive Squaring 
  // same algorithm as multiplying 
  // but a few steps can be eliminated
  //
  // y needs to be zeroed for this method to work!
  //
  // Computes y = x*x.  Words are ordered most significant first, and y must
  // provide 2*digitsX zeroed words.

  if ( digitsX ==1 ) 
    {
      BasicMultiply(x,x[0],y,1); // use a special method here?
      return;
    }

  // digits != 1

  if ( digitsX%2 == 0 )
    {
      if ( digitsX==2 )
	{
	  SquareDouble(x,y);
	  return;
	}

      // digits != 2
      //
      // Karatsuba step specialised to u==v.  With n=digitsX/2,
      // U1=x[0..n-1] and U0=x[n..2n-1]:
      //   x*x = (B^2n+B^n)U1^2 - B^n(U1-U0)^2 + (B^n+1)U0^2
      // The middle term is always subtracted.
      // Buffer layout: firstTerm (digitsX), middleTerm (digitsX),
      // |U1-U0| (digitsX reserved, only n words used), lastTerm (digitsX).
      
      unsigned int* firstTerm=new unsigned int[4*digitsX]; 
      unsigned int* middleTerm=firstTerm+digitsX;
      unsigned int* absu1minusu0=middleTerm+digitsX;
      unsigned int* lastTerm=absu1minusu0+digitsX;
      LMisc::MemZero(firstTerm,4*digitsX);

      // last term: U0^2 added at weight B^0 (y+digitsX) and B^n (y+n)

      Square(x+digitsX/2,digitsX/2,lastTerm);
      BasicAdd(lastTerm,y+digitsX,y+digitsX,digitsX);
      BasicAdd(lastTerm,y+digitsX/2,y+digitsX/2,digitsX);

      // first term: U1^2 added at weight B^n (y+n) and B^2n (y)

      Square(x,digitsX/2,firstTerm);
      RippleAdd(firstTerm,y+digitsX/2,y+digitsX/2,digitsX);
      BasicAdd(firstTerm,y,y,digitsX);

      // middle term: |U1-U0|^2 subtracted at weight B^n

      char cmp=GreaterThanOrEqualTo(x,x+digitsX/2,digitsX/2);

      if ( cmp )
        RippleSubtract(x,x+digitsX/2,absu1minusu0,digitsX/2);
      else
        RippleSubtract(x+digitsX/2,x,absu1minusu0,digitsX/2);

      Square(absu1minusu0,digitsX/2,middleTerm);
      RippleSubtract(y+digitsX/2,middleTerm,y+digitsX/2,digitsX);

      delete[] firstTerm;
      return;
    }

  // digitsX % 2 == 1
  //
  // With x = x[0]*B^(digitsX-1) + Xlow (Xlow = x+1):
  //   x*x = Xlow^2 + x[0]*x*B^(digitsX-1) + x[0]*Xlow*B^(digitsX-1)
  // These are placed at y+2, y and y+1 respectively.
  //
  // NOTE(review): the last BasicMultiply accumulates into y[1..digitsX].
  // That sum can exceed the span, e.g. with a large x[0] and a large Xlow.
  // In that case correctness depends on BasicMultiply carrying into y[0].
  // Confirm this against its definition.

  Square(x+1,digitsX-1,y+2);
  BasicMultiply(x,x[0],y,digitsX);
  BasicMultiply(x+1,x[0],y+1,digitsX-1);
}

void BMath::Divide(const unsigned int* dividend, 
		   const int dividendDigits,
		   const unsigned int* divisor, 
		   const int divisorDigits,
		   unsigned int*& quotient, 
		   unsigned int*& remainder)
{
  // this algorithm by Leonard Janke with a hint or two from Knuth

  char normalizationFactor=31-char(BSR(divisor[0]));

  // remainder=normalized dividend

  remainder=new unsigned int[dividendDigits+1];
  LMisc::MemCopy(remainder+1,dividend,dividendDigits);
  remainder[0]=0u;
  ShortShiftLeft(remainder,dividendDigits+1,normalizationFactor);

  // padded divisor = 0u|divisor for use in comparisons
  
  unsigned int* paddedDivisor=new unsigned int[divisorDigits+1];
  LMisc::MemCopy(paddedDivisor+1,divisor,divisorDigits);
  paddedDivisor[0]=0u;

  // normalize padded divisor

  ShortShiftLeft(paddedDivisor,divisorDigits+1,normalizationFactor);

  // create quotient

  quotient=new unsigned int[dividendDigits-divisorDigits+1];

  // main

  const unsigned int d1=paddedDivisor[1]+1;

  unsigned int highWord;
  unsigned int lowWord;
  unsigned int q;
  unsigned int r;          // Can I use this somehow?

  unsigned int* multBuf=new unsigned int[divisorDigits+1];

  for (int i=0; i<dividendDigits-divisorDigits+1; i++)
    {
      highWord=remainder[i];
      lowWord=remainder[i+1];

      if ( !d1 )
	q=highWord;
      else if ( d1==highWord)  
	q=0xffffffffu;
      else
	BasicDivide(highWord,lowWord,d1,q,r);

      // clear multBuf

      LMisc::MemZero(multBuf,divisorDigits+1);

      BasicMultiply(paddedDivisor+1,q,multBuf,divisorDigits);
      RippleSubtract(remainder+i,multBuf,remainder+i,divisorDigits+1);

      quotient[i]=q;

      if ( GreaterThanOrEqualTo(remainder+i,paddedDivisor,divisorDigits+1) )
	{
	  quotient[i]++;
	  RippleSubtract(remainder+i,paddedDivisor,remainder+i,
			 divisorDigits+1);

	  if ( GreaterThanOrEqualTo(remainder+i,paddedDivisor,
				    divisorDigits+1) )
	    {
	      quotient[i]++;
	      RippleSubtract(remainder+i,paddedDivisor,remainder+i,
			     divisorDigits+1);

	      if ( GreaterThanOrEqualTo(remainder+i,paddedDivisor,
					divisorDigits+1) )
		abort(); // oh, oh! this is supposed to be impossible! 
	    }
	}

    }
  // clean up

  delete[] paddedDivisor;
  delete[] multBuf;

  // Undo Normalization 
  ShortShiftRight(remainder,dividendDigits+1,normalizationFactor);
}

// Index of the lowest set bit of the digits-word number x (words most
// significant first; bit 0 is the low bit of x[digits-1]).  Returns -1 when
// every word is zero.
int BMath::BSF(const unsigned int* x, 
	       const int digits)
{
  // Walk from the least significant word upwards, counting how many
  // all-zero words are passed; each contributes bitsPerUInt bits.
  for (int skipped=0; skipped<digits; skipped++)
    {
      const unsigned int word=x[digits-1-skipped];
      if ( word != 0u )
	return skipped*bitsPerUInt+BMath::BSF(word);
    }

  return -1;
}

// Index of the highest set bit of the digits-word number x (words most
// significant first; bit 0 is the low bit of x[digits-1]).  Returns -1 when
// every word is zero.
int BMath::BSR(const unsigned int* x, 
	       const int digits)
{
  // The first non-zero word from the top decides the answer; every word
  // below it adds a full bitsPerUInt bits.
  for (int pos=0; pos<digits; pos++)
    if ( x[pos] )
      return (digits-1-pos)*bitsPerUInt+BMath::BSR(x[pos]);

  return -1;
}


// x >>= distance, for a digits-word number stored most significant first.
//
// x must be a new[]-allocated array; when the word count changes it is
// replaced by a freshly allocated, exactly sized array, and digits is
// updated.  A result of zero is left as a single 0 word.
//
// Fix: only the top digits-lostDigits words survive the shift.  The earlier
// code moved and shifted all `digits` words starting at x+lostDigits, which
// wrote lostDigits words past the end of x (heap overflow).
void BMath::ShiftRight(unsigned int*& x, 
		       int& digits, 
		       const int distance) 
{
  const int originalDigits=digits;
  const int lostDigits=distance/bitsPerUInt;    // whole words dropped
  const int bitShift=distance%bitsPerUInt;      // remaining sub-word shift
  const char bitPos=char(BMath::BSR(x[0]));

  // The top word empties out when its highest set bit falls below bit 0
  // under the sub-word shift.
  int phantomDigit=0;

  if ( (bitPos-bitShift) < 0 )
      phantomDigit=1;

  if ( (digits-lostDigits-phantomDigit) > 0 )
    {
      // slide the surviving high words down to the low end of the buffer
      const int keptDigits=digits-lostDigits;
      unsigned int* newX=x+lostDigits;

      LMisc::MemMove(newX,x,keptDigits);
      BMath::ShortShiftRight(newX,keptDigits,char(bitShift));

      digits-=(lostDigits+phantomDigit);
    }
  else
    {
      // everything shifted out: the result is zero
      x[digits-1]=0u;
      digits=1;
    }

  // the result occupies the trailing `digits` words; trim the allocation
  if ( digits != originalDigits )
    {
      unsigned int* temp=new unsigned int[digits];
      LMisc::MemCopy(temp,x+originalDigits-digits,digits);
      delete[] x;
      x=temp;
    }
}

// x <<= distance, for a digits-word number stored most significant first.
//
// x must be a new[]-allocated array.  When the result needs more words, x
// is replaced by a larger new[] array and digits is updated.
//
// On reallocation the original words are now copied directly to their
// final position (just after the optional overflow word).  Previously they
// were copied to the tail of the new buffer and then moved again.
void BMath::ShiftLeft(unsigned int*& x, 
		      int& digits, 
		      const int distance) 
{
  const int bitShift=distance%bitsPerUInt;      // sub-word part
  const int wordShift=distance/bitsPerUInt;     // whole zero words appended

  // One extra leading word is needed when the highest set bit of x[0] is
  // pushed past the word boundary by the sub-word shift.
  const int overflow= ( BMath::BSR(x[0])+bitShift >= bitsPerUInt ) ? 1 : 0;
  const int grownDigits=digits+wordShift+overflow;

  if ( grownDigits != digits )
    {
      unsigned int* grown=new unsigned int[grownDigits];
      LMisc::MemCopy(grown+overflow,x,digits);
      delete[] x;
      x=grown;
    }

  // Layout now: [overflow word?][original digits][wordShift zero words].
  if ( overflow )
    x[0]=0u;

  LMisc::MemZero(x+overflow+digits,wordShift);

  // bits leaving x[overflow] move into the leading word
  BMath::ShortShiftLeft(x,overflow+digits,char(bitShift));

  digits=grownDigits;
}
