172 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			HLSL
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			HLSL
		
	
	
	
	
	
| #ifndef THREADING_SM6_IMPL
 | |
| #define THREADING_SM6_IMPL
 | |
| 
 | |
| namespace Threading
 | |
| {
 | |
    // Only scalar types are covered for now: when this utility library was written, emulation was only needed for those.
    // Support for vector types is not there yet, but can be added as needed (and this comment removed).

    // Group-shared scratch storage for the emulated Group:: operations in this file.
    // One uint slot per thread of the group (THREADING_BLOCK_SIZE slots, indexed by groupIndex).
    // Values of every supported scalar type are stored bit-cast with asuint() and read back with
    // asuint/asint/asfloat. Each operation issues GroupMemoryBarrierWithGroupSync() before writing its
    // slot, so consecutive operations can safely reuse the same storage.
    groupshared uint g_Scratch[THREADING_BLOCK_SIZE];
 | |
| 
 | |
|     uint Wave::GetIndex() { return indexW; }
 | |
| 
 | |
|     void Wave::Init(uint groupIndex)
 | |
|     {
 | |
|         indexG = groupIndex;
 | |
|         indexW = indexG / GetLaneCount();
 | |
|     }
 | |
| 
 | |
|     // Note: The HLSL intrinsics should be correctly replaced by console-specific intrinsics by our API library.
 | |
|     #define DEFINE_API_FOR_TYPE(TYPE)                                                     \
 | |
|         bool Wave::AllEqual(TYPE v)                 { return WaveActiveAllEqual(v);     } \
 | |
|         TYPE Wave::Product(TYPE v)                  { return WaveActiveProduct(v);      } \
 | |
|         TYPE Wave::Sum(TYPE v)                      { return WaveActiveSum(v);          } \
 | |
|         TYPE Wave::Max(TYPE v)                      { return WaveActiveMax(v);          } \
 | |
|         TYPE Wave::Min(TYPE v)                      { return WaveActiveMin(v);          } \
 | |
|         TYPE Wave::InclusivePrefixSum (TYPE v)      { return WavePrefixSum(v) + v;      } \
 | |
|         TYPE Wave::InclusivePrefixProduct (TYPE v)  { return WavePrefixProduct(v) * v;  } \
 | |
|         TYPE Wave::PrefixSum(TYPE v)                { return WavePrefixSum(v);          } \
 | |
|         TYPE Wave::PrefixProduct(TYPE v)            { return WavePrefixProduct(v);      } \
 | |
|         TYPE Wave::ReadLaneAt(TYPE v, uint i)       { return WaveReadLaneAt(v, i);      } \
 | |
|         TYPE Wave::ReadLaneFirst(TYPE v)            { return WaveReadLaneFirst(v);      } \
 | |
| 
 | |
|     // Currently just support scalars.
 | |
|     DEFINE_API_FOR_TYPE(uint)
 | |
|     DEFINE_API_FOR_TYPE(int)
 | |
|     DEFINE_API_FOR_TYPE(float)
 | |
| 
 | |
|     // The following intrinsics need only be declared once.
 | |
|     uint  Wave::GetLaneCount()          { return WaveGetLaneCount();     }
 | |
|     uint  Wave::GetLaneIndex()          { return WaveGetLaneIndex();     }
 | |
|     bool  Wave::IsFirstLane()           { return WaveIsFirstLane();      }
 | |
|     bool  Wave::AllTrue(bool v)         { return WaveActiveAllTrue(v);   }
 | |
|     bool  Wave::AnyTrue(bool v)         { return WaveActiveAnyTrue(v);   }
 | |
|     uint4 Wave::Ballot(bool v)          { return WaveActiveBallot(v);    }
 | |
|     uint  Wave::CountBits(bool v)       { return WaveActiveCountBits(v); }
 | |
|     uint  Wave::PrefixCountBits(bool v) { return WavePrefixCountBits(v); }
 | |
|     uint  Wave::And(uint v)             { return WaveActiveBitAnd(v);    }
 | |
|     uint  Wave::Or (uint v)             { return WaveActiveBitOr(v);     }
 | |
|     uint  Wave::Xor(uint v)             { return WaveActiveBitXor(v);    }
 | |
| 
 | |
// EMULATED_GROUP_REDUCE(TYPE, OP)
// Expands to the complete body of a Group:: reduction that combines values with the infix operator OP
// (+, *, &, |, ^ at the call sites in this file).
//  - Every thread stores asuint(v) into g_Scratch[groupIndex].
//  - For s = THREADING_BLOCK_SIZE/2, /4, ..., 1, threads with groupIndex < s combine their slot with
//    slot groupIndex + s, with a group barrier after each step.
//  - Every thread then returns the combined value from g_Scratch[0].
// The leading barrier orders this write after any g_Scratch reads still pending from a previous
// Group:: operation, such as the final g_Scratch[0] read.
// Expects `v` (the input value) and `groupIndex` (the thread's flat index) to be in scope at the
// expansion site. It contains group barriers, so it must be reached by every thread of the group.
// NOTE(review): the halving loop only visits every slot when THREADING_BLOCK_SIZE is a power of two;
// other sizes would silently drop elements — confirm the block size is always a power of two.
#define EMULATED_GROUP_REDUCE(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u) \
    { \
        if (groupIndex < s) \
            g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP as##TYPE(g_Scratch[groupIndex + s])); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]); \
 | |
| 
 | |
// EMULATED_GROUP_REDUCE_CMP(TYPE, OP)
// Same tree reduction as EMULATED_GROUP_REDUCE, except that OP is a two-argument function called as
// OP(a, b) (min or max at the Group::Min/Group::Max call sites) instead of an infix operator.
// Every thread returns the reduced value read from g_Scratch[0].
// Expects `v` and `groupIndex` to be in scope at the expansion site. It contains group barriers, so it
// must be reached by every thread of the group.
// NOTE(review): like EMULATED_GROUP_REDUCE, the halving loop assumes THREADING_BLOCK_SIZE is a power of two.
#define EMULATED_GROUP_REDUCE_CMP(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u) \
    { \
        if (groupIndex < s) \
            g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]), as##TYPE(g_Scratch[groupIndex + s]))); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]); \
 | |
| 
 | |
// EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE)
// Expands to the complete body of an exclusive Group:: scan using the infix operator OP.
// FILL_VALUE must be the identity of OP; the callers pass (TYPE)0 for + and (TYPE)1 for *.
//  1. Every thread stores asuint(v) into g_Scratch[groupIndex].
//  2. For s = 1, 2, 4, ... while s < THREADING_BLOCK_SIZE, each thread combines its own slot with the
//     slot s positions earlier (threads with groupIndex < s use FILL_VALUE instead). The barrier between
//     the read and the write makes every thread finish reading before any slot is overwritten.
//     After the loop, g_Scratch[i] holds the inclusive result for threads 0..i.
//  3. Each thread returns its predecessor's inclusive value, which is the exclusive result.
//     Thread 0 returns FILL_VALUE.
// Expects `v` and `groupIndex` to be in scope at the expansion site. It contains group barriers, so it
// must be reached by every thread of the group.
#define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = 1u; s < THREADING_BLOCK_SIZE; s <<= 1u) \
    { \
        TYPE nv = FILL_VALUE; \
        if (groupIndex >= s) \
        { \
            nv = as##TYPE(g_Scratch[groupIndex - s]); \
        } \
        nv = as##TYPE(g_Scratch[groupIndex]) OP nv; \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(nv); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    TYPE result = FILL_VALUE; \
    if (groupIndex > 0u) \
        result = as##TYPE(g_Scratch[groupIndex - 1]); \
    return result; \
 | |
| 
 | |
|     uint Group::GetWaveCount()
 | |
|     {
 | |
|         return THREADING_BLOCK_SIZE / WaveGetLaneCount();
 | |
|     }
 | |
| 
 | |
|     #define DEFINE_API_FOR_TYPE_GROUP(TYPE)                                                                                                                                                       \
 | |
|         bool Group::AllEqual(TYPE v)                  { return AllTrue(ReadThreadFirst(v) == v);                                                                                                } \
 | |
|         TYPE Group::Product(TYPE v)                   { EMULATED_GROUP_REDUCE(TYPE, *)                                                                                                          } \
 | |
|         TYPE Group::Sum(TYPE v)                       { EMULATED_GROUP_REDUCE(TYPE, +)                                                                                                          } \
 | |
|         TYPE Group::Max(TYPE v)                       { EMULATED_GROUP_REDUCE_CMP(TYPE, max)                                                                                                    } \
 | |
|         TYPE Group::Min(TYPE v)                       { EMULATED_GROUP_REDUCE_CMP(TYPE, min)                                                                                                    } \
 | |
|         TYPE Group::InclusivePrefixSum (TYPE v)       { return PrefixSum(v) + v;                                                                                                                } \
 | |
|         TYPE Group::InclusivePrefixProduct (TYPE v)   { return PrefixProduct(v) * v;                                                                                                            } \
 | |
|         TYPE Group::PrefixSum (TYPE v)                { EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0)                                                                                                 } \
 | |
|         TYPE Group::PrefixProduct (TYPE v)            { EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1)                                                                                                 } \
 | |
|         TYPE Group::ReadThreadAt(TYPE v, uint i)      { GroupMemoryBarrierWithGroupSync(); g_Scratch[groupIndex] = asuint(v); GroupMemoryBarrierWithGroupSync(); return as##TYPE(g_Scratch[i]); } \
 | |
|         TYPE Group::ReadThreadFirst(TYPE v)           { return ReadThreadAt(v, 0u);                                                                                                             } \
 | |
|         TYPE Group::ReadThreadShuffle(TYPE v, uint i) { return ReadThreadAt(v, i);                                                                                                              } \
 | |
| 
 | |
|     // Currently just support scalars.
 | |
|     DEFINE_API_FOR_TYPE_GROUP(uint)
 | |
|     DEFINE_API_FOR_TYPE_GROUP(int)
 | |
|     DEFINE_API_FOR_TYPE_GROUP(float)
 | |
| 
 | |
|     // The following emulated functions need only be declared once.
 | |
|     uint  Group::GetThreadCount()        { return THREADING_BLOCK_SIZE;   }
 | |
|     uint  Group::GetThreadIndex()        { return groupIndex;             }
 | |
|     bool  Group::IsFirstThread()         { return groupIndex == 0u;       }
 | |
|     bool  Group::AllTrue(bool v)         { return And(v) != 0u;           }
 | |
|     bool  Group::AnyTrue(bool v)         { return Or (v) != 0u;           }
 | |
|     uint  Group::PrefixCountBits(bool v) { return PrefixSum((uint)v);     }
 | |
|     uint  Group::And(uint v)             { EMULATED_GROUP_REDUCE(uint, &) }
 | |
|     uint  Group::Or (uint v)             { EMULATED_GROUP_REDUCE(uint, |) }
 | |
|     uint  Group::Xor(uint v)             { EMULATED_GROUP_REDUCE(uint, ^) }
 | |
| 
 | |
|     GroupBallot Group::Ballot(bool v)
 | |
|     {
 | |
|         uint indexDw = groupIndex % 32u;
 | |
|         uint offsetDw = (groupIndex / 32u) * 32u;
 | |
|         uint indexScratch = offsetDw + indexDw;
 | |
| 
 | |
|         GroupMemoryBarrierWithGroupSync();
 | |
| 
 | |
|         g_Scratch[groupIndex] = v << indexDw;
 | |
| 
 | |
|         GroupMemoryBarrierWithGroupSync();
 | |
| 
 | |
|         [unroll]
 | |
|         for (uint s = min(THREADING_BLOCK_SIZE / 2u, 16u); s > 0u; s >>= 1u)
 | |
|         {
 | |
|             if (indexDw < s)
 | |
|                 g_Scratch[indexScratch] = g_Scratch[indexScratch] | g_Scratch[indexScratch + s];
 | |
| 
 | |
|             GroupMemoryBarrierWithGroupSync();
 | |
|         }
 | |
| 
 | |
|         GroupBallot ballot = (GroupBallot)0;
 | |
| 
 | |
|         // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
 | |
|         [unroll]
 | |
|         for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
 | |
|         {
 | |
|             ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
 | |
|         }
 | |
| 
 | |
|         return ballot;
 | |
|     }
 | |
| 
 | |
|     uint Group::CountBits(bool v)
 | |
|     {
 | |
|         return Ballot(v).CountBits();
 | |
|     }
 | |
| }
 | |
| #endif
 |