API reference guide
This document provides information about rocWMMA functions, data types, and other programming constructs.
Synchronous API
In general, the rocWMMA API functions (load_matrix_sync, store_matrix_sync, mma_sync) are assumed to be synchronous when used with global memory.
When these functions are used with shared memory (for example, LDS), additional explicit workgroup synchronization (synchronize_workgroup) may be required. LDS is shared by all wavefronts in the workgroup, so data that one wavefront writes must be visible before another wavefront reads it.
Supported GPU architectures
List of supported CDNA architectures (wave64):
gfx908
gfx90a
gfx942
gfx950
Note
gfx9 = gfx908, gfx90a, gfx942, gfx950
List of supported RDNA architectures (wave32):
gfx1100
gfx1101
gfx1102
gfx1200
gfx1201
Note
gfx11 = gfx1100, gfx1101, gfx1102
gfx12 = gfx1200, gfx1201
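The architecture family determines the wavefront size, which in turn sets the thread block sizes described later. The following minimal C++ sketch maps each target listed above to its wavefront size. sameName and waveSizeFor are hypothetical helpers for illustration, not part of rocWMMA.

```cpp
// Hypothetical helpers for illustration; not part of rocWMMA.

// Compares two C strings; usable at compile time (C++14 constexpr).
constexpr bool sameName(const char* a, const char* b) {
    while (*a != '\0' && *a == *b) {
        ++a;
        ++b;
    }
    return *a == *b;
}

// Returns the wavefront size for the targets listed above, or 0 otherwise.
constexpr unsigned waveSizeFor(const char* arch) {
    // CDNA (gfx9) targets run wave64.
    const char* const cdna[] = {"gfx908", "gfx90a", "gfx942", "gfx950"};
    // RDNA (gfx11, gfx12) targets run wave32.
    const char* const rdna[] = {"gfx1100", "gfx1101", "gfx1102",
                                "gfx1200", "gfx1201"};
    for (const char* name : cdna) {
        if (sameName(arch, name)) {
            return 64;
        }
    }
    for (const char* name : rdna) {
        if (sameName(arch, name)) {
            return 32;
        }
    }
    return 0;
}

static_assert(waveSizeFor("gfx942") == 64, "CDNA targets are wave64");
static_assert(waveSizeFor("gfx1100") == 32, "RDNA targets are wave32");
```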
Supported data types
rocWMMA mixed-precision multiply-accumulate operations support the following data type combinations.
Data Types <Ti / To / Tc> = <Input Type / Output Type / Compute Type>, where:
Input Type = Matrix A / B
Output Type = Matrix C / D
Compute Type = Math / accumulation type
i8 = 8-bit precision integer
f8 = 8-bit precision floating point
bf8 = 8-bit precision brain floating point
f16 = half-precision floating point
bf16 = half-precision brain floating point
f32 = single-precision floating point
i32 = 32-bit precision integer
xf32 = single-precision tensor floating point
f64 = double-precision floating point
Note
f16 represents equivalent support for both _Float16 and __half types.
Current f8 support is NANOO (optimized) format.
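| Ti / To / Tc | BlockM | BlockN | BlockK Range* (Powers of 2) | CDNA Support | RDNA Support |
|---|---|---|---|---|---|
| bf8 / f32 / f32 | 16 | 16 | 32+ | gfx942, gfx950 | gfx12 |
| bf8 / f32 / f32 | 32 | 32 | 16+ | gfx942, gfx950 | - |
| f8 / f32 / f32 | 16 | 16 | 32+ | gfx942, gfx950 | gfx12 |
| f8 / f32 / f32 | 32 | 32 | 16+ | gfx942, gfx950 | - |
| i8 / i32 / i32 | 16 | 16 | 16 | gfx908, gfx90a | gfx11, gfx12 |
| i8 / i32 / i32 | 16 | 16 | 32 | gfx942, gfx950 | - |
| i8 / i32 / i32 | 16 | 16 | 64+ | gfx950 | - |
| i8 / i32 / i32 | 32 | 32 | 8 | gfx908, gfx90a | - |
| i8 / i32 / i32 | 32 | 32 | 16 | gfx942, gfx950 | - |
| i8 / i32 / i32 | 32 | 32 | 32+ | gfx950 | - |
| i8 / i8 / i32 | 16 | 16 | 16 | gfx908, gfx90a | gfx11, gfx12 |
| i8 / i8 / i32 | 16 | 16 | 32 | gfx942, gfx950 | - |
| i8 / i8 / i32 | 16 | 16 | 64+ | gfx950 | - |
| i8 / i8 / i32 | 32 | 32 | 8 | gfx908, gfx90a | - |
| i8 / i8 / i32 | 32 | 32 | 16 | gfx942, gfx950 | - |
| i8 / i8 / i32 | 32 | 32 | 32+ | gfx950 | - |
| f16 / f32 / f32 | 16 | 16 | 16 | gfx9 | gfx11, gfx12 |
| f16 / f32 / f32 | 16 | 16 | 32+ | gfx950 | - |
| f16 / f32 / f32 | 32 | 32 | 8 | gfx9 | - |
| f16 / f32 / f32 | 32 | 32 | 16+ | gfx950 | - |
| f16 / f16 / f32 | 16 | 16 | 16 | gfx9 | gfx11, gfx12 |
| f16 / f16 / f32 | 16 | 16 | 32+ | gfx950 | - |
| f16 / f16 / f32 | 32 | 32 | 8 | gfx9 | - |
| f16 / f16 / f32 | 32 | 32 | 16+ | gfx950 | - |
| f16 / f16 / f16** | 16 | 16 | 16 | gfx9 | gfx11, gfx12 |
| f16 / f16 / f16** | 16 | 16 | 32+ | gfx950 | - |
| f16 / f16 / f16** | 32 | 32 | 8 | gfx9 | - |
| f16 / f16 / f16** | 32 | 32 | 16+ | gfx950 | - |
| bf16 / f32 / f32 | 16 | 16 | 8 | gfx908 | - |
| bf16 / f32 / f32 | 16 | 16 | 16 | gfx90a, gfx942, gfx950 | gfx11, gfx12 |
| bf16 / f32 / f32 | 16 | 16 | 32+ | gfx950 | - |
| bf16 / f32 / f32 | 32 | 32 | 4+ | gfx908 | - |
| bf16 / f32 / f32 | 32 | 32 | 8 | gfx90a, gfx942, gfx950 | - |
| bf16 / f32 / f32 | 32 | 32 | 16+ | gfx950 | - |
| bf16 / bf16 / f32 | 16 | 16 | 8 | gfx908 | - |
| bf16 / bf16 / f32 | 16 | 16 | 16 | gfx90a, gfx942, gfx950 | gfx11, gfx12 |
| bf16 / bf16 / f32 | 16 | 16 | 32+ | gfx950 | - |
| bf16 / bf16 / f32 | 32 | 32 | 4+ | gfx908 | - |
| bf16 / bf16 / f32 | 32 | 32 | 8 | gfx90a, gfx942, gfx950 | - |
| bf16 / bf16 / f32 | 32 | 32 | 16+ | gfx950 | - |
| bf16 / bf16 / bf16** | 16 | 16 | 8 | gfx908 | - |
| bf16 / bf16 / bf16** | 16 | 16 | 16 | gfx90a, gfx942, gfx950 | gfx11, gfx12 |
| bf16 / bf16 / bf16** | 16 | 16 | 32+ | gfx950 | - |
| bf16 / bf16 / bf16** | 32 | 32 | 4+ | gfx908 | - |
| bf16 / bf16 / bf16** | 32 | 32 | 8 | gfx90a, gfx942, gfx950 | - |
| bf16 / bf16 / bf16** | 32 | 32 | 16+ | gfx950 | - |
| f32 / f32 / f32 | 16 | 16 | 4+ | gfx9 | - |
| f32 / f32 / f32 | 32 | 32 | 2+ | gfx9 | - |
| xf32 / xf32 / xf32 | 16 | 16 | 8+ | gfx942 | - |
| xf32 / xf32 / xf32 | 32 | 32 | 4+ | gfx942 | - |
| f64 / f64 / f64 | 16 | 16 | 4+ | gfx90a, gfx942, gfx950 | - |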
Note
* = The BlockK range lists the minimum possible value. Other values in the range are powers of 2 larger than the minimum. Practical BlockK values are usually 32 or smaller.
** = On CDNA architectures, matrix unit accumulation is natively 32-bit precision and is converted to the desired type.
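Read this way, a candidate BlockK for an entry listed with a trailing + (for example, 16+) is valid when it is a power of 2 no smaller than the listed minimum. The following compile-time C++ sketch expresses that check. isPowerOfTwo and isValidBlockK are illustrative names, not rocWMMA APIs.

```cpp
#include <cstdint>

// True if x is a non-zero power of 2.
constexpr bool isPowerOfTwo(std::uint32_t x) {
    return x != 0 && (x & (x - 1)) == 0;
}

// Hypothetical check for a table entry written as "<minBlockK>+":
// valid values are the minimum and every larger power of 2.
constexpr bool isValidBlockK(std::uint32_t blockK, std::uint32_t minBlockK) {
    return isPowerOfTwo(blockK) && blockK >= minBlockK;
}

// f16 / f32 / f32 at 32x32 on gfx950 lists "16+": 16 and 32 qualify, 24 does not.
static_assert(isValidBlockK(16, 16) && isValidBlockK(32, 16), "powers of 2 >= 16");
static_assert(!isValidBlockK(24, 16), "24 is not a power of 2");
```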
Note
rocWMMA supports partial fragment sizes, where FragMNK may be smaller than the BlockMNK sizes listed in the table above. These fragments are internally padded to the nearest supported BlockMNK sizes.
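One way to picture this padding is per dimension: a partial FragM (or FragN, FragK) rounds up to the smallest supported block size that can hold it. The C++ sketch below assumes that rule and uses a hypothetical helper name. rocWMMA's internal selection may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical helper (not a rocWMMA API): the smallest supported block size
// that is >= fragSize, or 0 if fragSize is larger than every supported size.
inline std::uint32_t paddedBlockSize(std::uint32_t fragSize,
                                     std::initializer_list<std::uint32_t> supported) {
    assert(fragSize > 0);  // a fragment dimension must be non-empty
    std::uint32_t best = 0;
    for (std::uint32_t size : supported) {
        if (size >= fragSize && (best == 0 || size < best)) {
            best = size;
        }
    }
    return best;
}
```

For example, with supported BlockM sizes {16, 32}, a FragM of 12 would pad to 16 and a FragM of 20 would pad to 32.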
Supported matrix layouts
(N = col major, T = row major)
| LayoutA | LayoutB | LayoutC | LayoutD |
|---|---|---|---|
| N | N | N | N |
| N | N | T | T |
| N | T | N | N |
| N | T | T | T |
| T | N | N | N |
| T | N | T | T |
| T | T | N | N |
| T | T | T | T |
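In every combination listed above, LayoutC matches LayoutD, while A and B may each be N or T independently. The layout only changes how a (row, col) coordinate maps to a memory offset for a given leading dimension. The following minimal C++ sketch shows that mapping. Layout, offsetOf and convertLayout are illustrative names, not rocWMMA APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// N = column major, T = row major, as in the table above.
enum class Layout { N, T };

// Memory offset of element (row, col) given leading dimension ld.
// Column major (N): elements of a column are adjacent; ld >= rows.
// Row major (T): elements of a row are adjacent; ld >= cols.
inline std::size_t offsetOf(std::size_t row, std::size_t col, std::size_t ld,
                            Layout layout) {
    return layout == Layout::N ? row + col * ld : row * ld + col;
}

// Re-stores a tightly packed rows x cols matrix from one layout to another.
inline std::vector<float> convertLayout(const std::vector<float>& src,
                                        std::size_t rows, std::size_t cols,
                                        Layout from, Layout to) {
    assert(src.size() == rows * cols);
    const std::size_t ldFrom = (from == Layout::N) ? rows : cols;
    const std::size_t ldTo = (to == Layout::N) ? rows : cols;
    std::vector<float> dst(src.size());
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            dst[offsetOf(r, c, ldTo, to)] = src[offsetOf(r, c, ldFrom, from)];
    return dst;
}
```

For a 2 x 3 matrix stored row major as {1, 2, 3, 4, 5, 6}, the column-major copy is {1, 4, 2, 5, 3, 6}.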
Supported thread block sizes
rocWMMA generally supports and tests up to 4 wavefronts per thread block. TBlock_X is expected to be a multiple of the wave size, so the number of wavefronts along X is TBlock_X / WaveSize.
| TBlock_X | TBlock_Y |
|---|---|
| WaveSize | 1 |
| WaveSize | 2 |
| WaveSize | 4 |
| WaveSize*2 | 1 |
| WaveSize*2 | 2 |
| WaveSize*4 | 1 |
Note
WaveSize (RDNA) = 32
WaveSize (CDNA) = 64
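The listed shapes follow a simple rule. TBlock_X is a multiple of WaveSize, the wavefront counts along X and Y are powers of 2, and the block holds at most 4 wavefronts. The following compile-time C++ sketch encodes that rule. isTestedThreadBlock is an illustrative name, not a rocWMMA API, and it only reproduces the table above.

```cpp
#include <cstdint>

// True if x is a non-zero power of 2.
constexpr bool isPow2(std::uint32_t x) {
    return x != 0 && (x & (x - 1)) == 0;
}

// Hypothetical check (not a rocWMMA API) reproducing the table above:
// TBlock_X must be a multiple of WaveSize, the wavefront counts along X and Y
// must be powers of 2, and the block may hold at most 4 wavefronts.
constexpr bool isTestedThreadBlock(std::uint32_t tblockX, std::uint32_t tblockY,
                                   std::uint32_t waveSize) {
    if (waveSize == 0 || tblockX % waveSize != 0) {
        return false;
    }
    const std::uint32_t wavesX = tblockX / waveSize;
    return isPow2(wavesX) && isPow2(tblockY) && wavesX * tblockY <= 4;
}

static_assert(isTestedThreadBlock(64 * 2, 2, 64), "WaveSize*2 x 2 on CDNA");
static_assert(!isTestedThreadBlock(32 * 4, 2, 32), "8 wavefronts exceed the limit");
```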
Using rocWMMA API
This section describes how to use the rocWMMA library API.
rocWMMA datatypes
matrix_a
struct matrix_a
Meta-tag indicating data context is input Matrix A.
matrix_b
struct matrix_b
Meta-tag indicating data context is input Matrix B.
accumulator
struct accumulator
Meta-tag indicating data context is Accumulator (also used as Matrix C / D).
row_major
struct row_major
Meta-tag indicating 2D in-memory data layout as row major.
col_major
struct col_major
Meta-tag indicating 2D in-memory data layout as column major.
default_schedule
struct rocwmma::fragment_scheduler::default_schedule
coop_row_major_2d
struct rocwmma::fragment_scheduler::coop_row_major_2d
coop_col_major_2d
struct rocwmma::fragment_scheduler::coop_col_major_2d
coop_row_slice_2d
struct rocwmma::fragment_scheduler::coop_row_slice_2d
coop_col_slice_2d
struct rocwmma::fragment_scheduler::coop_col_slice_2d
single
struct rocwmma::fragment_scheduler::single
fragment
template<typename MatrixT, uint32_t FragM, uint32_t FragN, uint32_t FragK, typename DataT, typename DataLayoutT = void, typename Scheduler = fragment_scheduler::default_schedule>
class fragment
rocWMMA fragment class. This is the primary object used in the block-wise decomposition of the matrix multiply-accumulate (mma) problem space. In general, fragment data is associated with a matrix context (matrix_a, matrix_b or accumulator), a block size (BlockM/N/K), a datatype (for example, single-precision float) and an in-memory 2D layout (for example, row_major or col_major). These fragment properties define how data is handled and stored locally, and drive the API implementations for loading and storing, mma and transforms. Fragment abstractions promote a simple wavefront programming model, which can shorten development time: rocWMMA handles the internal thread-level details, which frees the user to focus on the wavefront block-wise decomposition. Because the class is written purely in device code, programmers can use it in their own device kernels.
Note
Fragments are stored in packed registers; however, vector elements have no guaranteed order or locality.
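To make the block-wise decomposition concrete, the following host-side C++ sketch computes D = A * B + C tile by tile, without rocWMMA or a GPU. Each blockM x blockN output tile plays the role of one wavefront's accumulator fragment, and the K loop advances blockK at a time, like repeated mma_sync calls. The function name blockedGemm and the row-major, evenly divisible setup are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side illustration only (no rocWMMA, no GPU). Computes
// D = A * B + C for row-major A (M x K), B (K x N), C and D (M x N),
// one blockM x blockN output tile at a time. D must be a separate vector.
void blockedGemm(const std::vector<float>& A, const std::vector<float>& B,
                 const std::vector<float>& C, std::vector<float>& D,
                 std::size_t M, std::size_t N, std::size_t K,
                 std::size_t blockM, std::size_t blockN, std::size_t blockK) {
    // This sketch assumes non-zero blocks that divide the problem evenly.
    assert(blockM > 0 && blockN > 0 && blockK > 0);
    assert(M % blockM == 0 && N % blockN == 0 && K % blockK == 0);
    assert(A.size() == M * K && B.size() == K * N && C.size() == M * N);
    D.assign(M * N, 0.0f);
    for (std::size_t bm = 0; bm < M; bm += blockM) {
        for (std::size_t bn = 0; bn < N; bn += blockN) {
            // The tile one wavefront would own, like an accumulator
            // fragment initialized from C.
            std::vector<float> acc(blockM * blockN);
            for (std::size_t i = 0; i < blockM; ++i)
                for (std::size_t j = 0; j < blockN; ++j)
                    acc[i * blockN + j] = C[(bm + i) * N + (bn + j)];
            // Advance through K one blockK slice at a time, like repeated
            // mma_sync calls on the same accumulator.
            for (std::size_t bk = 0; bk < K; bk += blockK)
                for (std::size_t i = 0; i < blockM; ++i)
                    for (std::size_t j = 0; j < blockN; ++j)
                        for (std::size_t k = bk; k < bk + blockK; ++k)
                            acc[i * blockN + j] +=
                                A[(bm + i) * K + k] * B[k * N + (bn + j)];
            // Write the finished tile to D, like store_matrix_sync.
            for (std::size_t i = 0; i < blockM; ++i)
                for (std::size_t j = 0; j < blockN; ++j)
                    D[(bm + i) * N + (bn + j)] = acc[i * blockN + j];
        }
    }
}
```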
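- Template Parameters:
MatrixT – Matrix context: matrix_a, matrix_b or accumulator
FragM, FragN, FragK – Fragment block sizes
DataT – Datatype of the fragment elements
DataLayoutT – In-memory 2D data layout, for example row_major or col_major (defaults to void)
Scheduler – Fragment scheduler (defaults to fragment_scheduler::default_schedule)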
Public Functions
inline DataT &operator[](uint32_t index)
- Parameters:
index – Element index
- Returns:
Mutable unpacked element accessor at given index
Public Static Functions
static inline constexpr uint32_t height()
- Returns:
The geometric height of the fragment
static inline constexpr uint32_t width()
- Returns:
The geometric width of the fragment
static inline constexpr uint32_t blockDim()
- Returns:
The leading block dimension (non-K)
static inline constexpr uint32_t kDim()
- Returns:
The K dimension
static inline constexpr uint32_t size()
- Returns:
The size of the unpacked elements vector
struct Traits
rocWMMA enumeration
layout_t
rocWMMA API functions
template<typename FragT, typename DataT>
void rocwmma::fill_fragment(FragT &frag, DataT value)
Fills the entire fragment with the desired value.
- Parameters:
frag – Fragment of type MatrixT with its associated block sizes, data type and layout
value – Fill value of type DataT
- Template Parameters:
FragT – Opaque fragment type
DataT – Datatype
template<typename FragT, typename DataT>
void rocwmma::load_matrix_sync(FragT &frag, const DataT *data, uint32_t ldm)
template<typename FragT, typename DataT>
void rocwmma::load_matrix_sync(FragT &frag, const DataT *data, uint32_t ldm, layout_t layout)
Loads the fragment from the memory pointed to by data.
- Parameters:
frag – Fragment to load into
data – Pointer to the first element of the block in memory
ldm – Leading dimension of the matrix in memory
layout – In-memory data layout given at run time (second overload only)
template<typename FragT, typename DataT>
void rocwmma::store_matrix_sync(DataT *data, FragT const &frag, uint32_t ldm)
template<typename FragT, typename DataT>
void rocwmma::store_matrix_sync(DataT *data, FragT const &frag, uint32_t ldm, layout_t layout)
Stores the fragment to the memory pointed to by data.
- Parameters:
data – Pointer to the first element of the destination block in memory
frag – Fragment to store
ldm – Leading dimension of the matrix in memory
layout – In-memory data layout given at run time (second overload only)
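The ldm argument is the leading dimension of the matrix in memory, and data points at the first element of the block to load or store. The following host-side C++ sketch mimics that addressing for a row-major matrix. loadTile and storeTile are hypothetical helpers, not the rocWMMA functions, which operate on fragments in device code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side helpers (not rocWMMA functions) that mimic the
// addressing of load_matrix_sync / store_matrix_sync for row-major data:
// `data` points at the tile's top-left element and `ldm` is the leading
// dimension (row pitch, in elements) of the full matrix.
std::vector<float> loadTile(const float* data, std::size_t ldm,
                            std::size_t rows, std::size_t cols) {
    assert(ldm >= cols);  // each tile row must fit within one row pitch
    std::vector<float> tile(rows * cols);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            tile[r * cols + c] = data[r * ldm + c];
    return tile;
}

void storeTile(float* data, const std::vector<float>& tile, std::size_t ldm,
               std::size_t rows, std::size_t cols) {
    assert(ldm >= cols && tile.size() == rows * cols);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            data[r * ldm + c] = tile[r * cols + c];
}
```

In a kernel, the analogous call for a row_major matrix_a fragment whose top-left element is at row r and column k of a row-major matrix a with leading dimension lda would look like load_matrix_sync(fragA, a + r * lda + k, lda).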
template<typename FragA, typename FragB, typename FragAccumIn, typename FragAccumOut>
void rocwmma::mma_sync(FragAccumOut &d, FragA const &a, FragB const &b, FragAccumIn &c)
Performs the multiply-accumulate operation on fragments A, B, C and D (D = A * B + C).
Note
Fragment c may be the same object as fragment d, which accumulates in place.
- Parameters:
d – Accumulator output D
a – Input fragment A
b – Input fragment B
c – Input accumulator fragment C
- Template Parameters:
FragA – Opaque fragment type for matrix A data
FragB – Opaque fragment type for matrix B data
FragAccumIn – Opaque fragment type for input accumulation data
FragAccumOut – Opaque fragment type for output accumulation data
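The following host-side C++ reference sketches the math of mma_sync for a single block, including the in-place case where c and d are the same object. mmaReference is a hypothetical name for illustration. It does not model how rocWMMA distributes the work across a wavefront.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host reference for the math of mma_sync on one block (illustration only):
// d = a * b + c with row-major a (M x K), b (K x N), c and d (M x N).
// c and d may be the same vector: each d element is written only after the
// matching c element has been read. d must not alias a or b.
void mmaReference(std::vector<float>& d, const std::vector<float>& a,
                  const std::vector<float>& b, const std::vector<float>& c,
                  std::size_t M, std::size_t N, std::size_t K) {
    assert(a.size() == M * K && b.size() == K * N && c.size() == M * N);
    d.resize(M * N);  // no-op when d is c, since c already holds M * N values
    for (std::size_t i = 0; i < M; ++i) {
        for (std::size_t j = 0; j < N; ++j) {
            float sum = c[i * N + j];
            for (std::size_t k = 0; k < K; ++k) {
                sum += a[i * K + k] * b[k * N + j];
            }
            d[i * N + j] = sum;
        }
    }
}
```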
void rocwmma::synchronize_workgroup()
Explicit workgroup synchronization, which may be required when the rocWMMA API functions are used with shared memory such as LDS (see Synchronous API).
rocWMMA transforms API functions
template<typename FragT>
static inline T rocwmma::apply_transpose(FragT &&frag)
Applies the transpose transform to the input fragment. Transposition changes both the matrix context and the data layout to their orthogonal counterparts. For example, T(fragment<matrix_a, BlockM, BlockN, BlockK, DataT, row_major>) = fragment<matrix_b, BlockN, BlockM, BlockK, DataT, col_major>.
- Parameters:
frag – Fragment of type MatrixT with its associated block sizes, data type and layout
- Template Parameters:
FragT – The incoming fragment type
- Returns:
Transposed (orthogonal) fragment
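The relationship in the example above can be checked on the host. The same buffer, read as a row-major M x K matrix and as a column-major K x M matrix, yields views that are transposes of each other. The following C++ sketch uses hypothetical helpers and illustrates only the index relationship, not how apply_transpose is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: element (row, col) of a rows x cols matrix stored
// contiguously in `buf`, in row-major or column-major order.
inline float elementAt(const std::vector<float>& buf, std::size_t row,
                       std::size_t col, std::size_t rows, std::size_t cols,
                       bool rowMajor) {
    assert(buf.size() == rows * cols && row < rows && col < cols);
    return rowMajor ? buf[row * cols + col] : buf[row + col * rows];
}

// True if viewing `buf` as a row-major M x K matrix gives the transpose of
// viewing the same buffer as a column-major K x M matrix.
inline bool isTransposedView(const std::vector<float>& buf, std::size_t M,
                             std::size_t K) {
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t k = 0; k < K; ++k)
            if (elementAt(buf, i, k, M, K, true) !=
                elementAt(buf, k, i, K, M, false))
                return false;
    return true;
}
```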
template<typename DataLayoutT, typename FragT>
static inline T rocwmma::apply_data_layout(FragT &&frag)
Transforms the input fragment to have the desired data layout.
- Parameters:
frag – Fragment of type MatrixT with its associated block sizes, data type and layout
- Template Parameters:
DataLayoutT – The desired fragment data layout to apply
FragT – The incoming fragment type
- Returns:
Fragment with transformed data layout
template<typename DstFragT, typename FragT>
static inline T rocwmma::apply_fragment(FragT &&frag)
Transforms the input fragment to the target fragment type. This could include changing matrix context and/or changing data layout, as long as there is a path from the source register layout to the destination register layout.
- Parameters:
frag – Source fragment of type MatrixT with its associated block sizes, data type and layout
- Template Parameters:
DstFragT – The target fragment type to transform to
FragT – The source incoming fragment type
- Returns:
Target fragment after transformation
template<typename FragT>
static inline T rocwmma::to_register_file(FragT &&frag)
Transforms the input fragment to a “register file” fragment type. Register contents are directly mapped to a 2D matrix space represented by [RegCount x WaveSize]. This transform is a geometry reinterpretation.
- Parameters:
frag – Source fragment of type MatrixT with its associated block sizes, data type and layout
- Template Parameters:
FragT – The source incoming fragment type
- Returns:
Target fragment after transformation
template<typename DstFragT, typename FragT>
static inline T rocwmma::from_register_file(FragT &&frag)
Transforms the “register file” fragment type to a target fragment type. Register contents are directly mapped to a 2D matrix space represented by [RegCount x WaveSize]. This transform is a geometry reinterpretation.
- Parameters:
frag – Source fragment of type MatrixT with its associated block sizes, data type and layout
- Template Parameters:
DstFragT – The target fragment type to transform to
FragT – The source incoming fragment type as register file
- Returns:
Fragment after transformation
Sample programs
A sample program demonstrates calls to the rocWMMA functions load_matrix_sync, store_matrix_sync, fill_fragment, and mma_sync.
For more sample programs, refer to the Samples directory.
Emulation tests
The emulation test is a smaller test suite designed specifically for emulators. It comprises a selection of test cases from the full rocWMMA test set, which allows significantly faster execution on emulated platforms. Despite its smaller size, the emulation test supports smoke, regression, and extended modes.
For example, to run the smoke test:
rtest.py --install_dir <build_dir> --emulation smoke