show: make highlight legible
[debiancodesearch.git] / grpcutil / grpcutil.go
blob1d69bbd59fe7e991c94173517821e3f5cc82d522
// Package grpcutil encapsulates common gRPC server and client setup.
2 package grpcutil
4 import (
5 "crypto/tls"
6 "crypto/x509"
7 "flag"
8 "fmt"
9 "io/ioutil"
10 "net"
11 "net/http"
12 "strings"
14 "github.com/Debian/dcs/internal/addrfd"
15 "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
16 "golang.org/x/net/http2"
17 "golang.org/x/net/trace"
18 "google.golang.org/grpc"
19 "google.golang.org/grpc/credentials"
20 "google.golang.org/grpc/reflection"
// requireClientAuth is the -tls_require_client_auth flag (default true).
// When set, ListenAndServeTLS demands and verifies a client certificate
// on every connection.
var requireClientAuth = flag.Bool(
	"tls_require_client_auth",
	true,
	"Require TLS Client Authentication",
)
29 func init() {
30 // Disable grpc tracing until
31 // https://github.com/grpc/grpc-go/issues/695 is fixed.
32 grpc.EnableTracing = false
35 // grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
36 // connections or otherHandler otherwise. Copied from cockroachdb.
37 func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {
38 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39 // This is a partial recreation of gRPC's internal checks:
40 // https://github.com/grpc/grpc-go/blob/7834b974e55fbf85a5b01afb5821391c71084efd/transport/handler_server.go#L61
41 if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
42 grpcServer.ServeHTTP(w, r)
43 } else {
44 otherHandler.ServeHTTP(w, r)
49 func DialTLS(addr, certFile, keyFile string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
50 cert, err := tls.LoadX509KeyPair(certFile, keyFile)
51 if err != nil {
52 return nil, err
54 roots := x509.NewCertPool()
55 contents, err := ioutil.ReadFile(certFile)
56 if err != nil {
57 return nil, err
59 if !roots.AppendCertsFromPEM(contents) {
60 return nil, fmt.Errorf("Could not parse %q as PEM file (contents: %q)", certFile, contents)
62 auth := credentials.NewTLS(&tls.Config{
63 RootCAs: roots,
64 Certificates: []tls.Certificate{cert}})
66 return grpc.Dial(addr,
67 append([]grpc.DialOption{
68 grpc.WithTransportCredentials(auth),
69 grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor()),
70 grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor()),
71 }, opts...)...)
74 func ListenAndServeTLS(addr, certFile, keyFile string, register func(s *grpc.Server)) error {
75 ln, err := net.Listen("tcp", addr)
76 if err != nil {
77 return err
80 auth, err := credentials.NewServerTLSFromFile(certFile, keyFile)
81 if err != nil {
82 return err
85 s := grpc.NewServer(
86 grpc.Creds(auth),
87 grpc.StreamInterceptor(grpc_opentracing.StreamServerInterceptor()),
88 grpc.UnaryInterceptor(grpc_opentracing.UnaryServerInterceptor()))
90 register(s)
91 reflection.Register(s)
93 srv := http.Server{
94 Addr: addr,
95 Handler: grpcHandlerFunc(s, http.DefaultServeMux),
97 if err := http2.ConfigureServer(&srv, nil); err != nil {
98 return err
100 roots := x509.NewCertPool()
101 contents, err := ioutil.ReadFile(certFile)
102 if err != nil {
103 return err
105 if !roots.AppendCertsFromPEM(contents) {
106 return fmt.Errorf("Could not parse %q as PEM file (contents: %q)", certFile, contents)
109 if *requireClientAuth {
110 srv.TLSConfig.ClientCAs = roots
111 srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
112 trace.AuthRequest = func(req *http.Request) (bool, bool) {
113 return true, true
116 srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
117 srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
118 addrfd.MustWrite(ln.Addr().String())
119 return srv.Serve(tls.NewListener(ln, srv.TLSConfig))