File MbedTLSClient.cpp
File List > arduino > libraries > common > WiFiClient > MbedTLSClient.cpp
Go to the documentation of this file.
/* Copyright (c) Kuba Szczodrzyński 2022-04-30. */
#if LT_ARD_HAS_WIFI && LT_HAS_MBEDTLS
#include "MbedTLSClient.h"
#include <WiFi.h>
extern "C" {
#include <mbedtls/debug.h>
#include <mbedtls/net.h>
#include <mbedtls/pk.h>
#include <mbedtls/platform.h>
#include <mbedtls/sha256.h>
#include <mbedtls/ssl.h>
#if LT_HAS_FREERTOS
#include <FreeRTOS.h>
#endif
} // extern "C"
#define _clientKeyC ((mbedtls_pk_context *)_clientKey)
MbedTLSClient::MbedTLSClient() : WiFiClient() {
init(); // ensure the context is zero filled
}
MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {
init(); // ensure the context is zero filled
}
MbedTLSClient::~MbedTLSClient() {
LT_VM(CLIENT, "~MbedTLSClient()");
stop();
}
void MbedTLSClient::stop() {
if (!_sslCtx)
return;
LT_VM(SSL, "Stopping SSL");
if (_sslCfg->ca_chain) {
mbedtls_x509_crt_free(_caCert);
}
if (_sslCfg->key_cert) {
mbedtls_x509_crt_free(_clientCert);
mbedtls_pk_free(_clientKeyC);
}
mbedtls_ssl_free(_sslCtx);
mbedtls_ssl_config_free(_sslCfg);
free(_sslCtx);
free(_sslCfg);
free(_caCert);
free(_clientCert);
free(_clientKey);
_sslCtx = NULL;
LT_HEAP_I();
}
void MbedTLSClient::init() {
if (!_sslCtx) {
_sslCtx = (mbedtls_ssl_context *)malloc(sizeof(mbedtls_ssl_context));
_sslCfg = (mbedtls_ssl_config *)malloc(sizeof(mbedtls_ssl_config));
_caCert = (mbedtls_x509_crt *)malloc(sizeof(mbedtls_x509_crt));
_clientCert = (mbedtls_x509_crt *)malloc(sizeof(mbedtls_x509_crt));
_clientKey = (mbedtls_pk_context *)malloc(sizeof(mbedtls_pk_context));
}
// Realtek AmbZ: init platform here to ensure HW crypto is initialized in ssl_init
mbedtls_platform_set_calloc_free(calloc, free);
mbedtls_ssl_init(_sslCtx);
mbedtls_ssl_config_init(_sslCfg);
}
int MbedTLSClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
return connect(ipToString(ip).c_str(), port, timeout);
}
int MbedTLSClient::connect(const char *host, uint16_t port, int32_t timeout) {
if (_pskIdentStr && _pskStr)
return connect(host, port, timeout, NULL, NULL, NULL, _pskIdentStr, _pskStr) == 0;
return connect(host, port, timeout, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL) == 0;
}
int MbedTLSClient::connect(
IPAddress ip,
uint16_t port,
const char *rootCABuf,
const char *clientCert,
const char *clientKey
) {
return connect(ipToString(ip).c_str(), port, 0, rootCABuf, clientCert, clientKey, NULL, NULL) == 0;
}
int MbedTLSClient::connect(
const char *host,
uint16_t port,
const char *rootCABuf,
const char *clientCert,
const char *clientKey
) {
return connect(host, port, 0, rootCABuf, clientCert, clientKey, NULL, NULL) == 0;
}
int MbedTLSClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) {
return connect(ipToString(ip).c_str(), port, 0, NULL, NULL, NULL, pskIdent, psk) == 0;
}
int MbedTLSClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) {
return connect(host, port, 0, NULL, NULL, NULL, pskIdent, psk) == 0;
}
static int ssl_random(void *data, unsigned char *output, size_t len) {
lt_rand_bytes((uint8_t *)output, len);
return 0;
}
void debug_cb(void *ctx, int level, const char *file, int line, const char *str) {
// do not print the trailing \n
uint16_t len = strlen(str);
char *msg = (char *)str;
msg[len - 1] = '\0';
LT_IM(SSL, "%04d: |%d| %s", line, level, msg);
}
int MbedTLSClient::connect(
const char *host,
uint16_t port,
int32_t timeout,
const char *rootCABuf,
const char *clientCert,
const char *clientKey,
const char *pskIdent,
const char *psk
) {
LT_HEAP_I();
if (!rootCABuf && !pskIdent && !psk && !_insecure && !_useRootCA)
return -1;
if (timeout <= 0)
timeout = _timeout; // use default when -1 passed as timeout
IPAddress addr = WiFi.hostByName(host);
if (!(uint32_t)addr)
return -1;
int ret = WiFiClient::connect(addr, port, timeout);
if (ret < 0) {
LT_EM(SSL, "SSL socket failed");
return ret;
}
char *uid = "lt-ssl"; // TODO
LT_VM(SSL, "Init SSL");
init();
LT_HEAP_I();
// mbedtls_debug_set_threshold(4);
// mbedtls_ssl_conf_dbg(&_sslCfg, debug_cb, NULL);
ret = mbedtls_ssl_config_defaults(
_sslCfg,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT
);
LT_RET_NZ(ret);
#ifdef MBEDTLS_SSL_ALPN
if (_alpnProtocols) {
ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, _alpnProtocols);
LT_RET_NZ(ret);
}
#endif
if (_insecure) {
mbedtls_ssl_conf_authmode(_sslCfg, MBEDTLS_SSL_VERIFY_NONE);
} else if (rootCABuf) {
mbedtls_x509_crt_init(_caCert);
mbedtls_ssl_conf_authmode(_sslCfg, MBEDTLS_SSL_VERIFY_REQUIRED);
ret = mbedtls_x509_crt_parse(_caCert, (const unsigned char *)rootCABuf, strlen(rootCABuf) + 1);
mbedtls_ssl_conf_ca_chain(_sslCfg, _caCert, NULL);
if (ret < 0) {
mbedtls_x509_crt_free(_caCert);
LT_RET(ret);
}
} else if (_useRootCA) {
return -1; // not implemented
} else if (pskIdent && psk) {
#ifdef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED
uint16_t len = strlen(psk);
if ((len & 1) != 0 || len > 2 * MBEDTLS_PSK_MAX_LEN) {
LT_EM(SSL, "PSK length invalid");
return -1;
}
unsigned char pskBin[MBEDTLS_PSK_MAX_LEN] = {};
for (uint8_t i = 0; i < len; i++) {
uint8_t c = psk[i];
c |= 0b00100000; // make lowercase
c -= '0' * (c >= '0' && c <= '9');
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
if (c > 0xf)
return -1;
pskBin[i / 2] |= c << (4 * ((i & 1) ^ 1));
}
ret = mbedtls_ssl_conf_psk(_sslCfg, pskBin, len / 2, (const unsigned char *)pskIdent, strlen(pskIdent));
LT_RET_NZ(ret);
#else
return -1;
#endif
} else {
return -1;
}
if (!_insecure && clientCert && clientKey) {
mbedtls_x509_crt_init(_clientCert);
mbedtls_pk_init(_clientKeyC);
LT_VM(SSL, "Loading client cert");
ret = mbedtls_x509_crt_parse(_clientCert, (const unsigned char *)clientCert, strlen(clientCert) + 1);
if (ret < 0) {
mbedtls_x509_crt_free(_clientCert);
LT_RET(ret);
}
LT_VM(SSL, "Loading private key");
ret = mbedtls_pk_parse_key(_clientKeyC, (const unsigned char *)clientKey, strlen(clientKey) + 1, NULL, 0);
if (ret < 0) {
mbedtls_x509_crt_free(_clientCert);
LT_RET(ret);
}
mbedtls_ssl_conf_own_cert(_sslCfg, _clientCert, _clientKeyC);
}
LT_VM(SSL, "Setting TLS hostname");
ret = mbedtls_ssl_set_hostname(_sslCtx, host);
LT_RET_NZ(ret);
mbedtls_ssl_conf_rng(_sslCfg, ssl_random, NULL);
ret = mbedtls_ssl_setup(_sslCtx, _sslCfg);
LT_RET_NZ(ret);
_sockTls = fd();
mbedtls_ssl_set_bio(_sslCtx, &_sockTls, mbedtls_net_send, mbedtls_net_recv, NULL);
mbedtls_net_set_nonblock((mbedtls_net_context *)&_sockTls);
LT_HEAP_I();
LT_VM(SSL, "SSL handshake");
if (_handshakeTimeout == 0)
_handshakeTimeout = timeout;
unsigned long start = millis();
while (ret = mbedtls_ssl_handshake(_sslCtx)) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
LT_RET(ret);
}
if ((millis() - start) > _handshakeTimeout) {
LT_EM(SSL, "SSL handshake timeout");
return -1;
}
delay(2);
}
LT_HEAP_I();
if (clientCert && clientKey) {
LT_DM(
SSL,
"Protocol %s, ciphersuite %s",
mbedtls_ssl_get_version(_sslCtx),
mbedtls_ssl_get_ciphersuite(_sslCtx)
);
ret = mbedtls_ssl_get_record_expansion(_sslCtx);
if (ret >= 0)
LT_DM(SSL, "Record expansion: %d", ret);
else {
LT_WM(SSL, "Record expansion unknown");
}
}
LT_VM(SSL, "Verifying certificate");
ret = mbedtls_ssl_get_verify_result(_sslCtx);
if (ret) {
char buf[512];
memset(buf, 0, sizeof(buf));
mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", ret);
LT_EM(SSL, "Failed to verify peer certificate! Verification info: %s", buf);
return ret;
}
if (rootCABuf)
mbedtls_x509_crt_free(_caCert);
if (clientCert)
mbedtls_x509_crt_free(_clientCert);
if (clientKey != NULL)
mbedtls_pk_free(_clientKeyC);
return 0; // OK
}
size_t MbedTLSClient::write(const uint8_t *buf, size_t size) {
int ret = -1;
while ((ret = mbedtls_ssl_write(_sslCtx, buf, size)) <= 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
LT_RET(ret);
}
delay(2);
}
return ret;
}
int MbedTLSClient::available() {
bool peeked = _peeked >= 0;
if (!connected())
return peeked;
int ret = mbedtls_ssl_read(_sslCtx, NULL, 0);
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
stop();
return peeked ? peeked : ret;
}
return mbedtls_ssl_get_bytes_avail(_sslCtx) + peeked;
}
int MbedTLSClient::read(uint8_t *buf, size_t size) {
bool peeked = false;
int toRead = available();
if ((!buf && size) || toRead <= 0)
return -1;
if (!size)
return 0;
if (_peeked >= 0) {
buf[0] = _peeked;
_peeked = -1;
size--;
toRead--;
if (!size || !toRead)
return 1;
buf++;
peeked = true;
}
int ret = mbedtls_ssl_read(_sslCtx, buf, size);
if (ret < 0) {
stop();
return peeked ? peeked : ret;
}
return ret + peeked;
}
int MbedTLSClient::peek() {
if (_peeked >= 0)
return _peeked;
_peeked = timedRead();
return _peeked;
}
void MbedTLSClient::flush() {}
int MbedTLSClient::lastError(char *buf, const size_t size) {
return 0; // TODO (?)
}
void MbedTLSClient::setInsecure() {
_caCertStr = NULL;
_clientCertStr = NULL;
_clientKeyStr = NULL;
_pskIdentStr = NULL;
_pskStr = NULL;
_insecure = true;
}
// TODO only allocate _caCert, _clientCert and _clientKey when one
// of the following functions is used
void MbedTLSClient::setPreSharedKey(const char *pskIdent, const char *psk) {
_pskIdentStr = pskIdent;
_pskStr = psk;
}
void MbedTLSClient::setCACert(const char *rootCA) {
_caCertStr = rootCA;
}
void MbedTLSClient::setCertificate(const char *clientCA) {
_clientCertStr = clientCA;
}
void MbedTLSClient::setPrivateKey(const char *privateKey) {
_clientKeyStr = privateKey;
}
char *streamToStr(Stream &stream, size_t size) {
char *buf = (char *)malloc(size + 1);
if (!buf)
return NULL;
if (size != stream.readBytes(buf, size)) {
free(buf);
return NULL;
}
buf[size] = '\0';
return buf;
}
bool MbedTLSClient::loadCACert(Stream &stream, size_t size) {
char *str = streamToStr(stream, size);
if (str) {
_caCertStr = str;
return true;
}
return false;
}
bool MbedTLSClient::loadCertificate(Stream &stream, size_t size) {
char *str = streamToStr(stream, size);
if (str) {
_clientCertStr = str;
return true;
}
return false;
}
bool MbedTLSClient::loadPrivateKey(Stream &stream, size_t size) {
char *str = streamToStr(stream, size);
if (str) {
_clientKeyStr = str;
return true;
}
return false;
}
bool MbedTLSClient::verify(const char *fingerprint, const char *domainName) {
uint8_t fpLocal[32] = {};
uint16_t len = strlen(fingerprint);
uint8_t byte = 0;
for (uint8_t i = 0; i < len; i++) {
uint8_t c = fingerprint[i];
while ((c == ' ' || c == ':') && i < len) {
c = fingerprint[++i];
}
c |= 0b00100000; // make lowercase
c -= '0' * (c >= '0' && c <= '9');
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
if (c > 0xf)
return -1;
fpLocal[byte / 2] |= c << (4 * ((byte & 1) ^ 1));
byte++;
if (byte >= 64)
break;
}
uint8_t fpRemote[32];
if (!getFingerprintSHA256(fpRemote))
return false;
if (memcmp(fpLocal, fpRemote, 32)) {
LT_DM(SSL, "Fingerprints don't match");
return false;
}
if (!domainName)
return true;
// TODO domain name verification
return true;
}
void MbedTLSClient::setHandshakeTimeout(unsigned long handshakeTimeout) {
_handshakeTimeout = handshakeTimeout * 1000;
}
void MbedTLSClient::setAlpnProtocols(const char **alpnProtocols) {
_alpnProtocols = alpnProtocols;
}
bool MbedTLSClient::getFingerprintSHA256(uint8_t result[32]) {
const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(_sslCtx);
if (!cert) {
LT_EM(SSL, "Failed to get peer certificate");
return false;
}
mbedtls_sha256_context shaCtx;
mbedtls_sha256_init(&shaCtx);
mbedtls_sha256_starts(&shaCtx, false);
mbedtls_sha256_update(&shaCtx, cert->raw.p, cert->raw.len);
mbedtls_sha256_finish(&shaCtx, result);
return true;
}
#endif // LT_ARD_HAS_WIFI && LT_HAS_MBEDTLS