diff --git a/cmd/notification-http/main.go b/cmd/notification-http/main.go index 6d1da788e..340d462c1 100644 --- a/cmd/notification-http/main.go +++ b/cmd/notification-http/main.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509" "fmt" "io" "net/http" @@ -22,6 +23,10 @@ type PluginConfig struct { SkipTLSVerification bool `yaml:"skip_tls_verification"` Method string `yaml:"method"` LogLevel *string `yaml:"log_level"` + Client *http.Client `yaml:"-"` + CertPath string `yaml:"cert_path"` + KeyPath string `yaml:"key_path"` + CAPath string `yaml:"ca_cert_path"` } type HTTPPlugin struct { @@ -35,6 +40,64 @@ var logger hclog.Logger = hclog.New(&hclog.LoggerOptions{ JSONFormat: true, }) +func getCertPool(caPath string) (*x509.CertPool, error) { + cp, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("unable to load system CA certificates: %w", err) + } + + if cp == nil { + cp = x509.NewCertPool() + } + + if caPath == "" { + return cp, nil + } + + logger.Info(fmt.Sprintf("Using CA cert '%s'", caPath)) + + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to load CA certificate '%s': %w", caPath, err) + } + + cp.AppendCertsFromPEM(caCert) + + return cp, nil +} + +func getTLSClient(tlsVerify bool, caPath, certPath, keyPath string) (*http.Client, error) { + var client *http.Client + + caCertPool, err := getCertPool(caPath) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + InsecureSkipVerify: tlsVerify, + } + + if certPath != "" && keyPath != "" { + logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", certPath, keyPath)) + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", certPath, keyPath, err) + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + return client, err +} + func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { if _, ok := s.PluginConfigByName[notification.Name]; !ok { return nil, fmt.Errorf("invalid plugin config name %s", notification.Name) @@ -46,13 +109,6 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific } logger.Info(fmt.Sprintf("received signal for %s config", notification.Name)) - client := http.Client{} - - if cfg.SkipTLSVerification { - client.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - } request, err := http.NewRequest(cfg.Method, cfg.URL, bytes.NewReader([]byte(notification.Text))) if err != nil { @@ -63,7 +119,7 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific request.Header.Add(headerName, headerValue) } logger.Debug(fmt.Sprintf("making HTTP %s call to %s with body %s", cfg.Method, cfg.URL, notification.Text)) - resp, err := client.Do(request.WithContext(ctx)) + resp, err := cfg.Client.Do(request.WithContext(ctx)) if err != nil { logger.Error(fmt.Sprintf("Failed to make HTTP request : %s", err)) return nil, err @@ -88,6 +144,13 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { d := PluginConfig{} err := yaml.Unmarshal(config.Config, &d) + if err != nil { + return nil, err + } + d.Client, err = getTLSClient(d.SkipTLSVerification, d.CAPath, d.CertPath, d.KeyPath) + if err != nil { + return nil, err + } s.PluginConfigByName[d.Name] = d logger.Debug(fmt.Sprintf("HTTP plugin '%s' use URL '%s'", d.Name, d.URL)) return &protobufs.Empty{}, err