C++代码设计:向Java借鉴Builder模式暨OpenCL内核代码编译

版权声明:本文为博主原创文章,转载请注明源地址。 https://blog.csdn.net/10km/article/details/50786063

Builder模式

所谓的builder模式是指在设计Java代码时,当方法调用的参数过多的时候,可以用builder模式将所有参数封装在一个类中,然后将这个类的实例作为参数传递给方法。这样一来,方法只需要接收一个类参数,就能获取所有想要的参数。尤其是对于多个类似方法都需要差不多相同参数的情况,这种设计就更加有效率,可以减少方法调用的复杂度,减少出错的机会。如果你还不懂什么叫builder模式,这篇文章介绍得很详细:《Java方法参数太多怎么办—Part3—Builder模式》。Builder模式传递参数在Java代码中应用挺广泛,下面是HttpClient中RequestConfig参数类的代码,它封装了用于HTTP请求的16个参数,就是典型的builder模式,所有的HTTP请求方法都会用到这个类中的参数。

/*
 * ====================================================================
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 * ====================================================================
 *
 * This software consists of voluntary contributions made by many
 * individuals on behalf of the Apache Software Foundation.  For more
 * information on the Apache Software Foundation, please see
 * <http://www.apache.org/>.
 *
 */

package org.apache.http.client.config;

import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

import org.apache.http.HttpHost;
import org.apache.http.annotation.Immutable;

/**
 *  Immutable class encapsulating request configuration items.
 *  The default setting for stale connection checking changed
 *  to false, and the feature was deprecated starting with version 4.4.
 *  <p>
 *  Instances are created with {@link #custom()} or {@link #copy(RequestConfig)}.
 *  The preferred-auth-scheme collections handed to the builder are copied into
 *  unmodifiable lists when the config is built, so changing the caller's
 *  collections afterwards does not alter an existing instance.
 *  </p>
 */
@Immutable
public class RequestConfig implements Cloneable {

    /** Configuration built entirely from the {@link Builder} defaults. */
    public static final RequestConfig DEFAULT = new Builder().build();

    private final boolean expectContinueEnabled;
    private final HttpHost proxy;
    private final InetAddress localAddress;
    private final boolean staleConnectionCheckEnabled;
    private final String cookieSpec;
    private final boolean redirectsEnabled;
    private final boolean relativeRedirectsAllowed;
    private final boolean circularRedirectsAllowed;
    private final int maxRedirects;
    private final boolean authenticationEnabled;
    private final Collection<String> targetPreferredAuthSchemes;
    private final Collection<String> proxyPreferredAuthSchemes;
    private final int connectionRequestTimeout;
    private final int connectTimeout;
    private final int socketTimeout;
    private final boolean decompressionEnabled;

    RequestConfig(
            final boolean expectContinueEnabled,
            final HttpHost proxy,
            final InetAddress localAddress,
            final boolean staleConnectionCheckEnabled,
            final String cookieSpec,
            final boolean redirectsEnabled,
            final boolean relativeRedirectsAllowed,
            final boolean circularRedirectsAllowed,
            final int maxRedirects,
            final boolean authenticationEnabled,
            final Collection<String> targetPreferredAuthSchemes,
            final Collection<String> proxyPreferredAuthSchemes,
            final int connectionRequestTimeout,
            final int connectTimeout,
            final int socketTimeout,
            final boolean decompressionEnabled) {
        super();
        this.expectContinueEnabled = expectContinueEnabled;
        this.proxy = proxy;
        this.localAddress = localAddress;
        this.staleConnectionCheckEnabled = staleConnectionCheckEnabled;
        this.cookieSpec = cookieSpec;
        this.redirectsEnabled = redirectsEnabled;
        this.relativeRedirectsAllowed = relativeRedirectsAllowed;
        this.circularRedirectsAllowed = circularRedirectsAllowed;
        this.maxRedirects = maxRedirects;
        this.authenticationEnabled = authenticationEnabled;
        // Snapshot the caller's collections: keeping the references would let later
        // mutations by the caller leak into this immutable object.
        this.targetPreferredAuthSchemes = copyOrNull(targetPreferredAuthSchemes);
        this.proxyPreferredAuthSchemes = copyOrNull(proxyPreferredAuthSchemes);
        this.connectionRequestTimeout = connectionRequestTimeout;
        this.connectTimeout = connectTimeout;
        this.socketTimeout = socketTimeout;
        this.decompressionEnabled = decompressionEnabled;
    }

    /**
     * Returns an unmodifiable copy of {@code source} that keeps its iteration
     * order (the order expresses scheme preference), or {@code null} if
     * {@code source} is {@code null}.
     */
    private static Collection<String> copyOrNull(final Collection<String> source) {
        return source == null
                ? null
                : Collections.unmodifiableList(new ArrayList<String>(source));
    }

    /**
     * Determines whether the 'Expect: 100-Continue' handshake is enabled
     * for entity enclosing methods. The purpose of the 'Expect: 100-Continue'
     * handshake is to allow a client that is sending a request message with
     * a request body to determine if the origin server is willing to
     * accept the request (based on the request headers) before the client
     * sends the request body.
     * <p>
     * The use of the 'Expect: 100-continue' handshake can result in
     * a noticeable performance improvement for entity enclosing requests
     * (such as POST and PUT) that require the target server's
     * authentication.
     * </p>
     * <p>
     * 'Expect: 100-continue' handshake should be used with caution, as it
     * may cause problems with HTTP servers and proxies that do not support
     * HTTP/1.1 protocol.
     * </p>
     * <p>
     * Default: {@code false}
     * </p>
     */
    public boolean isExpectContinueEnabled() {
        return expectContinueEnabled;
    }

    /**
     * Returns HTTP proxy to be used for request execution.
     * <p>
     * Default: {@code null}
     * </p>
     */
    public HttpHost getProxy() {
        return proxy;
    }

    /**
     * Returns local address to be used for request execution.
     * <p>
     * On machines with multiple network interfaces, this parameter
     * can be used to select the network interface from which the
     * connection originates.
     * </p>
     * <p>
     * Default: {@code null}
     * </p>
     */
    public InetAddress getLocalAddress() {
        return localAddress;
    }

    /**
     * Determines whether stale connection check is to be used. The stale
     * connection check can cause up to 30 millisecond overhead per request and
     * should be used only when appropriate. For performance critical
     * operations this check should be disabled.
     * <p>
     * Default: {@code false} since 4.4
     * </p>
     *
     * @deprecated (4.4) Use {@link
     *   org.apache.http.impl.conn.PoolingHttpClientConnectionManager#getValidateAfterInactivity()}
     */
    @Deprecated
    public boolean isStaleConnectionCheckEnabled() {
        return staleConnectionCheckEnabled;
    }

    /**
     * Determines the name of the cookie specification to be used for HTTP state
     * management.
     * <p>
     * Default: {@code null}
     * </p>
     */
    public String getCookieSpec() {
        return cookieSpec;
    }

    /**
     * Determines whether redirects should be handled automatically.
     * <p>
     * Default: {@code true}
     * </p>
     */
    public boolean isRedirectsEnabled() {
        return redirectsEnabled;
    }

    /**
     * Determines whether relative redirects should be allowed. HTTP specification
     * requires the location value be an absolute URI.
     * <p>
     * Default: {@code true}
     * </p>
     */
    public boolean isRelativeRedirectsAllowed() {
        return relativeRedirectsAllowed;
    }

    /**
     * Determines whether circular redirects (redirects to the same location) should
     * be allowed. The HTTP spec is not sufficiently clear whether circular redirects
     * are permitted, therefore optionally they can be enabled.
     * <p>
     * Default: {@code false}
     * </p>
     */
    public boolean isCircularRedirectsAllowed() {
        return circularRedirectsAllowed;
    }

    /**
     * Returns the maximum number of redirects to be followed. The limit on number
     * of redirects is intended to prevent infinite loops.
     * <p>
     * Default: {@code 50}
     * </p>
     */
    public int getMaxRedirects() {
        return maxRedirects;
    }

    /**
     * Determines whether authentication should be handled automatically.
     * <p>
     * Default: {@code true}
     * </p>
     */
    public boolean isAuthenticationEnabled() {
        return authenticationEnabled;
    }

    /**
     * Determines the order of preference for supported authentication schemes
     * when authenticating with the target host. When not {@code null}, the
     * returned collection is an unmodifiable snapshot.
     * <p>
     * Default: {@code null}
     * </p>
     */
    public Collection<String> getTargetPreferredAuthSchemes() {
        return targetPreferredAuthSchemes;
    }

    /**
     * Determines the order of preference for supported authentication schemes
     * when authenticating with the proxy host. When not {@code null}, the
     * returned collection is an unmodifiable snapshot.
     * <p>
     * Default: {@code null}
     * </p>
     */
    public Collection<String> getProxyPreferredAuthSchemes() {
        return proxyPreferredAuthSchemes;
    }

    /**
     * Returns the timeout in milliseconds used when requesting a connection
     * from the connection manager.
     * <p>
     * A timeout value of zero is interpreted as an infinite timeout.
     * A negative value is interpreted as undefined (system default).
     * </p>
     * <p>
     * Default: {@code -1}
     * </p>
     */
    public int getConnectionRequestTimeout() {
        return connectionRequestTimeout;
    }

    /**
     * Determines the timeout in milliseconds until a connection is established.
     * <p>
     * A timeout value of zero is interpreted as an infinite timeout.
     * A negative value is interpreted as undefined (system default).
     * </p>
     * <p>
     * Default: {@code -1}
     * </p>
     */
    public int getConnectTimeout() {
        return connectTimeout;
    }

    /**
     * Defines the socket timeout ({@code SO_TIMEOUT}) in milliseconds,
     * which is the timeout for waiting for data or, put differently,
     * a maximum period of inactivity between two consecutive data packets.
     * <p>
     * A timeout value of zero is interpreted as an infinite timeout.
     * A negative value is interpreted as undefined (system default).
     * </p>
     * <p>
     * Default: {@code -1}
     * </p>
     */
    public int getSocketTimeout() {
        return socketTimeout;
    }

    /**
     * Determines whether compressed entities should be decompressed automatically.
     * <p>
     * Default: {@code true}
     * </p>
     *
     * @since 4.4
     */
    public boolean isDecompressionEnabled() {
        return decompressionEnabled;
    }

    @Override
    protected RequestConfig clone() throws CloneNotSupportedException {
        return (RequestConfig) super.clone();
    }

    @Override
    public String toString() {
        final StringBuilder builder = new StringBuilder();
        builder.append("[");
        builder.append("expectContinueEnabled=").append(expectContinueEnabled);
        builder.append(", proxy=").append(proxy);
        builder.append(", localAddress=").append(localAddress);
        builder.append(", cookieSpec=").append(cookieSpec);
        builder.append(", redirectsEnabled=").append(redirectsEnabled);
        builder.append(", relativeRedirectsAllowed=").append(relativeRedirectsAllowed);
        builder.append(", maxRedirects=").append(maxRedirects);
        builder.append(", circularRedirectsAllowed=").append(circularRedirectsAllowed);
        builder.append(", authenticationEnabled=").append(authenticationEnabled);
        builder.append(", targetPreferredAuthSchemes=").append(targetPreferredAuthSchemes);
        builder.append(", proxyPreferredAuthSchemes=").append(proxyPreferredAuthSchemes);
        builder.append(", connectionRequestTimeout=").append(connectionRequestTimeout);
        builder.append(", connectTimeout=").append(connectTimeout);
        builder.append(", socketTimeout=").append(socketTimeout);
        builder.append(", decompressionEnabled=").append(decompressionEnabled);
        builder.append("]");
        return builder.toString();
    }

    /** Returns a new builder initialized with the default settings. */
    public static RequestConfig.Builder custom() {
        return new Builder();
    }

    /** Returns a new builder pre-populated with every setting of {@code config}. */
    @SuppressWarnings("deprecation")
    public static RequestConfig.Builder copy(final RequestConfig config) {
        return new Builder()
            .setExpectContinueEnabled(config.isExpectContinueEnabled())
            .setProxy(config.getProxy())
            .setLocalAddress(config.getLocalAddress())
            .setStaleConnectionCheckEnabled(config.isStaleConnectionCheckEnabled())
            .setCookieSpec(config.getCookieSpec())
            .setRedirectsEnabled(config.isRedirectsEnabled())
            .setRelativeRedirectsAllowed(config.isRelativeRedirectsAllowed())
            .setCircularRedirectsAllowed(config.isCircularRedirectsAllowed())
            .setMaxRedirects(config.getMaxRedirects())
            .setAuthenticationEnabled(config.isAuthenticationEnabled())
            .setTargetPreferredAuthSchemes(config.getTargetPreferredAuthSchemes())
            .setProxyPreferredAuthSchemes(config.getProxyPreferredAuthSchemes())
            .setConnectionRequestTimeout(config.getConnectionRequestTimeout())
            .setConnectTimeout(config.getConnectTimeout())
            .setSocketTimeout(config.getSocketTimeout())
            .setDecompressionEnabled(config.isDecompressionEnabled());
    }

    /**
     * Mutable builder for {@link RequestConfig}. A new builder starts from the
     * defaults documented on the corresponding getters; each setter returns the
     * builder for chaining.
     */
    public static class Builder {

        private boolean expectContinueEnabled;
        private HttpHost proxy;
        private InetAddress localAddress;
        private boolean staleConnectionCheckEnabled;
        private String cookieSpec;
        private boolean redirectsEnabled;
        private boolean relativeRedirectsAllowed;
        private boolean circularRedirectsAllowed;
        private int maxRedirects;
        private boolean authenticationEnabled;
        private Collection<String> targetPreferredAuthSchemes;
        private Collection<String> proxyPreferredAuthSchemes;
        private int connectionRequestTimeout;
        private int connectTimeout;
        private int socketTimeout;
        private boolean decompressionEnabled;

        Builder() {
            super();
            this.staleConnectionCheckEnabled = false;
            this.redirectsEnabled = true;
            this.maxRedirects = 50;
            this.relativeRedirectsAllowed = true;
            this.authenticationEnabled = true;
            this.connectionRequestTimeout = -1;
            this.connectTimeout = -1;
            this.socketTimeout = -1;
            this.decompressionEnabled = true;
        }

        public Builder setExpectContinueEnabled(final boolean expectContinueEnabled) {
            this.expectContinueEnabled = expectContinueEnabled;
            return this;
        }

        public Builder setProxy(final HttpHost proxy) {
            this.proxy = proxy;
            return this;
        }

        public Builder setLocalAddress(final InetAddress localAddress) {
            this.localAddress = localAddress;
            return this;
        }

        /**
         * @deprecated (4.4) Use {@link
         *   org.apache.http.impl.conn.PoolingHttpClientConnectionManager#setValidateAfterInactivity(int)}
         */
        @Deprecated
        public Builder setStaleConnectionCheckEnabled(final boolean staleConnectionCheckEnabled) {
            this.staleConnectionCheckEnabled = staleConnectionCheckEnabled;
            return this;
        }

        public Builder setCookieSpec(final String cookieSpec) {
            this.cookieSpec = cookieSpec;
            return this;
        }

        public Builder setRedirectsEnabled(final boolean redirectsEnabled) {
            this.redirectsEnabled = redirectsEnabled;
            return this;
        }

        public Builder setRelativeRedirectsAllowed(final boolean relativeRedirectsAllowed) {
            this.relativeRedirectsAllowed = relativeRedirectsAllowed;
            return this;
        }

        public Builder setCircularRedirectsAllowed(final boolean circularRedirectsAllowed) {
            this.circularRedirectsAllowed = circularRedirectsAllowed;
            return this;
        }

        public Builder setMaxRedirects(final int maxRedirects) {
            this.maxRedirects = maxRedirects;
            return this;
        }

        public Builder setAuthenticationEnabled(final boolean authenticationEnabled) {
            this.authenticationEnabled = authenticationEnabled;
            return this;
        }

        /** Sets the target auth scheme preference; the collection is copied by {@link #build()}. */
        public Builder setTargetPreferredAuthSchemes(final Collection<String> targetPreferredAuthSchemes) {
            this.targetPreferredAuthSchemes = targetPreferredAuthSchemes;
            return this;
        }

        /** Sets the proxy auth scheme preference; the collection is copied by {@link #build()}. */
        public Builder setProxyPreferredAuthSchemes(final Collection<String> proxyPreferredAuthSchemes) {
            this.proxyPreferredAuthSchemes = proxyPreferredAuthSchemes;
            return this;
        }

        public Builder setConnectionRequestTimeout(final int connectionRequestTimeout) {
            this.connectionRequestTimeout = connectionRequestTimeout;
            return this;
        }

        public Builder setConnectTimeout(final int connectTimeout) {
            this.connectTimeout = connectTimeout;
            return this;
        }

        public Builder setSocketTimeout(final int socketTimeout) {
            this.socketTimeout = socketTimeout;
            return this;
        }

        public Builder setDecompressionEnabled(final boolean decompressionEnabled) {
            this.decompressionEnabled = decompressionEnabled;
            return this;
        }

        /** Creates a {@link RequestConfig} from the current builder settings. */
        public RequestConfig build() {
            return new RequestConfig(
                    expectContinueEnabled,
                    proxy,
                    localAddress,
                    staleConnectionCheckEnabled,
                    cookieSpec,
                    redirectsEnabled,
                    relativeRedirectsAllowed,
                    circularRedirectsAllowed,
                    maxRedirects,
                    authenticationEnabled,
                    targetPreferredAuthSchemes,
                    proxyPreferredAuthSchemes,
                    connectionRequestTimeout,
                    connectTimeout,
                    socketTimeout,
                    decompressionEnabled);
        }

    }

}

向Java借鉴

C++的函数定义可以为参数提供缺省值,这是比Java方便的地方,因此可以比Java少定义一些重载函数。但C++的重构能力远不如Java,同一个函数具备多个重载版本时,代码维护的难度还是比Java更大。所以在这种情况下,借用Java的Builder模式封装参数的办法对C++来说收益就显得更大。

OpenCL实例说明

下面以我最近涉及的OpenCL相关开发工作为例,说说我的困扰。

OpenCL开发中,需要对OpenCL设备(GPU/CPU)进行内核编程(使用基于C99的OpenCL C语言,这不在本文讨论的范围),所以会写一些C代码,就是所谓的kernel代码。如果想要在OpenCL设备上执行kernel,首先要调用OpenCL的函数编译这些代码,将它们编译成可执行程序(Executable Program),然后通过Program创建kernel,之后才能执行kernel。所以在OpenCL C++接口(cl.hpp)中定义了cl::Program、cl::Kernel类。根据我们在主机平台(Windows/Linux…)上的开发经验,要将C/C++代码编译成目标文件(exe或动态库),需要经历compile和link两个阶段:compile阶段将所有的C/C++源文件编译为obj,link阶段将所有的obj链接生成目标文件。其实编译kernel也是一样一样的啊……所以,cl::Program对应地有build、compile函数,以及cl::linkProgram函数,其中build函数包含了compile/link,用于将单个源码编译成可执行程序。下面是cl::Program类的主要构造函数和方法的定义(摘自cl.hpp)。

// cl::Program constructor taking a single source string; `build` requests an immediate build
Program(const STRING_CLASS& source,bool build = false, cl_int* err = NULL);
// cl::Program constructor: as above, with an explicit context
Program(const Context& context,const STRING_CLASS& source,bool build = false,cl_int* err = NULL);
// cl::Program constructor taking several sources
Program(const Context& context,const Sources& sources,cl_int* err = NULL);
// cl::Program member: compiles and links this program's source into an executable program
cl_int build(
    const VECTOR_CLASS<Device>& devices,
    const char* options = NULL,
    void (CL_CALLBACK * notifyFptr)(cl_program, void *) = NULL,
    void* data = NULL) const;
// cl::Program member: compiles the source into an object (a non-executable program) for later linking
cl_int compile(
    const char* options = NULL,
    void (CL_CALLBACK * notifyFptr)(cl_program, void *) = NULL,
    void* data = NULL) const;
// Free function: links two compiled cl::Program objects into an executable program
inline Program linkProgram(
    Program input1,
    Program input2,
    const char* options = NULL,
    void (CL_CALLBACK * notifyFptr)(cl_program, void *) = NULL,
    void* data = NULL,
    cl_int* err = NULL);
// Free function: links several compiled cl::Program objects into an executable program
inline Program linkProgram(
    VECTOR_CLASS<Program> inputPrograms,
    const char* options = NULL,
    void (CL_CALLBACK * notifyFptr)(cl_program, void *) = NULL,
    void* data = NULL,
    cl_int* err = NULL);

从上面的代码中可以看出,构造一个cl::Program并将其编译成Executable Program时,可能需要提供的基本参数主要有:

const Context& context // device context object; required in this project
const VECTOR_CLASS<Device>& devices // target device list; defaults to an empty vector
const STRING_CLASS& source // kernel source code
const STRING_CLASS& source_name // name of the source, used to trace build errors; defaults to "Unknow_name"
const char* options // build options; defaults to nullptr

其他参数如notifyFptr、data、err,在本项目中都使用缺省值。

这其中:内核源码可能是一个字符串,也可能来自本地文件,所以它的类型可能是代表源码的字符串,也可能是代表文件名的字符串;另外,一个可执行程序可以由一个源码生成,也可以由多个源码分别编译后链接生成,所以源码可以有多个,多个源码的情况应该用std::vector<std::string>来描述;设备对象列表允许不提供,所以需要有缺省参数;编译选项允许不提供,所以也需要有缺省参数。内核代码编译时也有不少编译选项,其中有两个最基本的编译选项:-D和-I。当源码中#include了其他文件时,需要在options中用-I指定#include文件的搜索路径;-D可以为内核源码提供宏定义。以下描述来自OpenCL官网的clBuildProgram文档:

Preprocessor Options: These options control the OpenCL C preprocessor which is run on each program source before actual compilation. -D options are processed in the order they are given in the options argument to clBuildProgram or clCompileProgram. -D name: Predefine name as a macro, with definition 1. -D name=definition: The contents of definition are tokenized and processed as if they appeared during translation phase three in a '#define' directive. In particular, the definition will be truncated by embedded newline characters. -I dir: Add the directory dir to the list of directories to be searched for header files.

如果使用传统的方式,要提供一组编译内核源的函数,且满足上述要求,需要定义如下的函数:

////// Group 1: create a cl::Program from a single source ////////////////////////
// NOTE: a string-literal argument binds to the `const char *file` overload
// (exact match) rather than to `const std::string &source`.
cl::Program createProgram(const cl::Context& context,
    const std::string &source,
    const std::string& source_name=Unknow_Name,// name of the source; Unknow_Name is a global constant
    const char* options=nullptr,
    const std::vector<cl::Device>& devices=Empty_Devices_Vector //Empty_Devices_Vector is a global constant
    );
cl::Program createProgram(const cl::Context& context,
    const char*file,
    const std::string& source_name=Unknow_Name,
    const char* options=nullptr,
    const std::vector<cl::Device>& devices=Empty_Devices_Vector);
////// Group 1: create a cl::Program from several sources ////////////////////////
cl::Program createProgram(const cl::Context& context,
    const std::vector<std::pair<std::string,std::string>> &sources, // several sources: pair.first is the source name, pair.second the source code
    const char* options=nullptr,
    const std::vector<cl::Device>& devices=Empty_Devices_Vector);
cl::Program createProgram(const cl::Context& context,
    const std::vector<std::string> &source_files, // several source files; each file name doubles as its source_name
    const char* options=nullptr,
    const std::vector<cl::Device>& devices=Empty_Devices_Vector);

///////////// Group 2: the 4 declarations above, extended with source_root, define_str and include_str ////////////
// NOTE: declaring group 2 next to group 1 makes calls that rely on default
// arguments ambiguous (e.g. createProgram(context, source)), so presumably
// group 2 is meant to replace group 1 rather than coexist with it.
////// Group 2: create a cl::Program from a single source ////////////////////////
cl::Program createProgram(const cl::Context& context,
    const std::string &source,
    const std::string& source_name=Unknow_Name,// name of the source; Unknow_Name is a global constant
    const char* define_str =nullptr, // the -D part of the build options
    const char* include_str =nullptr,// the -I part of the build options
    const char* other_options=nullptr,// all other build options
    const std::vector<cl::Device>& devices=Empty_Devices_Vector //Empty_Devices_Vector is a global constant
    );
cl::Program createProgram(const cl::Context& context,
    const char*file,
    const std::string& source_name=Unknow_Name,
    const char* define_str =nullptr, // the -D part of the build options
    const char* include_str =nullptr,// the -I part of the build options
    const char* other_options=nullptr,// all other build options
    const std::vector<cl::Device>& devices=Empty_Devices_Vector);
////// Group 2: create a cl::Program from several sources ////////////////////////
cl::Program createProgram(const cl::Context& context,
    const std::vector<std::pair<std::string,std::string>> &sources, // several sources: pair.first is the source name, pair.second the source code
    const char* source_root=nullptr,// root directory of the sources
    const char* define_str =nullptr, // the -D part of the build options
    const char* include_str =nullptr,// the -I part of the build options
    const char* other_options=nullptr,// all other build options
    const std::vector<cl::Device>& devices=Empty_Devices_Vector);
cl::Program createProgram(const cl::Context& context,
    const std::vector<std::string> &source_files, // several source files; each file name doubles as its source_name
    const char* source_root=nullptr,// root directory of the sources
    const char* define_str =nullptr, // the -D part of the build options
    const char* include_str =nullptr,// the -I part of the build options
    const char* other_options=nullptr,// all other build options
    const std::vector<cl::Device>& devices=Empty_Devices_Vector);

好吧,看到上述8个函数定义,你有没有头大的感觉?反正我当初写这么多函数的时候花了一天时间,已经头大了。虽然每个函数都不只有几行,但相似内容太多,非常容易搞错,维护起来甚是麻烦;如果未来要加入更多的参数(比如前面忽略的notifyFptr、data、err参数),这代码真是没办法改了。

build_param封装所有参数

叔可忍,婶不可忍呐,写完上面这些代码我已经快崩溃了。第二天,痛定思痛,我想到了以前写Java代码时用到的builder模式,决定重写上面的代码,将编译内核所需要的所有参数封装到build_param类中。

/* Kernel program build parameters.
 * Bundles everything needed to build a cl::Program so that build functions take a
 * single argument. Instances are normally produced by build_param::custom(context),
 * followed by builder calls and builder::build(). */
struct build_param{
    // Kernel source descriptor: pair.first is the source name, pair.second is the source code.
    using source_info_type =std::pair<std::string,std::string>;
    cl::Context context; // device context
    std::vector<cl::Device> devices; // target devices to build for
    std::vector<source_info_type> sources; // sources to compile
    std::string options; // all build options
    /* Fluent builder that accumulates devices, sources and options, then produces a build_param. */
    class builder{
    private:
        const cl::Context _context;
        std::vector<cl::Device> _devices;
        std::vector<source_info_type> _sources;
        std::string _source_root=Empty_String;
        std::vector<std::string> _source_files;
        std::string _options=Empty_String;
        /* Formats one include-path option ("-I <dir> "): trims dir, throws if the result
         * is empty, and wraps it in double quotes when it contains spaces.
         * Shared by add_include() and build(). */
        static std::string include_option(std::string dir){
            auto p=gdface::trim(dir);
            throw_if(p.empty(),"dir is empty")
            if(gdface::has_space(p))
                p="\""+p+"\"";
            return "-I "+p+" ";
        }
    public:
        // Creates a builder bound to the given device context.
        builder(const cl::Context &context) :_context(context) {}
        // Sets the target device list, replacing any previous list.
        builder &set_devices(std::initializer_list<cl::Device>devices){
            this->_devices=devices;
            return *this;
        }
        // Sets the target device list, replacing any previous list.
        builder &set_devices(const std::vector<cl::Device>&devices){
            this->_devices=devices;
            return *this;
        }
        // Appends a group of (name, code) source descriptors.
        builder &add_sources(std::initializer_list<source_info_type>sources){
            this->_sources.insert(this->_sources.end(),sources);
            return *this;
        }
        // Appends one (name, code) source descriptor.
        builder &add_source(const source_info_type &source){
            return add_sources({source});
        }
        // Appends one in-memory source; source_name is optional.
        builder &add_source(const std::string &source,const std::string &source_name=Unknow_Name){
            // source_info_type stores the name first and the code second.
            return add_source(source_info_type(source_name,source));
        }
        /* Sets the root directory that source files are joined with, and also adds it
         * as an #include search path. Only the first non-empty root takes effect;
         * later calls are ignored. */
        builder &set_source_root(const std::string &root){
            if(_source_root.empty()&&!root.empty()){
                _source_root=root;
                add_include(root);
            }
            return *this;
        }
        /* Registers a group of source files, loaded when build() is called. */
        builder &add_source_files(std::initializer_list<std::string>sources){
            this->_source_files.insert(this->_source_files.end(),sources);
            return *this;
        }
        /* Registers one source file, loaded when build() is called. */
        builder &add_source_file(const std::string &source) {
            return add_source_files({source});
        }
        /* Adds a preprocessor macro definition (-D); throws if def is empty. */
        builder &add_define(std::string def){
            throw_if(def.empty(),"def is empty")
            _options+="-D "+def+" ";
            return *this;
        }
        /* Adds a header search path (-I); see include_option() for formatting. */
        builder &add_include(std::string dir){
            _options+=include_option(dir);
            return *this;
        }
        /* Appends free-form build options; an empty string is ignored. */
        builder &add_options(std::string opt){
            if(!opt.empty())
                _options+=opt+" ";
            return *this;
        }
        /* Produces the build_param.
         * Every registered source file is loaded (its path joined with the source root)
         * and appended as a (file name, file content) source. If no source root was set,
         * the current working directory is added as an #include search path.
         * Only local copies are modified, so the builder is left unchanged and repeated
         * calls do not duplicate sources or options. */
        build_param build(){
            auto options=_options;
            if(_source_root.empty()){
                options+=include_option(gdface::getcwd());
            }
            auto sources=_sources;
            for(const auto &file:_source_files){
                throw_if(file.empty(),"the argument 'file' is empty")
                sources.emplace_back(file,gdface::load_string(gdface::path_concate(_source_root,file).data()));
            }
            return {_context,_devices,sources,options};
        }
    };
    /* Creates a builder for the given device context. */
    static builder custom(const cl::Context &context){
        return builder(context);
    }
    // default constructible
    build_param() = default;
    // copy constructible
    build_param(const build_param &)=default;
    // move constructible
    build_param(build_param &&)=default;
    // copy assignable
    build_param&operator=(const build_param &)=default;
    // move assignable: with the copy/move constructors and copy assignment user-declared,
    // no move assignment would be declared implicitly.
    build_param&operator=(build_param &&)=default;
    build_param(const cl::Context &context,
            const std::vector<cl::Device> &devices,
            const std::vector<source_info_type> &sources,
            const std::string &options):context(context),devices(devices),sources(sources),options(options){}
};

有了build_param类,编译内核程序的函数定义就简化成了下面这样。

/* (General) builds an executable program from param.sources.
 * A single source is built directly with buildSource(); any other count is
 * handed to buildMultiFilesProgram(), which compiles each source and links the
 * results (and rejects an empty source list). */
cl::Program buildExecutableProgram(const build_param& param){
    // With exactly one source, param already has the shape buildSource() expects,
    // so pass it through instead of constructing an identical copy.
    if(1==param.sources.size()){
        return buildSource(param);
    }
    return buildMultiFilesProgram(param);
}

buildExecutableProgram会根据源码的个数自动决定调用buildSource还是buildMultiFilesProgram。

上面buildSource和buildMultiFilesProgram函数的实现是这样滴:

/* Builds a single kernel source; param.sources must contain exactly one entry.
 * When complie_only is true the source is only compiled into an object program
 * (for later linking), which requires OpenCL 1.2; otherwise it is built into an
 * executable program, for param.devices if that list is non-empty or without an
 * explicit device list otherwise.
 * OpenCL errors are rethrown as face_cl_build_exception; build errors additionally
 * print the build log through showBuildLog(). */
cl::Program buildSource(const build_param& param, bool complie_only = false) {
    try {
        throw_if(1 != param.sources.size(), "size of build_param::sources must be 1")
        cl::Program program(param.context, param.sources[0].second);
        try {
#ifdef CL_VERSION_1_2
            if (complie_only) {
                show_on_buildstart(_DEF_STRING(complie), param.options.data(), param.sources[0]);
                program.compile(param.options.data());
            } else
#else
            // compile-only builds need clCompileProgram, which was introduced in OpenCL 1.2
            throw_exception_if(face_cl_build_exception,complie_only,"unsupported complie under version OpenCL 1.2")
#endif
            {
                show_on_buildstart(_DEF_STRING(build), param.options.data(), param.sources[0]);
                if (!param.devices.empty())
                    program.build(param.devices, param.options.data());
                else
                    program.build(param.options.data());
            }
            show_on_buildend(program, param.sources[0]);
            // Plain return: wrapping a local in std::move would block copy elision.
            return program;
        }
#ifdef CL_VERSION_2_0
        // With OpenCL 2.0 and later, build errors are thrown as cl::BuildError
        catch (cl::BuildError &e) {
            auto log = e.getBuildLog();
            showBuildLog(log, param.sources[0].first);
            throw face_cl_build_exception(SOURCE_AT, e);
        }
#else
        // With OpenCL 1.1/1.2, build errors are thrown as cl::Error
        catch (cl::Error& e) {
            auto log = cl_utilits::getBuildInfo<CL_PROGRAM_BUILD_LOG>(program);
            showBuildLog(log, param.sources[0].first);
            throw face_cl_build_exception(SOURCE_AT, e, log);
        }
#endif
    } catch (cl::Error& e) {
        throw face_cl_build_exception(SOURCE_AT, e);
    }
}

/* Returns `options` with the preprocessor options removed (-D and -I, in both the
 * "-D name" and "-Dname" forms), so the remainder can be passed to clLinkProgram,
 * whose accepted options do not include preprocessor options.
 * Tokens are split on whitespace; a double-quoted span (as add_include produces
 * for paths containing spaces) stays inside a single token. */
static std::string cl_link_options_of(const std::string &options) {
    std::vector<std::string> tokens;
    std::string token;
    bool in_quotes = false;
    for (char c : options) {
        if (c == '"') {
            in_quotes = !in_quotes;
            token += c;
        } else if (!in_quotes && (c == ' ' || c == '\t' || c == '\n' || c == '\r')) {
            if (!token.empty()) {
                tokens.push_back(token);
                token.clear();
            }
        } else {
            token += c;
        }
    }
    if (!token.empty())
        tokens.push_back(token);

    std::string result;
    bool skip_argument = false;
    for (const auto &t : tokens) {
        if (skip_argument) {        // the separate argument of a preceding -D / -I
            skip_argument = false;
            continue;
        }
        if (t == "-D" || t == "-I") {
            skip_argument = true;
            continue;
        }
        if (t.size() > 2 && (t.compare(0, 2, "-D") == 0 || t.compare(0, 2, "-I") == 0))
            continue;               // attached form: -Dname, -Idir
        result += t;
        result += ' ';
    }
    return result;
}

/* Compiles each source of param separately and links the objects into an executable program.
 * Throws if param.sources is empty. The full options are used for compiling; only the
 * non-preprocessor options are forwarded to the link step. */
cl::Program cl_utilits::buildMultiFilesProgram(const build_param& param) {
    throw_if(0==param.sources.size(),"size of build_param::sources must not be 0")
    std::vector<cl::Program> objs;
    objs.reserve(param.sources.size());
    // compile every source into an object program
    for (const auto &source : param.sources) {
        objs.emplace_back(compileSource({ param.context,param.devices,{ source },param.options }));
    }
    // -D/-I are compile-time options; passing them to the linker risks CL_INVALID_LINKER_OPTIONS
    const auto link_options = cl_link_options_of(param.options);
    return cl::linkProgram(objs, link_options.data());// link the objects into an executable program
}

有了build_param封装所有参数,未来即使加入更多的参数,只需要在build_param和相关的函数中增加相应的代码,而不需要修改所有的函数接口定义,可维护性也相应提高了。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励