Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Always Encrypted support #637

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
80 changes: 80 additions & 0 deletions README.md
Expand Up @@ -28,6 +28,12 @@ Other supported formats are listed below.
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
* `columnEncryption` - Set to "true" if you want to use [Always Encrypted](https://docs.microsoft.com/en-us/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver15)
* `keyStoreAuthentication`
* `pfx` - Use a PFX file as a key store to authenticate and perform Always Encrypted operations, used when `columnEncryption` is enabled
* `keyStoreLocation` - The location of the key store file (e.g: `./resources/test/always-encrypted/ae-1.pfx`), used when `columnEncryption` is enabled
* `keyStoreSecret` - The password of the key store file provided in `keyStoreLocation`, used when `columnEncryption` is enabled


### Connection parameters for ODBC and ADO style connection strings:

Expand Down Expand Up @@ -126,6 +132,80 @@ Where `tokenProvider` is a function that returns a fresh access token or an erro
actually trigger the retrieval of a token, this happens when the first statment is issued and a connection
is created.


### Always Encrypted support (preview)

`go-mssql` supports a client-side decryption of the column encrypted values for those databases
that are using the [Always Encrypted](https://docs.microsoft.com/en-us/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver15)
feature.

To start using the feature, you have to use the following parameters in your DSN:

* `columnEncryption=true`
* `keyStoreAuthentication=pfx` - Only `pfx` is supported at the moment
* `keyStoreLocation=/path/to/your/keystore.pfx` - The location of the key store file (e.g: `./resources/test/always-encrypted/ae-1.pfx`), used when `columnEncryption` is enabled
* `keyStoreSecret=secret` - The password of your keystore (`keyStoreLocation`)

#### Usage

Using the Always Encrypted feature should be transparent in the driver:
```go
query := url.Values{}
query.Add("database", "dbname")
query.Add("columnEncryption", "true")
query.Add("keyStoreAuthentication", "pfx")
query.Add("keyStoreLocation", "./resources/test/always-encrypted/ae-1.pfx")
query.Add("keyStoreSecret", "password")


hostname := "172.20.0.2"
port:= 1433

u := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword("sa", "superSecurePassword_"),
Host: fmt.Sprintf("%s:%d", hostname, port),
RawQuery: query.Encode(),
}

db, err := sql.Open("sqlserver", u.String())
if err != nil {
logrus.Fatalf("unable to open db: %v", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this example, use return fmt.Errorf("unable to open db: %w", err) throughout.

}
rows, err := db.Query("SELECT id, ssn FROM [dbo].[cid]")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove [ and ].

if err != nil {
logrus.Fatalf("unable to perform query: %v", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

}

for ; rows.Next(); {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Go, there are three for loop variants: 0 argument (loop forever), 1 argument (test condition), and 3 argument.

You want the 1 argument version here: for rows.Next() {

var dest struct {
Id int
SSN string
}
err = rows.Scan(&dest.Id, &dest.SSN)
if err != nil {
logrus.Fatalf("unable to scan into struct: %v", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

}
fmt.Printf("%d, %s\n", dest.Id, dest.SSN)
}
```

The code above, when used against an Always Encrypted column, returns
the following:

```
1, 12345
2, 00000
```

If `columnEncryption` is set to false, the result will be similar to the following:
```
1, B��v��3O뗇��a�R��o�l��U�
�iE�#wOS�T횡5�R��1�i_n/Q��oLPBy��kL���8'/�
2, �ކ��?�Y
Ѕ���i_n��-g|����v��2����x�Q)y�p�x��O��9������r��Bt�L�"N����.N]Rc
```

## Executing Stored Procedures

To run a stored procedure, set the query text to the procedure name:
Expand Down
1 change: 1 addition & 0 deletions accesstokenconnector_test.go
Expand Up @@ -23,6 +23,7 @@ func TestNewAccessTokenConnector(t *testing.T) {
args args
want func(driver.Connector) error
wantErr bool

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove extra line that was added.

}{
{
name: "Happy path",
Expand Down
104 changes: 104 additions & 0 deletions always_encrypted_test.go
@@ -0,0 +1,104 @@
package mssql

import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove. We don't need an assert package.

"time"
)

type testAEStruct struct {
Id int
SSN string
Date time.Time
Float *float64
Money *float64
}

func TestAlwaysEncrypted(t *testing.T) {
conn := open(t)
defer conn.Close()
rows, err := conn.Query("SELECT id, ssn, secure_date, secure_float, secure_money FROM [dbo].[cid]")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move defer rows.Close after error check

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is dbo.cid coming from? If this is a unit test, it should be created first as part of the test.

defer rows.Close()

if err != nil {
t.Fatalf("unable to query db: %s", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The verb should be %w, not %s.

}

var dest testAEStruct

expectedValues := []string{
"12345 ",
"00000 ",
"041-64-841",
"009-34-870",
"517-04-462",
"158-16-318",
"136-01-843",
}

secureFloat := []float64{
1.0,
453.32,
}

secureMoney := []float64{
40333.95,
8284323.0,
}


expectedSecureFloat := []*float64 {
&secureFloat[0],
&secureFloat[1],
nil,
nil,
nil,
nil,
nil,
}

expectedSecureMoney := []*float64 {
&secureMoney[0],
&secureMoney[1],
nil,
nil,
nil,
nil,
nil,
}

expectedDate := time.Date(2021, 02, 11, 0, 0, 0, 0, time.UTC)
expectedIdx := 0

for rows.Next() {
err = rows.Scan(&dest.Id, &dest.SSN, &dest.Date, &dest.Float, &dest.Money)
fmt.Printf("col: %+v", dest)
if dest.Float != nil {
fmt.Printf("\t%f", *dest.Float)
}

if dest.Money != nil {
fmt.Printf("\t%f", *dest.Money)
}
fmt.Printf("\n")

assert.Equal(t, expectedValues[expectedIdx], dest.SSN)
assert.Equal(t, expectedDate, dest.Date.UTC())
checkNilandValue(t, expectedSecureFloat, expectedIdx, dest.Float)
checkNilandValue(t, expectedSecureMoney, expectedIdx, dest.Money)


expectedIdx++
assert.Nil(t, err)
}
}

func checkNilandValue(t *testing.T, expectedArr []*float64, expectedIdx int, res *float64) {
if expectedArr[expectedIdx] == nil {
assert.Nil(t, res)
} else {
assert.NotNil(t, res)
assert.Equal(t, *expectedArr[expectedIdx], *res)
}
}
38 changes: 38 additions & 0 deletions buf.go
Expand Up @@ -269,3 +269,41 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
r.rpos += copied
return
}

type sqlIdentifier struct {
serverName string
databaseName string
schemaName string
objectName string
}

func (r *tdsBuffer) sqlIdentifier() sqlIdentifier {
numParts := int(r.byte())
if numParts < 1 || numParts >= 5 {
panic("invalid sqlIdentifier: numparts is not between 1 and 4")
}

parts := make([]string, numParts)

for i := range parts {
parts[i] = r.UsVarChar()
}

sqlID := sqlIdentifier{
objectName: parts[0],
}

if numParts >= 2 {
sqlID.schemaName = parts[1]
}

if numParts >= 3{
sqlID.databaseName = parts[2]
}

if numParts == 4 {
sqlID.serverName = parts[3]
}

return sqlID
}
29 changes: 29 additions & 0 deletions cek.go
@@ -0,0 +1,29 @@
package mssql

type cekTable struct {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is cek mean? A comment or full name would be good.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Column Encryption Key, this is used everywhere in Microsoft's documentation. Not sure if it's a good idea to make it longer here since cek is mentioned everywhere in the codebase now

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, then above type cekTable struct {, put

// Column Encryption Key table that ...

entries []cekTableEntry
}

type encryptionKeyInfo struct {
encryptedKey []byte
databaseID int
cekID int
cekVersion int
cekMdVersion []byte
keyPath string
keyStoreName string
algorithmName string
}

type cekTableEntry struct {
databaseID int
keyId int
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KeyID

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should this be exported?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nevermind, you were referring to the ID part 😅

keyVersion int
mdVersion []byte
valueCount int
cekValues []encryptionKeyInfo
}

func newCekTable(size uint16) cekTable {
return cekTable{entries: make([]cekTableEntry, size)}
}
78 changes: 77 additions & 1 deletion conn_str.go
Expand Up @@ -39,11 +39,21 @@ type connectParams struct {
packetSize uint16
fedAuthLibrary int
fedAuthADALWorkflow byte
columnEncryption bool
keyStoreAuthentication KeyStoreAuthentication
keyStoreLocation string
keyStoreSecret string
}

// default packet size for TDS buffer
const defaultPacketSize = 4096

type KeyStoreAuthentication string

const (
PFXKeystoreAuth = "pfx"
)

func parseConnectParams(dsn string) (connectParams, error) {
p := connectParams{
fedAuthLibrary: fedAuthLibraryReserved,
Expand Down Expand Up @@ -169,6 +179,54 @@ func parseConnectParams(dsn string) (connectParams, error) {
} else {
p.trustServerCertificate = true
}

columnEncryption, ok := params["columnencryption"]
if ok {
if strings.EqualFold(columnEncryption, "true") {
p.columnEncryption = true
} else {
var err error
p.columnEncryption, err = strconv.ParseBool(columnEncryption)
if err != nil {
f := "invalid columnEncryption '%s': %s"
return p, fmt.Errorf(f, columnEncryption, err.Error())
}
}
} else {
p.columnEncryption = false
}

ksAuth, ok := params["keystoreauthentication"]
if ok {
var authMethod KeyStoreAuthentication
switch strings.ToLower(ksAuth) {
case "pfx":
authMethod = PFXKeystoreAuth
default:
return p, fmt.Errorf("invalid keystotreAuthentication '%s'", ksAuth)
}
p.keyStoreAuthentication = authMethod
}

ksLocation, ok := params["keystorelocation"]
if ok {
if ksLocation == "" {
return p, fmt.Errorf("invalid keystore location provided: '%s'", ksLocation)
}

_, err := os.Stat(ksLocation)
if err != nil {
return p, fmt.Errorf("unable to find keystore %s: %v", ksLocation, err)
}

p.keyStoreLocation = ksLocation
}

ksSecret, ok := params["keystoresecret"]
if ok {
p.keyStoreSecret = ksSecret
}

trust, ok := params["trustservercertificate"]
if ok {
var err error
Expand Down Expand Up @@ -248,6 +306,23 @@ func (p connectParams) toUrl() *url.URL {
if p.logFlags != 0 {
q.Add("log", strconv.FormatUint(p.logFlags, 10))
}

if p.columnEncryption {
q.Add("columnEncryption", "true")
}

if p.keyStoreAuthentication != "" {
q.Add("keyStoreAuthentication", string(p.keyStoreAuthentication))
}

if p.keyStoreLocation != "" {
q.Add("keyStoreLocation", p.keyStoreLocation)
}

if p.keyStoreSecret != "" {
q.Add("keyStoreSecret", p.keyStoreSecret)
}

res := url.URL{
Scheme: "sqlserver",
Host: p.host,
Expand All @@ -256,6 +331,7 @@ func (p connectParams) toUrl() *url.URL {
if p.instance != "" {
res.Path = p.instance
}

if len(q) > 0 {
res.RawQuery = q.Encode()
}
Expand All @@ -274,7 +350,7 @@ func splitConnectionString(dsn string) (res map[string]string) {
if len(name) == 0 {
continue
}
var value string = ""
var value = ""
if len(lst) > 1 {
value = strings.TrimSpace(lst[1])
}
Expand Down