//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_BARRIER
#define _CUDA_BARRIER

#include "std/barrier"

// Forward-declare CUtensorMap for use in the cp_async_bulk_tensor_* PTX wrapping
// functions. These functions only take a pointer to CUtensorMap, so they do not
// need to know its size (an incomplete type is sufficient). This type is
// defined in cuda.h (driver API) as:
//
//     typedef struct CUtensorMap_st {  [ .. snip .. ] } CUtensorMap;
//
// We need to forward-declare both CUtensorMap_st (the struct) and CUtensorMap
// (the typedef):
struct CUtensorMap_st;
typedef struct CUtensorMap_st CUtensorMap;

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA_DEVICE_EXPERIMENTAL

// Experimental exposure of TMA PTX:
//
// - cp_async_bulk_global_to_shared
// - cp_async_bulk_shared_to_global
// - cp_async_bulk_tensor_{1,2,3,4,5}d_global_to_shared
// - cp_async_bulk_tensor_{1,2,3,4,5}d_shared_to_global
// - fence_proxy_async_shared_cta
// - cp_async_bulk_commit_group
// - cp_async_bulk_wait_group_read<0, …, 63>

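// These PTX wrappers are only available when the code is compiled for compute
// capability 9.0 and above; availability is signalled by the feature-test macro
// __cccl_lib_experimental_ctk12_cp_async_exposure checked below. The check for
// (!defined(__CUDA_MINIMUM_ARCH__)) (not visible in this header; presumably part
// of that macro's definition) is necessary to prevent cudafe from ripping out
// the device functions before device compilation begins.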
#ifdef __cccl_lib_experimental_ctk12_cp_async_exposure

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
//
// Issues the PTX instruction
// `cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes`, i.e. an
// asynchronous bulk copy of `__size` bytes from `__src` to `__dest` that uses
// the `mbarrier::complete_tx::bytes` completion mechanism on `__bar`.
//
//   __dest  generic pointer; must refer to shared memory and be 16-byte aligned
//   __src   generic pointer; must refer to global memory and be 16-byte aligned
//   __size  number of bytes to copy; must be a multiple of 16
//   __bar   block-scope barrier whose native handle is passed as the mbarrier
//           operand
//
// The preconditions are only checked through _LIBCUDACXX_DEBUG_ASSERT.
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_global_to_shared(void *__dest, const void *__src, _CUDA_VSTD::uint32_t __size, ::cuda::barrier<::cuda::thread_scope_block> &__bar)
{
    _LIBCUDACXX_DEBUG_ASSERT(__size % 16 == 0,   "Size must be multiple of 16.");
    _LIBCUDACXX_DEBUG_ASSERT(__isShared(__dest), "Destination must be shared memory address.");
    _LIBCUDACXX_DEBUG_ASSERT(__isGlobal(__src),  "Source must be global memory address.");
    // cp.async.bulk additionally requires both addresses to be 16-byte aligned.
    // These checks come after the address-space checks so that the cvta
    // conversions are only evaluated for pointers of the expected space.
    _LIBCUDACXX_DEBUG_ASSERT(__cvta_generic_to_shared(__dest) % 16 == 0, "Destination must be 16-byte aligned.");
    _LIBCUDACXX_DEBUG_ASSERT(__cvta_generic_to_global(__src) % 16 == 0,  "Source must be 16-byte aligned.");

    // The "memory" clobber keeps the compiler from caching or reordering memory
    // accesses across the copy instruction.
    asm volatile(
        "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n"
        :
        : "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest))),
          "l"(static_cast<_CUDA_VSTD::uint64_t>(__cvta_generic_to_global(__src))),
          "r"(__size),
          "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar))))
        : "memory");
}


// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
//
// Issues the PTX instruction `cp.async.bulk.global.shared::cta.bulk_group`,
// i.e. an asynchronous bulk copy of `__size` bytes from `__src` to `__dest`
// that uses the bulk async-group (`bulk_group`) completion mechanism.
//
//   __dest  generic pointer; must refer to global memory and be 16-byte aligned
//   __src   generic pointer; must refer to shared memory and be 16-byte aligned
//   __size  number of bytes to copy; must be a multiple of 16
//
// The preconditions are only checked through _LIBCUDACXX_DEBUG_ASSERT.
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_shared_to_global(void *__dest, const void * __src, _CUDA_VSTD::uint32_t __size)
{
    _LIBCUDACXX_DEBUG_ASSERT(__size % 16 == 0,   "Size must be multiple of 16.");
    _LIBCUDACXX_DEBUG_ASSERT(__isGlobal(__dest), "Destination must be global memory address.");
    _LIBCUDACXX_DEBUG_ASSERT(__isShared(__src),  "Source must be shared memory address.");
    // cp.async.bulk additionally requires both addresses to be 16-byte aligned.
    // These checks come after the address-space checks so that the cvta
    // conversions are only evaluated for pointers of the expected space.
    _LIBCUDACXX_DEBUG_ASSERT(__cvta_generic_to_global(__dest) % 16 == 0, "Destination must be 16-byte aligned.");
    _LIBCUDACXX_DEBUG_ASSERT(__cvta_generic_to_shared(__src) % 16 == 0,  "Source must be 16-byte aligned.");

    // The "memory" clobber keeps the compiler from caching or reordering memory
    // accesses across the copy instruction.
    asm volatile(
        "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
        :
        : "l"(static_cast<_CUDA_VSTD::uint64_t>(__cvta_generic_to_global(__dest))),
          "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src))),
          "r"(__size)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::complete_tx::bytes`.
//
//   __dest        generic pointer, converted to a 32-bit shared-space address
//   __tensor_map  pointer to the CUtensorMap, passed through as a 64-bit operand
//   __c0          value placed in the instruction's `{...}` coordinate list
//   __bar         block-scope barrier whose native handle is passed as the
//                 mbarrier operand
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_1d_global_to_shared(
    void *__dest,
    const CUtensorMap *__tensor_map,
    int __c0,
    ::cuda::barrier<::cuda::thread_scope_block> &__bar)
{
    // Resolve the shared-space operands up front so that the asm statement only
    // lists plain values.
    const _CUDA_VSTD::uint32_t __smem_dest =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest));
    const _CUDA_VSTD::uint32_t __smem_bar =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)));

    asm volatile(
        "cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%2}], [%3];\n"
        : /* no outputs */
        : "r"(__smem_dest), "l"(__tensor_map), "r"(__c0), "r"(__smem_bar)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes`.
//
//   __dest        generic pointer, converted to a 32-bit shared-space address
//   __tensor_map  pointer to the CUtensorMap, passed through as a 64-bit operand
//   __c0, __c1    values placed, in this order, in the instruction's `{...}`
//                 coordinate list
//   __bar         block-scope barrier whose native handle is passed as the
//                 mbarrier operand
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_2d_global_to_shared(
    void *__dest,
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    ::cuda::barrier<::cuda::thread_scope_block> &__bar)
{
    // Resolve the shared-space operands up front so that the asm statement only
    // lists plain values.
    const _CUDA_VSTD::uint32_t __smem_dest =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest));
    const _CUDA_VSTD::uint32_t __smem_bar =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)));

    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%2, %3}], [%4];\n"
        : /* no outputs */
        : "r"(__smem_dest), "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__smem_bar)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes`.
//
//   __dest          generic pointer, converted to a 32-bit shared-space address
//   __tensor_map    pointer to the CUtensorMap, passed through as a 64-bit
//                   operand
//   __c0 .. __c2    values placed, in this order, in the instruction's `{...}`
//                   coordinate list
//   __bar           block-scope barrier whose native handle is passed as the
//                   mbarrier operand
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_3d_global_to_shared(
    void *__dest,
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    int __c2,
    ::cuda::barrier<::cuda::thread_scope_block> &__bar)
{
    // Resolve the shared-space operands up front so that the asm statement only
    // lists plain values.
    const _CUDA_VSTD::uint32_t __smem_dest =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest));
    const _CUDA_VSTD::uint32_t __smem_bar =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)));

    asm volatile(
        "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%2, %3, %4}], [%5];\n"
        : /* no outputs */
        : "r"(__smem_dest), "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__c2), "r"(__smem_bar)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes`.
//
//   __dest          generic pointer, converted to a 32-bit shared-space address
//   __tensor_map    pointer to the CUtensorMap, passed through as a 64-bit
//                   operand
//   __c0 .. __c3    values placed, in this order, in the instruction's `{...}`
//                   coordinate list
//   __bar           block-scope barrier whose native handle is passed as the
//                   mbarrier operand
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_4d_global_to_shared(
    void *__dest,
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    int __c2,
    int __c3,
    ::cuda::barrier<::cuda::thread_scope_block> &__bar)
{
    // Resolve the shared-space operands up front so that the asm statement only
    // lists plain values.
    const _CUDA_VSTD::uint32_t __smem_dest =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest));
    const _CUDA_VSTD::uint32_t __smem_bar =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)));

    asm volatile(
        "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%2, %3, %4, %5}], [%6];\n"
        : /* no outputs */
        : "r"(__smem_dest), "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__c2), "r"(__c3), "r"(__smem_bar)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes`.
//
//   __dest          generic pointer, converted to a 32-bit shared-space address
//   __tensor_map    pointer to the CUtensorMap, passed through as a 64-bit
//                   operand
//   __c0 .. __c4    values placed, in this order, in the instruction's `{...}`
//                   coordinate list
//   __bar           block-scope barrier whose native handle is passed as the
//                   mbarrier operand
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_5d_global_to_shared(
    void *__dest,
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    int __c2,
    int __c3,
    int __c4,
    ::cuda::barrier<::cuda::thread_scope_block> &__bar)
{
    // Resolve the shared-space operands up front so that the asm statement only
    // lists plain values.
    const _CUDA_VSTD::uint32_t __smem_dest =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__dest));
    const _CUDA_VSTD::uint32_t __smem_bar =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(::cuda::device::barrier_native_handle(__bar)));

    asm volatile(
        "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%2, %3, %4, %5, %6}], [%7];\n"
        : /* no outputs */
        : "r"(__smem_dest), "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__c2), "r"(__c3), "r"(__c4), "r"(__smem_bar)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.1d.global.shared::cta.tile.bulk_group`.
//
//   __tensor_map  pointer to the CUtensorMap, passed through as a 64-bit operand
//   __c0          value placed in the instruction's `{...}` coordinate list
//   __src         generic pointer, converted to a 32-bit shared-space address
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_1d_shared_to_global(
    const CUtensorMap *__tensor_map,
    int __c0,
    const void *__src)
{
    // Resolve the shared-space source address before issuing the instruction.
    const _CUDA_VSTD::uint32_t __smem_src =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src));

    asm volatile(
        "cp.async.bulk.tensor.1d.global.shared::cta.tile.bulk_group "
        "[%0, {%1}], [%2];\n"
        : /* no outputs */
        : "l"(__tensor_map), "r"(__c0), "r"(__smem_src)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group`.
//
//   __tensor_map  pointer to the CUtensorMap, passed through as a 64-bit operand
//   __c0, __c1    values placed, in this order, in the instruction's `{...}`
//                 coordinate list
//   __src         generic pointer, converted to a 32-bit shared-space address
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_2d_shared_to_global(
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    const void *__src)
{
    // Resolve the shared-space source address before issuing the instruction.
    const _CUDA_VSTD::uint32_t __smem_src =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src));

    asm volatile(
        "cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group "
        "[%0, {%1, %2}], [%3];\n"
        : /* no outputs */
        : "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__smem_src)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group`.
//
//   __tensor_map    pointer to the CUtensorMap, passed through as a 64-bit
//                   operand
//   __c0 .. __c2    values placed, in this order, in the instruction's `{...}`
//                   coordinate list
//   __src           generic pointer, converted to a 32-bit shared-space address
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_3d_shared_to_global(
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    int __c2,
    const void *__src)
{
    // Resolve the shared-space source address before issuing the instruction.
    const _CUDA_VSTD::uint32_t __smem_src =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src));

    asm volatile(
        "cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group "
        "[%0, {%1, %2, %3}], [%4];\n"
        : /* no outputs */
        : "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__c2), "r"(__smem_src)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group`.
//
//   __tensor_map    pointer to the CUtensorMap, passed through as a 64-bit
//                   operand
//   __c0 .. __c3    values placed, in this order, in the instruction's `{...}`
//                   coordinate list
//   __src           generic pointer, converted to a 32-bit shared-space address
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_4d_shared_to_global(
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    int __c2,
    int __c3,
    const void *__src)
{
    // Resolve the shared-space source address before issuing the instruction.
    const _CUDA_VSTD::uint32_t __smem_src =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src));

    asm volatile(
        "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group "
        "[%0, {%1, %2, %3, %4}], [%5];\n"
        : /* no outputs */
        : "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__c2), "r"(__c3), "r"(__smem_src)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
//
// Thin wrapper around the PTX instruction
// `cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group`.
//
//   __tensor_map    pointer to the CUtensorMap, passed through as a 64-bit
//                   operand
//   __c0 .. __c4    values placed, in this order, in the instruction's `{...}`
//                   coordinate list
//   __src           generic pointer, converted to a 32-bit shared-space address
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_tensor_5d_shared_to_global(
    const CUtensorMap *__tensor_map,
    int __c0,
    int __c1,
    int __c2,
    int __c3,
    int __c4,
    const void *__src)
{
    // Resolve the shared-space source address before issuing the instruction.
    const _CUDA_VSTD::uint32_t __smem_src =
        static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__src));

    asm volatile(
        "cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group "
        "[%0, {%1, %2, %3, %4, %5}], [%6];\n"
        : /* no outputs */
        : "l"(__tensor_map), "r"(__c0), "r"(__c1), "r"(__c2), "r"(__c3), "r"(__c4), "r"(__smem_src)
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar
//
// Emits the PTX instruction `fence.proxy.async.shared::cta`. Takes no operands;
// the "memory" clobber keeps the compiler from moving memory accesses across
// the fence.
inline _LIBCUDACXX_DEVICE
void fence_proxy_async_shared_cta()
{
    asm volatile(
        "fence.proxy.async.shared::cta; \n"
        : /* no outputs */
        : /* no inputs */
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
//
// Emits the PTX instruction `cp.async.bulk.commit_group`. Takes no operands;
// the "memory" clobber keeps the compiler from moving memory accesses across
// the instruction.
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_commit_group()
{
    asm volatile(
        "cp.async.bulk.commit_group;\n"
        : /* no outputs */
        : /* no inputs */
        : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
//
// Emits the PTX instruction `cp.async.bulk.wait_group.read N`, with the
// template argument `n_prior` encoded as the immediate operand N.
//
//   n_prior  compile-time constant in the range [0, 63]; values outside this
//            range are rejected at compile time
template <int n_prior>
inline _LIBCUDACXX_DEVICE
void cp_async_bulk_wait_group_read()
{
  // Reject negative counts here as well: they would otherwise pass the upper
  // bound check and end up as an invalid immediate in the generated PTX.
  static_assert(n_prior >= 0, "cp_async_bulk_wait_group_read: the number of groups must be non-negative.");
  static_assert(n_prior <= 63, "cp_async_bulk_wait_group_read: waiting for more than 63 groups is not supported.");
  // "n" passes n_prior as an immediate integer constant operand.
  asm volatile("cp.async.bulk.wait_group.read %0; \n"
               :
               : "n"(n_prior)
               : "memory");
}

#endif // __cccl_lib_experimental_ctk12_cp_async_exposure

_LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE_EXPERIMENTAL

#endif // _CUDA_BARRIER
