Remove an obsolete comment.
[champa.git] / champa-server / main.go
blob2828284d674bc56b54ff87894455ea2df2fa7c28
1 package main
3 import (
4 "bytes"
5 "encoding/base64"
6 "errors"
7 "flag"
8 "fmt"
9 "io"
10 "log"
11 "net"
12 "net/http"
13 "os"
14 "path"
15 "strings"
16 "sync"
17 "time"
19 "github.com/xtaci/kcp-go/v5"
20 "github.com/xtaci/smux"
21 "www.bamsoftware.com/git/champa.git/armor"
22 "www.bamsoftware.com/git/champa.git/encapsulation"
23 "www.bamsoftware.com/git/champa.git/noise"
24 "www.bamsoftware.com/git/champa.git/turbotunnel"
const (
	// idleTimeout is how long smux streams may go without receiving data
	// before they are closed.
	idleTimeout = 2 * time.Minute

	// maxResponseDelay is the longest an HTTP response is held open waiting
	// for downstream data before an empty response is sent instead.
	maxResponseDelay = 100 * time.Millisecond

	// upstreamDialTimeout bounds how long establishing a TCP connection to
	// the upstream address may take.
	upstreamDialTimeout = 30 * time.Second
)

// Timeouts for the net/http Server. Because traffic is likely to be proxied
// through an AMP cache, requests are expected to be small with no streaming
// body, and responses are limited in size and not streaming.
const (
	// serverReadTimeout is net/http Server.ReadTimeout: the maximum time
	// allowed to read an entire request, including the body.
	serverReadTimeout = 10 * time.Second

	// serverWriteTimeout is net/http Server.WriteTimeout: the maximum time
	// allowed to write an entire response, including the body.
	serverWriteTimeout = 20 * time.Second

	// serverIdleTimeout is net/http Server.IdleTimeout: how long to keep a
	// keep-alive HTTP connection open, awaiting another request.
	serverIdleTimeout = idleTimeout
)
53 // handleStream bidirectionally connects a client stream with a TCP socket
54 // addressed by upstream.
55 func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
56 dialer := net.Dialer{
57 Timeout: upstreamDialTimeout,
59 upstreamConn, err := dialer.Dial("tcp", upstream)
60 if err != nil {
61 return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
63 defer upstreamConn.Close()
64 upstreamTCPConn := upstreamConn.(*net.TCPConn)
66 var wg sync.WaitGroup
67 wg.Add(2)
68 go func() {
69 defer wg.Done()
70 _, err := io.Copy(stream, upstreamTCPConn)
71 if err == io.EOF {
72 // smux Stream.Write may return io.EOF.
73 err = nil
75 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
76 log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
78 upstreamTCPConn.CloseRead()
79 stream.Close()
80 }()
81 go func() {
82 defer wg.Done()
83 _, err := io.Copy(upstreamTCPConn, stream)
84 if err == io.EOF {
85 // smux Stream.WriteTo may return io.EOF.
86 err = nil
88 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
89 log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
91 upstreamTCPConn.CloseWrite()
92 }()
93 wg.Wait()
95 return nil
98 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
99 // then awaits smux streams. It passes each stream to handleStream.
100 func acceptStreams(conn *kcp.UDPSession, upstream string) error {
101 // Put an smux session on top of the KCP connection.
102 smuxConfig := smux.DefaultConfig()
103 smuxConfig.Version = 2
104 smuxConfig.KeepAliveTimeout = idleTimeout
105 smuxConfig.MaxReceiveBuffer = 16 * 1024 * 1024 // default is 4 * 1024 * 1024
106 smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024 // default is 65536
107 sess, err := smux.Server(conn, smuxConfig)
108 if err != nil {
109 return err
111 defer sess.Close()
113 for {
114 stream, err := sess.AcceptStream()
115 if err != nil {
116 if err, ok := err.(net.Error); ok && err.Temporary() {
117 continue
119 if err == io.ErrClosedPipe {
120 // We don't want to report this error.
121 err = nil
123 return err
125 log.Printf("begin stream %08x:%d", conn.GetConv(), stream.ID())
126 go func() {
127 defer func() {
128 log.Printf("end stream %08x:%d", conn.GetConv(), stream.ID())
129 stream.Close()
131 err := handleStream(stream, upstream, conn.GetConv())
132 if err != nil {
133 log.Printf("stream %08x:%d handleStream: %v", conn.GetConv(), stream.ID(), err)
139 // acceptSessions listens for incoming KCP connections and passes them to
140 // acceptStreams.
141 func acceptSessions(ln *kcp.Listener, upstream string) error {
142 for {
143 conn, err := ln.AcceptKCP()
144 if err != nil {
145 if err, ok := err.(net.Error); ok && err.Temporary() {
146 continue
148 return err
150 log.Printf("begin session %08x", conn.GetConv())
151 // Permit coalescing the payloads of consecutive sends.
152 conn.SetStreamMode(true)
153 // Disable the dynamic congestion window (limit only by the
154 // maximum of local and remote static windows).
155 conn.SetNoDelay(
156 0, // default nodelay
157 0, // default interval
158 0, // default resend
159 1, // nc=1 => congestion window off
161 conn.SetWindowSize(1024, 1024) // Default is 32, 32.
162 go func() {
163 defer func() {
164 log.Printf("end session %08x", conn.GetConv())
165 conn.Close()
167 err := acceptStreams(conn, upstream)
168 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
169 log.Printf("session %08x acceptStreams: %v", conn.GetConv(), err)
// Handler is the HTTP handler for client requests (see its ServeHTTP method).
// It takes packets from each request and returns outgoing packets in the
// response.
type Handler struct {
	// pconn is the packet queue shared with the rest of the server.
	// ServeHTTP queues packets decoded from requests with QueueIncoming,
	// and takes packets for responses from Unstash/OutgoingQueue,
	// returning ones that do not fit with Stash.
	pconn *turbotunnel.QueuePacketConn
}
179 // decodeRequest extracts a ClientID and a payload from an incoming HTTP
180 // request. In case of a decoding failure, the returned payload slice will be
181 // nil. The payload is always non-nil after a successful decoding, even if the
182 // payload is empty.
183 func decodeRequest(req *http.Request) (turbotunnel.ClientID, []byte) {
184 // Check the version indicator of the incoming client–server protocol.
185 switch {
186 case strings.HasPrefix(req.URL.Path, "/0"):
187 // Version "0"'s payload is base64-encoded, using the URL-safe
188 // alphabet without padding, in the final path component
189 // (earlier path components are ignored).
190 _, encoded := path.Split(req.URL.Path[2:]) // Remove "/0" prefix.
191 decoded, err := base64.RawURLEncoding.DecodeString(encoded)
192 if err != nil {
193 return turbotunnel.ClientID{}, nil
195 var clientID turbotunnel.ClientID
196 n := copy(clientID[:], decoded)
197 if n != len(clientID) {
198 return turbotunnel.ClientID{}, nil
200 payload := decoded[n:]
201 return clientID, payload
202 default:
203 return turbotunnel.ClientID{}, nil
207 func (handler *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
208 const maxPayloadLength = 5000
210 if req.Method != "GET" {
211 http.Error(rw, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
212 return
215 rw.Header().Set("Content-Type", "text/html")
216 // Attempt to hint to an AMP cache not to waste resources caching this
217 // document. "The Google AMP Cache considers any document fresh for at
218 // least 15 seconds."
219 // https://developers.google.com/amp/cache/overview#google-amp-cache-updates
220 rw.Header().Set("Cache-Control", "max-age=15")
221 rw.WriteHeader(http.StatusOK)
223 enc, err := armor.NewEncoder(rw)
224 if err != nil {
225 log.Printf("armor.NewEncoder: %v", err)
226 return
228 defer enc.Close()
230 clientID, payload := decodeRequest(req)
231 if payload == nil {
232 // Could not decode the client request. We do not even have a
233 // meaningful clientID or nonce. This may be a result of the
234 // client deliberately sending a short request for traffic
235 // shaping purposes. Send back a dummy, though still
236 // AMP-compatible, response.
237 // TODO: random padding.
238 return
241 // Read incoming packets from the payload.
242 r := bytes.NewReader(payload)
243 for {
244 p, err := encapsulation.ReadData(r)
245 if err != nil {
246 break
248 handler.pconn.QueueIncoming(p, clientID)
251 limit := maxPayloadLength
252 // We loop and bundle as many outgoing packets as will fit, up to
253 // maxPayloadLength. We wait up to maxResponseDelay for the first
254 // available packet; after that we only include whatever packets are
255 // immediately available.
256 timer := time.NewTimer(maxResponseDelay)
257 defer timer.Stop()
258 first := true
259 for {
260 var p []byte
261 unstash := handler.pconn.Unstash(clientID)
262 outgoing := handler.pconn.OutgoingQueue(clientID)
263 // Prioritize taking a packet first from the stash, then from
264 // the outgoing queue, then finally check for expiration of the
265 // timer. (We continue to bundle packets even after the timer
266 // expires, as long as the packets are immediately available.)
267 select {
268 case p = <-unstash:
269 default:
270 select {
271 case p = <-unstash:
272 case p = <-outgoing:
273 default:
274 select {
275 case p = <-unstash:
276 case p = <-outgoing:
277 case <-timer.C:
281 // We wait for the first packet only. Later packets must be
282 // immediately available.
283 timer.Reset(0)
285 if len(p) == 0 {
286 // Timer expired, we are done bundling packets into this
287 // response.
288 break
291 limit -= len(p)
292 if !first && limit < 0 {
293 // This packet doesn't fit in the payload size limit.
294 // Stash it so that it will be first in line for the
295 // next response.
296 handler.pconn.Stash(p, clientID)
297 break
299 first = false
301 // Write the packet to the AMP response.
302 _, err := encapsulation.WriteData(enc, p)
303 if err != nil {
304 log.Printf("encapsulation.WriteData: %v", err)
305 break
307 if rw, ok := rw.(http.Flusher); ok {
308 rw.Flush()
313 // noiseLoop is the Noise interface between an external noiseConn, which sends
314 // and receives encrypted Noise messages, and an internal plainConn, which sends
315 // and receives normal plaintext packets. This function tracks the state of
316 // Noise handshakes and a map of ongoing sessions, proxies packets between the
317 // connections while a session is active, and removes session from the map when
318 // they are finished.
319 func noiseLoop(noiseConn net.PacketConn, plainConn *turbotunnel.QueuePacketConn, privkey []byte) error {
320 sessions := make(map[turbotunnel.ClientID]*noise.Session)
321 var sessionsLock sync.RWMutex
323 for {
324 msgType, msg, addr, err := noise.ReadMessageFrom(noiseConn)
325 if err != nil {
326 if err, ok := err.(net.Error); ok && err.Temporary() {
327 continue
329 return err
332 sessionsLock.RLock()
333 sess := sessions[addr.(turbotunnel.ClientID)]
334 sessionsLock.RUnlock()
336 switch msgType {
337 // If the msgType of the incoming Noise message is
338 // MsgTypeHandshakeInit, send back a MsgTypeHandshakeResp and
339 // begin a new session for addr.
340 case noise.MsgTypeHandshakeInit:
341 if sess != nil {
342 // Already have a session for this addr.
343 continue
346 // Send back a MsgTypeHandshakeResp to permit the
347 // initiator to complete the Noise handshake.
348 p := []byte{noise.MsgTypeHandshakeResp}
349 sess, p, err := noise.AcceptHandshake(p, msg, privkey)
350 if err != nil {
351 log.Printf("AcceptHandshake: %v", err)
352 continue
354 _, err = noiseConn.WriteTo(p, addr)
355 if err != nil {
356 if err, ok := err.(net.Error); ok && err.Temporary() {
357 continue
359 return err
362 // We have enough information at this point to start a
363 // session. Store it in the map.
364 sessionsLock.Lock()
365 sessions[addr.(turbotunnel.ClientID)] = sess
366 sessionsLock.Unlock()
368 // Start a goroutine for sending to the peer on this
369 // session. Reading from the peer is handled in the
370 // MsgTypeTransport case in the top-level switch.
371 go func() {
372 defer func() {
373 sessionsLock.Lock()
374 delete(sessions, addr.(turbotunnel.ClientID))
375 sessionsLock.Unlock()
377 for p := range plainConn.OutgoingQueue(addr) {
378 buf := []byte{noise.MsgTypeTransport}
379 buf, err := sess.Encrypt(buf, p)
380 if err != nil {
381 log.Printf("Encrypt: %v", err)
382 break
384 _, err = noiseConn.WriteTo(buf, addr)
385 if err != nil {
386 log.Printf("WriteTo: %v", err)
387 if err, ok := err.(net.Error); ok && err.Temporary() {
388 continue
390 break
395 // If the msgType of the incoming Noise message is
396 // MsgTypeTransport, decrypt the message and queue the contents
397 // with plainConn.
398 case noise.MsgTypeTransport:
399 if sess == nil {
400 // No session yet for this addr.
401 continue
403 p, err := sess.Decrypt(nil, msg)
404 if err != nil {
405 log.Printf("Decrypt: %v", err)
406 continue
408 plainConn.QueueIncoming(p, addr)
410 default:
411 log.Printf("unknown msgType %d", msgType)
416 func run(listen, upstream string, privkey []byte) error {
417 done := make(chan error, 10)
419 // noiseConn is the packet interface that communicates with the AMP/HTTP
420 // Handler; it deals in encrypted Noise messages. plainConn is the
421 // packet interface that communicates with KCP. noiseLoop sits in the
422 // middle, handling Noise handshakes and sessions, and
423 // encrypting/decrypting between the two net.PacketConns.
424 noiseConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
425 plainConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
426 defer noiseConn.Close()
427 defer plainConn.Close()
428 go func() {
429 err := noiseLoop(noiseConn, plainConn, privkey)
430 done <- fmt.Errorf("noiseLoop: %w", err)
433 ln, err := kcp.ServeConn(nil, 0, 0, plainConn)
434 if err != nil {
435 return fmt.Errorf("opening KCP listener: %v", err)
437 defer ln.Close()
438 go func() {
439 err := acceptSessions(ln, upstream)
440 done <- fmt.Errorf("acceptSessions: %w", err)
443 handler := &Handler{
444 pconn: noiseConn,
446 server := &http.Server{
447 Addr: listen,
448 Handler: handler,
449 ReadTimeout: serverReadTimeout,
450 WriteTimeout: serverWriteTimeout,
451 IdleTimeout: serverIdleTimeout,
452 // The default MaxHeaderBytes is plenty for our purposes.
454 defer server.Close()
455 go func() {
456 err := server.ListenAndServe()
457 done <- fmt.Errorf("ListenAndServe: %w", err)
460 // The goroutines are expected to run forever. Return the first error
461 // from any of them.
462 return <-done
// main parses command-line flags and arguments and runs in one of two modes.
//
// In -gen-key mode, no positional arguments and no -privkey are allowed. It
// calls generateKeypair with the -privkey-file and -pubkey-file names (per the
// flag help: print to stdout or save to files) and exits with status 1 on
// failure.
//
// Otherwise it requires exactly two arguments, LISTENADDR and UPSTREAMADDR,
// and does the following:
//   - checks up front that UPSTREAMADDR resolves to a TCP address with a host;
//   - obtains the server private key from -privkey-file or -privkey (at most
//     one may be given), or else generates a temporary one-time key and logs
//     its public key;
//   - calls run, exiting via log.Fatal if run returns an error.
func main() {
	var genKey bool
	var privkeyFilename string
	var privkeyString string
	var pubkeyFilename string

	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(), `Usage:
  %[1]s -gen-key -privkey-file PRIVKEYFILE -pubkey-file PUBKEYFILE
  %[1]s -privkey-file PRIVKEYFILE LISTENADDR UPSTREAMADDR

Example:
  %[1]s -gen-key -privkey-file server.key -pubkey-file server.pub
  %[1]s -privkey-file server.key 127.0.0.1:8080 127.0.0.1:7001

`, os.Args[0])
		flag.PrintDefaults()
	}
	flag.BoolVar(&genKey, "gen-key", false, "generate a server keypair; print to stdout or save to files")
	flag.StringVar(&privkeyString, "privkey", "", fmt.Sprintf("server private key (%d hex digits)", noise.KeyLen*2))
	flag.StringVar(&privkeyFilename, "privkey-file", "", "read server private key from file (with -gen-key, write to file)")
	flag.StringVar(&pubkeyFilename, "pubkey-file", "", "with -gen-key, write server public key to file")
	flag.Parse()

	// Log timestamps in UTC.
	log.SetFlags(log.LstdFlags | log.LUTC)

	if genKey {
		// -gen-key mode.

		// Positional arguments and -privkey have no meaning here.
		if flag.NArg() != 0 || privkeyString != "" {
			flag.Usage()
			os.Exit(1)
		}
		if err := generateKeypair(privkeyFilename, pubkeyFilename); err != nil {
			fmt.Fprintf(os.Stderr, "cannot generate keypair: %v\n", err)
			os.Exit(1)
		}
	} else {
		// Ordinary server mode.

		if flag.NArg() != 2 {
			flag.Usage()
			os.Exit(1)
		}
		listen := flag.Arg(0)
		upstream := flag.Arg(1)
		// We keep upstream as a string in order to eventually pass it to
		// net.Dial in handleStream. But we do a preliminary resolution of the
		// name here, in order to exit with a quick error at startup if the
		// address cannot be parsed or resolved.
		{
			upstreamTCPAddr, err := net.ResolveTCPAddr("tcp", upstream)
			// A resolution can succeed and still have no host (IP is
			// nil); treat that as an error as well.
			if err == nil && upstreamTCPAddr.IP == nil {
				err = fmt.Errorf("missing host in address")
			}
			if err != nil {
				fmt.Fprintf(os.Stderr, "cannot parse upstream address: %v\n", err)
				os.Exit(1)
			}
		}

		// Choose the server private key: from a file, from the command
		// line, or, if neither is given, a fresh one-time key.
		var privkey []byte
		if privkeyFilename != "" && privkeyString != "" {
			fmt.Fprintf(os.Stderr, "only one of -privkey and -privkey-file may be used\n")
			os.Exit(1)
		} else if privkeyFilename != "" {
			var err error
			privkey, err = readKeyFromFile(privkeyFilename)
			if err != nil {
				fmt.Fprintf(os.Stderr, "cannot read privkey from file: %v\n", err)
				os.Exit(1)
			}
		} else if privkeyString != "" {
			var err error
			privkey, err = noise.DecodeKey(privkeyString)
			if err != nil {
				fmt.Fprintf(os.Stderr, "privkey format error: %v\n", err)
				os.Exit(1)
			}
		} else {
			log.Println("generating a temporary one-time keypair")
			log.Println("use the -privkey or -privkey-file option for a persistent server keypair")
			var err error
			privkey, err = noise.GeneratePrivkey()
			if err != nil {
				fmt.Fprintln(os.Stderr, err)
				os.Exit(1)
			}
			// Log the public key so clients can be configured with
			// this temporary key.
			log.Printf("pubkey %x", noise.PubkeyFromPrivkey(privkey))
		}

		err := run(listen, upstream, privkey)
		if err != nil {
			log.Fatal(err)
		}
	}
}