#include <string.h>
#include "sha_256_hash.h"

/*
__inline static void reverse_bytes_long( u8* bytes, u64 x )
{
	bytes[ 7 ] =   x & 0xFF;		 bytes[ 6 ] = ( x >> 8 )  & 0xFF;
	bytes[ 5 ] = ( x >> 16 ) & 0xFF; bytes[ 4 ] = ( x >> 24 ) & 0xFF;
	bytes[ 3 ] = ( x >> 32 ) & 0xFF; bytes[ 2 ] = ( x >> 40 ) & 0xFF;
	bytes[ 1 ] = ( x >> 48 ) & 0xFF; bytes[ 0 ] = ( x >> 56 ) & 0xFF;
}
*/

/* Run-time byte-order probe. is_bigendian() evaluates to 1 when the
   lowest-addressed byte of test_endian (value 1) is zero, i.e. when the
   least significant byte is NOT stored first; otherwise it evaluates to 0. */
static const unsigned int test_endian = 1;
#define is_bigendian() ( !*( const unsigned char* )&test_endian )

/* Store the 32-bit value x into bytes[0..3], most significant byte first
   (big-endian order). The bytes are peeled off with shifts, so the result
   does not depend on the host's own byte order. */
__inline static void reverse_bytes_int( uint8_t* bytes, uint32_t x )
{
	int k;
	for( k = 3; k >= 0; --k )
	{
		bytes[ k ] = ( uint8_t )( x & 0xFFu );
		x >>= 8;
	}
}


/*
 * SHA-256 per-round constants: processBlockSHA256 adds entry [t] in round t
 * (one entry per round, 64 rounds).
 * NOTE: these are expected to equal K{256} from FIPS 180-4 section 4.2.2;
 * re-check against the standard if this table is ever edited.
 */
const uint32_t SHA_256_ROUND_CONST [ 64 ] =
{
	0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5,
	0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174,
	0xE49B69C1, 0xEFBE4786, 0x0FC19DC6, 0x240CA1CC, 0x2DE92C6F, 0x4A7484AA, 0x5CB0A9DC, 0x76F988DA,
	0x983E5152, 0xA831C66D, 0xB00327C8, 0xBF597FC7, 0xC6E00BF3, 0xD5A79147, 0x06CA6351, 0x14292967,
	0x27B70A85, 0x2E1B2138, 0x4D2C6DFC, 0x53380D13, 0x650A7354, 0x766A0ABB, 0x81C2C92E, 0x92722C85,
	0xA2BFE8A1, 0xA81A664B, 0xC24B8B70, 0xC76C51A3, 0xD192E819, 0xD6990624, 0xF40E3585, 0x106AA070,
	0x19A4C116, 0x1E376C08, 0x2748774C, 0x34B0BCB5, 0x391C0CB3, 0x4ED8AA4A, 0x5B9CCA4F, 0x682E6FF3,
	0x748F82EE, 0x78A5636F, 0x84C87814, 0x8CC70208, 0x90BEFFFA, 0xA4506CEB, 0xBEF9A3F7, 0xC67178F2
};

/*
 * Initial chaining value: initSHA256Hash copies these eight words into
 * h->m_hash before any input is absorbed.
 * NOTE: expected to equal H(0) for SHA-256 from FIPS 180-4 section 5.3.3;
 * re-check against the standard if edited.
 */
const uint32_t SHA_256_INIT_CONST [ SHA256_HASH_LENGTH / 4 ] =
{
	0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19
};

/* Forward declarations for the block-level helpers defined later in this file. */
void padMessageSHA256( SHA_256_Hash *h );
void processBlockSHA256( SHA_256_Hash *h );
void loadBlockSHA256( SHA_256_Hash *h );

/* SHA-256 "choose" function (Ch): each result bit comes from y where the
   corresponding bit of x is 1 and from z where it is 0. Written as
   z ^ (x & (y ^ z)), which is bitwise equal to (x & y) ^ (~x & z). */
__inline static uint32_t sha_256_lf1( uint32_t x, uint32_t y, uint32_t z )
{
	const uint32_t diff = y ^ z;
	return z ^ ( x & diff );
}

/* SHA-256 "majority" function (Maj): a result bit is set when at least two
   of the corresponding bits of x, y and z are set. */
__inline static uint32_t sha_256_lf2( uint32_t x, uint32_t y, uint32_t z )
{
	const uint32_t both_xy   = x & y;
	const uint32_t either_xy = x | y;
	return both_xy | ( z & either_xy );
}

/*
 * Rotate the 32-bit word x right by n bit positions.
 *
 * Both shift counts are masked into 0..31. The unmasked form
 * `( x >> n ) | ( x << ( 32 - n ) )` is undefined behaviour when n == 0
 * (a shift by the full width of the type) and when n >= 32. With the masks,
 * n == 0 returns x unchanged and n is effectively taken modulo 32. For
 * n in 1..31, the range every caller in this file uses, the result is
 * identical to the unmasked form.
 */
__inline static uint32_t sha_256_rotr( uint32_t x, uint32_t n )
{
	n &= 31u;
	return ( x >> n ) | ( x << ( ( 32u - n ) & 31u ) );
}

/* SHA-256 upper-case Sigma0: XOR of x rotated right by 2, 13 and 22 bits.
   Applied to working variable a in every compression round. */
__inline static uint32_t sha_256_ss0( uint32_t x )
{
	const uint32_t r2  = sha_256_rotr( x, 2 );
	const uint32_t r13 = sha_256_rotr( x, 13 );
	const uint32_t r22 = sha_256_rotr( x, 22 );
	return r22 ^ r13 ^ r2;
}

/* SHA-256 upper-case Sigma1: XOR of x rotated right by 6, 11 and 25 bits.
   Applied to working variable e in every compression round. */
__inline static uint32_t sha_256_ss1( uint32_t x )
{
	const uint32_t r6  = sha_256_rotr( x, 6 );
	const uint32_t r11 = sha_256_rotr( x, 11 );
	const uint32_t r25 = sha_256_rotr( x, 25 );
	return r25 ^ r11 ^ r6;
}

/* SHA-256 lower-case sigma0 for message-schedule expansion: x rotated right
   by 7 and by 18, XORed with x logically shifted right by 3. Applied to
   W[t-15] in loadBlockSHA256. */
__inline static uint32_t sha_256_sg0( uint32_t x )
{
	const uint32_t r7  = sha_256_rotr( x, 7 );
	const uint32_t r18 = sha_256_rotr( x, 18 );
	const uint32_t s3  = x >> 3;
	return s3 ^ r18 ^ r7;
}

/* SHA-256 lower-case sigma1 for message-schedule expansion: x rotated right
   by 17 and by 19, XORed with x logically shifted right by 10. Applied to
   W[t-2] in loadBlockSHA256. */
__inline static uint32_t sha_256_sg1( uint32_t x )
{
	const uint32_t r17 = sha_256_rotr( x, 17 );
	const uint32_t r19 = sha_256_rotr( x, 19 );
	const uint32_t s10 = x >> 10;
	return s10 ^ r19 ^ r17;
}

/*
 * Put a hash state into its starting configuration.
 *
 * Copies the eight SHA_256_INIT_CONST words into h->m_hash, zeroes the
 * message bit counter (m_lengthHigh:m_lengthLow) and the block-buffer index,
 * and clears both the input block buffer and the message-schedule buffer.
 */
void initSHA256Hash( SHA_256_Hash* h )
{
	unsigned int word;

	/* Seed the chaining value: eight 32-bit words. */
	for( word = 0; word < 8; ++word )
	{
		h->m_hash[ word ] = SHA_256_INIT_CONST[ word ];
	}

	/* No input absorbed yet. */
	h->m_messageBlockIndex = 0;
	h->m_lengthLow = 0;
	h->m_lengthHigh = 0;

	/* Clear the message-schedule scratch words and the partial input block. */
	memset( h->m_hashBlock, 0, SHA256_HASH_BLOCK_LENGTH * sizeof( uint32_t ) );
	memset( h->m_messageBlock, 0, SHA256_MESSAGE_BLOCK_LENGTH * sizeof( uint8_t ) );
}

/*
 * Build the SHA-256 message schedule W[0..SHA256_HASH_BLOCK_LENGTH-1] in
 * h->m_hashBlock from the 64 input bytes in h->m_messageBlock.
 *
 * W[0..15] are the input bytes taken four at a time, first byte in the most
 * significant position. The remaining words follow the expansion recurrence
 *   W[t] = sg1(W[t-2]) + W[t-7] + sg0(W[t-15]) + W[t-16].
 *
 * Each byte is widened to uint32_t BEFORE it is shifted. A uint8_t operand
 * would otherwise be promoted to signed int, and shifting a byte >= 0x80 left
 * by 24 overflows int, which is undefined behaviour (C11 6.5.7p4).
 */
void loadBlockSHA256( SHA_256_Hash* h )
{
	uint32_t i = 0;
	for( i = 0; i < 16; ++i )
	{
		const uint32_t base = i << 2;
		h->m_hashBlock[ i ] = ( ( uint32_t )h->m_messageBlock[ base     ] << 24 ) |
		                      ( ( uint32_t )h->m_messageBlock[ base + 1 ] << 16 ) |
		                      ( ( uint32_t )h->m_messageBlock[ base + 2 ] <<  8 ) |
		                        ( uint32_t )h->m_messageBlock[ base + 3 ];
	}

	for( i = 16; i < SHA256_HASH_BLOCK_LENGTH; ++i )
	{
		h->m_hashBlock[ i ] = sha_256_sg1( h->m_hashBlock[ i - 2 ] ) + h->m_hashBlock[ i - 7 ] +
		                      sha_256_sg0( h->m_hashBlock[ i - 15 ] ) + h->m_hashBlock[ i - 16 ];
	}
}

/*
 * Compress the block currently held in h->m_messageBlock into h->m_hash.
 *
 * Processing happens in three steps:
 *  1. loadBlockSHA256 builds the message schedule.
 *  2. Eight working variables, seeded from the current chaining value, go
 *     through SHA256_HASH_BLOCK_LENGTH rounds.
 *  3. Each working variable is added back into the matching m_hash word.
 *
 * Finally m_messageBlockIndex is reset to 0, so the next input byte starts
 * a fresh block.
 */
void processBlockSHA256( SHA_256_Hash* h )
{
	uint32_t va, vb, vc, vd, ve, vf, vg, vh;
	uint32_t t;

	loadBlockSHA256( h );

	va = h->m_hash[ 0 ];  vb = h->m_hash[ 1 ];
	vc = h->m_hash[ 2 ];  vd = h->m_hash[ 3 ];
	ve = h->m_hash[ 4 ];  vf = h->m_hash[ 5 ];
	vg = h->m_hash[ 6 ];  vh = h->m_hash[ 7 ];

	for( t = 0; t < SHA256_HASH_BLOCK_LENGTH; ++t )
	{
		const uint32_t temp1 = vh + sha_256_ss1( ve ) + sha_256_lf1( ve, vf, vg ) +
		                       SHA_256_ROUND_CONST[ t ] + h->m_hashBlock[ t ];
		const uint32_t temp2 = sha_256_ss0( va ) + sha_256_lf2( va, vb, vc );

		/* Slide every working variable down one slot, injecting temp1/temp2. */
		vh = vg;
		vg = vf;
		vf = ve;
		ve = vd + temp1;
		vd = vc;
		vc = vb;
		vb = va;
		va = temp1 + temp2;
	}

	h->m_hash[ 0 ] += va;  h->m_hash[ 1 ] += vb;
	h->m_hash[ 2 ] += vc;  h->m_hash[ 3 ] += vd;
	h->m_hash[ 4 ] += ve;  h->m_hash[ 5 ] += vf;
	h->m_hash[ 6 ] += vg;  h->m_hash[ 7 ] += vh;

	h->m_messageBlockIndex = 0;
}

/*
 * Absorb `length` bytes from `message` into the running hash.
 *
 * Each byte is appended to the block buffer. Whenever the buffer reaches
 * SHA256_MESSAGE_BLOCK_LENGTH bytes, the block is compressed, which also
 * resets the buffer index. The message length in bits is tracked as the
 * word pair m_lengthHigh:m_lengthLow; the high word is incremented when the
 * low word wraps around to zero.
 *
 * `message` must be readable for `length` bytes; it is not checked for NULL.
 */
void updateSHA256Hash( SHA_256_Hash* h, const uint8_t* message, unsigned int length )
{
	unsigned int pos;

	for( pos = 0; pos < length; ++pos )
	{
		h->m_messageBlock[ h->m_messageBlockIndex ] = message[ pos ];
		h->m_messageBlockIndex++;

		h->m_lengthLow += 8;
		if( h->m_lengthLow == 0 )
		{
			h->m_lengthHigh++;   /* carry out of the low word */
		}

		if( h->m_messageBlockIndex == SHA256_MESSAGE_BLOCK_LENGTH )
		{
			processBlockSHA256( h );
		}
	}
}

/*
 * Apply final padding to the buffered tail and compress it.
 *
 * Padding proceeds as follows:
 *  - Append the 0x80 marker byte.
 *  - Zero-fill up to byte 56 of a block.
 *  - Store the 8-byte bit count in bytes 56..63, most significant byte
 *    first (m_lengthHigh, then m_lengthLow).
 *  - Compress the block.
 *
 * If the marker leaves the index past byte 56, there is no room for the
 * length field. In that case the current block is zero-filled to its end and
 * compressed first, and the length goes into a fresh block whose first 56
 * bytes are zero.
 */
void padMessageSHA256( SHA_256_Hash* h )
{
	enum { LENGTH_FIELD_OFFSET = 56 };
	uint32_t k = 0;

	h->m_messageBlock[ h->m_messageBlockIndex ] = 0x80;
	h->m_messageBlockIndex++;

	if( h->m_messageBlockIndex <= LENGTH_FIELD_OFFSET )
	{
		/* The length field still fits: clear only the gap before it. */
		memset( h->m_messageBlock + h->m_messageBlockIndex, 0,
		        LENGTH_FIELD_OFFSET - h->m_messageBlockIndex );
	}
	else
	{
		/* No room for the length: flush this block zero-padded to its end... */
		memset( h->m_messageBlock + h->m_messageBlockIndex, 0,
		        SHA256_MESSAGE_BLOCK_LENGTH - h->m_messageBlockIndex );
		processBlockSHA256( h );
		/* ...and begin a block that carries only zeros plus the length. */
		memset( h->m_messageBlock, 0, LENGTH_FIELD_OFFSET );
	}

	for( k = 0; k < 4; ++k )
	{
		const uint32_t shift = 24 - 8 * k;
		h->m_messageBlock[ LENGTH_FIELD_OFFSET + k ]     = h->m_lengthHigh >> shift;
		h->m_messageBlock[ LENGTH_FIELD_OFFSET + 4 + k ] = h->m_lengthLow  >> shift;
	}
	processBlockSHA256( h );
}

/*
 * Finish the hash and write the digest.
 *
 * Pads and compresses the buffered tail, then writes the SHA256_HASH_LENGTH
 * digest bytes to `hash`. Finally `h` is reinitialised with initSHA256Hash,
 * so it can be reused for a new message.
 *
 * `hash` must point to at least SHA256_HASH_LENGTH writable bytes. Each state
 * word is written most significant byte first.
 *
 * The bytes are produced with shifts (reverse_bytes_int) rather than by
 * copying the in-memory representation of m_hash. The output therefore does
 * not depend on the host's byte order, on how a run-time endianness probe
 * classifies that order, or on the width of m_hash's element type.
 */
void getResultAndResetSHA256Hash( SHA_256_Hash* h, uint8_t* hash )
{
	int i = 0;

	padMessageSHA256( h );

	for( i = 0; i < ( SHA256_HASH_LENGTH / 4 ); ++i )
	{
		reverse_bytes_int( hash + i * 4, h->m_hash[ i ] );
	}

	initSHA256Hash( h );
}
