summaryrefslogtreecommitdiffstats
path: root/apt-pkg/contrib/srvrec.cc
blob: 3eb5f1d4c6749d1a762137e53005d4399f0db136 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
// -*- mode: cpp; mode: fold -*-
// Description								/*{{{*/
/* ######################################################################

   SRV record support

   ##################################################################### */
									/*}}}*/
#include <config.h>

#include <netdb.h>

#include <arpa/nameser.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <resolv.h>
#include <time.h>

#include <algorithm>
#include <memory>
#include <tuple>

#include <apt-pkg/configuration.h>
#include <apt-pkg/error.h>
#include <apt-pkg/strutl.h>

#include "srvrec.h"

// Equality looks at exactly four members: target, priority, weight and port.
// Any other state a SrvRec carries does not take part in the comparison.
bool SrvRec::operator==(SrvRec const &other) const
{
   return target == other.target &&
          priority == other.priority &&
          weight == other.weight &&
          port == other.port;
}

/* Build the SRV query name for a host/port pair and look it up.

   host   - host name to query; literal IPv4/IPv6 addresses are skipped
   port   - TCP port in host byte order, mapped to a service name via
            getservbyport_r()
   Result - forwarded unchanged to the name-based GetSrvRecords() overload

   Returns true without touching Result when host parses as an IP address,
   false when no service entry is found for the port, and otherwise the
   result of the name-based lookup of "_<service>._tcp.<host>". */
bool GetSrvRecords(std::string host, int port, std::vector<SrvRec> &Result)
{
   // SRV lookups only make sense for names, so bail out early (successfully)
   // if the host is an address literal of either family
   struct in_addr v4;
   if (inet_pton(AF_INET, host.c_str(), &v4) == 1)
      return true;
   struct in6_addr v6;
   if (inet_pton(AF_INET6, host.c_str(), &v6) == 1)
      return true;

   // translate the port number into its service name
   std::vector<char> scratch(1024);
   struct servent entryStorage;
   struct servent *entry = nullptr;
   int const rc = getservbyport_r(htons(port), "tcp", &entryStorage,
                                  scratch.data(), scratch.size(), &entry);
   if (rc != 0 || entry == nullptr)
      return false;

   std::string target;
   strprintf(target, "_%s._tcp.%s", entry->s_name, host.c_str());
   return GetSrvRecords(target, Result);
}

/* Query the resolver for the SRV records of name and append them to Result.

   name   - complete query name, handed to the resolver as-is
   Result - receives one SrvRec (target, priority, weight, port) per parsed
            answer; the whole vector is stable-sorted before returning

   Returns true once the answer section has been parsed. Returns false when
   the query itself fails; resolver-init failures and malformed/unexpected
   responses are reported through _error and its return value is passed on.
   Records emplaced before a parse error stay in Result.

   NOTE: the parse loop is bounded by Result.size() < answer_count, so entries
   already in Result on entry reduce the number of answers that get parsed. */
bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
{
   unsigned char answer[PACKETSZ];
   int answer_len, compressed_name_len;
   int answer_count;
#if __RES >= 19991006
   struct __res_state res;

   if (res_ninit(&res) != 0)
      return _error->Errno("res_init", "Failed to init resolver");

   // Close on return
   std::shared_ptr<void> guard(&res, res_nclose);

   answer_len = res_nquery(&res, name.c_str(), C_IN, T_SRV, answer, sizeof(answer));
#else
   if (res_init() != 0)
      return _error->Errno("res_init", "Failed to init resolver");

   answer_len = res_query(name.c_str(), C_IN, T_SRV, answer, sizeof(answer));
#endif //__RES >= 19991006
   if (answer_len == -1)
      return false;
   if (answer_len < (int)sizeof(HEADER))
      return _error->Warning("Not enough data from res_query (%i)", answer_len);

   // The resolver may report the length of the full response even if it had
   // to truncate it to fit our buffer. Everything below uses
   // answer+answer_len as the end-of-message bound, so never let it point
   // beyond the bytes we actually own.
   if (answer_len > static_cast<int>(sizeof(answer)))
      answer_len = static_cast<int>(sizeof(answer));

   // check the header
   HEADER *header = (HEADER*)answer;
   if (header->rcode != NOERROR)
      return _error->Warning("res_query returned rcode %i", header->rcode);
   answer_count = ntohs(header->ancount);
   if (answer_count <= 0)
      return _error->Warning("res_query returned no answers (%i) ", answer_count);

   // skip the header
   compressed_name_len = dn_skipname(answer+sizeof(HEADER), answer+answer_len);
   if(compressed_name_len < 0)
      return _error->Warning("dn_skipname failed %i", compressed_name_len);

   // pt points to the first answer record, go over all of them now
   // (header + question name + fixed-size remainder of the question)
   unsigned char *pt = answer+sizeof(HEADER)+compressed_name_len+QFIXEDSZ;
   while ((int)Result.size() < answer_count && pt < answer+answer_len)
   {
      u_int16_t type, klass, priority, weight, port, dlen;
      char buf[MAXDNAME];

      // skip the owner name of this record
      compressed_name_len = dn_skipname(pt, answer+answer_len);
      if (compressed_name_len < 0)
         return _error->Warning("dn_skipname failed (2): %i",
                                compressed_name_len);
      pt += compressed_name_len;
      // the fixed fields read below take 16 bytes:
      // type, class (2 each), ttl (4), dlen, priority, weight, port (2 each)
      if (((answer+answer_len) - pt) < 16)
         return _error->Warning("packet too short");

      // extract the data out of the result buffer
      // (two bytes, most significant first; advances p past them)
      #define extract_u16(target, p) target = *p++ << 8; target |= *p++;

      extract_u16(type, pt);
      if(type != T_SRV)
         return _error->Warning("Unexpected type excepted %x != %x",
                                T_SRV, type);
      extract_u16(klass, pt);
      if(klass != C_IN)
         return _error->Warning("Unexpected class excepted %x != %x",
                                C_IN, klass);
      pt += 4;  // ttl
      extract_u16(dlen, pt); // read only to step over it, value is unused
      extract_u16(priority, pt);
      extract_u16(weight, pt);
      extract_u16(port, pt);

      #undef extract_u16

      // the remainder of the record is the (possibly compressed) target name
      compressed_name_len = dn_expand(answer, answer+answer_len, pt, buf, sizeof(buf));
      if(compressed_name_len < 0)
         return _error->Warning("dn_expand failed %i", compressed_name_len);
      pt += compressed_name_len;

      // add it to our class
      Result.emplace_back(buf, priority, weight, port);
   }

   // implement load balancing as specified in RFC-2782

   // sort them by priority
   std::stable_sort(Result.begin(), Result.end());

   if (_config->FindB("Debug::Acquire::SrvRecs", false))
      for(auto const &R : Result)
	 std::cerr << "SrvRecs: got " << R.target
		   << " prio: " << R.priority
		   << " weight: " << R.weight
		   << '\n';

   return true;
}

/* Remove one record from Recs and return it.

   The candidate set is the leading run of entries that share the priority of
   the first element; one of them is chosen using the current clock() value
   and erased from the vector.

   Recs must not be empty: with no elements the candidate run has length zero
   and the modulo below would divide by zero. */
SrvRec PopFromSrvRecs(std::vector<SrvRec> &Recs)
{
   // FIXME: instead of the simplistic shuffle below use the algorithm
   //        described in rfc2782 (with weights)
   //        and figure out how the weights need to be adjusted if
   //        a host refuses connections

#if 0  // all code below is only needed for the weight adjusted selection 
   // assign random number ranges
   int prev_weight = 0;
   int prev_priority = 0;
   for(std::vector<SrvRec>::iterator I = Result.begin();
      I != Result.end(); ++I)
   {
      if(prev_priority != I->priority)
         prev_weight = 0;
      I->random_number_range_start = prev_weight;
      I->random_number_range_end = prev_weight + I->weight;
      prev_weight = I->random_number_range_end;
      prev_priority = I->priority;

      if (_config->FindB("Debug::Acquire::SrvRecs", false) == true)
         std::cerr << "SrvRecs: got " << I->target
                   << " prio: " << I->priority
                   << " weight: " << I->weight
                   << std::endl;
   }

   // go over the code in reverse order and note the max random range
   int max = 0;
   prev_priority = 0;
   for(std::vector<SrvRec>::iterator I = Result.end();
      I != Result.begin(); --I)
   {
      if(prev_priority != I->priority)
         max = I->random_number_range_end;
      I->random_number_range_max = max;
   }
#endif

   // equal-weight pick: find where the priority of the first entry stops
   auto const first = Recs.begin();
   auto const groupEnd = std::find_if(first, Recs.end(),
         [first](SrvRec const &rec) { return first->priority != rec.priority; });
   auto const groupSize = std::distance(first, groupEnd);

   // clock seems random enough; negative values are treated as zero
   clock_t const ticks = std::max(static_cast<clock_t>(0), clock());
   auto const pick = first + ticks % groupSize;

   SrvRec const selected = std::move(*pick);
   Recs.erase(pick);

   if (_config->FindB("Debug::Acquire::SrvRecs", false))
      std::cerr << "PopFromSrvRecs: selecting " << selected.target << std::endl;

   return selected;
}