//------------------------------------------------------------------------------
// File: HostResolver.cc
// Author: Georgios Bitzes - CERN
//------------------------------------------------------------------------------
/************************************************************************
* qclient - A simple redis C++ client with support for redirects *
* Copyright (C) 2019 CERN/Switzerland *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see .*
************************************************************************/
#include "qclient/network/HostResolver.hh"
#include "qclient/GlobalInterceptor.hh"
#include "qclient/Logger.hh"
#include "qclient/Status.hh"
#include "qclient/GlobalInterceptor.hh"
#include
#include
#define SSTR(message) static_cast(std::ostringstream().flush() << message).str()
namespace qclient {
//------------------------------------------------------------------------------
// Protocol type as string
//------------------------------------------------------------------------------
std::string protocolTypeToString(ProtocolType prot) {
switch(prot) {
case ProtocolType::kIPv4: {
return "IPv4";
}
case ProtocolType::kIPv6: {
return "IPv6";
}
}
return "unknown protocol";
}
//------------------------------------------------------------------------------
// Socket type as string
//------------------------------------------------------------------------------
std::string socketTypeToString(SocketType sock) {
switch(sock) {
case SocketType::kStream: {
return "stream";
}
case SocketType::kDatagram: {
return "datagram";
}
}
return "unknown socket";
}
//------------------------------------------------------------------------------
// Empty constructor
//------------------------------------------------------------------------------
ServiceEndpoint::ServiceEndpoint() {}
//------------------------------------------------------------------------------
// Constructor
//------------------------------------------------------------------------------
ServiceEndpoint::ServiceEndpoint(ProtocolType protocol, SocketType socket,
const std::vector addr, const std::string &original)
: protocolType(protocol), socketType(socket), address(addr),
originalHostname(original) { }
//------------------------------------------------------------------------------
// Constructor, taking the IP address as text and a port, not sockaddr bytes
//------------------------------------------------------------------------------
ServiceEndpoint::ServiceEndpoint(ProtocolType protocol, SocketType socket,
const std::string &addr, int port, const std::string &original)
: protocolType(protocol), socketType(socket), originalHostname(original) {
if(protocolType == ProtocolType::kIPv4) {
struct sockaddr_in out;
memset(&out, 0, sizeof(struct sockaddr_in));
out.sin_family = AF_INET;
out.sin_port = htons(port);
inet_pton(AF_INET, addr.c_str(), &(out.sin_addr));
address.resize(sizeof(struct sockaddr_in));
memcpy(address.data(), &out, sizeof(struct sockaddr_in));
}
else if(protocolType == ProtocolType::kIPv6) {
struct sockaddr_in6 out;
memset(&out, 0, sizeof(sockaddr_in6));
out.sin6_family = AF_INET6;
out.sin6_port = htons(port);
inet_pton(AF_INET6, addr.c_str(), &(out.sin6_addr));
address.resize(sizeof(struct sockaddr_in6));
memcpy(address.data(), &out, sizeof(struct sockaddr_in6));
}
}
//------------------------------------------------------------------------------
// Get stored protocol type
//------------------------------------------------------------------------------
ProtocolType ServiceEndpoint::getProtocolType() const {
return protocolType;
}
//------------------------------------------------------------------------------
// Get socket type
//------------------------------------------------------------------------------
SocketType ServiceEndpoint::getSocketType() const {
return socketType;
}
//------------------------------------------------------------------------------
// Get raw address bytes (the form ::connect expects)
//------------------------------------------------------------------------------
const std::vector& ServiceEndpoint::getAddressBytes() const {
return address;
}
//----------------------------------------------------------------------------
// Get printable address (ie 127.0.0.1) that a human would expect
//----------------------------------------------------------------------------
std::string ServiceEndpoint::getPrintableAddress() const {
char buffer[INET6_ADDRSTRLEN];
switch(protocolType) {
case ProtocolType::kIPv4: {
const struct sockaddr_in* sockaddr = (const struct sockaddr_in*)(address.data());
inet_ntop(AF_INET, &(sockaddr->sin_addr), buffer, INET6_ADDRSTRLEN);
break;
}
case ProtocolType::kIPv6: {
const struct sockaddr_in6* sockaddr = (const struct sockaddr_in6*)(address.data());
inet_ntop(AF_INET6, &(sockaddr->sin6_addr), buffer, INET6_ADDRSTRLEN);
break;
}
}
return buffer;
}
//------------------------------------------------------------------------------
// Get service port number
//------------------------------------------------------------------------------
uint16_t ServiceEndpoint::getPort() const {
switch(protocolType) {
case ProtocolType::kIPv4: {
const struct sockaddr_in* sockaddr = (const struct sockaddr_in*)(address.data());
return htons(sockaddr->sin_port);
}
case ProtocolType::kIPv6: {
const struct sockaddr_in6* sockaddr = (const struct sockaddr_in6*)(address.data());
return ntohs(sockaddr->sin6_port);
}
}
return 0; // should never happen
}
//------------------------------------------------------------------------------
// Describe as a string
//----------------------------------------------- ------------------------------
std::string ServiceEndpoint::getString() const {
std::ostringstream ss;
ss << "[" << getPrintableAddress() << "]" << ":" << getPort() << " (" << protocolTypeToString(protocolType) << "," <<
socketTypeToString(socketType) << " resolved from " << originalHostname << ")";
return ss.str();
}
//----------------------------------------------------------------------------
// Get ai_family to pass to ::connect
//----------------------------------------------------------------------------
int ServiceEndpoint::getAiFamily() const {
switch(protocolType) {
case ProtocolType::kIPv4: {
return AF_INET;
}
case ProtocolType::kIPv6: {
return AF_INET6;
}
}
return 0;
}
//------------------------------------------------------------------------------
// Get ai_socktype to pass to ::socket
//------------------------------------------------------------------------------
int ServiceEndpoint::getAiSocktype() const {
switch(socketType) {
case SocketType::kStream: {
return SOCK_STREAM;
}
case SocketType::kDatagram: {
return SOCK_DGRAM;
}
}
return 0;
}
//------------------------------------------------------------------------------
// Get ai_protocol to pass to ::socket
//------------------------------------------------------------------------------
int ServiceEndpoint::getAiProtocol() const {
switch(socketType) {
case SocketType::kStream: {
return IPPROTO_TCP;
}
case SocketType::kDatagram: {
return IPPROTO_UDP;
}
}
return 0;
}
//------------------------------------------------------------------------------
// Recover original hostname, the one we passed to HostResolver
//------------------------------------------------------------------------------
std::string ServiceEndpoint::getOriginalHostname() const {
return originalHostname;
}
//------------------------------------------------------------------------------
// Equality operator
//------------------------------------------------------------------------------
bool ServiceEndpoint::operator==(const ServiceEndpoint& other) const {
return protocolType == other.protocolType &&
socketType == other.socketType &&
address == other.address &&
originalHostname == other.originalHostname;
}
//------------------------------------------------------------------------------
// Constructor
//------------------------------------------------------------------------------
HostResolver::HostResolver(Logger *log) : logger(log) { }
//------------------------------------------------------------------------------
// Resolve, while taking into account intercepts as well
//------------------------------------------------------------------------------
std::vector HostResolver::resolve(const std::string &host, int port, Status &st) {
Endpoint translated = GlobalInterceptor::translate(Endpoint(host, port));
return resolveNoIntercept(translated.getHost(), translated.getPort(), st);
}
//------------------------------------------------------------------------------
// Main resolve function: How many service endpoints match the given
// hostname and port pair?
//------------------------------------------------------------------------------
std::vector HostResolver::resolveNoIntercept(const std::string &host, int port, Status &st) {
if(!fakeMap.empty()) {
return resolveFake(host, port, st);
}
std::vector output;
struct addrinfo hints, *servinfo, *p;
int rv;
memset(&hints, 0, sizeof hints);
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_CANONNAME;
if ((rv = getaddrinfo(host.c_str(), std::to_string(port).c_str(),
&hints, &servinfo)) != 0) {
st = Status(rv, SSTR("error when resolving '" << host << "': " << gai_strerror(rv)));
return output;
}
//----------------------------------------------------------------------------
// getaddrinfo was successful: Loop through all results, build list of
// service endpoints
//----------------------------------------------------------------------------
for (p = servinfo; p != NULL; p = p->ai_next) {
std::vector addr(p->ai_addrlen);
memcpy(addr.data(), p->ai_addr, addr.size());
ProtocolType protocolType = ProtocolType::kIPv4;
if(p->ai_family == AF_INET) {
protocolType = ProtocolType::kIPv4;
}
else if(p->ai_family == AF_INET6) {
protocolType = ProtocolType::kIPv6;
}
else {
QCLIENT_LOG(logger, LogLevel::kWarn, "Encountered unknown network family during resolution of " << host << ":" << port << " - neither IPv4, nor IPv6!");
continue;
}
SocketType socketType = SocketType::kStream;
if(p->ai_socktype == SOCK_STREAM) {
socketType = SocketType::kStream;
}
else if(p->ai_socktype == SOCK_DGRAM) {
socketType = SocketType::kDatagram;
}
else {
QCLIENT_LOG(logger, LogLevel::kWarn, "Encountered unknown socket type during resolution of " << host << ":" << port << " - neither stream, nor datagram!");
continue;
}
output.emplace_back(protocolType, socketType, addr, host);
}
freeaddrinfo(servinfo);
st = Status();
return output;
}
//----------------------------------------------------------------------------
// Feed fake data - once you call this, _all_ responses will be faked
//----------------------------------------------------------------------------
void HostResolver::feedFake(const std::string &host, int port, const std::vector &out) {
std::lock_guard lock(mtx);
fakeMap[std::pair(host, port)] = out;
}
//------------------------------------------------------------------------------
// Resolve function that only returns fake data
//------------------------------------------------------------------------------
std::vector HostResolver::resolveFake(const std::string &host, int port,
Status &st) {
std::lock_guard lock(mtx);
auto it = fakeMap.find(std::pair(host, port));
if(it != fakeMap.end()) {
st = Status();
return it->second;
}
st = Status(ENOENT, "Unable to resolve");
return {};
}
}