using System;
using System.Diagnostics;
using Unity.Burst;
using Unity.Mathematics;
namespace Unity.Collections
{
    /// 
    /// A type that uses Huffman encoding to encode values in a lossless manner.
    /// 
    /// 
    /// This type puts values into a manageable number of power-of-two-sized buckets.
    /// It codes the bucket index with Huffman, and uses several raw bits that correspond
    /// to the size of the bucket to code the position in the bucket.
    ///
    /// For example, if you want to send a 32-bit integer over the network, it's
    /// impractical to create a Huffman tree that encompasses every value the integer
    /// can take because it requires a tree with 2^32 leaves. This type manages that situation.
    ///
    /// The buckets are small, around 0, and become progressively larger as the data moves away from zero.
    /// Because most data is deltas against predictions, most values are small and most of the redundancy
    /// is in the error's size and not in the values of that size we end up hitting.
    ///
    /// The context is as a sub-model that has its own statistics and uses its own Huffman tree.
    /// When using the context to read and write a specific value, the context must always be the same.
    /// The benefit of using multiple contexts is that it allows you to separate the statistics of things that have
    /// different expected distributions, which leads to more precise statistics, which again yields better compression.
    /// More contexts does, however, result in a marginal cost of a slightly larger model.
    /// 
    [GenerateTestsForBurstCompatibility]
    public unsafe struct StreamCompressionModel
    {
        internal static readonly byte[] k_BucketSizes =
        {
            0, 0, 1, 2, 3, 4, 6, 8, 10, 12, 15, 18, 21, 24, 27, 32
        };
        internal static readonly uint[] k_BucketOffsets =
        {
            0, 1, 2, 4, 8, 16, 32, 96, 352, 1376, 5472, 38240, 300384, 2397536, 19174752, 153392480
        };
        internal static readonly int[] k_FirstBucketCandidate =
        {
            // 0   1   2   3   4   5   6   7   8   9   10  11  12  13  14  15  16  17  18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
            15, 15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10, 10, 10, 9, 9, 8, 8, 7, 7, 6, 5, 4, 3, 2, 1, 1, 0
        };
        internal static readonly byte[] k_DefaultModelData = { 16, // 16 symbols
                                                               2, 3, 3, 3,   4, 4, 4, 5,     5, 5, 6, 6,     6, 6, 6, 6,
                                                               0, 0 }; // no contexts
        internal const int k_AlphabetSize = 16;
        internal const int k_MaxHuffmanSymbolLength = 6;
        internal const int k_MaxContexts = 1;
        byte m_Initialized;
        static class SharedStaticCompressionModel
        {
            internal static readonly SharedStatic Default = SharedStatic.GetOrCreate();
        }
        /// 
        /// A shared singleton instance of , this instance is initialized using
        /// hardcoded bucket parameters and model.
        /// 
        public static StreamCompressionModel Default {
            get
            {
                if (SharedStaticCompressionModel.Default.Data.m_Initialized == 1)
                {
                    return SharedStaticCompressionModel.Default.Data;
                }
                Initialize();
                SharedStaticCompressionModel.Default.Data.m_Initialized = 1;
                return SharedStaticCompressionModel.Default.Data;
            }
        }
        static void Initialize()
        {
            for (int i = 0; i < k_AlphabetSize; ++i)
            {
                SharedStaticCompressionModel.Default.Data.bucketSizes[i] = k_BucketSizes[i];
                SharedStaticCompressionModel.Default.Data.bucketOffsets[i] = k_BucketOffsets[i];
            }
            var modelData = new NativeArray(k_DefaultModelData.Length, Allocator.Temp);
            for (var index = 0; index < k_DefaultModelData.Length; index++)
            {
                modelData[index] = k_DefaultModelData[index];
            }
            //int numContexts = NetworkConfig.maxContexts;
            int numContexts = 1;
            var symbolLengths = new NativeArray(numContexts * k_AlphabetSize, Allocator.Temp);
            int readOffset = 0;
            {
                // default model
                int defaultModelAlphabetSize = modelData[readOffset++];
                CheckAlphabetSize(defaultModelAlphabetSize);
                for (int i = 0; i < k_AlphabetSize; i++)
                {
                    byte length = modelData[readOffset++];
                    for (int context = 0; context < numContexts; context++)
                    {
                        symbolLengths[numContexts * context + i] = length;
                    }
                }
                // other models
                int numModels = modelData[readOffset] | (modelData[readOffset + 1] << 8);
                readOffset += 2;
                for (int model = 0; model < numModels; model++)
                {
                    int context = modelData[readOffset] | (modelData[readOffset + 1] << 8);
                    readOffset += 2;
                    int modelAlphabetSize = modelData[readOffset++];
                    CheckAlphabetSize(modelAlphabetSize);
                    for (int i = 0; i < k_AlphabetSize; i++)
                    {
                        byte length = modelData[readOffset++];
                        symbolLengths[numContexts * context + i] = length;
                    }
                }
            }
            // generate tables
            var tmpSymbolLengths = new NativeArray(k_AlphabetSize, Allocator.Temp);
            var tmpSymbolDecodeTable = new NativeArray(1 << k_MaxHuffmanSymbolLength, Allocator.Temp);
            var symbolCodes = new NativeArray(k_AlphabetSize, Allocator.Temp);
            for (int context = 0; context < numContexts; context++)
            {
                for (int i = 0; i < k_AlphabetSize; i++)
                    tmpSymbolLengths[i] = symbolLengths[numContexts * context + i];
                GenerateHuffmanCodes(symbolCodes, 0, tmpSymbolLengths, 0, k_AlphabetSize, k_MaxHuffmanSymbolLength);
                GenerateHuffmanDecodeTable(tmpSymbolDecodeTable, 0, tmpSymbolLengths, symbolCodes, k_AlphabetSize, k_MaxHuffmanSymbolLength);
                for (int i = 0; i < k_AlphabetSize; i++)
                {
                    SharedStaticCompressionModel.Default.Data.encodeTable[context * k_AlphabetSize + i] = (ushort)((symbolCodes[i] << 8) | symbolLengths[numContexts * context + i]);
                }
                for (int i = 0; i < (1 << k_MaxHuffmanSymbolLength); i++)
                {
                    SharedStaticCompressionModel.Default.Data.decodeTable[context * (1 << k_MaxHuffmanSymbolLength) + i] = tmpSymbolDecodeTable[i];
                }
            }
        }
        static void GenerateHuffmanCodes(NativeArray symbolCodes, int symbolCodesOffset, NativeArray symbolLengths, int symbolLengthsOffset, int alphabetSize, int maxCodeLength)
        {
            CheckAlphabetAndMaxCodeLength(alphabetSize, maxCodeLength);
            var lengthCounts = new NativeArray(maxCodeLength + 1, Allocator.Temp);
            var symbolList = new NativeArray((maxCodeLength + 1) * alphabetSize, Allocator.Temp);
            //byte[] symbol_list[(MAX_HUFFMAN_CODE_LENGTH + 1u) * MAX_NUM_HUFFMAN_SYMBOLS];
            for (int symbol = 0; symbol < alphabetSize; symbol++)
            {
                int symbolLength = symbolLengths[symbol + symbolLengthsOffset];
                CheckExceedMaxCodeLength(symbolLength, maxCodeLength);
                symbolList[(maxCodeLength + 1) * symbolLength + lengthCounts[symbolLength]++] = (byte)symbol;
            }
            uint nextCodeWord = 0;
            for (int length = 1; length <= maxCodeLength; length++)
            {
                int length_count = lengthCounts[length];
                for (int i = 0; i < length_count; i++)
                {
                    int symbol = symbolList[(maxCodeLength + 1) * length + i];
                    CheckSymbolLength(symbolLengths, symbolLengthsOffset, symbol, length);
                    symbolCodes[symbol + symbolCodesOffset] = (byte)ReverseBits(nextCodeWord++, length);
                }
                nextCodeWord <<= 1;
            }
        }
        static uint ReverseBits(uint value, int num_bits)
        {
            value = ((value & 0x55555555u) << 1) | ((value & 0xAAAAAAAAu) >> 1);
            value = ((value & 0x33333333u) << 2) | ((value & 0xCCCCCCCCu) >> 2);
            value = ((value & 0x0F0F0F0Fu) << 4) | ((value & 0xF0F0F0F0u) >> 4);
            value = ((value & 0x00FF00FFu) << 8) | ((value & 0xFF00FF00u) >> 8);
            value = (value << 16) | (value >> 16);
            return value >> (32 - num_bits);
        }
        // decode table entries: (symbol << 8) | length
        static void GenerateHuffmanDecodeTable(NativeArray decodeTable, int decodeTableOffset, NativeArray symbolLengths, NativeArray symbolCodes, int alphabetSize, int maxCodeLength)
        {
            CheckAlphabetAndMaxCodeLength(alphabetSize, maxCodeLength);
            uint maxCode = 1u << maxCodeLength;
            for (int symbol = 0; symbol < alphabetSize; symbol++)
            {
                int length = symbolLengths[symbol];
                CheckExceedMaxCodeLength(length, maxCodeLength);
                if (length > 0)
                {
                    uint code = symbolCodes[symbol];
                    uint step = 1u << length;
                    do
                    {
                        decodeTable[(int)(decodeTableOffset + code)] = (ushort)(symbol << 8 | length);
                        code += step;
                    }
                    while (code < maxCode);
                }
            }
        }
        /// 
        /// Bucket n starts at bucketOffsets[n] and ends at bucketOffsets[n] + (1 << bucketSizes[n]).
        /// (code << 8) | length
        /// 
        internal fixed ushort encodeTable[k_MaxContexts * k_AlphabetSize];
        /// 
        /// Bucket n starts at bucketOffsets[n] and ends at bucketOffsets[n] + (1 << bucketSizes[n]).
        /// (symbol << 8) | length
        /// 
        internal fixed ushort decodeTable[k_MaxContexts * (1 << k_MaxHuffmanSymbolLength)];
        /// 
        /// Specifies the sizes of the buckets in bits, so a bucket of n bits has 2^n values.
        /// 
        internal fixed byte bucketSizes[k_AlphabetSize];
        /// 
        /// Specifies the starting positions of the bucket.
        /// 
        internal fixed uint bucketOffsets[k_AlphabetSize];
        /// 
        /// Calculates the bucket index into the  where the specified value should be written.
        /// 
        /// A 4-byte unsigned integer value to find a bucket for.
        /// The bucket index where to put the value.
        readonly public int CalculateBucket(uint value)
        {
            int bucketIndex = k_FirstBucketCandidate[math.lzcnt(value)];
            if (bucketIndex + 1 < k_AlphabetSize && value >= bucketOffsets[bucketIndex + 1])
                bucketIndex++;
            return bucketIndex;
        }
        /// 
        /// The compressed size in bits of the given unsigned integer value
        /// 
        /// the unsigned int value you want to compress
        /// 
        public readonly int GetCompressedSizeInBits(uint value)
        {
            int bucket = CalculateBucket(value);
            int bits = bucketSizes[bucket];
            ushort encodeEntry = encodeTable[bucket];
            return (encodeEntry & 0xff) + bits;
        }
        [Conditional("ENABLE_UNITY_COLLECTIONS_CHECKS"), Conditional("UNITY_DOTS_DEBUG")]
        static void CheckAlphabetSize(int alphabetSize)
        {
            if (alphabetSize != k_AlphabetSize)
            {
                throw new InvalidOperationException("The alphabet size of compression models must be " + k_AlphabetSize);
            }
        }
        [Conditional("ENABLE_UNITY_COLLECTIONS_CHECKS"), Conditional("UNITY_DOTS_DEBUG")]
        static void CheckSymbolLength(NativeArray symbolLengths, int symbolLengthsOffset, int symbol, int length)
        {
            if (symbolLengths[symbol + symbolLengthsOffset] != length)
                throw new InvalidOperationException("Incorrect symbol length");
        }
        [Conditional("ENABLE_UNITY_COLLECTIONS_CHECKS"), Conditional("UNITY_DOTS_DEBUG")]
        static void CheckAlphabetAndMaxCodeLength(int alphabetSize, int maxCodeLength)
        {
            if (alphabetSize > 256 || maxCodeLength > 8)
                throw new InvalidOperationException("Can only generate huffman codes up to alphabet size 256 and maximum code length 8");
        }
        [Conditional("ENABLE_UNITY_COLLECTIONS_CHECKS"), Conditional("UNITY_DOTS_DEBUG")]
        static void CheckExceedMaxCodeLength(int length, int maxCodeLength)
        {
            if (length > maxCodeLength)
                throw new InvalidOperationException("Maximum code length exceeded");
        }
    }
}