File size: 2,695 Bytes
be11144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#pragma once

#include <string>

#include <thrust/detail/static_assert.h>
// Discard the definitions of these two macros made by the
// <thrust/detail/static_assert.h> include just above, so they can be replaced
// with the run-time hooks defined below.
#undef THRUST_STATIC_ASSERT
#undef THRUST_STATIC_ASSERT_MSG

// Both replacements forward the condition, together with the file and line of
// the expansion site, to unittest::assert_static.
//   - THRUST_STATIC_ASSERT_MSG accepts `msg` but does not use it.
//   - Each expansion already ends in ';', so a call site written with its own
//     ';' produces an extra empty statement.
// NOTE(review): presumably the point is to let tests observe a failing
// assertion at run time instead of at compile time -- confirm.
#define THRUST_STATIC_ASSERT(B) unittest::assert_static((B), __FILE__, __LINE__);
#define THRUST_STATIC_ASSERT_MSG(B, msg) unittest::assert_static((B), __FILE__, __LINE__);

namespace unittest
{
    // Forward declaration: makes assert_static visible to code expanded from
    // the macros above before its definition, which appears later in this file.
    __host__ __device__
    void assert_static(bool condition, const char * filename, int lineno);
}

#include <thrust/device_new.h>
#include <thrust/device_delete.h>

#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA

// ASSERT_STATIC_ASSERT(X)
//
// Evaluates X and throws unittest::UnitTestFailure unless X triggered a
// THRUST_STATIC_ASSERT.
//
// CUDA variant. A failed assertion can be reported in two ways, and both are
// checked:
//   * assert_static throws static_assert_exception (non-device compilation
//     path), which is caught here;
//   * assert_static writes the exception object through
//     unittest::detail::device_exception (device compilation path). To make
//     that possible a device-side slot is allocated and its address copied to
//     the symbol before X runs; the slot's `triggered` flag is read back
//     afterwards.
//
// The slot was obtained with thrust::device_new, so it is released with its
// counterpart thrust::device_delete. The symbol is then reset to NULL so it
// does not keep pointing at freed storage.
//
// NOTE(review): the return values of cudaMemcpyToSymbol are not checked, and
// the slot is not released if X throws something other than
// static_assert_exception.
#define ASSERT_STATIC_ASSERT(X) \
    { \
        bool triggered = false; \
        typedef unittest::static_assert_exception ex_t; \
        thrust::device_ptr<ex_t> device_ptr = thrust::device_new<ex_t>(); \
        ex_t* raw_ptr = thrust::raw_pointer_cast(device_ptr); \
        ::cudaMemcpyToSymbol(unittest::detail::device_exception, &raw_ptr, sizeof(ex_t*)); \
        try { X; } catch (ex_t const&) { triggered = true; } \
        if (!triggered) { \
            triggered = static_cast<ex_t>(*device_ptr).triggered; \
        } \
        thrust::device_delete(device_ptr); \
        raw_ptr = NULL; \
        ::cudaMemcpyToSymbol(unittest::detail::device_exception, &raw_ptr, sizeof(ex_t*)); \
        if (!triggered) { unittest::UnitTestFailure f; f << "[" << __FILE__ << ":" << __LINE__ << "] did not trigger a THRUST_STATIC_ASSERT"; throw f; } \
    }

#else

// ASSERT_STATIC_ASSERT(X)
//
// Evaluates X and throws unittest::UnitTestFailure unless X triggered a
// THRUST_STATIC_ASSERT.
//
// Non-CUDA variant: the only signal checked is a static_assert_exception
// thrown while X is evaluated.
#define ASSERT_STATIC_ASSERT(X) \
    { \
        bool triggered = false; \
        typedef unittest::static_assert_exception ex_t; \
        try { X; } catch (ex_t const&) { triggered = true; } \
        if (!triggered) { unittest::UnitTestFailure f; f << "[" << __FILE__ << ":" << __LINE__ << "] did not trigger a THRUST_STATIC_ASSERT"; throw f; } \
    }

#endif

namespace unittest
{
    // Record of a THRUST_STATIC_ASSERT outcome. It is thrown as an exception
    // by assert_static on the host path and stored through
    // detail::device_exception on the device path.
    class static_assert_exception
    {
    public:
        // Builds a record in the "not triggered" state. Every member is
        // given a defined value because the whole object is copied around
        // (e.g. read back from device storage by ASSERT_STATIC_ASSERT).
        __host__ __device__
        static_assert_exception()
            : triggered(false), filename(NULL), lineno(0)
        {
        }

        // Builds a record in the "triggered" state.
        //   filename - source file of the failed assertion; the pointer is
        //              stored as-is, the string is not copied
        //   lineno   - source line of the failed assertion
        __host__ __device__
        static_assert_exception(const char * filename, int lineno)
            : triggered(true), filename(filename), lineno(lineno)
        {
        }

        bool triggered;         // true once an assertion has failed
        const char * filename;  // file of the failed assertion, NULL if none
        int lineno;             // line of the failed assertion, 0 if none
    };

    namespace detail
    {
        // Device-side pointer through which assert_static reports a failure
        // when compiled for the device. It starts out NULL; the CUDA variant
        // of ASSERT_STATIC_ASSERT sets it with cudaMemcpyToSymbol.
        // Declared `static`, so each translation unit gets its own copy.
#ifdef __clang__
        __attribute__((used))
#endif
        __device__ static static_assert_exception* device_exception = NULL;
    }

    // Run-time replacement for a static assertion.
    //   condition - the asserted expression; nothing happens when it is true
    //   filename  - file of the assertion site (from __FILE__)
    //   lineno    - line of the assertion site (from __LINE__)
    // When `condition` is false:
    //   * under __CUDA_ARCH__ the record is stored through
    //     detail::device_exception;
    //   * otherwise a static_assert_exception is thrown.
    //
    // NOTE(review): the device path dereferences detail::device_exception
    // without a null check, so it relies on the pointer having been set
    // beforehand.
    // NOTE(review): this is a non-inline definition in a header; including
    // the header from several translation units linked into one binary would
    // define the function more than once -- confirm against the build setup.
    __host__ __device__
    void assert_static(bool condition, const char * filename, int lineno)
    {
        if (!condition)
        {
            static_assert_exception ex(filename, lineno);

#ifdef __CUDA_ARCH__
            *detail::device_exception = ex;
#else
            throw ex;
#endif
        }
    }
}