diff --git a/src/async_client.cpp b/src/async_client.cpp index f1e3226..a6ad6e1 100644 --- a/src/async_client.cpp +++ b/src/async_client.cpp @@ -35,101 +35,70 @@ namespace mqtt { // Constructors async_client::async_client(const string& serverURI, const string& clientId, - const string& persistDir) - : async_client(serverURI, clientId, 0, persistDir) + const string& persistDir, + ipersistence_encoder* encoder /*=nullptr*/) + : async_client(serverURI, clientId, 0, persistDir, encoder) { } async_client::async_client(const string& serverURI, const string& clientId, - iclient_persistence* persistence /*=nullptr*/) - : async_client(serverURI, clientId, 0, persistence) + iclient_persistence* persistence /*=nullptr*/, + ipersistence_encoder* encoder /*=nullptr*/) + : async_client(serverURI, clientId, 0, persistence, encoder) { } async_client::async_client(const string& serverURI, const string& clientId, - int maxBufferedMessages, const string& persistDir) - : serverURI_(serverURI), clientId_(clientId), mqttVersion_(MQTTVERSION_DEFAULT), - persist_(nullptr), userCallback_(nullptr) + int maxBufferedMessages, const string& persistDir, + ipersistence_encoder* encoder /*=nullptr*/) + : serverURI_(serverURI), clientId_(clientId), + mqttVersion_(MQTTVERSION_DEFAULT), userCallback_(nullptr) { - create_options opts; - - if (maxBufferedMessages != 0) { - opts.set_send_while_disconnected(true); - opts.set_max_buffered_messages(maxBufferedMessages); - } + create_options opts(maxBufferedMessages); int rc = MQTTAsync_createWithOptions(&cli_, serverURI.c_str(), clientId.c_str(), MQTTCLIENT_PERSISTENCE_DEFAULT, const_cast(persistDir.c_str()), &opts.opts_); - if (rc != 0) throw exception(rc); + + persistence_encoder(encoder); } async_client::async_client(const string& serverURI, const string& clientId, - int maxBufferedMessages, iclient_persistence* persistence /*=nullptr*/) - : serverURI_(serverURI), clientId_(clientId), mqttVersion_(MQTTVERSION_DEFAULT), - persist_(nullptr), userCallback_(nullptr) + int maxBufferedMessages, + iclient_persistence* persistence /*=nullptr*/, + ipersistence_encoder* encoder /*=nullptr*/) + : async_client(serverURI, clientId, + create_options(maxBufferedMessages), + persistence, encoder) { - create_options opts; - - if (maxBufferedMessages != 0) { - opts.set_send_while_disconnected(true); - opts.set_max_buffered_messages(maxBufferedMessages); - } - - int rc = MQTTASYNC_SUCCESS; - - if (!persistence) { - rc = MQTTAsync_createWithOptions(&cli_, serverURI.c_str(), clientId.c_str(), - MQTTCLIENT_PERSISTENCE_NONE, nullptr, - &opts.opts_); - } - else { - persist_.reset(new MQTTClient_persistence { - persistence, - &iclient_persistence::persistence_open, - &iclient_persistence::persistence_close, - &iclient_persistence::persistence_put, - &iclient_persistence::persistence_get, - &iclient_persistence::persistence_remove, - &iclient_persistence::persistence_keys, - &iclient_persistence::persistence_clear, - &iclient_persistence::persistence_containskey - }); - - rc = MQTTAsync_createWithOptions(&cli_, serverURI.c_str(), clientId.c_str(), - MQTTCLIENT_PERSISTENCE_USER, persist_.get(), - &opts.opts_); - } - if (rc != 0) - throw exception(rc); } - async_client::async_client(const string& serverURI, const string& clientId, const create_options& opts, - const string& persistDir) + const string& persistDir, + ipersistence_encoder* encoder /*=nullptr*/) : serverURI_(serverURI), clientId_(clientId), - mqttVersion_(opts.opts_.MQTTVersion), - persist_(nullptr), userCallback_(nullptr) + mqttVersion_(opts.opts_.MQTTVersion), userCallback_(nullptr) { int rc = MQTTAsync_createWithOptions(&cli_, serverURI.c_str(), clientId.c_str(), MQTTCLIENT_PERSISTENCE_DEFAULT, const_cast(persistDir.c_str()), const_cast(&opts.opts_)); - if (rc != 0) throw exception(rc); + + persistence_encoder(encoder); } async_client::async_client(const string& serverURI, const string& clientId, const create_options& opts, - iclient_persistence* persistence /*=nullptr*/) + iclient_persistence* persistence /*=nullptr*/, + ipersistence_encoder* encoder /*=nullptr*/) : serverURI_(serverURI), clientId_(clientId), - mqttVersion_(opts.opts_.MQTTVersion), - persist_(nullptr), userCallback_(nullptr) + mqttVersion_(opts.opts_.MQTTVersion), userCallback_(nullptr) { int rc = MQTTASYNC_SUCCESS; @@ -154,6 +123,9 @@ async_client::async_client(const string& serverURI, const string& clientId, rc = MQTTAsync_createWithOptions(&cli_, serverURI.c_str(), clientId.c_str(), MQTTCLIENT_PERSISTENCE_USER, persist_.get(), const_cast(&opts.opts_)); + + if (rc == 0) + persistence_encoder(encoder); } if (rc != 0) throw exception(rc); @@ -350,6 +322,13 @@ void async_client::remove_token(token* tok) } } +void async_client::persistence_encoder(ipersistence_encoder* encoder) +{ + if (encoder && cli_) { + MQTTAsync_setBeforePersistenceWrite(cli_, encoder, &ipersistence_encoder::before_write); + MQTTAsync_setAfterPersistenceRead(cli_, encoder, &ipersistence_encoder::after_read); + } +} // -------------------------------------------------------------------------- // Callback management diff --git a/src/create_options.cpp b/src/create_options.cpp index 45747f3..9506fd5 100644 --- a/src/create_options.cpp +++ b/src/create_options.cpp @@ -24,6 +24,15 @@ namespace mqtt { const MQTTAsync_createOptions create_options::DFLT_C_STRUCT = MQTTAsync_createOptions_initializer5; + +create_options::create_options(int maxBufferedMessages) : create_options() +{ + if (maxBufferedMessages != 0) { + opts_.sendWhileDisconnected = to_int(true); + opts_.maxBufferedMessages = maxBufferedMessages; + } +} + ///////////////////////////////////////////////////////////////////////////// } // end namespace mqtt diff --git a/src/iclient_persistence.cpp b/src/iclient_persistence.cpp index 95465a3..057b4d7 100644 --- a/src/iclient_persistence.cpp +++ b/src/iclient_persistence.cpp @@ -145,6 +145,58 @@ int iclient_persistence::persistence_containskey(void* handle, char* key) return MQTTCLIENT_PERSISTENCE_ERROR; } +///////////////////////////////////////////////////////////////////////////// +// Encoder + +int ipersistence_encoder::before_write(void* context, int nbuf, char* bufs[], int buflens[]) +{ + try { + if (context && nbuf > 0 && bufs && buflens) { + std::vector vec; + auto n = size_t(nbuf); + vec.reserve(n); + + for (size_t i=0; i(context)->encode(&vec[0], n); + + for (size_t i=0; i(vec[i].data()); + } + buflens[i] = vec[i].size(); + } + return MQTTASYNC_SUCCESS; + } + } + catch (...) {} + + return MQTTCLIENT_PERSISTENCE_ERROR; +} + +int ipersistence_encoder::after_read(void* context, char** buf, int* buflen) +{ + try { + if (context && buf && *buf && buflen && *buflen > 0) { + string_view sv(*buf, *buflen); + + static_cast(context)->decode(sv); + + if (*buf != sv.data()) { + MQTTAsync_free(*buf); + *buf = const_cast(sv.data()); + } + *buflen = sv.size(); + return MQTTASYNC_SUCCESS; + } + } + catch (...) {} + + return MQTTCLIENT_PERSISTENCE_ERROR; +} + ///////////////////////////////////////////////////////////////////////////// // end namespace mqtt } diff --git a/src/mqtt/async_client.h b/src/mqtt/async_client.h index 87c0218..79f323f 100644 --- a/src/mqtt/async_client.h +++ b/src/mqtt/async_client.h @@ -154,6 +154,9 @@ private: throw exception(rc); } + /** Installs a persistence encoder/decoder */ + void persistence_encoder(ipersistence_encoder* encoder); + public: /** * Create an async_client that can be used to communicate with an MQTT @@ -164,10 +167,13 @@ public: * @param clientId a client identifier that is unique on the server * being connected to * @param persistDir The directory to use for persistence data + * @param encoder An object to encode and decode the persistence data. * @throw exception if an argument is invalid */ async_client(const string& serverURI, const string& clientId, - const string& persistDir); + const string& persistDir, + ipersistence_encoder* encoder=nullptr); + /** * Create an async_client that can be used to communicate with an MQTT * server. @@ -179,10 +185,12 @@ public: * being connected to * @param persistence The user persistence structure. If this is null, * then no persistence is used. + * @param encoder An object to encode and decode the persistence data. * @throw exception if an argument is invalid */ async_client(const string& serverURI, const string& clientId, - iclient_persistence* persistence=nullptr); + iclient_persistence* persistence=nullptr, + ipersistence_encoder* encoder=nullptr); /** * Create an async_client that can be used to communicate with an MQTT * server, which allows for off-line message buffering. @@ -194,10 +202,12 @@ public: * @param maxBufferedMessages the maximum number of messages allowed to * be buffered while not connected * @param persistDir The directory to use for persistence data + * @param encoder An object to encode and decode the persistence data. * @throw exception if an argument is invalid */ async_client(const string& serverURI, const string& clientId, - int maxBufferedMessages, const string& persistDir); + int maxBufferedMessages, const string& persistDir, + ipersistence_encoder* encoder=nullptr); /** * Create an async_client that can be used to communicate with an MQTT * server, which allows for off-line message buffering. @@ -211,11 +221,13 @@ public: * be buffered while not connected * @param persistence The user persistence structure. If this is null, * then no persistence is used. + * @param encoder An object to encode and decode the persistence data. * @throw exception if an argument is invalid */ async_client(const string& serverURI, const string& clientId, - int maxBufferedMessages, iclient_persistence* persistence=nullptr); - + int maxBufferedMessages, + iclient_persistence* persistence=nullptr, + ipersistence_encoder* encoder=nullptr); /** * Create an async_client that can be used to communicate with an MQTT * server, which allows for off-line message buffering. @@ -226,10 +238,12 @@ public: * being connected to * @param opts The create options * @param persistDir The directory to use for persistence data + * @param encoder An object to encode and decode the persistence data. * @throw exception if an argument is invalid */ async_client(const string& serverURI, const string& clientId, - const create_options& opts, const string& persistDir); + const create_options& opts, const string& persistDir, + ipersistence_encoder* encoder=nullptr); /** * Create an async_client that can be used to communicate with an MQTT * server, which allows for off-line message buffering. @@ -243,11 +257,13 @@ public: * be buffered while not connected * @param persistence The user persistence structure. If this is null, * then no persistence is used. + * @param encoder An object to encode and decode the persistence data. * @throw exception if an argument is invalid */ async_client(const string& serverURI, const string& clientId, const create_options& opts, - iclient_persistence* persistence=nullptr); + iclient_persistence* persistence=nullptr, + ipersistence_encoder* encoder=nullptr); /** * Destructor */ diff --git a/src/mqtt/create_options.h b/src/mqtt/create_options.h index 9d61024..3a5743a 100644 --- a/src/mqtt/create_options.h +++ b/src/mqtt/create_options.h @@ -56,6 +56,12 @@ public: * Default set of client create options. */ create_options() : opts_(DFLT_C_STRUCT) {} + /** + * Default create options, but with off-line buffering enabled. + * @param maxBufferedMessages the maximum number of messages allowed to + * be buffered while not connected + */ + explicit create_options(int maxBufferedMessages); /** * Gets whether the client will accept message to publish while * disconnected. diff --git a/src/mqtt/iclient_persistence.h b/src/mqtt/iclient_persistence.h index 417ff63..1fa25b9 100644 --- a/src/mqtt/iclient_persistence.h +++ b/src/mqtt/iclient_persistence.h @@ -32,6 +32,24 @@ namespace mqtt { +/** + * Allocate memory for use with user persistence. + * + * @param n The number of bytes for the buffer. + * @return A pointer to the allocated memory + */ +inline char* persistence_malloc(size_t n) { + return static_cast(MQTTAsync_malloc(n)); +} + +/** + * Frees memory allocated with @ref persistence_malloc + * @param p Pointer to a buffer obtained by persistence_malloc. + */ +inline void persistence_free(char* p) { + MQTTAsync_free(p); +} + ///////////////////////////////////////////////////////////////////////////// /** @@ -128,8 +146,56 @@ using iclient_persistence_ptr = iclient_persistence::ptr_t; /** Smart/shared pointer to a persistence client */ using const_iclient_persistence_ptr = iclient_persistence::const_ptr_t; +///////////////////////////////////////////////////////////////////////////// + +/** + * Interface for objects to encode and decode data going into the + * persistence store. + * + * This is typically used to encrypt the data before writing to + * persistence, and then decrypt it when reading it back from persistence. + */ +class ipersistence_encoder +{ + friend class async_client; + + /** Callbacks from the C library */ + static int before_write(void* context, int bufcount, char* buffers[], int buflens[]); + static int after_read(void* context, char** buffer, int* buflen); + +public: + /** + * Virtual destructor. + */ + virtual ~ipersistence_encoder() {} + /** + * Callback to let the application encode data before writing it to + * persistence. + * + * This is called just prior to writing the data to persistence. If a + * buffer needs to grow, the application can call @ref + * persistence_malloc to get a new buffer, and then update the pointer + * and side of the buffer. It *should not* free the old buffer. That is + * done automatically. + * + * @param bufs The data buffers that need to be encoded. + * @param n The number of buffers + */ + virtual void encode(string_view bufs[], size_t n) =0; + /** + * Callback to let the application decode data after it is retrieved + * from persistence. + * + * @param buffers The data buffers that need to be decoded. + * @param n The number of buffers + */ + virtual void decode(string_view& buf) =0; +}; + + ///////////////////////////////////////////////////////////////////////////// // end namespace mqtt } #endif // __mqtt_iclient_persistence_h + diff --git a/src/samples/data_publish.cpp b/src/samples/data_publish.cpp index aece8bd..05c5764 100644 --- a/src/samples/data_publish.cpp +++ b/src/samples/data_publish.cpp @@ -60,6 +60,7 @@ using namespace std; using namespace std::chrono; const std::string DFLT_ADDRESS { "tcp://localhost:1883" }; +const std::string CLIENT_ID { "paho-cpp-data-publish" }; const string TOPIC { "data/rand" }; const int QOS = 1; @@ -72,11 +73,45 @@ const string PERSIST_DIR { "data-persist" }; ///////////////////////////////////////////////////////////////////////////// +class persistence_encoder : virtual public mqtt::ipersistence_encoder +{ + /** + * Callback to let the application encode data before writing it to + * persistence. + */ + void encode(mqtt::string_view bufs[], size_t n) override { + cout << "Encoding " << n << " buffers" << endl; + auto sz = bufs[0].size(); + auto buf = mqtt::persistence_malloc(sz+6); + strcpy(buf, "bubba"); + memcpy(buf+6, bufs[0].data(), sz); + bufs[0] = mqtt::string_view(buf, n+6); + } + /** + * Callback to let the application decode data after it is retrieved + * from persistence. + * + * @param buffers The data buffers that need to be decoded. + * @param n The number of buffers + */ + void decode(mqtt::string_view& buf) override { + cout << "Decoding buffer: " << buf.data() << endl; + auto n = buf.size(); + auto newBuf = mqtt::persistence_malloc(n-6); + memcpy(newBuf, buf.data(), n-6); + buf = mqtt::string_view(newBuf, n-6); + } +}; + +///////////////////////////////////////////////////////////////////////////// + int main(int argc, char* argv[]) { string address = (argc > 1) ? string(argv[1]) : DFLT_ADDRESS; - mqtt::async_client cli(address, "", MAX_BUFFERED_MSGS, PERSIST_DIR); + persistence_encoder encoder; + mqtt::async_client cli(address, CLIENT_ID, MAX_BUFFERED_MSGS, + PERSIST_DIR, &encoder); mqtt::connect_options connOpts; connOpts.set_keep_alive_interval(MAX_BUFFERED_MSGS * PERIOD);