Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls-expire(api): add TLS cert expiration check #1059

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 35 additions & 4 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -106,28 +107,42 @@ func listen(network, address string) {
}
}

func tlsListen(network, address, certFile, keyFile string) {
var cert tls.Certificate
func LoadCertificate(certFile, keyFile string) (tls.Certificate, error) {
var err error
var cert tls.Certificate
if strings.IndexByte(certFile, '\n') < 0 && strings.IndexByte(keyFile, '\n') < 0 {
// check if file path
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
} else {
// if text file content
cert, err = tls.X509KeyPair([]byte(certFile), []byte(keyFile))
}

return cert, err
}

func tlsListen(network, address, certFile, keyFile string) {
log.Trace().Str("address", address).Msg("[api] tls listen")
cert, err := LoadCertificate(certFile, keyFile)
if err != nil {
log.Error().Err(err).Caller().Send()
return
}

ln, err := net.Listen(network, address)
if err != nil {
log.Error().Err(err).Msg("[api] tls listen")
return
}

log.Info().Str("addr", address).Msg("[api] tls listen")
certInfo, err := x509.ParseCertificate(cert.Certificate[0])

if err != nil {
log.Error().Err(err).Caller().Send()
return
}

tlsExpire := certInfo.NotAfter
checkCertExpiration(tlsExpire, address)

server := &http.Server{
Handler: Handler,
Expand All @@ -139,6 +154,22 @@ func tlsListen(network, address, certFile, keyFile string) {
}
}

// checkCertExpiration logs the certificate expiration status.
func checkCertExpiration(expirationTime time.Time, address string) (int, time.Duration) {
now := time.Now()
switch {
case now.Unix()-expirationTime.Unix() > 0 && now.Unix()-expirationTime.Unix() < int64(time.Hour.Seconds()*24):
log.Warn().Str("ExpireDate", expirationTime.Local().String()).Str("listen addr", address).Msg("[api] tls cert will expire today")
return 1, time.Until(expirationTime)
case expirationTime.Before(now):
log.Error().Str("ExpireDate", expirationTime.Local().String()).Str("listen addr", address).Msg("[api] tls cert expired")
return -1, time.Until(expirationTime)
default:
log.Info().Str("ExpireDate", expirationTime.Local().String()).Str("listen addr", address).Msg("[api] tls")
return 0, time.Until(expirationTime)
}
}

var Port int

const (
Expand Down
59 changes: 59 additions & 0 deletions internal/api/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package api

import (
"testing"
"time"
)

func TestCheckCertExpiration(t *testing.T) {
// Define a structure for test cases
tests := []struct {
name string
expirationTime time.Time // Input expiration time
expectedCode int // Expected status code returned by the function
expectedDuration time.Duration // Expected duration until expiration
}{
{
name: "Expired Certificate",
expirationTime: time.Now().Add(-48 * time.Hour), // 2 days ago
expectedCode: -1,
expectedDuration: -48 * time.Hour, // Negative duration indicating past expiration
},
{
name: "Expiring Today",
expirationTime: time.Now().Add(-12 * time.Hour), // 12 hours from now
expectedCode: 1,
expectedDuration: -12 * time.Hour, // Positive, less than 24 hours
},
{
name: "Valid Certificate",
expirationTime: time.Now().Add(48 * time.Hour), // 2 days from now
expectedCode: 0,
expectedDuration: 48 * time.Hour, // Positive, more than 24 hours
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, duration := checkCertExpiration(tc.expirationTime, "localhost")

if code != tc.expectedCode {
t.Errorf("Expected status code %d, got %d", tc.expectedCode, code)
}

// Since exact duration comparison can be flaky due to execution time, check if the duration is within a reasonable threshold
if !durationApproxEquals(duration, tc.expectedDuration, 5*time.Second) {
t.Errorf("Expected duration %v, got %v", tc.expectedDuration, duration)
}
})
}
}

// durationApproxEquals checks if two durations are approximately equal within a given threshold.
func durationApproxEquals(a, b, threshold time.Duration) bool {
diff := a - b
if diff < 0 {
diff = -diff
}
return diff <= threshold
}