Botan  1.11.15
src/lib/tls/tls_handshake_io.cpp
Go to the documentation of this file.
00001 /*
00002 * TLS Handshake IO
00003 * (C) 2012,2014 Jack Lloyd
00004 *
00005 * Botan is released under the Simplified BSD License (see license.txt)
00006 */
00007 
00008 #include <botan/internal/tls_handshake_io.h>
00009 #include <botan/internal/tls_messages.h>
00010 #include <botan/internal/tls_record.h>
00011 #include <botan/internal/tls_seq_numbers.h>
00012 #include <botan/exceptn.h>
00013 #include <chrono>
00014 
00015 namespace Botan {
00016 
00017 namespace TLS {
00018 
00019 namespace {
00020 
00021 inline size_t load_be24(const byte q[3])
00022    {
00023    return make_u32bit(0,
00024                       q[0],
00025                       q[1],
00026                       q[2]);
00027    }
00028 
00029 void store_be24(byte out[3], size_t val)
00030    {
00031    out[0] = get_byte<u32bit>(1, val);
00032    out[1] = get_byte<u32bit>(2, val);
00033    out[2] = get_byte<u32bit>(3, val);
00034    }
00035 
00036 }
00037 
00038 Protocol_Version Stream_Handshake_IO::initial_record_version() const
00039    {
00040    return Protocol_Version::TLS_V10;
00041    }
00042 
00043 void Stream_Handshake_IO::add_record(const std::vector<byte>& record,
00044                                      Record_Type record_type, u64bit)
00045    {
00046    if(record_type == HANDSHAKE)
00047       {
00048       m_queue.insert(m_queue.end(), record.begin(), record.end());
00049       }
00050    else if(record_type == CHANGE_CIPHER_SPEC)
00051       {
00052       if(record.size() != 1 || record[0] != 1)
00053          throw Decoding_Error("Invalid ChangeCipherSpec");
00054 
00055       // Pretend it's a regular handshake message of zero length
00056       const byte ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
00057       m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
00058       }
00059    else
00060       throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
00061    }
00062 
00063 std::pair<Handshake_Type, std::vector<byte>>
00064 Stream_Handshake_IO::get_next_record(bool)
00065    {
00066    if(m_queue.size() >= 4)
00067       {
00068       const size_t length = make_u32bit(0, m_queue[1], m_queue[2], m_queue[3]);
00069 
00070       if(m_queue.size() >= length + 4)
00071          {
00072          Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
00073 
00074          std::vector<byte> contents(m_queue.begin() + 4,
00075                                     m_queue.begin() + 4 + length);
00076 
00077          m_queue.erase(m_queue.begin(), m_queue.begin() + 4 + length);
00078 
00079          return std::make_pair(type, contents);
00080          }
00081       }
00082 
00083    return std::make_pair(HANDSHAKE_NONE, std::vector<byte>());
00084    }
00085 
00086 std::vector<byte>
00087 Stream_Handshake_IO::format(const std::vector<byte>& msg,
00088                             Handshake_Type type) const
00089    {
00090    std::vector<byte> send_buf(4 + msg.size());
00091 
00092    const size_t buf_size = msg.size();
00093 
00094    send_buf[0] = type;
00095 
00096    store_be24(&send_buf[1], buf_size);
00097 
00098    copy_mem(&send_buf[4], &msg[0], msg.size());
00099 
00100    return send_buf;
00101    }
00102 
00103 std::vector<byte> Stream_Handshake_IO::send(const Handshake_Message& msg)
00104    {
00105    const std::vector<byte> msg_bits = msg.serialize();
00106 
00107    if(msg.type() == HANDSHAKE_CCS)
00108       {
00109       m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
00110       return std::vector<byte>(); // not included in handshake hashes
00111       }
00112 
00113    const std::vector<byte> buf = format(msg_bits, msg.type());
00114    m_send_hs(HANDSHAKE, buf);
00115    return buf;
00116    }
00117 
00118 Protocol_Version Datagram_Handshake_IO::initial_record_version() const
00119    {
00120    return Protocol_Version::DTLS_V10;
00121    }
00122 
00123 namespace {
00124 
00125 // 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
00126 const u64bit INITIAL_TIMEOUT = 1*1000;
00127 const u64bit MAXIMUM_TIMEOUT = 60*1000;
00128 
00129 u64bit steady_clock_ms()
00130    {
00131    return std::chrono::duration_cast<std::chrono::milliseconds>(
00132       std::chrono::steady_clock::now().time_since_epoch()).count();
00133    }
00134 
00135 }
00136 
00137 bool Datagram_Handshake_IO::timeout_check()
00138    {
00139    if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
00140       {
00141       /*
00142       If we haven't written anything yet obviously no timeout.
00143       Also no timeout possible if we are mid-flight,
00144       */
00145       return false;
00146       }
00147 
00148    const u64bit ms_since_write = steady_clock_ms() - m_last_write;
00149 
00150    if(ms_since_write < m_next_timeout)
00151       return false;
00152 
00153    std::vector<u16bit> flight;
00154    if(m_flights.size() == 1)
00155       flight = m_flights.at(0); // lost initial client hello
00156    else
00157       flight = m_flights.at(m_flights.size() - 2);
00158 
00159    BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
00160 
00161    u16bit epoch = m_flight_data[flight[0]].epoch;
00162 
00163    for(auto msg_seq : flight)
00164       {
00165       auto& msg = m_flight_data[msg_seq];
00166 
00167       if(msg.epoch != epoch)
00168          {
00169          // Epoch gap: insert the CCS
00170          std::vector<byte> ccs(1, 1);
00171          m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
00172          }
00173 
00174       send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
00175       epoch = msg.epoch;
00176       }
00177 
00178    m_next_timeout = std::min(2 * m_next_timeout, MAXIMUM_TIMEOUT);
00179    return true;
00180    }
00181 
00182 void Datagram_Handshake_IO::add_record(const std::vector<byte>& record,
00183                                        Record_Type record_type,
00184                                        u64bit record_sequence)
00185    {
00186    const u16bit epoch = static_cast<u16bit>(record_sequence >> 48);
00187 
00188    if(record_type == CHANGE_CIPHER_SPEC)
00189       {
00190       // TODO: check this is otherwise empty
00191       m_ccs_epochs.insert(epoch);
00192       return;
00193       }
00194 
00195    const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
00196 
00197    const byte* record_bits = &record[0];
00198    size_t record_size = record.size();
00199 
00200    while(record_size)
00201       {
00202       if(record_size < DTLS_HANDSHAKE_HEADER_LEN)
00203          return; // completely bogus? at least degenerate/weird
00204 
00205       const byte msg_type = record_bits[0];
00206       const size_t msg_len = load_be24(&record_bits[1]);
00207       const u16bit message_seq = load_be<u16bit>(&record_bits[4], 0);
00208       const size_t fragment_offset = load_be24(&record_bits[6]);
00209       const size_t fragment_length = load_be24(&record_bits[9]);
00210 
00211       const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
00212 
00213       if(record_size < total_size)
00214          throw Decoding_Error("Bad lengths in DTLS header");
00215 
00216       if(message_seq >= m_in_message_seq)
00217          {
00218          m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
00219                                               fragment_length,
00220                                               fragment_offset,
00221                                               epoch,
00222                                               msg_type,
00223                                               msg_len);
00224          }
00225       else
00226          {
00227          // TODO: detect retransmitted flight
00228          }
00229 
00230       record_bits += total_size;
00231       record_size -= total_size;
00232       }
00233    }
00234 
00235 std::pair<Handshake_Type, std::vector<byte>>
00236 Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
00237    {
00238    // Expecting a message means the last flight is concluded
00239    if(!m_flights.rbegin()->empty())
00240       m_flights.push_back(std::vector<u16bit>());
00241 
00242    if(expecting_ccs)
00243       {
00244       if(!m_messages.empty())
00245          {
00246          const u16bit current_epoch = m_messages.begin()->second.epoch();
00247 
00248          if(m_ccs_epochs.count(current_epoch))
00249             return std::make_pair(HANDSHAKE_CCS, std::vector<byte>());
00250          }
00251 
00252       return std::make_pair(HANDSHAKE_NONE, std::vector<byte>());
00253       }
00254 
00255    auto i = m_messages.find(m_in_message_seq);
00256 
00257    if(i == m_messages.end() || !i->second.complete())
00258       return std::make_pair(HANDSHAKE_NONE, std::vector<byte>());
00259 
00260    m_in_message_seq += 1;
00261 
00262    return i->second.message();
00263    }
00264 
00265 void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
00266    const byte fragment[],
00267    size_t fragment_length,
00268    size_t fragment_offset,
00269    u16bit epoch,
00270    byte msg_type,
00271    size_t msg_length)
00272    {
00273    if(complete())
00274       return; // already have entire message, ignore this
00275 
00276    if(m_msg_type == HANDSHAKE_NONE)
00277       {
00278       m_epoch = epoch;
00279       m_msg_type = msg_type;
00280       m_msg_length = msg_length;
00281       }
00282 
00283    if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
00284       throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
00285 
00286    if(fragment_offset > m_msg_length)
00287       throw Decoding_Error("Fragment offset past end of message");
00288 
00289    if(fragment_offset + fragment_length > m_msg_length)
00290       throw Decoding_Error("Fragment overlaps past end of message");
00291 
00292    if(fragment_offset == 0 && fragment_length == m_msg_length)
00293       {
00294       m_fragments.clear();
00295       m_message.assign(fragment, fragment+fragment_length);
00296       }
00297    else
00298       {
00299       /*
00300       * FIXME. This is a pretty lame way to do defragmentation, huge
00301       * overhead with a tree node per byte.
00302       *
00303       * Also should confirm that all overlaps have no changes,
00304       * otherwise we expose ourselves to the classic fingerprinting
00305       * and IDS evasion attacks on IP fragmentation.
00306       */
00307       for(size_t i = 0; i != fragment_length; ++i)
00308          m_fragments[fragment_offset+i] = fragment[i];
00309 
00310       if(m_fragments.size() == m_msg_length)
00311          {
00312          m_message.resize(m_msg_length);
00313          for(size_t i = 0; i != m_msg_length; ++i)
00314             m_message[i] = m_fragments[i];
00315          m_fragments.clear();
00316          }
00317       }
00318    }
00319 
00320 bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
00321    {
00322    return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
00323    }
00324 
00325 std::pair<Handshake_Type, std::vector<byte>>
00326 Datagram_Handshake_IO::Handshake_Reassembly::message() const
00327    {
00328    if(!complete())
00329       throw Internal_Error("Datagram_Handshake_IO - message not complete");
00330 
00331    return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
00332    }
00333 
00334 std::vector<byte>
00335 Datagram_Handshake_IO::format_fragment(const byte fragment[],
00336                                        size_t frag_len,
00337                                        u16bit frag_offset,
00338                                        u16bit msg_len,
00339                                        Handshake_Type type,
00340                                        u16bit msg_sequence) const
00341    {
00342    std::vector<byte> send_buf(12 + frag_len);
00343 
00344    send_buf[0] = type;
00345 
00346    store_be24(&send_buf[1], msg_len);
00347 
00348    store_be(msg_sequence, &send_buf[4]);
00349 
00350    store_be24(&send_buf[6], frag_offset);
00351    store_be24(&send_buf[9], frag_len);
00352 
00353    copy_mem(&send_buf[12], &fragment[0], frag_len);
00354 
00355    return send_buf;
00356    }
00357 
00358 std::vector<byte>
00359 Datagram_Handshake_IO::format_w_seq(const std::vector<byte>& msg,
00360                                     Handshake_Type type,
00361                                     u16bit msg_sequence) const
00362    {
00363    return format_fragment(&msg[0], msg.size(), 0, msg.size(), type, msg_sequence);
00364    }
00365 
00366 std::vector<byte>
00367 Datagram_Handshake_IO::format(const std::vector<byte>& msg,
00368                               Handshake_Type type) const
00369    {
00370    return format_w_seq(msg, type, m_in_message_seq - 1);
00371    }
00372 
namespace {

/*
* Number of fragments to split a msg_size byte handshake body into so that
* each fragment, plus DTLS record and handshake headers, fits in mtu bytes.
*
* The caller (send_message) sizes each fragment as
* (msg_size + parts) / parts, which equals floor(msg_size / parts) + 1.
* With payload = mtu - DTLS_HEADERS_SIZE, that needs
* floor(msg_size / parts) < payload. This holds exactly when
* parts >= msg_size / payload + 1.
*
* Example: mtu 1400 and a 2790 byte body need 3 parts. With 2 parts each
* fragment would be 1396 bytes, i.e. 1421 bytes with headers.
*
* If mtu cannot even hold the headers, a payload of one byte is assumed so
* the result is never zero. The datagrams will then exceed the MTU anyway.
*
* NOTE(review): record protection overhead (MAC/IV/padding once the epoch
* is nonzero) is not included in DTLS_HEADERS_SIZE.
*/
size_t split_for_mtu(size_t mtu, size_t msg_size)
   {
   const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers

   const size_t payload = (mtu > DTLS_HEADERS_SIZE) ? (mtu - DTLS_HEADERS_SIZE) : 1;

   return msg_size / payload + 1;
   }

}
00388 
00389 std::vector<byte>
00390 Datagram_Handshake_IO::send(const Handshake_Message& msg)
00391    {
00392    const std::vector<byte> msg_bits = msg.serialize();
00393    const u16bit epoch = m_seqs.current_write_epoch();
00394    const Handshake_Type msg_type = msg.type();
00395 
00396    if(msg_type == HANDSHAKE_CCS)
00397       {
00398       m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
00399       return std::vector<byte>(); // not included in handshake hashes
00400       }
00401 
00402    // Note: not saving CCS, instead we know it was there due to change in epoch
00403    m_flights.rbegin()->push_back(m_out_message_seq);
00404    m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
00405 
00406    m_out_message_seq += 1;
00407    m_last_write = steady_clock_ms();
00408    m_next_timeout = INITIAL_TIMEOUT;
00409 
00410    return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
00411    }
00412 
00413 std::vector<byte> Datagram_Handshake_IO::send_message(u16bit msg_seq,
00414                                                       u16bit epoch,
00415                                                       Handshake_Type msg_type,
00416                                                       const std::vector<byte>& msg_bits)
00417    {
00418    const std::vector<byte> no_fragment =
00419       format_w_seq(msg_bits, msg_type, msg_seq);
00420 
00421    if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
00422       m_send_hs(epoch, HANDSHAKE, no_fragment);
00423    else
00424       {
00425       const size_t parts = split_for_mtu(m_mtu, msg_bits.size());
00426 
00427       const size_t parts_size = (msg_bits.size() + parts) / parts;
00428 
00429       size_t frag_offset = 0;
00430 
00431       while(frag_offset != msg_bits.size())
00432          {
00433          const size_t frag_len =
00434             std::min<size_t>(msg_bits.size() - frag_offset,
00435                              parts_size);
00436 
00437          m_send_hs(epoch,
00438                    HANDSHAKE,
00439                    format_fragment(&msg_bits[frag_offset],
00440                                    frag_len,
00441                                    frag_offset,
00442                                    msg_bits.size(),
00443                                    msg_type,
00444                                    msg_seq));
00445 
00446          frag_offset += frag_len;
00447          }
00448       }
00449 
00450    return no_fragment;
00451    }
00452 
00453 }
00454 }