1 #ifndef jb_tde_conjugate_and_multiply_hpp 2 #define jb_tde_conjugate_and_multiply_hpp 6 #include <boost/compute/algorithm/copy.hpp> 7 #include <boost/compute/buffer.hpp> 8 #include <boost/compute/container/vector.hpp> 19 return std::string(
"-DTYPENAME_MACRO=float2");
22 return std::string(
"__jaybeams_conjugate_and_multiply_float");
29 return std::string(
"-DTYPENAME_MACRO=double2");
32 return std::string(
"__jaybeams_conjugate_and_multiply_double");
36 template <
typename precision_t>
37 boost::compute::kernel
40 auto cache = boost::compute::program_cache::get_global_cache(context);
41 auto program = cache->get_or_build(
42 traits::program_name(), traits::flags(),
44 return program.create_kernel(
"conjugate_and_multiply");
47 template <
typename InputIterator,
typename OutputIterator>
49 InputIterator a_start, InputIterator a_end, InputIterator b_start,
50 InputIterator b_end, OutputIterator output,
51 boost::compute::command_queue& queue,
52 boost::compute::wait_list
const& wait = boost::compute::wait_list()) {
54 typename std::iterator_traits<InputIterator>::value_type input_value_type;
55 typedef typename std::iterator_traits<OutputIterator>::value_type
60 namespace bc = boost::compute;
61 namespace bcdetail = boost::compute::detail;
64 std::is_same<input_value_type, output_value_type>::value,
65 "jb::td::conjugate_and_multiply() input and output value types" 66 " must be identical");
68 std::is_same<std::complex<precision_type>, output_value_type>::value,
69 "jb::td::conjugate_and_multiply() value type must be an instance" 72 bc::is_device_iterator<InputIterator>::value,
73 "jb::td::conjugate_and_multiply() input range must be" 74 " a device container");
76 bc::is_device_iterator<OutputIterator>::value,
77 "jb::td::conjugate_and_multiply() output range must be" 78 " a device container");
80 std::size_t a_count = bcdetail::iterator_range_size(a_start, a_end);
82 return bc::future<OutputIterator>();
84 std::size_t b_count = bcdetail::iterator_range_size(b_start, b_end);
85 if (b_count != a_count) {
86 throw std::invalid_argument(
87 "jb::td::conjugate_and_multiply() mismatched range sizes");
90 bc::buffer
const& a_buffer = a_start.get_buffer();
91 std::size_t a_offset = a_start.get_index();
92 bc::buffer
const& b_buffer = b_start.get_buffer();
93 bc::buffer
const& dst_buffer = output.get_buffer();
96 conjugate_and_multiply_kernel<precision_type>(queue.get_context());
97 kernel.set_arg(0, dst_buffer);
98 kernel.set_arg(1, a_buffer);
99 kernel.set_arg(2, b_buffer);
100 kernel.set_arg(3, cl_uint(a_count));
103 queue.enqueue_1d_range_kernel(kernel, a_offset, a_count, 0, wait);
105 return bc::make_future(
106 bcdetail::iterator_plus_distance(output, a_count), event);
112 #endif // jb_tde_conjugate_and_multiply_hpp static std::string program_name()
static std::string flags()
boost::compute::kernel conjugate_and_multiply_kernel(boost::compute::context context)
static std::string flags()
char const conjugate_and_multiply_kernel_source[]
Contains the source code for the kernels used by conjugate_and_multiply().
static std::string program_name()
boost::compute::future< OutputIterator > conjugate_and_multiply(InputIterator a_start, InputIterator a_end, InputIterator b_start, InputIterator b_end, OutputIterator output, boost::compute::command_queue &queue, boost::compute::wait_list const &wait=boost::compute::wait_list())
The top-level namespace for the JayBeams library.