Fix sending of leftover packets.
[dnstt.git] / dnstt-server / main.go
blob539a4f45be40dbd85b925bf4d02e6d062d4a350e
1 // dnstt-server is the server end of a DNS tunnel.
2 //
3 // Usage:
4 // dnstt-server -gen-key [-privkey-file PRIVKEYFILE] [-pubkey-file PUBKEYFILE]
5 // dnstt-server -udp ADDR [-privkey PRIVKEY|-privkey-file PRIVKEYFILE] DOMAIN UPSTREAMADDR
6 //
7 // Example:
8 // dnstt-server -gen-key -privkey-file server.key -pubkey-file server.pub
9 // dnstt-server -udp 127.0.0.1:5300 -privkey-file server.key t.example.com 127.0.0.1:8000
11 // To generate a persistent server private key, first run with the -gen-key
12 // option. By default the generated private and public keys are printed to
13 // standard output. To save them to files instead, use the -privkey-file and
14 // -pubkey-file options.
15 // dnstt-server -gen-key
16 // dnstt-server -gen-key -privkey-file server.key -pubkey-file server.pub
18 // You can give the server's private key as a file or as a hex string.
19 // -privkey-file server.key
20 // -privkey 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
// The -udp option sets the local UDP address on which to listen for incoming
// DNS queries.
// The -mtu option controls the maximum size of response UDP payloads.
// Queries that do not advertise requestor support for responses of at least
// this size will be responded to with a FORMERR. The default value is
// maxUDPPayload (1232 bytes).
30 // DOMAIN is the root of the DNS zone reserved for the tunnel. See README for
31 // instructions on setting it up.
33 // UPSTREAMADDR is the TCP address to which incoming tunnelled streams will be
34 // forwarded.
35 package main
37 import (
38 "bytes"
39 "encoding/base32"
40 "encoding/binary"
41 "flag"
42 "fmt"
43 "io"
44 "io/ioutil"
45 "log"
46 "net"
47 "os"
48 "sync"
49 "time"
51 "github.com/xtaci/kcp-go/v5"
52 "github.com/xtaci/smux"
53 "www.bamsoftware.com/git/dnstt.git/dns"
54 "www.bamsoftware.com/git/dnstt.git/noise"
55 "www.bamsoftware.com/git/dnstt.git/turbotunnel"
const (
	// idleTimeout is how long an smux stream may go without receiving data
	// before it is closed.
	idleTimeout = 10 * time.Minute

	// responseTTL is the value placed in the TTL field of Answer resource
	// records.
	responseTTL = 60

	// maxResponseDelay bounds how long a response may be held back while
	// waiting for downstream data; once it elapses, the response goes out
	// empty. The arrival of another query also ends the wait early (sending
	// an empty response), and the delay starts over for the next response.
	//
	// Keep this under 2 seconds, which in 2019 was reported to be the query
	// timeout of the Quad9 DoH server.
	// https://dnsencryption.info/imc19-doe.html Section 4.2, Finding 2.4
	maxResponseDelay = 1 * time.Second
)
var (
	// maxUDPPayload is the largest UDP payload we send, chosen to avoid
	// network-layer fragmentation: 1280 (the minimum IPv6 MTU) minus 40
	// (an IPv6 header, ignoring extension headers) minus 8 (a UDP header).
	//
	// The -mtu command-line option overrides this value.
	//
	// https://dnsflagday.net/2020/#message-size-considerations
	// "An EDNS buffer size of 1232 bytes will avoid fragmentation on nearly
	// all current networks."
	//
	// On 2020-04-19, the Quad9 resolver was seen to have a UDP payload size
	// of 1232. Cloudflare's was 1452, and Google's was 4096.
	maxUDPPayload = 1280 - 40 - 8
)
// base32Encoding is the standard base32 alphabet with padding turned off, so
// encoded text never ends in '=' characters.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
96 // generateKeypair generates a private key and the corresponding public key. If
97 // privkeyFilename and pubkeyFilename are respectively empty, it prints the
98 // corresponding key to standard output; otherwise it saves the key to the given
99 // file name. The private key is saved with mode 0400 and the public key is
100 // saved with 0666 (before umask). In case of any error, it attempts to delete
101 // any files it has created before returning.
102 func generateKeypair(privkeyFilename, pubkeyFilename string) (err error) {
103 // Filenames to delete in case of error (avoid leaving partially written
104 // files).
105 var toDelete []string
106 defer func() {
107 for _, filename := range toDelete {
108 fmt.Fprintf(os.Stderr, "deleting partially written file %s\n", filename)
109 if closeErr := os.Remove(filename); closeErr != nil {
110 fmt.Fprintf(os.Stderr, "cannot remove %s: %v\n", filename, closeErr)
111 if err == nil {
112 err = closeErr
118 privkey, pubkey, err := noise.GenerateKeypair()
119 if err != nil {
120 return err
123 if privkeyFilename != "" {
124 // Save the privkey to a file.
125 f, err := os.OpenFile(privkeyFilename, os.O_RDWR|os.O_CREATE, 0400)
126 if err != nil {
127 return err
129 toDelete = append(toDelete, privkeyFilename)
130 err = noise.WriteKey(f, privkey)
131 if err2 := f.Close(); err == nil {
132 err = err2
134 if err != nil {
135 return err
139 if pubkeyFilename != "" {
140 // Save the pubkey to a file.
141 f, err := os.Create(pubkeyFilename)
142 if err != nil {
143 return err
145 toDelete = append(toDelete, pubkeyFilename)
146 err = noise.WriteKey(f, pubkey)
147 if err2 := f.Close(); err == nil {
148 err = err2
150 if err != nil {
151 return err
155 // All good, allow the written files to remain.
156 toDelete = nil
158 if privkeyFilename != "" {
159 fmt.Printf("privkey written to %s\n", privkeyFilename)
160 } else {
161 fmt.Printf("privkey %x\n", privkey)
163 if pubkeyFilename != "" {
164 fmt.Printf("pubkey written to %s\n", pubkeyFilename)
165 } else {
166 fmt.Printf("pubkey %x\n", pubkey)
169 return nil
172 // readKeyFromFile reads a key from a named file.
173 func readKeyFromFile(filename string) ([]byte, error) {
174 f, err := os.Open(filename)
175 if err != nil {
176 return nil, err
178 defer f.Close()
179 return noise.ReadKey(f)
182 // handleStream bidirectionally connects a client stream with a TCP socket
183 // addressed by upstream.
184 func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error {
185 conn, err := net.DialTCP("tcp", nil, upstream)
186 if err != nil {
187 return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
189 defer conn.Close()
191 var wg sync.WaitGroup
192 wg.Add(2)
193 go func() {
194 defer wg.Done()
195 _, err := io.Copy(stream, conn)
196 if err == io.EOF {
197 // smux Stream.Write may return io.EOF.
198 err = nil
200 if err != nil {
201 log.Printf("stream %08x:%d copy streamā†upstream: %v\n", conv, stream.ID(), err)
203 conn.CloseRead()
204 stream.Close()
206 go func() {
207 defer wg.Done()
208 _, err := io.Copy(conn, stream)
209 if err == io.EOF {
210 // smux Stream.WriteTo may return io.EOF.
211 err = nil
213 if err != nil && err != io.ErrClosedPipe {
214 log.Printf("stream %08x:%d copy upstreamā†stream: %v\n", conv, stream.ID(), err)
216 conn.CloseWrite()
218 wg.Wait()
220 return nil
223 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
224 // then awaits smux streams. It passes each stream to handleStream.
225 func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream *net.TCPAddr) error {
226 // Put a Noise channel on top of the KCP conn.
227 rw, err := noise.NewServer(conn, privkey, pubkey)
228 if err != nil {
229 return err
232 // Put an smux session on top of the encrypted Noise channel.
233 smuxConfig := smux.DefaultConfig()
234 smuxConfig.Version = 2
235 smuxConfig.KeepAliveTimeout = idleTimeout
236 sess, err := smux.Server(rw, smuxConfig)
237 if err != nil {
238 return err
241 for {
242 stream, err := sess.AcceptStream()
243 if err != nil {
244 if err, ok := err.(net.Error); ok && err.Temporary() {
245 continue
247 return err
249 log.Printf("begin stream %08x:%d", conn.GetConv(), stream.ID())
250 go func() {
251 defer func() {
252 log.Printf("end stream %08x:%d", conn.GetConv(), stream.ID())
253 stream.Close()
255 err := handleStream(stream, upstream, conn.GetConv())
256 if err != nil {
257 log.Printf("stream %08x:%d handleStream: %v\n", conn.GetConv(), stream.ID(), err)
263 // acceptSessions listens for incoming KCP connections and passes them to
264 // acceptStreams.
265 func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream *net.TCPAddr) error {
266 for {
267 conn, err := ln.AcceptKCP()
268 if err != nil {
269 if err, ok := err.(net.Error); ok && err.Temporary() {
270 continue
272 return err
274 log.Printf("begin session %08x", conn.GetConv())
275 // Permit coalescing the payloads of consecutive sends.
276 conn.SetStreamMode(true)
277 // Disable the dynamic congestion window (limit only by the
278 // maximum of local and remote static windows).
279 conn.SetNoDelay(
280 0, // default nodelay
281 0, // default interval
282 0, // default resend
283 1, // nc=1 => congestion window off
285 if rc := conn.SetMtu(mtu); !rc {
286 panic(rc)
288 go func() {
289 defer func() {
290 log.Printf("end session %08x", conn.GetConv())
291 conn.Close()
293 err := acceptStreams(conn, privkey, pubkey, upstream)
294 if err != nil {
295 log.Printf("session %08x acceptStreams: %v\n", conn.GetConv(), err)
// nextPacket reads the next length-prefixed packet from r, skipping any
// padding. It returns a nil error only when a packet was read successfully. It
// returns io.EOF only when there were 0 bytes remaining to read from r. It
// returns io.ErrUnexpectedEOF when EOF occurs in the middle of an encoded
// packet or of padding. Whenever the error is non-nil, the packet is nil.
//
// The prefixing scheme is as follows. A length prefix L < 0xe0 means a data
// packet of L bytes. A length prefix L >= 0xe0 means padding of L - 0xe0 bytes
// (not counting the length of the length prefix itself).
func nextPacket(r *bytes.Reader) ([]byte, error) {
	// Convert io.EOF to io.ErrUnexpectedEOF.
	eof := func(err error) error {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return err
	}

	for {
		prefix, err := r.ReadByte()
		if err != nil {
			// We may return a real io.EOF only here.
			return nil, err
		}
		if prefix >= 0xe0 {
			// Padding: discard it and look at the next prefix.
			paddingLen := prefix - 0xe0
			if _, err := io.CopyN(ioutil.Discard, r, int64(paddingLen)); err != nil {
				return nil, eof(err)
			}
			continue
		}
		p := make([]byte, int(prefix))
		if _, err := io.ReadFull(r, p); err != nil {
			// Don't hand back a partially filled packet.
			return nil, eof(err)
		}
		return p, nil
	}
}
338 // responseFor constructs a response dns.Message that is appropriate for query.
339 // Along with the dns.Message, it returns the query's decoded data payload. If
340 // the returned dns.Message is nil, it means that there should be no response to
341 // this query. If the returned dns.Message has an Rcode() of dns.RcodeNoError,
342 // the message is a candidate for for carrying downstream data in a TXT record.
343 func responseFor(query *dns.Message, domain dns.Name) (*dns.Message, []byte) {
344 resp := &dns.Message{
345 ID: query.ID,
346 Flags: 0x8000, // QR = 1, RCODE = no error
347 Question: query.Question,
350 if query.Flags&0x8000 != 0 {
351 // QR != 0, this is not a query. Don't even send a response.
352 return nil, nil
355 // Check for EDNS(0) support. Include our own OPT RR only if we receive
356 // one from the requestor.
357 // https://tools.ietf.org/html/rfc6891#section-6.1.1
358 // "Lack of presence of an OPT record in a request MUST be taken as an
359 // indication that the requestor does not implement any part of this
360 // specification and that the responder MUST NOT include an OPT record
361 // in its response."
362 payloadSize := 0
363 for _, rr := range query.Additional {
364 if rr.Type != dns.RRTypeOPT {
365 continue
367 if len(resp.Additional) != 0 {
368 // https://tools.ietf.org/html/rfc6891#section-6.1.1
369 // "If a query message with more than one OPT RR is
370 // received, a FORMERR (RCODE=1) MUST be returned."
371 resp.Flags |= dns.RcodeFormatError
372 log.Printf("FORMERR: more than one OPT RR")
373 return resp, nil
375 resp.Additional = append(resp.Additional, dns.RR{
376 Name: dns.Name{},
377 Type: dns.RRTypeOPT,
378 Class: 4096, // responder's UDP payload size
379 TTL: 0,
380 Data: []byte{},
382 additional := &resp.Additional[0]
384 version := (rr.TTL >> 16) & 0xff
385 if version != 0 {
386 // https://tools.ietf.org/html/rfc6891#section-6.1.1
387 // "If a responder does not implement the VERSION level
388 // of the request, then it MUST respond with
389 // RCODE=BADVERS."
390 resp.Flags |= dns.ExtendedRcodeBadVers & 0xf
391 additional.TTL = (dns.ExtendedRcodeBadVers >> 4) << 24
392 log.Printf("BADVERS: EDNS version %d != 0", version)
393 return resp, nil
396 payloadSize = int(rr.Class)
398 if payloadSize < 512 {
399 // https://tools.ietf.org/html/rfc6891#section-6.1.1 "Values
400 // lower than 512 MUST be treated as equal to 512."
401 payloadSize = 512
403 // We will return RcodeFormatError if payloadSize is too small, but
404 // first, check the name in order to set the AA bit properly.
406 // There must be exactly one question.
407 if len(query.Question) != 1 {
408 resp.Flags |= dns.RcodeFormatError
409 log.Printf("FORMERR: too many questions (%d)", len(query.Question))
410 return resp, nil
412 question := query.Question[0]
413 // Check the name to see if it ends in our chosen domain, and extract
414 // all that comes before the domain if it does. If it does not, we will
415 // return RcodeNameError below, but prefer to return RcodeFormatError
416 // for payload size if that applies as well.
417 prefix, ok := question.Name.TrimSuffix(domain)
418 if !ok {
419 // Not a name we are authoritative for.
420 resp.Flags |= dns.RcodeNameError
421 log.Printf("NXDOMAIN: not authoritative for %s", question.Name)
422 return resp, nil
424 resp.Flags |= 0x0400 // AA = 1
426 if query.Opcode() != 0 {
427 // We don't support OPCODE != QUERY.
428 resp.Flags |= dns.RcodeNotImplemented
429 log.Printf("NOTIMPL: unrecognized OPCODE %d", query.Opcode())
430 return resp, nil
433 if question.Type != dns.RRTypeTXT {
434 // We only support QTYPE == TXT.
435 resp.Flags |= dns.RcodeNameError
436 // No log message here; it's common for recursive resolvers to
437 // send NS or A queries when the client only asked for a TXT. I
438 // suspect this is related to QNAME minimization, but I'm not
439 // sure. https://tools.ietf.org/html/rfc7816
440 // log.Printf("NXDOMAIN: QTYPE %d != TXT", question.Type)
441 return resp, nil
444 encoded := bytes.ToUpper(bytes.Join(prefix, nil))
445 payload := make([]byte, base32Encoding.DecodedLen(len(encoded)))
446 n, err := base32Encoding.Decode(payload, encoded)
447 if err != nil {
448 // Base32 error, make like the name doesn't exist.
449 resp.Flags |= dns.RcodeNameError
450 log.Printf("NXDOMAIN: base32 decoding: %v", err)
451 return resp, nil
453 payload = payload[:n]
455 // We require clients to support EDNS(0) with a minimum payload size;
456 // otherwise we would have to set a small KCP MTU (only around 200
457 // bytes). https://tools.ietf.org/html/rfc6891#section-7 "If there is a
458 // problem with processing the OPT record itself, such as an option
459 // value that is badly formatted or that includes out-of-range values, a
460 // FORMERR MUST be returned."
461 if payloadSize < maxUDPPayload {
462 resp.Flags |= dns.RcodeFormatError
463 log.Printf("FORMERR: requestor payload size %d is too small (minimum %d)", payloadSize, maxUDPPayload)
464 return resp, nil
467 return resp, payload
470 // record represents a DNS message appropriate for a response to a previously
471 // received query, along with metadata necessary for sending the response.
472 // recvLoop sends instances of record to sendLoop via a channel. sendLoop
473 // receives instances of record and may fill in the message's Answer section
474 // before sending it.
475 type record struct {
476 Resp *dns.Message
477 Addr net.Addr
478 ClientID turbotunnel.ClientID
481 // recvLoop repeatedly calls dnsConn.ReadFrom, extracts the packets contained in
482 // the incoming DNS queries, and puts them on ttConn's incoming queue. Whenever
483 // a query calls for a response, constructs a partial response and passes it to
484 // sendLoop over ch.
485 func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- *record) error {
486 for {
487 var buf [4096]byte
488 n, addr, err := dnsConn.ReadFrom(buf[:])
489 if err != nil {
490 if err, ok := err.(net.Error); ok && err.Temporary() {
491 log.Printf("ReadFrom temporary error: %v", err)
492 continue
494 return err
497 // Got a UDP packet. Try to parse it as a DNS message.
498 query, err := dns.MessageFromWireFormat(buf[:n])
499 if err != nil {
500 log.Printf("cannot parse DNS query: %v", err)
501 continue
504 resp, payload := responseFor(&query, domain)
505 // Extract the ClientID from the payload.
506 var clientID turbotunnel.ClientID
507 n = copy(clientID[:], payload)
508 payload = payload[n:]
509 if n == len(clientID) {
510 // Discard padding and pull out the packets contained in
511 // the payload.
512 r := bytes.NewReader(payload)
513 for {
514 p, err := nextPacket(r)
515 if err != nil {
516 break
518 // Feed the incoming packet to KCP.
519 ttConn.QueueIncoming(p, clientID)
521 } else {
522 // Payload is not long enough to contain a ClientID.
523 if resp != nil && resp.Rcode() == dns.RcodeNoError {
524 resp.Flags |= dns.RcodeNameError
525 log.Printf("NXDOMAIN: %d bytes are too short to contain a ClientID", n)
528 // If a response is called for, pass it to sendLoop via the channel.
529 if resp != nil {
530 select {
531 case ch <- &record{resp, addr, clientID}:
532 default:
538 // sendLoop repeatedly receives records from ch. Those that represent an error
539 // response, it sends on the network immediately. Those that represent a
540 // response capable of carrying data, it packs full of as many packets as will
541 // fit while keeping the total size under maxEncodedPayload, then sends it.
542 func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan *record, maxEncodedPayload int) error {
543 var nextRec *record
544 for {
545 rec := nextRec
546 nextRec = nil
548 if rec == nil {
549 var ok bool
550 rec, ok = <-ch
551 if !ok {
552 break
556 if rec.Resp.Rcode() == dns.RcodeNoError && len(rec.Resp.Question) == 1 {
557 // If it's a non-error response, we can fill the Answer
558 // section with downstream packets.
560 // Any changes to how responses are built need to happen
561 // also in computeMaxEncodedPayload.
562 rec.Resp.Answer = []dns.RR{
564 Name: rec.Resp.Question[0].Name,
565 Type: rec.Resp.Question[0].Type,
566 Class: rec.Resp.Question[0].Class,
567 TTL: responseTTL,
568 Data: nil, // will be filled in below
572 var payload bytes.Buffer
573 limit := maxEncodedPayload
574 // We loop and bundle as many packets from OutgoingQueue
575 // into the response as will fit. Any packet that would
576 // overflow the capacity of the DNS response, we stash
577 // to be bundled into a future response.
578 timer := time.NewTimer(maxResponseDelay)
579 loop:
580 for {
581 var p []byte
582 select {
583 // Check the nextRec, timer, and stash cases
584 // before considering the OutgoingQueue case.
585 // Only if all these cases fail do we enter the
586 // default arm, where they are checked again in
587 // addition to OutgoingQueue.
588 case nextRec = <-ch:
589 // If there's another response waiting
590 // to be sent, wait no longer for a
591 // payload for this one.
592 break loop
593 case <-timer.C:
594 break loop
595 case p = <-ttConn.Unstash(rec.ClientID):
596 default:
597 select {
598 case nextRec = <-ch:
599 break loop
600 case <-timer.C:
601 break loop
602 case p = <-ttConn.Unstash(rec.ClientID):
603 case p = <-ttConn.OutgoingQueue(rec.ClientID):
606 // We wait for the first packet in a bundle
607 // only. The second and later packets must be
608 // immediately available or they will be omitted
609 // from this bundle.
610 timer.Reset(0)
612 limit -= 2 + len(p)
613 if payload.Len() == 0 {
614 // No packet length check for the first
615 // packet; if it's too large, we allow
616 // it to be truncated and dropped by the
617 // receiver.
618 } else if limit < 0 {
619 // Stash this packet to send in the next
620 // response.
621 ttConn.Stash(p, rec.ClientID)
622 break loop
624 if int(uint16(len(p))) != len(p) {
625 panic(len(p))
627 binary.Write(&payload, binary.BigEndian, uint16(len(p)))
628 payload.Write(p)
630 timer.Stop()
632 rec.Resp.Answer[0].Data = dns.EncodeRDataTXT(payload.Bytes())
635 buf, err := rec.Resp.WireFormat()
636 if err != nil {
637 log.Printf("resp WireFormat: %v", err)
638 continue
640 // Truncate if necessary.
641 // https://tools.ietf.org/html/rfc1035#section-4.1.1
642 if len(buf) > maxUDPPayload {
643 log.Printf("truncating response of %d bytes to max of %d", len(buf), maxUDPPayload)
644 buf = buf[:maxUDPPayload]
645 buf[2] |= 0x02 // TC = 1
648 // Now we actually send the message as a UDP packet.
649 _, err = dnsConn.WriteTo(buf, rec.Addr)
650 if err != nil {
651 if err, ok := err.(net.Error); ok && err.Temporary() {
652 log.Printf("WriteTo temporary error: %v", err)
653 continue
655 return err
658 return nil
661 // computeMaxEncodedPayload computes the maximum amount of downstream TXT RR
662 // data that keep the overall response size less than maxUDPPayload, in the
663 // worst case when the response answers a query that has a maximum-length name
664 // in its Question section. Returns 0 in the case that no amount of data makes
665 // the overall response size small enough.
667 // This function needs to be kept in sync with sendLoop with regard to how it
668 // builds candidate responses.
669 func computeMaxEncodedPayload(limit int) int {
670 // 64+64+64+62 octets, needs to be base32-decodable.
671 maxLengthName, err := dns.NewName([][]byte{
672 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
673 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
674 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
675 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
677 if err != nil {
678 panic(err)
680 if len(maxLengthName.String())+2 != 255 {
681 panic(fmt.Sprintf("max-length name is %d octets, should be %d %s",
682 len(maxLengthName.String())+2, 255, maxLengthName))
685 queryLimit := uint16(limit)
686 if int(queryLimit) != limit {
687 queryLimit = 0xffff
689 query := &dns.Message{
690 Question: []dns.Question{
692 Name: maxLengthName,
693 Type: dns.RRTypeTXT,
694 Class: dns.RRTypeTXT,
697 // EDNS(0)
698 Additional: []dns.RR{
700 Name: dns.Name{},
701 Type: dns.RRTypeOPT,
702 Class: queryLimit, // requestor's UDP payload size
703 TTL: 0, // extended RCODE and flags
704 Data: []byte{},
708 resp, _ := responseFor(query, dns.Name([][]byte{}))
709 // As in sendLoop.
710 resp.Answer = []dns.RR{
712 Name: query.Question[0].Name,
713 Type: query.Question[0].Type,
714 Class: query.Question[0].Class,
715 TTL: responseTTL,
716 Data: nil, // will be filled in below
720 // Binary search to find the maximum payload length that does not result
721 // in a wire-format message whose length exceeds the limit.
722 low := 0
723 high := 32768
724 for low+1 < high {
725 mid := (low + high) / 2
726 resp.Answer[0].Data = dns.EncodeRDataTXT(make([]byte, mid))
727 buf, err := resp.WireFormat()
728 if err != nil {
729 panic(err)
731 if len(buf) <= limit {
732 low = mid
733 } else {
734 high = mid
738 return low
741 func run(privkey, pubkey []byte, domain dns.Name, upstream net.Addr, dnsConn net.PacketConn) error {
742 defer dnsConn.Close()
744 log.Printf("pubkey %x", pubkey)
746 // We have a variable amount of room in which to encode downstream
747 // packets in each response, because each response must contain the
748 // query's Question section, which is of variable length. But we cannot
749 // give dynamic packet size limits to KCP; the best we can do is set a
750 // global maximum which no packet will exceed. We choose that maximum to
751 // keep the UDP payload size under maxUDPPayload, even in the worst case
752 // of a maximum-length name in the query's Question section.
753 maxEncodedPayload := computeMaxEncodedPayload(maxUDPPayload)
754 // 2 bytes accounts for a packet length prefix.
755 mtu := maxEncodedPayload - 2
756 if mtu < 80 {
757 if mtu < 0 {
758 mtu = 0
760 return fmt.Errorf("maximum UDP payload size of %d leaves only %d bytes for payload", maxUDPPayload, mtu)
762 log.Printf("effective MTU %d\n", mtu)
764 // Start up the virtual PacketConn for turbotunnel.
765 ttConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
766 ln, err := kcp.ServeConn(nil, 0, 0, ttConn)
767 if err != nil {
768 return fmt.Errorf("opening KCP listener: %v", err)
770 defer ln.Close()
771 go func() {
772 err := acceptSessions(ln, privkey, pubkey, mtu, upstream.(*net.TCPAddr))
773 if err != nil {
774 log.Printf("acceptSessions: %v\n", err)
778 ch := make(chan *record, 100)
779 defer close(ch)
781 go func() {
782 err := sendLoop(dnsConn, ttConn, ch, maxEncodedPayload)
783 if err != nil {
784 log.Printf("sendLoop: %v", err)
788 return recvLoop(domain, dnsConn, ttConn, ch)
791 func main() {
792 var genKey bool
793 var privkeyFilename string
794 var privkeyString string
795 var pubkeyFilename string
796 var udpAddr string
798 flag.Usage = func() {
799 fmt.Fprintf(flag.CommandLine.Output(), `Usage:
800 %[1]s -gen-key -privkey-file PRIVKEYFILE -pubkey-file PUBKEYFILE
801 %[1]s -udp ADDR -privkey-file PRIVKEYFILE DOMAIN UPSTREAMADDR
803 Example:
804 %[1]s -gen-key -privkey-file server.key -pubkey-file server.pub
805 %[1]s -udp 127.0.0.1:5300 -privkey-file server.key t.example.com 127.0.0.1:8000
807 `, os.Args[0])
808 flag.PrintDefaults()
810 flag.BoolVar(&genKey, "gen-key", false, "generate a server keypair; print to stdout or save to files")
811 flag.IntVar(&maxUDPPayload, "mtu", maxUDPPayload, "maximum size of DNS responses")
812 flag.StringVar(&privkeyString, "privkey", "", fmt.Sprintf("server private key (%d hex digits)", noise.KeyLen*2))
813 flag.StringVar(&privkeyFilename, "privkey-file", "", "read server private key from file (with -gen-key, write to file)")
814 flag.StringVar(&pubkeyFilename, "pubkey-file", "", "with -gen-key, write server public key to file")
815 flag.StringVar(&udpAddr, "udp", "", "UDP address to listen on (required)")
816 flag.Parse()
818 log.SetFlags(log.LstdFlags | log.LUTC)
820 if genKey {
821 // -gen-key mode.
822 if flag.NArg() != 0 || privkeyString != "" || udpAddr != "" {
823 flag.Usage()
824 os.Exit(1)
826 if err := generateKeypair(privkeyFilename, pubkeyFilename); err != nil {
827 fmt.Fprintf(os.Stderr, "cannot generate keypair: %v\n", err)
828 os.Exit(1)
830 } else {
831 // Ordinary server mode.
832 if flag.NArg() != 2 {
833 flag.Usage()
834 os.Exit(1)
836 domain, err := dns.ParseName(flag.Arg(0))
837 if err != nil {
838 fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
839 os.Exit(1)
841 upstream, err := net.ResolveTCPAddr("tcp", flag.Arg(1))
842 if err != nil {
843 fmt.Fprintf(os.Stderr, "cannot resolve %+q: %v\n", flag.Arg(1), err)
844 os.Exit(1)
847 if udpAddr == "" {
848 fmt.Fprintf(os.Stderr, "the -udp option is required\n")
849 os.Exit(1)
851 dnsConn, err := net.ListenPacket("udp", udpAddr)
852 if err != nil {
853 fmt.Fprintf(os.Stderr, "opening UDP listener: %v\n", err)
854 os.Exit(1)
857 if pubkeyFilename != "" {
858 fmt.Fprintf(os.Stderr, "-pubkey-file may only be used with -gen-key\n")
859 os.Exit(1)
862 var privkey []byte
863 if privkeyFilename != "" && privkeyString != "" {
864 fmt.Fprintf(os.Stderr, "only one of -privkey and -privkey-file may be used\n")
865 os.Exit(1)
866 } else if privkeyFilename != "" {
867 var err error
868 privkey, err = readKeyFromFile(privkeyFilename)
869 if err != nil {
870 fmt.Fprintf(os.Stderr, "cannot read privkey from file: %v\n", err)
871 os.Exit(1)
873 } else if privkeyString != "" {
874 var err error
875 privkey, err = noise.DecodeKey(privkeyString)
876 if err != nil {
877 fmt.Fprintf(os.Stderr, "privkey format error: %v\n", err)
878 os.Exit(1)
881 if len(privkey) == 0 {
882 log.Println("generating a temporary one-time keypair")
883 log.Println("use the -privkey or -privkey-file option for a persistent server keypair")
884 var err error
885 privkey, _, err = noise.GenerateKeypair()
886 if err != nil {
887 fmt.Fprintln(os.Stderr, err)
888 os.Exit(1)
891 pubkey := noise.PubkeyFromPrivkey(privkey)
893 err = run(privkey, pubkey, domain, upstream, dnsConn)
894 if err != nil {
895 log.Fatal(err)