Skip to content

Commit

Permalink
Merge pull request #1249 from cloudflare/nicky/new-db-accessor
Browse files Browse the repository at this point in the history
add db accessor to get unexpired certs by labels, add DB tests back to CI
  • Loading branch information
nickysemenza authored Oct 4, 2022
2 parents d4be5f5 + e0c522a commit 079aed0
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 11 deletions.
41 changes: 38 additions & 3 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,45 @@ jobs:
strategy:
matrix:
go: ["1.18", "1.19"]
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
# Provide the password for postgres
env:
POSTGRES_DB: postgres_db
POSTGRES_PASSWORD: ""
POSTGRES_HOST_AUTH_METHOD: trust # allow no password
POSTGRES_PORT: 5432
POSTGRES_USER: postgres
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
mysql:
image: mysql
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_ROOT_PASSWORD: ""
ports:
- 3306:3306
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3

env:
GOFLAGS: "-mod=vendor"
GODEBUG: "x509sha1=1"
BUILD_TAGS: "postgresql"
PGHOST: localhost
MYSQL_HOST: 127.0.0.1
steps:
- run: psql -c 'create database certdb_development;' -U postgres;
- run: mysql -e 'create database certdb_development;' -u root;
- run: mysql -e 'SET global sql_mode = 0;' -u root;
- uses: actions/checkout@v2

- name: Set up Go
Expand All @@ -24,11 +59,11 @@ jobs:

- name: Build
run: go build -v ./...

- run: make bin/goose;
- run: ./bin/goose -path certdb/pg up;
- run: ./bin/goose -path certdb/mysql up;
- name: Test
run: ./test.sh
# todo: these Actions tests still need to be updated to run the database tests
# that used to run in travis
- uses: codecov/codecov-action@v3

golangci:
Expand Down
1 change: 1 addition & 0 deletions certdb/certdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type Accessor interface {
GetCertificate(serial, aki string) ([]CertificateRecord, error)
GetUnexpiredCertificates() ([]CertificateRecord, error)
GetRevokedAndUnexpiredCertificates() ([]CertificateRecord, error)
GetUnexpiredCertificatesByLabel(labels []string) (crs []CertificateRecord, err error)
GetRevokedAndUnexpiredCertificatesByLabel(label string) ([]CertificateRecord, error)
GetRevokedAndUnexpiredCertificatesByLabelSelectColumns(label string) ([]CertificateRecord, error)
RevokeCertificate(serial, aki string, reasonCode int) error
Expand Down
2 changes: 1 addition & 1 deletion certdb/pg/dbconf.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
development:
driver: postgres
open: dbname=certdb_development sslmode=disable
open: dbname=certdb_development sslmode=disable user=postgres

test:
driver: postgres
Expand Down
25 changes: 25 additions & 0 deletions certdb/sql/database_accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ type Accessor struct {
db *sqlx.DB
}

var _ certdb.Accessor = &Accessor{}

func wrapSQLError(err error) error {
if err != nil {
return cferr.Wrap(cferr.CertStoreError, cferr.Unknown, err)
Expand Down Expand Up @@ -176,6 +178,29 @@ func (d *Accessor) GetUnexpiredCertificates() (crs []certdb.CertificateRecord, e
return crs, nil
}

// GetUnexpiredCertificatesByLabel gets all unexpired certificate from db that have the provided label.
func (d *Accessor) GetUnexpiredCertificatesByLabel(labels []string) (crs []certdb.CertificateRecord, err error) {
err = d.checkDB()
if err != nil {
return nil, err
}

query, args, err := sqlx.In(
fmt.Sprintf(`SELECT %s FROM certificates WHERE CURRENT_TIMESTAMP < expiry AND ca_label IN (?)`,
sqlstruct.Columns(certdb.CertificateRecord{}),
), labels)
if err != nil {
return nil, wrapSQLError(err)
}

err = d.db.Select(&crs, d.db.Rebind(query), args...)
if err != nil {
return nil, wrapSQLError(err)
}

return crs, nil
}

// GetRevokedAndUnexpiredCertificates gets all revoked and unexpired certificate from db (for CRLs).
func (d *Accessor) GetRevokedAndUnexpiredCertificates() (crs []certdb.CertificateRecord, err error) {
err = d.checkDB()
Expand Down
21 changes: 15 additions & 6 deletions certdb/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ func testInsertCertificateAndGetUnexpiredCertificate(ta TestAccessor, t *testing

expiry := time.Now().Add(time.Minute)
want := certdb.CertificateRecord{
PEM: "fake cert data",
Serial: "fake serial 2",
AKI: fakeAKI,
Status: "good",
Reason: 0,
Expiry: expiry,
PEM: "fake cert data",
Serial: "fake serial 2",
AKI: fakeAKI,
Status: "good",
Reason: 0,
Expiry: expiry,
CALabel: "foo",
}

if err := ta.Accessor.InsertCertificate(want); err != nil {
Expand Down Expand Up @@ -153,6 +154,14 @@ func testInsertCertificateAndGetUnexpiredCertificate(ta TestAccessor, t *testing
if len(unexpired) != 1 {
t.Error("Should have 1 unexpired certificate record:", len(unexpired))
}

unexpiredFiltered, err := ta.Accessor.GetUnexpiredCertificatesByLabel([]string{"foo"})
require.NoError(t, err)
require.Len(t, unexpiredFiltered, 1)
unexpiredFiltered, err = ta.Accessor.GetUnexpiredCertificatesByLabel([]string{"bar"})
require.NoError(t, err)
require.Len(t, unexpiredFiltered, 0)

}
func testInsertCertificateAndGetUnexpiredCertificateNullCommonName(ta TestAccessor, t *testing.T) {
ta.Truncate()
Expand Down
2 changes: 1 addition & 1 deletion certdb/testdb/testdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func MySQLDB() *sqlx.DB {

// PostgreSQLDB returns a PostgreSQL db instance for certdb testing.
func PostgreSQLDB() *sqlx.DB {
connStr := "dbname=certdb_development sslmode=disable"
connStr := "dbname=certdb_development sslmode=disable user=postgres"

if dbURL := os.Getenv("DATABASE_URL"); dbURL != "" {
connStr = dbURL
Expand Down

0 comments on commit 079aed0

Please sign in to comment.