diff --git a/Makefile b/Makefile index 4b0aa0a..c28271c 100644 --- a/Makefile +++ b/Makefile @@ -36,4 +36,17 @@ curl-bad: .PHONY: hurl hurl: - @hurl --variable api_host='http://localhost:8080' hurl/*.hurl + @hurl --verbose --error-format=long --variable api_host='http://localhost:8080' hurl/*.hurl + +.PHONY: coverage clean-cover +clean-cover: + rm -f cover.out cover.html + +coverage: cover.out cover.html + +cover.out: + go test ./... -coverprofile cover.out + +cover.html: cover.out + go tool cover -html=cover.out -o cover.html + diff --git a/api/api.go b/api/api.go index 0da90a1..c96e42b 100644 --- a/api/api.go +++ b/api/api.go @@ -11,12 +11,17 @@ import ( "github.com/Ajnasz/sekret.link/api/middlewares" "github.com/Ajnasz/sekret.link/internal/api" + "github.com/Ajnasz/sekret.link/internal/hasher" "github.com/Ajnasz/sekret.link/internal/models" "github.com/Ajnasz/sekret.link/internal/parsers" "github.com/Ajnasz/sekret.link/internal/services" "github.com/Ajnasz/sekret.link/internal/views" ) +func newAESEncrypter(b []byte) services.Encrypter { + return services.NewAESEncrypter(b) +} + // HandlerConfig configuration for http handlers type HandlerConfig struct { ExpireSeconds int @@ -26,18 +31,29 @@ type HandlerConfig struct { DB *sql.DB } -// NewSecretHandler creates a SecretHandler instance -func NewSecretHandler(config HandlerConfig) SecretHandler { - return SecretHandler{config} -} - // SecretHandler is an http.Handler implementation which handles requests to // encode or decode the post body type SecretHandler struct { config HandlerConfig } +// NewSecretHandler creates a SecretHandler instance +func NewSecretHandler(config HandlerConfig) SecretHandler { + return SecretHandler{config: config} +} + // POST method handler +// This method is responsible for creating a new entry +// url: / +// query: +// - expire: the expiration time of the entry +// - maxReads: the maximum number of reads for the entry +// +// method: POST +// response: 200 OK +// response: 400 Bad Request +// response: 500 Internal Server Error +// response: 413 Payload Too Large func (s SecretHandler) Post(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" && r.URL.Path != "" { http.Error(w, "Not found", http.StatusNotFound) @@ -45,12 +61,9 @@ func (s SecretHandler) Post(w http.ResponseWriter, r *http.Request) { return } - encrypter := func(b []byte) services.Encrypter { - return services.NewAESEncrypter(b) - } - parser := parsers.NewCreateEntryParser(s.config.MaxExpireSeconds) - entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(s.config.DB, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), newAESEncrypter) + entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, newAESEncrypter, keyManager) view := views.NewEntryCreateView(s.config.WebExternalURL) createHandler := api.NewCreateHandler( @@ -64,13 +77,10 @@ func (s SecretHandler) Post(w http.ResponseWriter, r *http.Request) { // GET method handler func (s SecretHandler) Get(w http.ResponseWriter, r *http.Request) { - encrypter := func(b []byte) services.Encrypter { - return services.NewAESEncrypter(b) - } - view := views.NewEntryReadView() parser := parsers.NewGetEntryParser() - entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(s.config.DB, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), newAESEncrypter) + entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, newAESEncrypter, keyManager) getHandler := api.NewGetHandler( parser, entryManager, @@ -81,11 +91,8 @@ func (s SecretHandler) Get(w http.ResponseWriter, r *http.Request) { // DELETE method handler func (s SecretHandler) Delete(w http.ResponseWriter, r *http.Request) { - encrypter := func(b []byte) services.Encrypter { - return services.NewAESEncrypter(b) - } - - entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(s.config.DB, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), newAESEncrypter) + entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, newAESEncrypter, keyManager) view := views.NewEntryDeleteView() deleteHandler := api.NewDeleteHandler(entryManager, view) deleteHandler.Handle(w, r) @@ -97,6 +104,32 @@ func (s SecretHandler) Options(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } +// GenerateEncryptionKey provides a way to generate a new encryption key for an existing entry +// This allows to share the same entry with multiple users without sharing the encryption key +// url: /key/{uuid}/{key} +// - uuid: the uuid of the entry +// - key: the key of the entry +// query: +// - expire: the expiration time of the new key +// - maxReads: the maximum number of reads for the new key +// +// method: GET +// response: 200 OK +func (s SecretHandler) GenerateEncryptionKey(w http.ResponseWriter, r *http.Request) { + keyManager := services.NewEntryKeyManager(s.config.DB, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), newAESEncrypter) + entryManager := services.NewEntryManager(s.config.DB, &models.EntryModel{}, newAESEncrypter, keyManager) + view := views.NewGenerateEntryKeyView(s.config.WebExternalURL) + parser := parsers.NewGenerateEntryKeyParser(s.config.MaxExpireSeconds) + getHandler := api.NewGenerateEntryKeyHandler( + parser, + entryManager, + view, + ) + + getHandler.Handle(w, r) + +} + // NotFound handler func (s SecretHandler) NotFound(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not found", http.StatusNotFound) @@ -128,6 +161,7 @@ func clearApiRoot(apiRoot string) string { } func (s SecretHandler) RegisterHandlers(mux *http.ServeMux, apiRoot string) { + apiRoot = clearApiRoot(apiRoot) mux.Handle( fmt.Sprintf("GET %s", path.Join("/", apiRoot, "{uuid}", "{key}")), http.StripPrefix( @@ -138,7 +172,7 @@ func (s SecretHandler) RegisterHandlers(mux *http.ServeMux, apiRoot string) { ), ) mux.Handle( - fmt.Sprintf("POST %s", clearApiRoot(apiRoot)), + fmt.Sprintf("POST %s", apiRoot), http.StripPrefix( path.Join("/", apiRoot), middlewares.SetupLogging( @@ -158,7 +192,7 @@ func (s SecretHandler) RegisterHandlers(mux *http.ServeMux, apiRoot string) { ) mux.Handle( - fmt.Sprintf("OPTIONS %s", clearApiRoot(apiRoot)), + fmt.Sprintf("OPTIONS %s", apiRoot), http.StripPrefix( apiRoot, middlewares.SetupLogging( @@ -167,6 +201,17 @@ func (s SecretHandler) RegisterHandlers(mux *http.ServeMux, apiRoot string) { ), ) + // TODO + mux.Handle( + fmt.Sprintf("GET %s", path.Join(apiRoot, "key", "{uuid}", "{key}")), + http.StripPrefix( + apiRoot, + middlewares.SetupLogging( + middlewares.SetupHeaders(http.HandlerFunc(s.GenerateEncryptionKey)), + ), + ), + ) + mux.Handle("/", middlewares.SetupLogging(middlewares.SetupHeaders(http.HandlerFunc(s.NotFound)))) } diff --git a/api/api_test.go b/api/api_test.go index b1b3d25..393488c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -17,6 +17,7 @@ import ( "testing" "time" + "github.com/Ajnasz/sekret.link/internal/hasher" "github.com/Ajnasz/sekret.link/internal/models" "github.com/Ajnasz/sekret.link/internal/services" "github.com/Ajnasz/sekret.link/internal/test/durable" @@ -91,7 +92,9 @@ func TestCreateEntry(t *testing.T) { encrypter := func(b []byte) services.Encrypter { return services.NewAESEncrypter(b) } - entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) + + entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) entry, err := entryManager.ReadEntry(ctx, savedUUID, key) if err != nil { @@ -209,7 +212,8 @@ func TestCreateEntryJSON(t *testing.T) { encrypter := func(b []byte) services.Encrypter { return services.NewAESEncrypter(b) } - entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) + entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) entry, err := entryManager.ReadEntry(ctx, encode.UUID, key) if err != nil { @@ -299,7 +303,8 @@ func TestCreateEntryForm(t *testing.T) { encrypter := func(b []byte) services.Encrypter { return services.NewAESEncrypter(b) } - entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) + entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) entry, err := entryManager.ReadEntry(ctx, savedUUID, key) if err != nil { @@ -387,7 +392,9 @@ func TestGetEntry(t *testing.T) { encrypter := func(b []byte) services.Encrypter { return services.NewAESEncrypter(b) } - entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter) + + keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) + entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) meta, encKey, err := entryManager.CreateEntry(ctx, []byte(testCase.Value), 1, time.Second*10) if err != nil { @@ -442,7 +449,9 @@ func TestGetEntryJSON(t *testing.T) { encrypter := func(b []byte) services.Encrypter { return services.NewAESEncrypter(b) } - entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter) + + keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) + entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) meta, encKey, err := entryManager.CreateEntry(ctx, []byte(testCase.Value), 1, time.Second*10) if err != nil { t.Error(err) @@ -577,7 +586,8 @@ func TestCreateEntryWithExpiration(t *testing.T) { encrypter := func(b []byte) services.Encrypter { return services.NewAESEncrypter(b) } - entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter) + keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter) + entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager) entry, err := entryManager.ReadEntry(ctx, savedUUID, decodedKey) if err != nil { @@ -641,7 +651,6 @@ func TestCreateEntryWithMaxReads(t *testing.T) { model := &models.EntryModel{} savedUUID := resp.Header.Get("x-entry-uuid") - fmt.Println("savedUUID", savedUUID) if err != nil { t.Fatal(err) diff --git a/cmd/prepare/main.go b/cmd/prepare/main.go index 306420a..7ecc507 100644 --- a/cmd/prepare/main.go +++ b/cmd/prepare/main.go @@ -17,7 +17,6 @@ func prepareDatabase(ctx context.Context) error { flag.StringVar(&postgresDB, "postgresDB", "", "Connection string for postgresql database backend") flag.Parse() - fmt.Println(postgresDB) db, err := durable.OpenDatabaseClient(context.Background(), config.GetConnectionString(postgresDB)) if err != nil { return err diff --git a/hurl/createnewkey.hurl b/hurl/createnewkey.hurl new file mode 100644 index 0000000..cd88a73 --- /dev/null +++ b/hurl/createnewkey.hurl @@ -0,0 +1,55 @@ +# Create a new entry +POST {{api_host}}/api/?maxReads=3 +content-type: application/json +{ + "name": "John Doe", + "email": "john.do@acheron.space" +} + +HTTP 200 +[Captures] +entry_uuid: header "x-entry-uuid" +entry_key: header "x-entry-key" +entry_expire: header "x-entry-expire" +entry_delete_key: header "x-entry-delete-key" + + +# Generate a new key for the entry +GET {{api_host}}/api/key/{{entry_uuid}}/{{entry_key}} + +HTTP 200 +[Captures] +entry_key2: header "x-entry-key" +entry_expire2: header "x-entry-expire" + +GET {{api_host}}/api/key/{{entry_uuid}}/{{entry_key}} + +HTTP 200 +[Captures] +entry_key3: header "x-entry-key" +entry_expire3: header "x-entry-expire" + +# Retrieve the entry with key 2 +GET {{api_host}}/api/{{entry_uuid}}/{{entry_key2}} + +HTTP 200 +[Asserts] +{ + "name": "John Doe", + "email": "john.do@acheron.space" +} + +# Retrieve the entry with key 3 +GET {{api_host}}/api/{{entry_uuid}}/{{entry_key3}} + +HTTP 200 +[Asserts] +{ + "name": "John Doe", + "email": "john.do@acheron.space" +} + +# # Should not be able to retrieve the entry again +# GET {{api_host}}/api/{{entry_uuid}}/{{entry_key2}} +# +# HTTP 404 diff --git a/hurl/createread.hurl b/hurl/createread.hurl index ba007bb..707d140 100644 --- a/hurl/createread.hurl +++ b/hurl/createread.hurl @@ -23,3 +23,9 @@ HTTP 200 "name": "John Doe", "email": "john.do@acheron.space" } + + +# Should not be able to retrieve the entry again +GET {{api_host}}/api/{{entry_uuid}}/{{entry_key}} + +HTTP 404 diff --git a/internal/api/createentry.go b/internal/api/createentry.go index 5b737ae..ea95b05 100644 --- a/internal/api/createentry.go +++ b/internal/api/createentry.go @@ -9,6 +9,7 @@ import ( "github.com/Ajnasz/sekret.link/internal/parsers" "github.com/Ajnasz/sekret.link/internal/services" + "github.com/Ajnasz/sekret.link/internal/views" ) // CreateEntryParser is an interface for parsing the create entry request @@ -32,7 +33,7 @@ type CreateHandler struct { maxDataSize int64 parser CreateEntryParser entryManager CreateEntryManager - view CreateEntryView + view views.View[views.EntryCreatedResponse] } // NewCreateHandler creates a new CreateHandler @@ -40,7 +41,7 @@ func NewCreateHandler( maxDataSize int64, parser CreateEntryParser, entryManager CreateEntryManager, - view CreateEntryView, + view views.View[views.EntryCreatedResponse], ) CreateHandler { return CreateHandler{ maxDataSize: maxDataSize, @@ -67,7 +68,9 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error { return err } - c.view.RenderEntryCreated(w, r, entry, hex.EncodeToString(key)) + viewData := views.BuildCreatedResponse(entry, hex.EncodeToString(key)) + + c.view.Render(w, r, viewData) return nil } @@ -76,6 +79,6 @@ func (c CreateHandler) Handle(w http.ResponseWriter, r *http.Request) { err := c.handle(w, r) if err != nil { - c.view.RenderCreateEntryErrorResponse(w, r, err) + c.view.RenderError(w, r, err) } } diff --git a/internal/api/createentry_test.go b/internal/api/createentry_test.go index e34eb6c..e2754bd 100644 --- a/internal/api/createentry_test.go +++ b/internal/api/createentry_test.go @@ -11,6 +11,7 @@ import ( "github.com/Ajnasz/sekret.link/internal/parsers" "github.com/Ajnasz/sekret.link/internal/services" + "github.com/Ajnasz/sekret.link/internal/views" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -44,11 +45,11 @@ type MockEntryView struct { mock.Mock } -func (m *MockEntryView) RenderEntryCreated(w http.ResponseWriter, r *http.Request, entry *services.EntryMeta, key string) { - m.Called(w, r, entry, key) +func (m *MockEntryView) Render(w http.ResponseWriter, r *http.Request, data views.EntryCreatedResponse) { + m.Called(w, r, data) } -func (m *MockEntryView) RenderCreateEntryErrorResponse(w http.ResponseWriter, r *http.Request, err error) { +func (m *MockEntryView) RenderError(w http.ResponseWriter, r *http.Request, err error) { m.Called(w, r, err) } @@ -64,7 +65,7 @@ func Test_CreateEntryHandle(t *testing.T) { parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil) entryManager.On("CreateEntry", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, []byte("key"), nil) - view.On("RenderEntryCreated", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + view.On("Render", mock.Anything, mock.Anything, mock.Anything).Return() handler := NewCreateHandler(10, parser, entryManager, view) @@ -88,7 +89,7 @@ func Test_CreateEntryHandleParserError(t *testing.T) { response := httptest.NewRecorder() parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, errors.New("error")) - view.On("RenderCreateEntryErrorResponse", mock.Anything, mock.Anything, mock.Anything).Return() + view.On("RenderError", mock.Anything, mock.Anything, mock.Anything).Return() handler := NewCreateHandler(10, parser, entryManager, view) @@ -112,7 +113,7 @@ func Test_CreateEntryHandleError(t *testing.T) { parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil) entryManager.On("CreateEntry", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, []byte("key"), errors.New("error")) - view.On("RenderCreateEntryErrorResponse", mock.Anything, mock.Anything, mock.Anything).Return() + view.On("RenderError", mock.Anything, mock.Anything, mock.Anything).Return() handler := NewCreateHandler(10, parser, entryManager, view) diff --git a/internal/api/deleteentry.go b/internal/api/deleteentry.go index dba6db5..7d6c7ec 100644 --- a/internal/api/deleteentry.go +++ b/internal/api/deleteentry.go @@ -2,10 +2,10 @@ package api import ( "context" - "fmt" "net/http" "github.com/Ajnasz/sekret.link/internal/parsers" + "github.com/Ajnasz/sekret.link/internal/views" ) // DeleteEntryManager is the interface for deleting an entry @@ -22,11 +22,11 @@ type DeleteEntryView interface { // DeleteHandler is the handler for deleting an entry type DeleteHandler struct { entryManager DeleteEntryManager - view DeleteEntryView + view views.View[views.DeleteEntryResponse] } // NewDeleteHandler creates a new DeleteHandler instance -func NewDeleteHandler(entryManager DeleteEntryManager, view DeleteEntryView) DeleteHandler { +func NewDeleteHandler(entryManager DeleteEntryManager, view views.View[views.DeleteEntryResponse]) DeleteHandler { return DeleteHandler{ entryManager: entryManager, view: view, @@ -38,7 +38,6 @@ func (d DeleteHandler) handle(w http.ResponseWriter, r *http.Request) error { UUID, _, deleteKey, err := parsers.ParseDeleteEntryPath(r.URL.Path) if err != nil { - fmt.Println("parse error", err) return err } @@ -48,13 +47,13 @@ func (d DeleteHandler) handle(w http.ResponseWriter, r *http.Request) error { return err } - d.view.RenderDeleteEntry(w, r) + d.view.Render(w, r, views.DeleteEntryResponse{}) return nil } // Handle handles the delete request func (d DeleteHandler) Handle(w http.ResponseWriter, r *http.Request) { if err := d.handle(w, r); err != nil { - d.view.RenderDeleteEntryError(w, r, err) + d.view.RenderError(w, r, err) } } diff --git a/internal/api/deleteentry_test.go b/internal/api/deleteentry_test.go index a0cecd8..e4f7a8e 100644 --- a/internal/api/deleteentry_test.go +++ b/internal/api/deleteentry_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/Ajnasz/sekret.link/internal/parsers" + "github.com/Ajnasz/sekret.link/internal/views" "github.com/stretchr/testify/mock" ) @@ -24,11 +25,11 @@ type MockDeleteEntryView struct { mock.Mock } -func (m *MockDeleteEntryView) RenderDeleteEntry(w http.ResponseWriter, r *http.Request) { +func (m *MockDeleteEntryView) Render(w http.ResponseWriter, r *http.Request, data views.DeleteEntryResponse) { m.Called(w, r) } -func (m *MockDeleteEntryView) RenderDeleteEntryError(w http.ResponseWriter, r *http.Request, err error) { +func (m *MockDeleteEntryView) RenderError(w http.ResponseWriter, r *http.Request, err error) { m.Called(w, r, err) } @@ -37,7 +38,7 @@ func Test_DeleteHandle(t *testing.T) { view := new(MockDeleteEntryView) entryManager.On("DeleteEntry", mock.Anything, "40e7d7d6-db0d-11ee-b9ee-1340bdbad9b2", "delete-key").Return(nil) - view.On("RenderDeleteEntry", mock.Anything, mock.Anything).Return() + view.On("Render", mock.Anything, mock.Anything).Return() handler := NewDeleteHandler(entryManager, view) @@ -55,7 +56,7 @@ func Test_DeleteHandle_InvalidUUID(t *testing.T) { entryManager := new(MockDeleteEntryManager) view := new(MockDeleteEntryView) - view.On("RenderDeleteEntryError", mock.Anything, mock.Anything, mock.MatchedBy(func(err error) bool { + view.On("RenderError", mock.Anything, mock.Anything, mock.MatchedBy(func(err error) bool { return errors.Is(err, parsers.ErrInvalidUUID) })).Return() diff --git a/internal/api/generateentrykey.go b/internal/api/generateentrykey.go new file mode 100644 index 0000000..0116f83 --- /dev/null +++ b/internal/api/generateentrykey.go @@ -0,0 +1,67 @@ +package api + +import ( + "context" + "encoding/hex" + "net/http" + + "github.com/Ajnasz/sekret.link/internal/parsers" + "github.com/Ajnasz/sekret.link/internal/services" + "github.com/Ajnasz/sekret.link/internal/views" +) + +type GenerateEntryKeyView interface { + RenderGenerateEntryKey(w http.ResponseWriter, r *http.Request, entry views.GenerateEntryKeyResponseData) + RenderGenerateEntryKeyError(w http.ResponseWriter, r *http.Request, err error) +} + +type GenerateEntryKeyManager interface { + GenerateEntryKey(ctx context.Context, UUID string, key []byte) (*services.EntryKeyData, error) +} + +type GenerateEntryKeyHandler struct { + entryManager GenerateEntryKeyManager + view views.View[views.GenerateEntryKeyResponseData] + parser parsers.Parser[parsers.GenerateEntryKeyRequestData] +} + +func NewGenerateEntryKeyHandler( + parser parsers.Parser[parsers.GenerateEntryKeyRequestData], + entryManager GenerateEntryKeyManager, + view views.View[views.GenerateEntryKeyResponseData], +) GenerateEntryKeyHandler { + return GenerateEntryKeyHandler{ + view: view, + parser: parser, + entryManager: entryManager, + } +} + +func (g GenerateEntryKeyHandler) handle(w http.ResponseWriter, r *http.Request) error { + request, err := g.parser.Parse(r) + + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key) + if err != nil { + return err + } + + g.view.Render(w, r, views.GenerateEntryKeyResponseData{ + UUID: request.UUID, + Key: hex.EncodeToString(entry.KEK), + Expire: entry.Expire, + }) + return nil +} + +func (g GenerateEntryKeyHandler) Handle(w http.ResponseWriter, r *http.Request) { + if err := g.handle(w, r); err != nil { + g.view.RenderError(w, r, err) + } +} diff --git a/internal/api/generateentrykey_test.go b/internal/api/generateentrykey_test.go new file mode 100644 index 0000000..02dfdf1 --- /dev/null +++ b/internal/api/generateentrykey_test.go @@ -0,0 +1,114 @@ +package api + +import ( + "context" + "encoding/hex" + "net/http" + "testing" + "time" + + "github.com/Ajnasz/sekret.link/internal/parsers" + "github.com/Ajnasz/sekret.link/internal/services" + "github.com/Ajnasz/sekret.link/internal/views" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockGenerateEntryKeyView struct { + mock.Mock +} + +func (m *MockGenerateEntryKeyView) Render(w http.ResponseWriter, r *http.Request, entry views.GenerateEntryKeyResponseData) { + m.Called(w, r, entry) +} + +func (m *MockGenerateEntryKeyView) RenderError(w http.ResponseWriter, r *http.Request, err error) { + m.Called(w, r, err) +} + +type MockGenerateEntryKeyManager struct { + mock.Mock +} + +func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context, UUID string, key []byte) (*services.EntryKeyData, error) { + args := m.Called(ctx, UUID, key) + return args.Get(0).(*services.EntryKeyData), args.Error(2) +} + +type MockGenerateEntryKeyParser struct { + mock.Mock +} + +func (g *MockGenerateEntryKeyParser) Parse(u *http.Request) (parsers.GenerateEntryKeyRequestData, error) { + args := g.Called(u) + return args.Get(0).(parsers.GenerateEntryKeyRequestData), args.Error(1) +} + +func TestGenerateEntryKey_Handle(t *testing.T) { + viewMock := new(MockGenerateEntryKeyView) + parserMock := new(MockGenerateEntryKeyParser) + managerMock := new(MockGenerateEntryKeyManager) + + handler := NewGenerateEntryKeyHandler(parserMock, managerMock, viewMock) + + newKey := []byte{18, 18, 18, 18, 174, 173, 15} + expire := time.Now().Add(time.Hour * 24) + + viewMock.On("Render", mock.Anything, mock.Anything, views.GenerateEntryKeyResponseData{ + UUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", + Key: hex.EncodeToString(newKey), + Expire: expire, + }).Return() + parserMock.On("Parse", mock.Anything).Return(parsers.GenerateEntryKeyRequestData{ + UUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", + Key: []byte{18, 18, 18, 18, 174, 173, 15}, + }, nil) + + managerMock.On("GenerateEntryKey", mock.Anything, "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", []byte{18, 18, 18, 18, 174, 173, 15}).Return(&services.EntryKeyData{ + Expire: expire, + EntryUUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", + KEK: newKey, + }, newKey, nil) + + handler.Handle(nil, nil) + managerMock.AssertExpectations(t) + parserMock.AssertExpectations(t) + viewMock.AssertExpectations(t) +} + +func TestGenerateEntryKey_HandleParseError(t *testing.T) { + viewMock := new(MockGenerateEntryKeyView) + parserMock := new(MockGenerateEntryKeyParser) + managerMock := new(MockGenerateEntryKeyManager) + + handler := NewGenerateEntryKeyHandler(parserMock, managerMock, viewMock) + + parserMock.On("Parse", mock.Anything).Return(parsers.GenerateEntryKeyRequestData{}, assert.AnError) + + viewMock.On("RenderError", mock.Anything, mock.Anything, mock.Anything).Return() + handler.Handle(nil, nil) + managerMock.AssertExpectations(t) + parserMock.AssertExpectations(t) + viewMock.AssertExpectations(t) +} + +func TestGenerateEntryKey_HandleManagerError(t *testing.T) { + viewMock := new(MockGenerateEntryKeyView) + parserMock := new(MockGenerateEntryKeyParser) + managerMock := new(MockGenerateEntryKeyManager) + + handler := NewGenerateEntryKeyHandler(parserMock, managerMock, viewMock) + + viewMock.On("RenderError", mock.Anything, mock.Anything, mock.Anything).Return() + parserMock.On("Parse", mock.Anything).Return(parsers.GenerateEntryKeyRequestData{ + UUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", + Key: []byte{18, 18, 18, 18, 174, 173, 15}, + }, nil) + + managerMock.On("GenerateEntryKey", mock.Anything, "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", []byte{18, 18, 18, 18, 174, 173, 15}).Return(&services.EntryKeyData{}, []byte{}, assert.AnError) + + handler.Handle(nil, nil) + managerMock.AssertExpectations(t) + parserMock.AssertExpectations(t) + viewMock.AssertExpectations(t) +} diff --git a/internal/api/getentry.go b/internal/api/getentry.go index 1091611..7b43f3f 100644 --- a/internal/api/getentry.go +++ b/internal/api/getentry.go @@ -7,6 +7,7 @@ import ( "github.com/Ajnasz/sekret.link/internal/parsers" "github.com/Ajnasz/sekret.link/internal/services" + "github.com/Ajnasz/sekret.link/internal/views" ) // GetEntryManager is the interface for getting an entry @@ -26,7 +27,7 @@ var ErrInvalidKeyError = errors.New("invalid key") // GetHandler is the handler for getting an entry type GetHandler struct { entryManager GetEntryManager - view GetEntryView + view views.View[views.EntryReadResponse] parser parsers.Parser[parsers.GetEntryRequestData] } @@ -34,7 +35,7 @@ type GetHandler struct { func NewGetHandler( parser parsers.Parser[parsers.GetEntryRequestData], entryManager GetEntryManager, - view GetEntryView, + view views.View[views.EntryReadResponse], ) GetHandler { return GetHandler{ view: view, @@ -58,7 +59,7 @@ func (g GetHandler) handle(w http.ResponseWriter, r *http.Request) error { return err } - g.view.RenderReadEntry(w, r, entry, request.KeyString) + g.view.Render(w, r, views.BuildEntryReadResponse(*entry, request.KeyString)) return nil } @@ -67,6 +68,6 @@ func (g GetHandler) handle(w http.ResponseWriter, r *http.Request) error { func (g GetHandler) Handle(w http.ResponseWriter, r *http.Request) { err := g.handle(w, r) if err != nil { - g.view.RenderReadEntryError(w, r, err) + g.view.RenderError(w, r, err) } } diff --git a/internal/api/getentry_test.go b/internal/api/getentry_test.go index 74ad875..93237b2 100644 --- a/internal/api/getentry_test.go +++ b/internal/api/getentry_test.go @@ -9,6 +9,7 @@ import ( "github.com/Ajnasz/sekret.link/internal/parsers" "github.com/Ajnasz/sekret.link/internal/services" + "github.com/Ajnasz/sekret.link/internal/views" "github.com/stretchr/testify/mock" ) @@ -16,11 +17,11 @@ type MockGetEntryView struct { mock.Mock } -func (m *MockGetEntryView) RenderReadEntry(w http.ResponseWriter, r *http.Request, entry *services.Entry, key string) { - m.Called(w, r, entry, key) +func (m *MockGetEntryView) Render(w http.ResponseWriter, r *http.Request, entry views.EntryReadResponse) { + m.Called(w, r, entry) } -func (m *MockGetEntryView) RenderReadEntryError(w http.ResponseWriter, r *http.Request, err error) { +func (m *MockGetEntryView) RenderError(w http.ResponseWriter, r *http.Request, err error) { m.Called(w, r, err) } @@ -43,14 +44,13 @@ func (g *GetEntryManagerMock) ReadEntry(ctx context.Context, UUID string, key [] } func TestGetHandle(t *testing.T) { - viewMock := new(MockGetEntryView) parserMock := new(GetEntryParserMock) managerMock := new(GetEntryManagerMock) handler := NewGetHandler(parserMock, managerMock, viewMock) - viewMock.On("RenderReadEntry", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + viewMock.On("Render", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() // viewMock.On("RenderReadEntryError", mock.Anything, mock.Anything, mock.Anything).Return() parserMock.On("Parse", mock.Anything).Return(parsers.GetEntryRequestData{ UUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", diff --git a/internal/api/getentrykey_test.go b/internal/api/getentrykey_test.go new file mode 100644 index 0000000..778f64e --- /dev/null +++ b/internal/api/getentrykey_test.go @@ -0,0 +1 @@ +package api diff --git a/internal/hasher/hasher.go b/internal/hasher/hasher.go new file mode 100644 index 0000000..7165670 --- /dev/null +++ b/internal/hasher/hasher.go @@ -0,0 +1,31 @@ +package hasher + +import "crypto/sha256" + +type Hasher interface { + Hash(data []byte) []byte +} + +type Sha256Hasher struct{} + +func NewSHA256Hasher() *Sha256Hasher { + return &Sha256Hasher{} +} + +func (h *Sha256Hasher) Hash(data []byte) []byte { + hasher := sha256.New() + hasher.Write(data) + return hasher.Sum(nil) +} + +func Compare(k, k2 []byte) bool { + if len(k) != len(k2) { + return false + } + for i := range k { + if k[i] != k2[i] { + return false + } + } + return true +} diff --git a/internal/hasher/hasher_test.go b/internal/hasher/hasher_test.go new file mode 100644 index 0000000..c06f16f --- /dev/null +++ b/internal/hasher/hasher_test.go @@ -0,0 +1,10 @@ +package hasher + +import "fmt" + +func ExampleSha256Hasher_Hash() { + hashGenerator := NewSHA256Hasher() + hash := hashGenerator.Hash([]byte("test")) + fmt.Println(hash) + // Output: [159 134 208 129 136 76 125 101 154 47 234 160 197 90 208 21 163 191 79 27 43 11 130 44 209 93 108 21 176 240 10 8] +} diff --git a/internal/key/key.go b/internal/key/key.go index 162aa80..b575376 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -10,9 +10,10 @@ import ( // Key object which already has a generated key var ErrorKeyAlreadyGenerated = errors.New("key already generated") var ErrorKeyGenerateFailed = errors.New("Key generation failed") +var ErrorInvalidKey = errors.New("invalid key") // SizeAES256 the byte size required for aes 256 encoding -const SizeAES256 uint = 32 +const SizeAES256 int = 32 // NewKey creates a Key struct func NewKey() *Key { @@ -36,7 +37,7 @@ type Key struct { b64 string } -func (k Key) generateRandomBytes(size uint) ([]byte, error) { +func (k Key) generateRandomBytes(size int) ([]byte, error) { bytes := make([]byte, size) if _, err := rand.Read(bytes); err != nil { return nil, err @@ -61,12 +62,22 @@ func (k *Key) Generate() error { } // Get returns the key -func (k Key) Get() []byte { +func (k *Key) Get() []byte { return k.key } +func (k *Key) Set(key []byte) error { + if len(key) != SizeAES256 { + return ErrorInvalidKey + } + + k.key = key + + return nil +} + // ToHex Converts the key to hex string -func (k Key) ToHex() string { +func (k *Key) ToHex() string { if k.hex == "" { k.hex = hex.EncodeToString(k.key) } @@ -74,6 +85,6 @@ func (k Key) ToHex() string { return k.hex } -func (k Key) String() string { +func (k *Key) String() string { return k.ToHex() } diff --git a/internal/key/key_test.go b/internal/key/key_test.go index a448bcd..e51bf89 100644 --- a/internal/key/key_test.go +++ b/internal/key/key_test.go @@ -10,18 +10,18 @@ func TestNewKey(t *testing.T) { err := k.Generate() if err != nil { - t.Errorf("Generate returned error on first call") + t.Fatal("Generate returned error on first call") } err = k.Generate() if err != ErrorKeyAlreadyGenerated { - t.Errorf("Generate returned error %s, expected %s", err, ErrorKeyAlreadyGenerated) + t.Fatalf("Generate returned error %s, expected %s", err, ErrorKeyAlreadyGenerated) } bytesKey := k.Get() if len(bytesKey) != 32 { - t.Errorf("Expected k.Get() return a 32 length byte slice") + t.Fatalf("Expected k.Get() return a 32 length byte slice") } hexStr := k.ToHex() @@ -32,11 +32,11 @@ func TestNewKey(t *testing.T) { } if !isHex { - t.Errorf("expected %s to match hex string regexp", hexStr) + t.Fatalf("expected %s to match hex string regexp", hexStr) } str := k.String() if str != hexStr { - t.Errorf("Stringer interface expected to return hex value: %s, but got %s", hexStr, str) + t.Fatalf("Stringer interface expected to return hex value: %s, but got %s", hexStr, str) } } diff --git a/internal/keywrapper/keywrapper.go b/internal/keywrapper/keywrapper.go new file mode 100644 index 0000000..35ee9de --- /dev/null +++ b/internal/keywrapper/keywrapper.go @@ -0,0 +1,56 @@ +package services + +import ( + "github.com/Ajnasz/sekret.link/internal/key" + "github.com/Ajnasz/sekret.link/internal/services" +) + +// KeyWrapper is a simple interface to wrap and unwrap keys +// rfc3394 +type KeyWrapper interface { + Wrap(dek *key.Key) ([]byte, *key.Key, error) + Unwrap(kek *key.Key, encrypted []byte) (*key.Key, error) +} + +// AesKeyWrapper is a simple key wrapper that uses AES to wrap and unwrap keys +type AesKeyWrapper struct{} + +// NewAesKeyWrapper creates a new AesKeyWrapper +func NewAesKeyWrapper() *AesKeyWrapper { + return &AesKeyWrapper{} +} + +// Wrap will wrap the Data Encryption Key (DEK) with the Key Encryption Key (KEK) +func (*AesKeyWrapper) Wrap(dek *key.Key) ([]byte, *key.Key, error) { + kek := key.NewKey() + err := kek.Generate() + if err != nil { + return nil, nil, err + } + + encrypter := services.NewAESEncrypter(kek.Get()) + + encrypted, err := encrypter.Encrypt(dek.Get()) + if err != nil { + return nil, nil, err + } + + return encrypted, kek, nil +} + +// Unwrap will unwrap the Data Encryption Key (DEK) with the Key Encryption Key (KEK) +func (*AesKeyWrapper) Unwrap(kek *key.Key, encrypted []byte) (*key.Key, error) { + decrypter := services.NewAESEncrypter(kek.Get()) + + decrypted, err := decrypter.Decrypt(encrypted) + if err != nil { + return nil, err + } + + k := key.NewKey() + if err := k.Set(decrypted); err != nil { + return nil, err + } + + return k, nil +} diff --git a/internal/keywrapper/keywrapper_test.go b/internal/keywrapper/keywrapper_test.go new file mode 100644 index 0000000..b093b50 --- /dev/null +++ b/internal/keywrapper/keywrapper_test.go @@ -0,0 +1,43 @@ +package services + +import ( + "testing" + + "github.com/Ajnasz/sekret.link/internal/key" +) + +func Test_KeyWrap(t *testing.T) { + k := key.NewKey() + + err := k.Generate() + if err != nil { + t.Fatal(err) + } + + keyWrapper := NewAesKeyWrapper() + + encrypted, kek, err := keyWrapper.Wrap(k) + if err != nil { + t.Fatal(err) + } + + badKek := key.NewKey() + err = badKek.Generate() + if err != nil { + t.Fatal(err) + } + + decrypted, err := keyWrapper.Unwrap(kek, encrypted) + if err != nil { + t.Fatal(err) + } + + if string(decrypted.Get()) != string(k.Get()) { + t.Errorf("key unwrap failed, expected %s, got %s", k.Get(), decrypted.Get()) + } + + _, err = keyWrapper.Unwrap(badKek, encrypted) + if err == nil { + t.Error("Expected error, got nil") + } +} diff --git a/internal/models/entry.go b/internal/models/entry.go index 8ffccbb..24ab42b 100644 --- a/internal/models/entry.go +++ b/internal/models/entry.go @@ -15,17 +15,10 @@ import ( var ErrEntryNotFound = errors.New("entry not found") var ErrInvalidKey = errors.New("invalid key") +var ErrCreateEntry = errors.New("failed to create entry") -// uuid uuid PRIMARY KEY, -// data BYTEA, -// remaining_reads SMALLINT DEFAULT 1, -// delete_key CHAR(256) NOT NULL, -// created TIMESTAMPTZ, -// accessed TIMESTAMPTZ, -// expire TIMESTAMPTZ -type Entry struct { +type EntryMeta struct { UUID string - Data []byte RemainingReads int DeleteKey string Created time.Time @@ -33,13 +26,16 @@ type Entry struct { Expire time.Time } -type EntryMeta struct { - UUID string - RemainingReads int - DeleteKey string - Created time.Time - Accessed sql.NullTime - Expire time.Time +// uuid uuid PRIMARY KEY, +// data BYTEA, +// remaining_reads SMALLINT DEFAULT 1, +// delete_key CHAR(256) NOT NULL, +// created TIMESTAMPTZ, +// accessed TIMESTAMPTZ, +// expire TIMESTAMPTZ +type Entry struct { + EntryMeta + Data []byte } type EntryModel struct { @@ -55,13 +51,26 @@ func (e *EntryModel) getDeleteKey() (string, error) { // CreateEntry creates a new entry into the database func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, error) { - now := time.Now() deleteKey, err := e.getDeleteKey() if err != nil { - return nil, err + return nil, errors.Join(err, ErrCreateEntry) } - _, err = tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, uuid, data, now, now.Add(expire), remainingReads, deleteKey) + now := time.Now() + res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, uuid, data, now, now.Add(expire), remainingReads, deleteKey) + + if err != nil { + return nil, errors.Join(err, ErrCreateEntry) + } + + rows, err := res.RowsAffected() + if err != nil { + return nil, errors.Join(err, ErrCreateEntry) + } + + if rows != 1 { + return nil, ErrCreateEntry + } return &EntryMeta{ UUID: uuid, @@ -72,7 +81,7 @@ func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, d }, err } -func (e *EntryModel) UpdateAccessed(ctx context.Context, tx *sql.Tx, uuid string) error { +func (e *EntryModel) Use(ctx context.Context, tx *sql.Tx, uuid string) error { _, err := tx.ExecContext(ctx, "UPDATE entries SET accessed = NOW(), remaining_reads = remaining_reads - 1 WHERE uuid = $1 AND remaining_reads > 0", uuid) return err } diff --git a/internal/models/entry_test.go b/internal/models/entry_test.go index 21931fe..3779b3b 100644 --- a/internal/models/entry_test.go +++ b/internal/models/entry_test.go @@ -65,7 +65,7 @@ func Test_EntryModel_CreateEntry(t *testing.T) { } } -func Test_EntryModel_UpdateAccessed(t *testing.T) { +func Test_EntryModel_Use(t *testing.T) { ctx := context.Background() db, err := durable.TestConnection(ctx) if err != nil { @@ -91,7 +91,7 @@ func Test_EntryModel_UpdateAccessed(t *testing.T) { t.Fatal(err) } - if err := model.UpdateAccessed(ctx, tx, uid); err != nil { + if err := model.Use(ctx, tx, uid); err != nil { t.Fatal(errors.Join(err, errors.New("failed to access entry"))) } diff --git a/internal/models/entrykey.go b/internal/models/entrykey.go new file mode 100644 index 0000000..df834f6 --- /dev/null +++ b/internal/models/entrykey.go @@ -0,0 +1,117 @@ +package models + +import ( + "context" + "database/sql" + "time" +) + +type EntryKey struct { + UUID string + EntryUUID string + EncryptedKey []byte + KeyHash []byte + Created time.Time + Expire sql.NullTime + RemainingReads sql.NullInt16 +} + +type EntryKeyModel struct{} + +func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*EntryKey, error) { + + now := time.Now() + res := tx.QueryRowContext(ctx, ` + INSERT INTO entry_key (uuid, entry_uuid, encrypted_key, key_hash, created) + VALUES (gen_random_uuid(), $1, $2, $3, $4) RETURNING uuid, created; + `, entryUUID, encryptedKey, hash, now) + + var uid string + var created time.Time + + err := res.Scan(&uid, &created) + + if err != nil { + return nil, err + + } + + return &EntryKey{ + UUID: uid, + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: now, + }, err +} + +func (e *EntryKeyModel) Get(ctx context.Context, tx *sql.Tx, entryUUID string) ([]EntryKey, error) { + rows, err := tx.QueryContext(ctx, ` + SELECT uuid, entry_uuid, encrypted_key, key_hash, created, expire, remaining_reads + FROM entry_key + WHERE entry_uuid = $1 + AND (expire IS NULL OR expire > NOW()); + `, entryUUID) + + if err != nil { + return nil, err + } + + defer rows.Close() + + var entryKeys []EntryKey + + for rows.Next() { + var ek EntryKey + err := rows.Scan(&ek.UUID, &ek.EntryUUID, &ek.EncryptedKey, &ek.KeyHash, &ek.Created, &ek.Expire, &ek.RemainingReads) + if err != nil { + return nil, err + } + + entryKeys = append(entryKeys, ek) + } + + return entryKeys, nil + +} + +// GetByUUID returns the entry key by its UUID +func (e *EntryKeyModel) Delete(ctx context.Context, tx *sql.Tx, uuid string) error { + _, err := tx.ExecContext(ctx, ` + DELETE FROM entry_key + WHERE uuid = $1 + `, uuid) + + return err +} + +// SetExpire sets the expire time for the entry key +func (e *EntryKeyModel) SetExpire(ctx context.Context, tx *sql.Tx, uuid string, expire time.Time) error { + _, err := tx.ExecContext(ctx, ` + UPDATE entry_key + SET expire = $1 + WHERE uuid = $2 + `, expire, uuid) + + return err +} + +func (e *EntryKeyModel) SetMaxReads(ctx context.Context, tx *sql.Tx, uuid string, maxReads int) error { + _, err := tx.ExecContext(ctx, ` + UPDATE entry_key + SET remaining_reads = $1 + WHERE uuid = $2 + `, maxReads, uuid) + + return err +} + +func (e *EntryKeyModel) Use(ctx context.Context, tx *sql.Tx, uuid string) error { + _, err := tx.ExecContext(ctx, ` + UPDATE entry_key + SET remaining_reads = remaining_reads - 1 + WHERE uuid = $1 AND remaining_reads IS NOT NULL + `, uuid) + + return err +} diff --git a/internal/models/entrykey_test.go b/internal/models/entrykey_test.go new file mode 100644 index 0000000..d862e73 --- /dev/null +++ b/internal/models/entrykey_test.go @@ -0,0 +1,458 @@ +package models + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/Ajnasz/sekret.link/internal/test/durable" + "github.com/google/uuid" +) + +func getTestDbTx(ctx context.Context) (*sql.DB, *sql.Tx, error) { + db, err := durable.TestConnection(ctx) + + if err != nil { + return nil, nil, err + } + + tx, err := db.Begin() + + if err != nil { + defer db.Close() + return nil, nil, err + } + + return db, tx, nil +} + +func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error) { + uid := uuid.New().String() + + entryModel := &EntryModel{} + + _, err := entryModel.CreateEntry(ctx, tx, uid, []byte("test data"), 2, 3600) + + if err != nil { + return "", "", err + } + + model := &EntryKeyModel{} + + entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx")) + + if err != nil { + return "", "", err + } + + return uid, entryKey.UUID, nil +} + +func Test_EntryKeyModel_Create(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := db.Close(); err != nil { + t.Errorf("close failed: %v", err) + } + }() + + uid := uuid.New().String() + + entryModel := &EntryModel{} + _, err = entryModel.CreateEntry(ctx, tx, uid, []byte("test data"), 2, 3600) + if err != nil { + t.Fatal(err) + } + + model := &EntryKeyModel{} + + entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke")) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + if entryKey.UUID == "" { + t.Error("expected uuid to be set") + } + + if entryKey.Created.IsZero() { + t.Error("expected created to be set") + } + + if entryKey.EntryUUID != uid { + t.Errorf("expected %s got %s", uid, entryKey.EntryUUID) + } + + if entryKey.EncryptedKey == nil { + t.Error("expected encrypted data to be set") + } + + if entryKey.KeyHash == nil { + t.Error("expected encrypted key to be set") + } + +} + +func Test_EntryKeyModel_Get(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Errorf("close failed: %v", err) + } + }() + + uid := uuid.New().String() + + entryModel := &EntryModel{} + _, err = entryModel.CreateEntry(ctx, tx, uid, []byte("test data"), 2, 3600) + if err != nil { + tx.Rollback() + t.Fatal(err) + } + + model := &EntryKeyModel{} + + for i := 0; i < 10; i++ { + _, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i))) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Error(err) + } + t.Fatal(err) + } + + } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + + entryKeys, err := model.Get(ctx, tx, uid) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Error(err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + if len(entryKeys) != 10 { + t.Fatalf("expected 1 got %d", len(entryKeys)) + } + + if entryKeys[0].EntryUUID != uid { + t.Errorf("expected %s got %s", uid, entryKeys[0].EntryUUID) + } + + if entryKeys[0].EncryptedKey == nil { + t.Error("expected encrypted data to be set") + } + + if entryKeys[0].KeyHash == nil { + t.Error("expected encrypted key to be set") + } +} + +func Test_EntryKeyModel_Get_Empty(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Fatal(err) + } + }() + + model := &EntryKeyModel{} + + entryKeys, err := model.Get(ctx, tx, uuid.New().String()) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Error(err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + if len(entryKeys) != 0 { + if err := tx.Rollback(); err != nil { + t.Error(err) + } + t.Fatalf("expected 0 got %d", len(entryKeys)) + } +} + +func Test_EntryKeyModel_Delete(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + defer db.Close() + + uid, entryKeyUUID, err := createTestEntryKey(ctx, tx) + + model := &EntryKeyModel{} + + err = model.Delete(ctx, tx, entryKeyUUID) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + tx, err = db.Begin() + + if err != nil { + t.Fatal(err) + } + + entryKeys, err := model.Get(ctx, tx, uid) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + if len(entryKeys) != 0 { + t.Fatalf("expected 0 got %d", len(entryKeys)) + } +} + +func Test_EntryKeyModel_Delete_Empty(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + defer db.Close() + + model := &EntryKeyModel{} + err = model.Delete(ctx, tx, uuid.New().String()) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } +} + +func Test_EntryKeyModel_SetExpire(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := db.Close(); err != nil { + t.Fatal(err) + } + }() + + model := &EntryKeyModel{} + + uid, entryKeyUUID, err := createTestEntryKey(ctx, tx) + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + err = model.SetExpire(ctx, tx, entryKeyUUID, time.Now().Add(time.Hour)) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + tx, err = db.Begin() + + if err != nil { + t.Fatal(err) + } + + entryKeys, err := model.Get(ctx, tx, uid) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("rollback failed: %v", err) + } + + if len(entryKeys) != 1 { + t.Fatalf("expected 1 got %d", len(entryKeys)) + } + + if entryKeys[0].Expire.Time.IsZero() { + t.Error("expected expire to be set") + } + + if entryKeys[0].Expire.Time.Before(time.Now()) { + t.Error("expected expire to be in the future") + } +} + +func Test_EntryKeyModel_SetExpire_Empty(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := db.Close(); err != nil { + t.Errorf("close failed: %v", err) + } + }() + + model := &EntryKeyModel{} + + err = model.SetExpire(ctx, tx, uuid.New().String(), time.Now().Add(time.Hour)) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } +} + +func Test_EntryKeyModel_UseTx(t *testing.T) { + ctx := context.Background() + db, tx, err := getTestDbTx(ctx) + + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := db.Close(); err != nil { + t.Errorf("close failed: %v", err) + } + }() + + uid, entryKeyUUID, err := createTestEntryKey(ctx, tx) + + model := &EntryKeyModel{} + + if err := model.SetMaxReads(ctx, tx, entryKeyUUID, 2); err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + err = model.Use(ctx, tx, entryKeyUUID) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + tx, err = db.Begin() + + if err != nil { + t.Fatal(err) + } + + entryKeys, err := model.Get(ctx, tx, uid) + + if err != nil { + if err := tx.Rollback(); err != nil { + t.Errorf("rollback failed: %v", err) + } + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Errorf("commit failed: %v", err) + } + + if len(entryKeys) != 1 { + t.Fatalf("expected 1 got %d", len(entryKeys)) + } + + if entryKeys[0].RemainingReads.Int16 != 1 { + t.Errorf("expected 1 got %d", entryKeys[0].RemainingReads.Int16) + } +} diff --git a/internal/models/migrate/entry.go b/internal/models/migrate/entry.go new file mode 100644 index 0000000..b80f712 --- /dev/null +++ b/internal/models/migrate/entry.go @@ -0,0 +1,101 @@ +package migrate + +import ( + "context" + "database/sql" + "fmt" + + "github.com/Ajnasz/sekret.link/internal/key" +) + +type EntryMigration struct{} + +func NewEntryMigration() *EntryMigration { + return &EntryMigration{} +} + +func (*EntryMigration) Create(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS entries ( + uuid uuid PRIMARY KEY, + data BYTEA, + remaining_reads SMALLINT DEFAULT 1, + delete_key CHAR(256) NOT NULL, + created TIMESTAMPTZ DEFAULT NOW(), + accessed TIMESTAMPTZ, + expire TIMESTAMPTZ DEFAULT NULL + );`) + + if err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + return nil +} + +func (e *EntryMigration) Alter(ctx context.Context, tx *sql.Tx) error { + e.addRemainingRead(ctx, tx) + e.addDeleteKey(ctx, tx) + + return nil +} + +func (*EntryMigration) addRemainingRead(ctx context.Context, tx *sql.Tx) error { + alterTable, err := tx.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS remaining_reads SMALLINT DEFAULT 1;") + + if err != nil { + return err + } + + _, err = alterTable.Exec() + + if err != nil { + return fmt.Errorf("failed to add remaining_reads column: %w", err) + } + + return nil +} + +func (*EntryMigration) addDeleteKey(ctx context.Context, db *sql.Tx) error { + alterTable, err := db.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS delete_key CHAR(256);") + + if err != nil { + return fmt.Errorf("failed to add delete_key column: %w", err) + } + + _, err = alterTable.ExecContext(ctx) + + if err != nil { + return fmt.Errorf("failed to add delete_key column: %w", err) + } + + rows, err := db.QueryContext(ctx, "SELECT uuid FROM entries WHERE delete_key IS NULL;") + if err != nil { + return err + } + + for rows.Next() { + var UUID string + if err := rows.Scan(&UUID); err != nil { + return fmt.Errorf("failed to scan UUID: %w", err) + } + + k, err := key.NewGeneratedKey() + if err != nil { + return err + } + + deleteKey := k.ToHex() + + _, err = db.ExecContext(ctx, "UPDATE entries SET delete_key=$2 WHERE uuid=$1", UUID, deleteKey) + if err != nil { + return fmt.Errorf("failed to update delete_key: %w", err) + } + } + _, err = db.ExecContext(ctx, "ALTER TABLE entries ALTER COLUMN delete_key SET NOT NULL;") + if err != nil { + return fmt.Errorf("failed to alter delete_key column: %w", err) + } + + return nil +} diff --git a/internal/models/migrate/entrykey.go b/internal/models/migrate/entrykey.go new file mode 100644 index 0000000..4e97a67 --- /dev/null +++ b/internal/models/migrate/entrykey.go @@ -0,0 +1,40 @@ +package migrate + +import ( + "context" + "database/sql" + "fmt" +) + +type EntryKeyMigration struct{} + +func NewEntryKeyMigration() *EntryKeyMigration { + return &EntryKeyMigration{} +} + +func (e *EntryKeyMigration) Create(ctx context.Context, tx *sql.Tx) error { + query := ` + CREATE TABLE IF NOT EXISTS entry_key ( + uuid UUID PRIMARY KEY, + entry_uuid UUID NOT NULL, + encrypted_key BYTEA NOT NULL, + key_hash BYTEA NOT NULL, + expire TIMESTAMPTZ DEFAULT NULL, + remaining_reads SMALLINT DEFAULT NULL, + accesed TIMESTAMPTZ DEFAULT NULL, + created TIMESTAMPTZ, + FOREIGN KEY (entry_uuid) REFERENCES entries(uuid) ON DELETE CASCADE + ); +` + _, err := tx.ExecContext(ctx, query) + + if err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + return nil +} + +func (e *EntryKeyMigration) Alter(ctx context.Context, tx *sql.Tx) error { + return nil +} diff --git a/internal/models/migrate/migrate.go b/internal/models/migrate/migrate.go index 9b8326c..aac2fe0 100644 --- a/internal/models/migrate/migrate.go +++ b/internal/models/migrate/migrate.go @@ -5,89 +5,11 @@ import ( "database/sql" "fmt" "sync" - - "github.com/Ajnasz/sekret.link/internal/key" ) -type dbExec func(context.Context, *sql.Tx) error - -func createTable(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS entries ( - uuid uuid PRIMARY KEY, - data BYTEA, - remaining_reads SMALLINT DEFAULT 1, - delete_key CHAR(256) NOT NULL, - created TIMESTAMPTZ, - accessed TIMESTAMPTZ, - expire TIMESTAMPTZ - );`) - - if err != nil { - return fmt.Errorf("failed to create table: %w", err) - } - - return nil -} - -func addRemainingRead(ctx context.Context, tx *sql.Tx) error { - alterTable, err := tx.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS remaining_reads SMALLINT DEFAULT 1;") - - if err != nil { - return err - } - - _, err = alterTable.Exec() - - if err != nil { - return fmt.Errorf("failed to add remaining_reads column: %w", err) - } - - return nil -} - -func addDeleteKey(ctx context.Context, db *sql.Tx) error { - alterTable, err := db.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS delete_key CHAR(256);") - - if err != nil { - return fmt.Errorf("failed to add delete_key column: %w", err) - } - - _, err = alterTable.ExecContext(ctx) - - if err != nil { - return fmt.Errorf("failed to add delete_key column: %w", err) - } - - rows, err := db.QueryContext(ctx, "SELECT uuid FROM entries WHERE delete_key IS NULL;") - if err != nil { - return err - } - - for rows.Next() { - var UUID string - if err := rows.Scan(&UUID); err != nil { - return fmt.Errorf("failed to scan UUID: %w", err) - } - - k, err := key.NewGeneratedKey() - if err != nil { - return err - } - - deleteKey := k.ToHex() - - _, err = db.ExecContext(ctx, "UPDATE entries SET delete_key=$2 WHERE uuid=$1", UUID, deleteKey) - if err != nil { - return fmt.Errorf("failed to update delete_key: %w", err) - } - } - _, err = db.ExecContext(ctx, "ALTER TABLE entries ALTER COLUMN delete_key SET NOT NULL;") - if err != nil { - return fmt.Errorf("failed to alter delete_key column: %w", err) - } - - return nil +type Migrator interface { + Create(context.Context, *sql.Tx) error + Alter(context.Context, *sql.Tx) error } func prepareDatabase(ctx context.Context, db *sql.DB) error { @@ -95,16 +17,21 @@ func prepareDatabase(ctx context.Context, db *sql.DB) error { if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } - actions := []dbExec{ - createTable, - addRemainingRead, - addDeleteKey, + + migrations := []Migrator{ + NewEntryMigration(), + NewEntryKeyMigration(), } - for _, action := range actions { - if err := action(ctx, tx); err != nil { + for _, migration := range migrations { + if err := migration.Create(ctx, tx); err != nil { + tx.Rollback() + return fmt.Errorf("failed to create migration: %w", err) + } + + if err := migration.Alter(ctx, tx); err != nil { tx.Rollback() - return fmt.Errorf("failed to execute action: %w", err) + return fmt.Errorf("failed to alter migration: %w", err) } } diff --git a/internal/test/mocks.go b/internal/models/mock.go similarity index 53% rename from internal/test/mocks.go rename to internal/models/mock.go index e7b208f..a34796e 100644 --- a/internal/test/mocks.go +++ b/internal/models/mock.go @@ -1,28 +1,13 @@ -package test +package models import ( "context" "database/sql" "time" - "github.com/Ajnasz/sekret.link/internal/models" "github.com/stretchr/testify/mock" ) -type MockEntryCrypto struct { - mock.Mock -} - -func (m *MockEntryCrypto) Encrypt(data []byte) ([]byte, error) { - args := m.Called(data) - return args.Get(0).([]byte), args.Error(1) -} - -func (m *MockEntryCrypto) Decrypt(data []byte) ([]byte, error) { - args := m.Called(data) - return args.Get(0).([]byte), args.Error(1) -} - type MockEntryModel struct { mock.Mock } @@ -33,17 +18,17 @@ func (m *MockEntryModel) CreateEntry( UUID string, data []byte, remainingReads int, - expire time.Duration) (*models.EntryMeta, error) { + expire time.Duration) (*EntryMeta, error) { args := m.Called(ctx, tx, UUID, data, remainingReads, expire) - return args.Get(0).(*models.EntryMeta), args.Error(1) + return args.Get(0).(*EntryMeta), args.Error(1) } -func (m *MockEntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, UUID string) (*models.Entry, error) { +func (m *MockEntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, UUID string) (*Entry, error) { args := m.Called(ctx, tx, UUID) - return args.Get(0).(*models.Entry), args.Error(1) + return args.Get(0).(*Entry), args.Error(1) } -func (m *MockEntryModel) UpdateAccessed(ctx context.Context, tx *sql.Tx, UUID string) error { +func (m *MockEntryModel) Use(ctx context.Context, tx *sql.Tx, UUID string) error { args := m.Called(ctx, tx, UUID) return args.Error(0) } diff --git a/internal/parsers/createentry.go b/internal/parsers/createentry.go index 35b554c..59fd871 100644 --- a/internal/parsers/createentry.go +++ b/internal/parsers/createentry.go @@ -4,8 +4,10 @@ import ( "io" "mime" "net/http" - "strconv" "time" + + "github.com/Ajnasz/sekret.link/internal/parsers/expiration" + "github.com/Ajnasz/sekret.link/internal/parsers/maxreads" ) type CreateEntryParser struct { @@ -64,26 +66,12 @@ func getBody(r *http.Request) ([]byte, error) { } func (c CreateEntryParser) calculateExpiration(expire string, defaultExpire time.Duration) (time.Duration, error) { - if expire == "" { - return defaultExpire, nil - } - - userExpire, err := time.ParseDuration(expire) + exp, err := expiration.Parse(expire, defaultExpire, c.maxExpireSeconds) if err != nil { - return 0, err - } - - maxExpire := time.Duration(c.maxExpireSeconds) * time.Second - - if userExpire > maxExpire { return 0, ErrInvalidExpirationDate } - if userExpire <= 0 { - return 0, ErrInvalidExpirationDate - } - - return userExpire, nil + return exp, nil } func (c CreateEntryParser) getSecretExpiration(r *http.Request) (time.Duration, error) { @@ -94,28 +82,15 @@ func (c CreateEntryParser) getSecretExpiration(r *http.Request) (time.Duration, return c.calculateExpiration(expiration, time.Second*time.Duration(c.maxExpireSeconds)) } -func getSecretMaxReads(r *http.Request) (int, error) { +func (c CreateEntryParser) getSecretMaxReads(r *http.Request) (int, error) { r.ParseForm() - const minMaxReadCount int = 1 - val := r.Form.Get("maxReads") - if val == "" { - return minMaxReadCount, nil - } - maxReads, err := strconv.Atoi(val) + reads, err := maxreads.Parse(r.Form.Get("maxReads")) if err != nil { - if _, isNumError := err.(*strconv.NumError); isNumError { - return 0, ErrInvalidMaxRead - } - - return 0, err - } - - if maxReads < minMaxReadCount { return 0, ErrInvalidMaxRead } - return maxReads, nil + return reads, nil } func (c CreateEntryParser) Parse(r *http.Request) (*CreateEntryRequestData, error) { @@ -135,7 +110,7 @@ func (c CreateEntryParser) Parse(r *http.Request) (*CreateEntryRequestData, erro return nil, err } - maxReads, err := getSecretMaxReads(r) + maxReads, err := c.getSecretMaxReads(r) if err != nil { return nil, err diff --git a/internal/parsers/expiration/expiration.go b/internal/parsers/expiration/expiration.go new file mode 100644 index 0000000..4cd225f --- /dev/null +++ b/internal/parsers/expiration/expiration.go @@ -0,0 +1,33 @@ +package expiration + +import ( + "errors" + "time" +) + +// ErrInvalidExpirationDate request parse error happens when the user set +// expiration date is larger than the system maximum expiration date +var ErrInvalidExpirationDate = errors.New("Invalid expiration date") + +func Parse(expire string, defaultExpire time.Duration, maxExpireSeconds int) (time.Duration, error) { + if expire == "" { + return defaultExpire, nil + } + + userExpire, err := time.ParseDuration(expire) + if err != nil { + return 0, err + } + + maxExpire := time.Duration(maxExpireSeconds) * time.Second + + if userExpire > maxExpire { + return 0, ErrInvalidExpirationDate + } + + if userExpire <= 0 { + return 0, ErrInvalidExpirationDate + } + + return userExpire, nil +} diff --git a/internal/parsers/generateentry.go b/internal/parsers/generateentry.go new file mode 100644 index 0000000..9447a52 --- /dev/null +++ b/internal/parsers/generateentry.go @@ -0,0 +1,91 @@ +package parsers + +import ( + "encoding/hex" + "errors" + "net/http" + "time" + + "github.com/Ajnasz/sekret.link/internal/parsers/expiration" + "github.com/Ajnasz/sekret.link/internal/parsers/maxreads" + "github.com/google/uuid" +) + +// GenerateEntryKeyRequestData is the data for the GenerateEntryKey endpoint. +type GenerateEntryKeyRequestData struct { + UUID string + Key []byte + Expiration time.Duration + MaxReads int +} + +// GenerateEntryKeyParser is the http request parser for the GenerateEntryKey endpoint. +type GenerateEntryKeyParser struct { + maxExpireSeconds int +} + +// NewGenerateEntryKeyParser returns a new GenerateEntryKeyParser. +func NewGenerateEntryKeyParser(maxExpireSeconds int) *GenerateEntryKeyParser { + return &GenerateEntryKeyParser{ + maxExpireSeconds: maxExpireSeconds, + } +} + +func (g GenerateEntryKeyParser) calculateExpiration(expire string, defaultExpire time.Duration) (time.Duration, error) { + exp, err := expiration.Parse(expire, defaultExpire, g.maxExpireSeconds) + if err != nil { + return 0, ErrInvalidExpirationDate + } + + return exp, nil +} + +func (g GenerateEntryKeyParser) getSecretExpiration(req *http.Request) (time.Duration, error) { + expiration := req.URL.Query().Get("expire") + + return g.calculateExpiration(expiration, time.Second*time.Duration(g.maxExpireSeconds)) +} + +func (g GenerateEntryKeyParser) getSecretMaxReads(req *http.Request) (int, error) { + maxReads := req.URL.Query().Get("maxReads") + + return maxreads.Parse(maxReads) +} + +// Parse parses the http request for the GenerateEntryKey endpoint. +func (g *GenerateEntryKeyParser) Parse(r *http.Request) (GenerateEntryKeyRequestData, error) { + var reqData GenerateEntryKeyRequestData + uuidFromPath := r.PathValue("uuid") + UUID, err := uuid.Parse(uuidFromPath) + + if err != nil { + return reqData, errors.Join(ErrInvalidUUID, err) + } + + keyString := r.PathValue("key") + if keyString == "" { + return reqData, ErrInvalidKey + } + key, err := hex.DecodeString(keyString) + + if err != nil { + return reqData, errors.Join(ErrInvalidKey, err) + } + + expiration, err := g.getSecretExpiration(r) + if err != nil { + return reqData, err + } + + maxReads, err := g.getSecretMaxReads(r) + if err != nil { + return reqData, err + } + + reqData.UUID = UUID.String() + reqData.Key = key + reqData.Expiration = expiration + reqData.MaxReads = maxReads + + return reqData, nil +} diff --git a/internal/parsers/getentry.go b/internal/parsers/getentry.go index 7a70eb2..0881cec 100644 --- a/internal/parsers/getentry.go +++ b/internal/parsers/getentry.go @@ -3,7 +3,6 @@ package parsers import ( "encoding/hex" "errors" - "fmt" "net/http" "github.com/google/uuid" @@ -25,7 +24,6 @@ func (g GetEntryParser) Parse(req *http.Request) (GetEntryRequestData, error) { var reqData GetEntryRequestData keyString := req.PathValue("key") if keyString == "" { - fmt.Println("EMPTY KEY", req.URL.Path) return reqData, ErrInvalidKey } diff --git a/internal/parsers/maxreads/maxreads.go b/internal/parsers/maxreads/maxreads.go new file mode 100644 index 0000000..8a07480 --- /dev/null +++ b/internal/parsers/maxreads/maxreads.go @@ -0,0 +1,33 @@ +package maxreads + +import ( + "errors" + "strconv" +) + +// ErrInvalidMaxRead request parse error happens when the user maximum read +// number is greater than the system maximum read number +var ErrInvalidMaxRead = errors.New("Invalid max read") + +// Parse returns the maximum number of reads for a secret. +func Parse(val string) (int, error) { + const minMaxReadCount int = 1 + if val == "" { + return minMaxReadCount, nil + } + + maxReads, err := strconv.Atoi(val) + if err != nil { + if _, isNumError := err.(*strconv.NumError); isNumError { + return 0, ErrInvalidMaxRead + } + + return 0, err + } + + if maxReads < minMaxReadCount { + return 0, ErrInvalidMaxRead + } + + return maxReads, nil +} diff --git a/internal/services/encrypter.go b/internal/services/encrypter.go index 9d7231c..35ed9d2 100644 --- a/internal/services/encrypter.go +++ b/internal/services/encrypter.go @@ -34,21 +34,16 @@ func (e *AESEncrypter) Encrypt(data []byte) ([]byte, error) { return nil, err } - //Create a new GCM - https://en.wikipedia.org/wiki/Galois/Counter_Mode - //https://golang.org/pkg/crypto/cipher/#NewGCM aesGCM, err := cipher.NewGCM(block) if err != nil { return nil, err } - //Create a nonce. Nonce should be from GCM nonce := make([]byte, aesGCM.NonceSize()) if _, err = io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } - //Encrypt the data using aesGCM.Seal - //Since we don't want to save the nonce somewhere else in this case, we add it as a prefix to the encrypted data. The first nonce argument in Seal is the prefix. return aesGCM.Seal(nonce, nonce, data, nil), nil } @@ -59,19 +54,15 @@ func (e *AESEncrypter) Decrypt(data []byte) ([]byte, error) { return nil, err } - //Create a new GCM aesGCM, err := cipher.NewGCM(block) if err != nil { return nil, err } - //Get the nonce size nonceSize := aesGCM.NonceSize() - //Extract the nonce from the encrypted data nonce, ciphertext := data[:nonceSize], data[nonceSize:] - //Decrypt the data plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, err diff --git a/internal/services/entrykeymanager.go b/internal/services/entrykeymanager.go new file mode 100644 index 0000000..fff2aa2 --- /dev/null +++ b/internal/services/entrykeymanager.go @@ -0,0 +1,237 @@ +package services + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/Ajnasz/sekret.link/internal/hasher" + "github.com/Ajnasz/sekret.link/internal/key" + "github.com/Ajnasz/sekret.link/internal/models" +) + +var ErrEntryKeyNotFound = errors.New("entry key not found") + +type EntryKeyModel interface { + Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*models.EntryKey, error) + Get(ctx context.Context, tx *sql.Tx, entryUUID string) ([]models.EntryKey, error) + Delete(ctx context.Context, tx *sql.Tx, uuid string) error + SetExpire(ctx context.Context, tx *sql.Tx, uuid string, expire time.Time) error + SetMaxReads(ctx context.Context, tx *sql.Tx, uuid string, maxRead int) error + Use(ctx context.Context, tx *sql.Tx, uuid string) error +} + +type EntryKeyManager struct { + db *sql.DB + tx *sql.Tx + model EntryKeyModel + hasher hasher.Hasher + encrypter EncrypterFactory +} + +func NewEntryKeyManager(db *sql.DB, model EntryKeyModel, hasher hasher.Hasher, encrypter EncrypterFactory) *EntryKeyManager { + return &EntryKeyManager{ + db: db, + model: model, + hasher: hasher, + encrypter: encrypter, + } +} + +func (e *EntryKeyManager) Create(ctx context.Context, entryUUID string, dek []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) { + + tx, err := e.db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, err + } + + entryKey, k, err := e.CreateWithTx(ctx, tx, entryUUID, dek, expire, maxRead) + + if err := tx.Commit(); err != nil { + return nil, nil, err + } + + return entryKey, k, nil +} + +type EntryKey struct { + UUID string + EntryUUID string + EncryptedKey []byte + KeyHash []byte + Created time.Time + Expire time.Time + RemainingReads int +} + +func modelEntryKeyToEntryKey(m *models.EntryKey) *EntryKey { + return &EntryKey{ + UUID: m.UUID, + EntryUUID: m.EntryUUID, + EncryptedKey: m.EncryptedKey, + KeyHash: m.KeyHash, + Created: m.Created, + Expire: m.Expire.Time, + RemainingReads: int(m.RemainingReads.Int16), + } +} + +func (e *EntryKeyManager) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) { + k, err := key.NewGeneratedKey() + + if err != nil { + return nil, nil, err + } + encrypter := e.encrypter(k.Get()) + encryptedKey, err := encrypter.Encrypt(dek) + if err != nil { + return nil, nil, err + } + + hash := e.hasher.Hash(dek) + entryKey, err := e.model.Create(ctx, tx, entryUUID, encryptedKey, hash) + if err != nil { + return nil, nil, err + } + + if expire != nil { + err := e.model.SetExpire(ctx, tx, entryKey.UUID, *expire) + if err != nil { + return nil, nil, err + } + entryKey.Expire = sql.NullTime{ + Time: *expire, + Valid: true, + } + } + + if maxRead != nil { + err := e.model.SetMaxReads(ctx, tx, entryKey.UUID, *maxRead) + if err != nil { + return nil, nil, err + } + + entryKey.RemainingReads = sql.NullInt16{ + Int16: int16(*maxRead), + Valid: true, + } + } + return modelEntryKeyToEntryKey(entryKey), k, nil +} + +func (e *EntryKeyManager) Delete(ctx context.Context, uuid string) error { + tx, err := e.db.BeginTx(ctx, nil) + if err != nil { + return err + } + + if err := e.model.Delete(ctx, tx, uuid); err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + +func (e *EntryKeyManager) UseTx(ctx context.Context, tx *sql.Tx, entryUUID string) error { + return e.model.Use(ctx, tx, entryUUID) +} + +func (e *EntryKeyManager) findDEK(ctx context.Context, tx *sql.Tx, entryUUID string, key []byte) (dek []byte, entryKey *models.EntryKey, err error) { + entryKeys, err := e.model.Get(ctx, tx, entryUUID) + if err != nil { + return nil, nil, err + } + + crypter := e.encrypter(key) + for _, ek := range entryKeys { + decrypted, err := crypter.Decrypt(ek.EncryptedKey) + if err != nil { + continue + } + + hash := e.hasher.Hash(decrypted) + + if hasher.Compare(hash, ek.KeyHash) { + return decrypted, &ek, nil + } + } + + return nil, nil, ErrEntryKeyNotFound +} + +// GetDEK returns the decrypted data encryption key and the entry key +// if the key is not found it returns ErrEntryKeyNotFound +// if the key is found but the hash does not match it returns an error +func (e *EntryKeyManager) GetDEK(ctx context.Context, entryUUID string, key []byte) (dek []byte, entryKey *EntryKey, err error) { + tx, err := e.db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, err + } + + dek, entryKey, err = e.GetDEKTx(ctx, tx, entryUUID, key) + if err != nil { + tx.Rollback() + return nil, nil, err + } + + if err := tx.Commit(); err != nil { + return nil, nil, err + } + + return dek, entryKey, nil +} + +// GetDEKTx returns the decrypted data encryption key and the entry key +// if the key is not found it returns ErrEntryKeyNotFound +// if the key is found but the hash does not match it returns an error +func (e *EntryKeyManager) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID string, key []byte) (dek []byte, entryKey *EntryKey, err error) { + dek, entryKeyModel, err := e.findDEK(ctx, tx, entryUUID, key) + + if err != nil { + return nil, nil, err + } + + if err := validateEntryKey(entryKeyModel); err != nil { + return nil, nil, err + } + + if e.model == nil { + return nil, nil, errors.New("model is nil") + } + + if err := e.model.Use(ctx, tx, entryKeyModel.UUID); err != nil { + return nil, nil, err + } + + return dek, modelEntryKeyToEntryKey(entryKeyModel), nil + +} + +// GenerateEncryptionKey creates a new key for the entry +func (e EntryKeyManager) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) { + tx, err := e.db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, err + } + + dek, _, err := e.findDEK(ctx, tx, entryUUID, existingKey) + + if err != nil { + tx.Rollback() + return nil, nil, err + } + + entryKey, k, err := e.CreateWithTx(ctx, tx, entryUUID, dek, expire, maxRead) + if err != nil { + tx.Rollback() + return nil, nil, err + } + + if err := tx.Commit(); err != nil { + return nil, nil, err + } + + return entryKey, k, nil +} diff --git a/internal/services/entrykeymanager_test.go b/internal/services/entrykeymanager_test.go new file mode 100644 index 0000000..9ac7061 --- /dev/null +++ b/internal/services/entrykeymanager_test.go @@ -0,0 +1,668 @@ +package services + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/Ajnasz/sekret.link/internal/models" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockEntryKeyModel struct { + mock.Mock +} + +func (m *MockEntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*models.EntryKey, error) { + args := m.Called(ctx, tx, entryUUID, encryptedKey, hash) + return args.Get(0).(*models.EntryKey), args.Error(1) +} + +func (m *MockEntryKeyModel) Get(ctx context.Context, tx *sql.Tx, entryUUID string) ([]models.EntryKey, error) { + args := m.Called(ctx, tx, entryUUID) + return args.Get(0).([]models.EntryKey), args.Error(1) +} + +func (m *MockEntryKeyModel) Delete(ctx context.Context, tx *sql.Tx, uuid string) error { + args := m.Called(ctx, tx, uuid) + return args.Error(0) +} + +func (m *MockEntryKeyModel) SetExpire(ctx context.Context, tx *sql.Tx, uuid string, expire time.Time) error { + args := m.Called(ctx, tx, uuid, expire) + return args.Error(0) +} + +func (m *MockEntryKeyModel) SetMaxReads(ctx context.Context, tx *sql.Tx, uuid string, maxRead int) error { + args := m.Called(ctx, tx, uuid, maxRead) + return args.Error(0) +} + +func (m *MockEntryKeyModel) Use(ctx context.Context, tx *sql.Tx, uuid string) error { + args := m.Called(ctx, tx, uuid) + return args.Error(0) +} + +type MockHasher struct { + mock.Mock +} + +func (m *MockHasher) Hash(data []byte) []byte { + args := m.Called(data) + return args.Get(0).([]byte) +} + +type EncrypterMock struct { + mock.Mock +} + +func (e *EncrypterMock) Encrypt(data []byte) ([]byte, error) { + args := e.Called(data) + return args.Get(0).([]byte), args.Error(1) +} + +func (e *EncrypterMock) Decrypt(data []byte) ([]byte, error) { + args := e.Called(data) + return args.Get(0).([]byte), args.Error(1) +} + +func TestEntryKeyManager_Create(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + dek := []byte("test-dek") + encryptedKey := []byte("test-encrypted-key") + hash := []byte("test-hash") + expire := time.Now() + maxRead := 10 + + encrypter.On("Encrypt", dek).Return(encryptedKey, nil) + hasher.On("Hash", dek).Return(hash) + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + Created: time.Now(), + Expire: sql.NullTime{Time: time.Now(), Valid: false}, + RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + }, nil) + + model.On("SetExpire", ctx, mock.Anything, "test-uuid", expire).Return(nil) + model.On("SetMaxReads", ctx, mock.Anything, "test-uuid", maxRead).Return(nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + entryKey, key, err := manager.Create(ctx, entryUUID, dek, &expire, &maxRead) + + model.AssertExpectations(t) + encrypter.AssertExpectations(t) + hasher.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, "test-uuid", entryKey.UUID) + assert.Equal(t, expire, entryKey.Expire) + assert.Equal(t, maxRead, entryKey.RemainingReads) + assert.NotEmpty(t, key.Get()) +} + +func TestEntryKeyManager_Create_NoExpire(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + dek := []byte("test-dek") + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + hash := []byte("test-hash") + + hasher.On("Hash", dek).Return(hash) + encrypter.On("Encrypt", dek).Return(encryptedKey, nil) + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + Expire: sql.NullTime{Time: time.Now(), Valid: false}, + RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + }, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + manager := NewEntryKeyManager(db, model, hasher, crypto) + entryKey, key, err := manager.Create(ctx, entryUUID, dek, nil, nil) + + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + model.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, "test-uuid", entryKey.UUID) + fmt.Println("__------------------------", entryKey.Expire) + // assert.False(nil, entryKey.Expire) + assert.NotEmpty(t, key.Get()) +} + +func TestEntryKeyManager_Create_NoMaxRead(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + dek := []byte("test-dek") + encryptedKey := []byte("test-encrypted-key") + hash := []byte("test-hash") + + hasher.On("Hash", dek).Return(hash) + encrypter.On("Encrypt", dek).Return(encryptedKey, nil) + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + Expire: sql.NullTime{Time: time.Now(), Valid: false}, + RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + }, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + entryKey, key, err := manager.Create(ctx, entryUUID, dek, nil, nil) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, "test-uuid", entryKey.UUID) + // key.Get should not return an empty string + assert.NotEmpty(t, key.Get()) +} + +func TestEntryKeyManager_GetDEK(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + dek := []byte("test-dek") + hash := []byte("test-hash") + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + hasher.On("Hash", dek).Return(hash) + encrypter.On("Decrypt", encryptedKey).Return(dek, nil) + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{ + { + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + }, + }, nil) + model.On("Use", ctx, mock.Anything, "test-uuid").Return(nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + foundDEK, entryKey, err := manager.GetDEK(ctx, entryUUID, dek) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, dek, foundDEK) + assert.Equal(t, "test-uuid", entryKey.UUID) +} + +// TestEntryKeyManager_GetDEK_NotFound tests the case when the entry key is not +// found so the function should return an ErrEntryKeyNotFound error +func TestEntryKeyManager_GetDEK_NotFound(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + dek := []byte("test-dek") + + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{}, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + foundDEK, entryKey, err := manager.GetDEK(ctx, entryUUID, dek) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.Error(t, ErrEntryKeyNotFound, err) + assert.Nil(t, foundDEK) + assert.Nil(t, entryKey) +} + +func TestEntryKeyManager_GetDEK_DecryptError(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-keyke") + dek := []byte("test-dekk") + hash := []byte("test-hashh") + + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + + encrypter.On("Decrypt", encryptedKey).Return([]byte{}, assert.AnError) + + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{ + { + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + }, + }, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + foundDEK, entryKey, err := manager.GetDEK(ctx, entryUUID, dek) + + assert.Error(t, err) + assert.Nil(t, foundDEK) + assert.Nil(t, entryKey) + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } +} + +func TestEntryManager_InvalidDEK(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + dek := []byte("test-dek") + badDEK := []byte("bad-dek") + hash := []byte("test-hash") + badHash := []byte("bad-hash") + + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + encrypter.On("Decrypt", encryptedKey).Return(badDEK, nil) + hasher.On("Hash", badDEK).Return(badHash) + + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{ + { + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + }, + }, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + foundDEK, entryKey, err := manager.GetDEK(ctx, entryUUID, dek) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.Error(t, err) + assert.Nil(t, foundDEK) + assert.Nil(t, entryKey) +} + +func TestEntryKeyManager_GenerateEncryptionKey(t *testing.T) { + // reads an existing key from the db, creates a new key, and returns the new key + + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + newEncryptedKey := []byte("new-test-encrypted-key") + dek := []byte("test-dek") + hash := []byte("test-hash") + expire := time.Now() + maxRead := 10 + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{ + { + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + }, + }, nil) + encrypter.On("Decrypt", encryptedKey).Return(dek, nil) + hasher.On("Hash", dek).Return(hash) + + encrypter.On("Encrypt", mock.Anything).Return(newEncryptedKey, nil) + model.On("Create", ctx, mock.Anything, entryUUID, newEncryptedKey, hash).Return(&models.EntryKey{ + UUID: "new-test-uuid", + EntryUUID: entryUUID, + EncryptedKey: newEncryptedKey, + KeyHash: hash, + Created: time.Now(), + Expire: sql.NullTime{Time: time.Now(), Valid: false}, + RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + }, nil) + model.On("SetExpire", ctx, mock.Anything, "new-test-uuid", expire).Return(nil) + model.On("SetMaxReads", ctx, mock.Anything, "new-test-uuid", maxRead).Return(nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + + entryKey, key, err := manager.GenerateEncryptionKey(ctx, entryUUID, encryptedKey, &expire, &maxRead) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) + assert.Equal(t, "new-test-uuid", entryKey.UUID) + assert.Equal(t, expire, entryKey.Expire) + assert.Equal(t, maxRead, entryKey.RemainingReads) + assert.NotEmpty(t, key.Get()) +} + +// TestEntryKeyManager_GenerateEncryptionKey_DecryptError tests if the UseTx method correctly calls the model's Use method +func TestEntryKeyManager_UseTx(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + model.On("Use", ctx, mock.Anything, "test-uuid").Return(nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + tx, err := db.Begin() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + tx.Commit() + + err = manager.UseTx(ctx, tx, "test-uuid") + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) +} + +func Test_EntryKeyManager_GetDEKTx_NoRemainingReads(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + dek := []byte("test-dek") + hash := []byte("test-hash") + + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + hasher.On("Hash", dek).Return(hash) + encrypter.On("Decrypt", encryptedKey).Return(dek, nil) + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{ + { + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + RemainingReads: sql.NullInt16{Int16: 0, Valid: true}, + }, + }, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + foundDEK, entryKey, err := manager.GetDEK(ctx, entryUUID, dek) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.Error(t, ErrEntryNoRemainingReads, err) + assert.Nil(t, foundDEK) + assert.Nil(t, entryKey) +} + +func Test_EntryKeyManager_GetDEKTx_Expired(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + entryUUID := "test-entry-uuid" + encryptedKey := []byte("test-encrypted-key") + dek := []byte("test-dek") + hash := []byte("test-hash") + + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + hasher.On("Hash", dek).Return(hash) + encrypter.On("Decrypt", encryptedKey).Return(dek, nil) + model.On("Get", ctx, mock.Anything, entryUUID).Return([]models.EntryKey{ + { + UUID: "test-uuid", + EntryUUID: entryUUID, + EncryptedKey: encryptedKey, + KeyHash: hash, + Created: time.Now(), + Expire: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, + RemainingReads: sql.NullInt16{Int16: 1, Valid: true}, + }, + }, nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + foundDEK, entryKey, err := manager.GetDEK(ctx, entryUUID, dek) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.Error(t, ErrEntryExpired, err) + assert.Nil(t, foundDEK) + assert.Nil(t, entryKey) +} + +func Test_EntryKeyManager_Delete(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + defer db.Close() + + ctx := context.Background() + model := &MockEntryKeyModel{} + hasher := &MockHasher{} + encrypter := &EncrypterMock{} + uuid := "test-uuid" + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + model.On("Delete", ctx, mock.Anything, uuid).Return(nil) + + crypto := func(key []byte) Encrypter { + return encrypter + } + + manager := NewEntryKeyManager(db, model, hasher, crypto) + err = manager.Delete(ctx, uuid) + + model.AssertExpectations(t) + hasher.AssertExpectations(t) + encrypter.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + assert.NoError(t, err) +} diff --git a/internal/services/entrymanager.go b/internal/services/entrymanager.go index 86ecb8e..669221f 100644 --- a/internal/services/entrymanager.go +++ b/internal/services/entrymanager.go @@ -13,16 +13,7 @@ import ( var ErrEntryExpired = errors.New("entry expired") var ErrEntryNotFound = errors.New("entry not found") - -// EntryModel is the interface for the entry model -// It is used to create, read and access entries -type EntryModel interface { - CreateEntry(ctx context.Context, tx *sql.Tx, UUID string, data []byte, remainingReads int, expire time.Duration) (*models.EntryMeta, error) - ReadEntry(ctx context.Context, tx *sql.Tx, UUID string) (*models.Entry, error) - UpdateAccessed(ctx context.Context, tx *sql.Tx, UUID string) error - DeleteEntry(ctx context.Context, tx *sql.Tx, UUID string, deleteKey string) error - DeleteExpired(ctx context.Context, tx *sql.Tx) error -} +var ErrEntryNoRemainingReads = errors.New("entry has no remaining reads") // EntryMeta provides the entry meta type EntryMeta struct { @@ -44,28 +35,37 @@ type Entry struct { Expire time.Time } -type EncrypterFactory = func(key []byte) Encrypter - -func (e *Entry) IsExpired() bool { - return e.Expire.Before(time.Now()) +type EntryKeyData struct { + EntryUUID string + KEK []byte + RemainingReads int + Expire time.Time } // EntryManager provides the entry service type EntryManager struct { - db *sql.DB - model EntryModel - crypto EncrypterFactory + db *sql.DB + model EntryModel + crypto EncrypterFactory + keyManager EntryKeyer } // NewEntryManager creates a new EntryService -func NewEntryManager(db *sql.DB, model EntryModel, crypto EncrypterFactory) *EntryManager { +func NewEntryManager(db *sql.DB, model EntryModel, crypto EncrypterFactory, keyManager EntryKeyer) *EntryManager { return &EntryManager{ - db: db, - model: model, - crypto: crypto, + db: db, + model: model, + crypto: crypto, + keyManager: keyManager, } } +// CreateEntry creates a new entry +// It generates a new UUID for the entry +// It encrypts the data with a new generated key +// It stores the encrypted data in the database +// It stores the key in the key manager +// It returns the meta data of the entry and the key func (e *EntryManager) CreateEntry(ctx context.Context, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, []byte, error) { uid := uuid.NewUUIDString() @@ -73,27 +73,34 @@ func (e *EntryManager) CreateEntry(ctx context.Context, data []byte, remainingRe if err != nil { return nil, nil, err } - k, err := key.NewGeneratedKey() + dek, err := key.NewGeneratedKey() if err != nil { tx.Rollback() return nil, nil, err } - crypto := e.crypto(k.Get()) + crypto := e.crypto(dek.Get()) encryptedData, err := crypto.Encrypt(data) if err != nil { tx.Rollback() return nil, nil, err } - // meta, err := e.model.CreateEntry(ctx, tx, uid, data, remainingReads, expire) meta, err := e.model.CreateEntry(ctx, tx, uid, encryptedData, remainingReads, expire) if err != nil { tx.Rollback() return nil, nil, err } + expireAt := time.Now().Add(expire) + _, kek, err := e.keyManager.CreateWithTx(ctx, tx, uid, dek.Get(), &expireAt, &remainingReads) + + if err != nil { + tx.Rollback() + return nil, nil, err + } + tx.Commit() return &EntryMeta{ @@ -103,11 +110,27 @@ func (e *EntryManager) CreateEntry(ctx context.Context, data []byte, remainingRe Created: meta.Created, Accessed: meta.Accessed.Time, Expire: meta.Expire, - }, k.Get(), nil - + }, kek.Get(), nil +} +func (e *EntryManager) readEntryLegacy(ctx context.Context, key []byte, entry *models.Entry) ([]byte, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + crypto := e.crypto(key) + return crypto.Decrypt(entry.Data) } +// ReadEntry reads an entry +// It reads the entry from the database +// It reads the key from the key manager +// It decrypts the data with the key +// It returns the decrypted data +// It returns an error if the entry is not found or expired +// It returns an error if the key is not found func (e *EntryManager) ReadEntry(ctx context.Context, UUID string, key []byte) (*Entry, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } tx, err := e.db.Begin() if err != nil { return nil, err @@ -122,24 +145,41 @@ func (e *EntryManager) ReadEntry(ctx context.Context, UUID string, key []byte) ( return nil, err } - if entry.RemainingReads <= 0 { - tx.Rollback() - return nil, ErrEntryExpired - } - - if entry.Expire.Before(time.Now()) { + if err := validateEntry(entry); err != nil { tx.Rollback() - return nil, ErrEntryExpired + return nil, err } - crypto := e.crypto(key) - decryptedData, err := crypto.Decrypt(entry.Data) + dek, entryKey, err := e.keyManager.GetDEKTx(ctx, tx, UUID, key) + var decryptedData []byte if err != nil { - tx.Rollback() - return nil, err + if errors.Is(err, ErrEntryKeyNotFound) { + legacyData, legacyErr := e.readEntryLegacy(ctx, key, entry) + if legacyErr == nil { + decryptedData = legacyData + } else { + tx.Rollback() + return nil, err + } + } else { + tx.Rollback() + return nil, err + } + } else { + crypto := e.crypto(dek) + decryptedData, err = crypto.Decrypt(entry.Data) + if err != nil { + tx.Rollback() + return nil, err + } + + if err := e.keyManager.UseTx(ctx, tx, entryKey.UUID); err != nil { + tx.Rollback() + return nil, err + } } - if err := e.model.UpdateAccessed(ctx, tx, UUID); err != nil { + if err := e.model.Use(ctx, tx, UUID); err != nil { tx.Rollback() return nil, err } @@ -190,3 +230,17 @@ func (e *EntryManager) DeleteExpired(ctx context.Context) error { tx.Commit() return nil } + +func (e *EntryManager) GenerateEntryKey(ctx context.Context, entryUUID string, key []byte) (*EntryKeyData, error) { + meta, kek, err := e.keyManager.GenerateEncryptionKey(ctx, entryUUID, key, nil, nil) + if err != nil { + return nil, err + } + + return &EntryKeyData{ + EntryUUID: entryUUID, + RemainingReads: meta.RemainingReads, + Expire: meta.Expire, + KEK: kek.Get(), + }, nil +} diff --git a/internal/services/entrymanager_test.go b/internal/services/entrymanager_test.go index 74fe67c..a262f86 100644 --- a/internal/services/entrymanager_test.go +++ b/internal/services/entrymanager_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" + "github.com/Ajnasz/sekret.link/internal/key" "github.com/Ajnasz/sekret.link/internal/models" - "github.com/Ajnasz/sekret.link/internal/test" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -31,7 +31,7 @@ func Test_EntryService_Create(t *testing.T) { data := []byte("data") encryptedData := []byte("encrypted") - entryModel := new(test.MockEntryModel) + entryModel := new(models.MockEntryModel) entryModel. On("CreateEntry", ctx, mock.Anything, mock.Anything, encryptedData, 1, mock.Anything). Return(&models.EntryMeta{ @@ -42,20 +42,26 @@ func Test_EntryService_Create(t *testing.T) { Expire: timenow.Add(time.Minute), }, nil) - entryCrypto := new(test.MockEntryCrypto) + entryCrypto := new(MockEntryCrypto) entryCrypto.On("Encrypt", data).Return(encryptedData, nil) crypto := func(key []byte) Encrypter { return entryCrypto } - service := NewEntryManager(db, entryModel, crypto) + keyManager := new(MockEntryKeyer) + kek := key.NewKey() + kek.Set([]byte("kek")) + keyManager.On("CreateWithTx", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&EntryKey{}, kek, nil) + + service := NewEntryManager(db, entryModel, crypto, keyManager) meta, key, err := service.CreateEntry(ctx, data, 1, time.Minute) assert.NoError(t, err) assert.NotNil(t, meta) - assert.NotNil(t, key) + assert.Equal(t, key, kek.Get()) entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) if meta.UUID == "" { t.Error("expected UUID to be set") } @@ -93,18 +99,20 @@ func TestCreateError(t *testing.T) { data := []byte("data") encryptedData := []byte("encrypted") - entryModel := new(test.MockEntryModel) + entryModel := new(models.MockEntryModel) entryModel. On("CreateEntry", ctx, mock.Anything, mock.Anything, encryptedData, 1, mock.Anything). Return(&models.EntryMeta{}, fmt.Errorf("error")) - entryCrypto := new(test.MockEntryCrypto) + entryCrypto := new(MockEntryCrypto) entryCrypto.On("Encrypt", data).Return(encryptedData, nil) crypto := func(key []byte) Encrypter { return entryCrypto } - service := NewEntryManager(db, entryModel, crypto) + keyManager := new(MockEntryKeyer) + + service := NewEntryManager(db, entryModel, crypto, keyManager) meta, key, err := service.CreateEntry(ctx, data, 1, time.Minute) assert.Error(t, err) @@ -112,62 +120,161 @@ func TestCreateError(t *testing.T) { assert.Nil(t, key) entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) } func TestReadEntry(t *testing.T) { - db, sqlMock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer db.Close() - - sqlMock.ExpectBegin() - sqlMock.ExpectCommit() - - ctx := context.Background() - - entry := models.Entry{ - UUID: "uuid", - Data: []byte("encrypted"), - RemainingReads: 1, - DeleteKey: "delete_key", - Created: timenow, - Accessed: sql.NullTime{Time: timenow, Valid: true}, - Expire: timenow.Add(time.Minute), - } - - entryModel := new(test.MockEntryModel) - entryModel. - On("ReadEntry", ctx, mock.Anything, "uuid"). - Return(&entry, nil) - entryModel. - On("UpdateAccessed", ctx, mock.Anything, "uuid"). - Return(nil) - - key := []byte("key") - entryCrypto := new(test.MockEntryCrypto) - entryCrypto.On("Decrypt", []byte("encrypted")).Return([]byte("data"), nil) - - crypto := func(key []byte) Encrypter { - return entryCrypto - } - - service := NewEntryManager(db, entryModel, crypto) - data, err := service.ReadEntry(ctx, "uuid", key) - - assert.NoError(t, err) - assert.NotNil(t, data) - assert.Equal(t, Entry{ - UUID: entry.UUID, - Data: []byte("data"), - RemainingReads: 0, - DeleteKey: entry.DeleteKey, - Created: entry.Created, - Accessed: entry.Accessed.Time, - Expire: entry.Expire, - }, *data) - - entryModel.AssertExpectations(t) + t.Run("read entry with valid data", func(t *testing.T) { + + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + + entry := models.Entry{ + EntryMeta: models.EntryMeta{ + UUID: "uuid", + RemainingReads: 1, + DeleteKey: "delete_key", + Created: timenow, + Accessed: sql.NullTime{Time: timenow, Valid: true}, + Expire: timenow.Add(time.Minute), + }, + Data: []byte("encrypted"), + } + + entryModel := new(models.MockEntryModel) + entryModel. + On("ReadEntry", ctx, mock.Anything, "uuid"). + Return(&entry, nil) + entryModel. + On("Use", ctx, mock.Anything, "uuid"). + Return(nil) + + key := []byte("key") + entryCrypto := new(MockEntryCrypto) + entryCrypto.On("Decrypt", []byte("encrypted")).Return([]byte("data"), nil) + + crypto := func(key []byte) Encrypter { + return entryCrypto + } + + keyManager := new(MockEntryKeyer) + + keyManager.On("GetDEKTx", ctx, mock.Anything, "uuid", key).Return([]byte("dek"), &EntryKey{ + UUID: "entrykey uuid", + }, nil) + keyManager.On("UseTx", ctx, mock.Anything, "entrykey uuid").Return(nil) + + service := NewEntryManager(db, entryModel, crypto, keyManager) + data, err := service.ReadEntry(ctx, "uuid", key) + + assert.NoError(t, err) + assert.NotNil(t, data) + assert.Equal(t, Entry{ + UUID: entry.UUID, + Data: []byte("data"), + RemainingReads: 0, + DeleteKey: entry.DeleteKey, + Created: entry.Created, + Accessed: entry.Accessed.Time, + Expire: entry.Expire, + }, *data) + + entryModel.AssertExpectations(t) + }) + + t.Run("should return notfound error when entry not found", func(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + + ctx := context.Background() + + var emptyEntry *models.Entry + entryModel := new(models.MockEntryModel) + entryModel. + On("ReadEntry", ctx, mock.Anything, "uuid"). + Return(emptyEntry, models.ErrEntryNotFound) + + entryCrypto := new(MockEntryCrypto) + crypto := func(key []byte) Encrypter { + return entryCrypto + } + keyManager := new(MockEntryKeyer) + + service := NewEntryManager(db, entryModel, crypto, keyManager) + data, err := service.ReadEntry(ctx, "uuid", []byte("key")) + + assert.Error(t, err) + assert.Nil(t, data) + + entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) + }) + + t.Run("it should try to decrypt with legacy method when key not found", func(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + + entry := models.Entry{ + EntryMeta: models.EntryMeta{ + UUID: "uuid", + RemainingReads: 1, + DeleteKey: "delete_key", + Created: timenow, + Accessed: sql.NullTime{Time: timenow, Valid: true}, + Expire: timenow.Add(time.Minute), + }, + Data: []byte("encrypted"), + } + + entryModel := new(models.MockEntryModel) + entryModel. + On("ReadEntry", ctx, mock.Anything, "uuid"). + Return(&entry, nil) + entryModel.On("Use", ctx, mock.Anything, "uuid").Return(nil) + + entryCrypto := new(MockEntryCrypto) + entryCrypto.On("Decrypt", []byte("encrypted")).Return([]byte("decrypted"), nil) + + crypto := func(key []byte) Encrypter { + return entryCrypto + } + + var emptyEntryKey *EntryKey + var emptyDEK []byte + keyManager := new(MockEntryKeyer) + keyManager.On("GetDEKTx", ctx, mock.Anything, "uuid", []byte("key")).Return(emptyDEK, emptyEntryKey, ErrEntryKeyNotFound) + + service := NewEntryManager(db, entryModel, crypto, keyManager) + data, err := service.ReadEntry(ctx, "uuid", []byte("key")) + + assert.Nil(t, err) + assert.Equal(t, "decrypted", string(data.Data)) + + entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) + }) } func TestReadEntryError(t *testing.T) { @@ -182,23 +289,25 @@ func TestReadEntryError(t *testing.T) { ctx := context.Background() - entryModel := new(test.MockEntryModel) + entryModel := new(models.MockEntryModel) entryModel. On("ReadEntry", ctx, mock.Anything, "uuid"). - Return(&models.Entry{}, fmt.Errorf("error")) + Return(&models.Entry{}, models.ErrEntryNotFound) - entryCrypto := new(test.MockEntryCrypto) + entryCrypto := new(MockEntryCrypto) crypto := func(key []byte) Encrypter { return entryCrypto } + keyManager := new(MockEntryKeyer) - service := NewEntryManager(db, entryModel, crypto) + service := NewEntryManager(db, entryModel, crypto, keyManager) data, err := service.ReadEntry(ctx, "uuid", []byte("key")) - assert.Error(t, err) + assert.Error(t, ErrEntryNotFound) assert.Nil(t, data) entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) if sqlMock.ExpectationsWereMet() != nil { t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) } @@ -216,22 +325,24 @@ func TestDeleteEntry(t *testing.T) { ctx := context.Background() - entryModel := new(test.MockEntryModel) + entryModel := new(models.MockEntryModel) entryModel. On("DeleteEntry", ctx, mock.Anything, "uuid", "delete_key"). Return(nil) - entryCrypto := new(test.MockEntryCrypto) + entryCrypto := new(MockEntryCrypto) crypto := func(key []byte) Encrypter { return entryCrypto } + keyManager := new(MockEntryKeyer) - service := NewEntryManager(db, entryModel, crypto) + service := NewEntryManager(db, entryModel, crypto, keyManager) err = service.DeleteEntry(ctx, "uuid", "delete_key") assert.NoError(t, err) entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) if sqlMock.ExpectationsWereMet() != nil { t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) } @@ -249,22 +360,25 @@ func TestDeleteEntryError(t *testing.T) { ctx := context.Background() - entryModel := new(test.MockEntryModel) + entryModel := new(models.MockEntryModel) entryModel. On("DeleteEntry", ctx, mock.Anything, "uuid", "delete_key"). Return(fmt.Errorf("error")) - entryCrypto := new(test.MockEntryCrypto) + entryCrypto := new(MockEntryCrypto) crypto := func(key []byte) Encrypter { return entryCrypto } - service := NewEntryManager(db, entryModel, crypto) + keyManager := new(MockEntryKeyer) + + service := NewEntryManager(db, entryModel, crypto, keyManager) err = service.DeleteEntry(ctx, "uuid", "delete_key") assert.Error(t, err) entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) if sqlMock.ExpectationsWereMet() != nil { t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) } @@ -282,25 +396,144 @@ func TestDeleteEntryInvalidDeleteKey(t *testing.T) { ctx := context.Background() - entryModel := new(test.MockEntryModel) + entryModel := new(models.MockEntryModel) entryModel. On("DeleteEntry", ctx, mock.Anything, "uuid", "delete_key"). Return(models.ErrEntryNotFound) - entryCrypto := new(test.MockEntryCrypto) + entryCrypto := new(MockEntryCrypto) crypto := func(key []byte) Encrypter { return entryCrypto } - service := NewEntryManager(db, entryModel, crypto) + keyManager := new(MockEntryKeyer) + + service := NewEntryManager(db, entryModel, crypto, keyManager) err = service.DeleteEntry(ctx, "uuid", "delete_key") assert.Error(t, err) assert.Equal(t, models.ErrEntryNotFound, err) entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) if sqlMock.ExpectationsWereMet() != nil { t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) } } + +func Test_EntryManager_DeleteExpired(t *testing.T) { + t.Run("call delete method on the model", func(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + sqlMock.ExpectBegin() + sqlMock.ExpectCommit() + + ctx := context.Background() + + entryModel := new(models.MockEntryModel) + entryModel. + On("DeleteExpired", ctx, mock.Anything). + Return(nil) + + entryCrypto := new(MockEntryCrypto) + crypto := func(key []byte) Encrypter { + return entryCrypto + } + + keyManager := new(MockEntryKeyer) + + service := NewEntryManager(db, entryModel, crypto, keyManager) + err = service.DeleteExpired(ctx) + + assert.NoError(t, err) + + entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + }) + + t.Run("call delete method on the model with error", func(t *testing.T) { + db, sqlMock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + sqlMock.ExpectBegin() + sqlMock.ExpectRollback() + + ctx := context.Background() + + entryModel := new(models.MockEntryModel) + entryModel. + On("DeleteExpired", ctx, mock.Anything). + Return(fmt.Errorf("error")) + + entryCrypto := new(MockEntryCrypto) + crypto := func(key []byte) Encrypter { + return entryCrypto + } + + keyManager := new(MockEntryKeyer) + + service := NewEntryManager(db, entryModel, crypto, keyManager) + err = service.DeleteExpired(ctx) + + assert.Error(t, err) + + entryModel.AssertExpectations(t) + keyManager.AssertExpectations(t) + if sqlMock.ExpectationsWereMet() != nil { + t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) + } + }) +} + +func Test_EntryManager_GenerateEntryKey(t *testing.T) { + t.Run("call generate entry key method on the key manager", func(t *testing.T) { + entryUUID := "entry-uuid" + dek := []byte("dek") + kek := key.NewKey() + kek.Generate() + + keyManager := new(MockEntryKeyer) + keyManager.On("GenerateEncryptionKey", mock.Anything, entryUUID, dek, mock.Anything, mock.Anything). + Return(&EntryKey{ + EntryUUID: entryUUID, + RemainingReads: 1, + Expire: time.Now().Add(time.Minute), + }, kek, nil) + + service := NewEntryManager(nil, nil, nil, keyManager) + + entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, dek) + + assert.NoError(t, err) + assert.Equal(t, entryUUID, entryKey.EntryUUID) + assert.Equal(t, kek.Get(), entryKey.KEK) + }) + + t.Run("call generate entry key method on the key manager with error", func(t *testing.T) { + entryUUID := "entry-uuid" + dek := []byte("dek") + + var emptyEntryKey *EntryKey + var emptyKey *key.Key + + keyManager := new(MockEntryKeyer) + keyManager.On("GenerateEncryptionKey", mock.Anything, entryUUID, dek, mock.Anything, mock.Anything). + Return(emptyEntryKey, emptyKey, fmt.Errorf("error")) + + service := NewEntryManager(nil, nil, nil, keyManager) + + entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, dek) + + assert.Error(t, err) + assert.Nil(t, entryKey) + }) +} diff --git a/internal/services/entryvalidation.go b/internal/services/entryvalidation.go new file mode 100644 index 0000000..e8e9297 --- /dev/null +++ b/internal/services/entryvalidation.go @@ -0,0 +1,39 @@ +package services + +import ( + "time" + + "github.com/Ajnasz/sekret.link/internal/models" +) + +func validateEntry(entry *models.Entry) error { + if entry == nil { + return ErrEntryNotFound + } + + if entry.Expire.Before(time.Now()) { + return ErrEntryExpired + } + + if entry.RemainingReads <= 0 { + return ErrEntryExpired + } + + return nil +} + +func validateEntryKey(entryKey *models.EntryKey) error { + if entryKey == nil { + return ErrEntryKeyNotFound + } + + if entryKey.Expire.Valid && entryKey.Expire.Time.Before(time.Now()) { + return ErrEntryExpired + } + + if entryKey.RemainingReads.Valid && entryKey.RemainingReads.Int16 <= 0 { + return ErrEntryNoRemainingReads + } + + return nil +} diff --git a/internal/services/interfaces.go b/internal/services/interfaces.go new file mode 100644 index 0000000..77a64c5 --- /dev/null +++ b/internal/services/interfaces.go @@ -0,0 +1,32 @@ +package services + +import ( + "context" + "database/sql" + "time" + + "github.com/Ajnasz/sekret.link/internal/key" + "github.com/Ajnasz/sekret.link/internal/models" +) + +// EntryModel is the interface for the entry model +// It is used to create, read and access entries +type EntryModel interface { + CreateEntry(ctx context.Context, tx *sql.Tx, UUID string, data []byte, remainingReads int, expire time.Duration) (*models.EntryMeta, error) + ReadEntry(ctx context.Context, tx *sql.Tx, UUID string) (*models.Entry, error) + Use(ctx context.Context, tx *sql.Tx, UUID string) error + DeleteEntry(ctx context.Context, tx *sql.Tx, UUID string, deleteKey string) error + DeleteExpired(ctx context.Context, tx *sql.Tx) error +} + +// EntryKeyer is the interface for the entry key manager +// It is used to create, read and access entry keys +type EntryKeyer interface { + CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek []byte, expire *time.Time, maxRead *int) (entryKey *EntryKey, kek *key.Key, err error) + GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID string, kek []byte) (dek []byte, entryKey *EntryKey, err error) + GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) + UseTx(ctx context.Context, tx *sql.Tx, entryUUID string) error +} + +// EncrypterFactory is function to create a new Encrypter for a given key +type EncrypterFactory = func(key []byte) Encrypter diff --git a/internal/services/mocks.go b/internal/services/mocks.go new file mode 100644 index 0000000..4ba0c41 --- /dev/null +++ b/internal/services/mocks.go @@ -0,0 +1,58 @@ +package services + +import ( + "context" + "database/sql" + "time" + + "github.com/Ajnasz/sekret.link/internal/key" + "github.com/stretchr/testify/mock" +) + +type MockEntryKeyer struct { + mock.Mock +} + +func (m *MockEntryKeyer) Create(ctx context.Context, entryUUID string, dek []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) { + args := m.Called(ctx, entryUUID, dek, expire, maxRead) + return args.Get(0).(*EntryKey), args.Get(1).(*key.Key), args.Error(2) +} + +func (m *MockEntryKeyer) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) { + args := m.Called(ctx, tx, entryUUID, dek, expire, maxRead) + return args.Get(0).(*EntryKey), args.Get(1).(*key.Key), args.Error(2) +} + +func (m *MockEntryKeyer) GetDEK(ctx context.Context, entryUUID string, kek []byte) ([]byte, *EntryKey, error) { + args := m.Called(ctx, entryUUID, kek) + return args.Get(0).([]byte), args.Get(1).(*EntryKey), args.Error(2) +} + +func (m *MockEntryKeyer) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID string, kek []byte) ([]byte, *EntryKey, error) { + args := m.Called(ctx, tx, entryUUID, kek) + return args.Get(0).([]byte), args.Get(1).(*EntryKey), args.Error(2) +} + +func (m *MockEntryKeyer) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey []byte, expire *time.Time, maxRead *int) (*EntryKey, *key.Key, error) { + args := m.Called(ctx, entryUUID, existingKey, expire, maxRead) + return args.Get(0).(*EntryKey), args.Get(1).(*key.Key), args.Error(2) +} + +func (m *MockEntryKeyer) UseTx(ctx context.Context, tx *sql.Tx, entryUUID string) error { + args := m.Called(ctx, tx, entryUUID) + return args.Error(0) +} + +type MockEntryCrypto struct { + mock.Mock +} + +func (m *MockEntryCrypto) Encrypt(data []byte) ([]byte, error) { + args := m.Called(data) + return args.Get(0).([]byte), args.Error(1) +} + +func (m *MockEntryCrypto) Decrypt(data []byte) ([]byte, error) { + args := m.Called(data) + return args.Get(0).([]byte), args.Error(1) +} diff --git a/internal/test/durable/durable.go b/internal/test/durable/durable.go index 1ea8021..9ccd89d 100644 --- a/internal/test/durable/durable.go +++ b/internal/test/durable/durable.go @@ -3,7 +3,6 @@ package durable import ( "context" "database/sql" - "fmt" "github.com/Ajnasz/sekret.link/internal/config" "github.com/Ajnasz/sekret.link/internal/durable" @@ -18,7 +17,6 @@ func TestConnection(ctx context.Context) (*sql.DB, error) { Database: "sekret_link_test", SslMode: "disable", } - fmt.Println(config.GetConnectionString(conf.String())) db, err := durable.OpenDatabaseClient(ctx, config.GetConnectionString(conf.String())) if err != nil { diff --git a/internal/views/entrycreate.go b/internal/views/entrycreate.go index ae60ffc..bdce9e1 100644 --- a/internal/views/entrycreate.go +++ b/internal/views/entrycreate.go @@ -24,7 +24,7 @@ type EntryCreatedResponse struct { DeleteKey string } -func buildCreatedResponse(meta *services.EntryMeta, keyString string) EntryCreatedResponse { +func BuildCreatedResponse(meta *services.EntryMeta, keyString string) EntryCreatedResponse { return EntryCreatedResponse{ UUID: meta.UUID, Created: meta.Created, @@ -43,20 +43,18 @@ func NewEntryCreateView(webExternalURL *url.URL) EntryCreateView { return EntryCreateView{webExternalURL: webExternalURL} } -func (e EntryCreateView) RenderEntryCreated(w http.ResponseWriter, r *http.Request, entry *services.EntryMeta, keyString string) { +func (e EntryCreateView) Render(w http.ResponseWriter, r *http.Request, entry EntryCreatedResponse) { w.Header().Add("x-entry-uuid", entry.UUID) - w.Header().Add("x-entry-key", keyString) + w.Header().Add("x-entry-key", entry.Key) w.Header().Add("x-entry-expire", entry.Expire.Format(time.RFC3339)) w.Header().Add("x-entry-delete-key", entry.DeleteKey) if r.Header.Get("Accept") == "application/json" { w.Header().Set("Content-Type", "application/json") - response := buildCreatedResponse(entry, keyString) - - json.NewEncoder(w).Encode(response) + json.NewEncoder(w).Encode(entry) } else { - newURL, err := uuid.GetUUIDUrlWithSecret(e.webExternalURL, entry.UUID, keyString) + newURL, err := uuid.GetUUIDUrlWithSecret(e.webExternalURL, entry.UUID, entry.Key) if err != nil { log.Println("Get UUID URL with secret failed", err) http.Error(w, "Internal error", http.StatusInternalServerError) @@ -67,7 +65,7 @@ func (e EntryCreateView) RenderEntryCreated(w http.ResponseWriter, r *http.Reque } } -func (e EntryCreateView) RenderCreateEntryErrorResponse(w http.ResponseWriter, r *http.Request, err error) { +func (e EntryCreateView) RenderError(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, parsers.ErrInvalidExpirationDate) { http.Error(w, "Invalid expiration", http.StatusBadRequest) return diff --git a/internal/views/entrydelete.go b/internal/views/entrydelete.go index bd34fa0..14a6db4 100644 --- a/internal/views/entrydelete.go +++ b/internal/views/entrydelete.go @@ -9,16 +9,18 @@ import ( "github.com/google/uuid" ) +type DeleteEntryResponse struct{} + type EntryDeleteView struct{} func NewEntryDeleteView() EntryDeleteView { return EntryDeleteView{} } -func (e EntryDeleteView) RenderDeleteEntry(w http.ResponseWriter, r *http.Request) { +func (e EntryDeleteView) Render(w http.ResponseWriter, r *http.Request, data DeleteEntryResponse) { w.WriteHeader(http.StatusAccepted) } -func (e EntryDeleteView) RenderDeleteEntryError(w http.ResponseWriter, r *http.Request, err error) { +func (e EntryDeleteView) RenderError(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, models.ErrEntryNotFound) || errors.Is(err, models.ErrEntryNotFound) { http.Error(w, "Not Found", http.StatusNotFound) diff --git a/internal/views/entrykeycreate.go b/internal/views/entrykeycreate.go new file mode 100644 index 0000000..82c9bca --- /dev/null +++ b/internal/views/entrykeycreate.go @@ -0,0 +1,73 @@ +package views + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/Ajnasz/sekret.link/internal/parsers" + "github.com/Ajnasz/sekret.link/internal/uuid" +) + +type GenerateEntryKeyResponseData struct { + // The UUID of the entry. + UUID string + // The key decryption key string of the entry. + Key string + + // The time when the entry was created. + Expire time.Time +} + +// GenerateEntryKeyView is the view for the GenerateEntryKey endpoint. +type GenerateEntryKeyView struct { + webExternalURL *url.URL +} + +func NewGenerateEntryKeyView(webExternalURL *url.URL) GenerateEntryKeyView { + return GenerateEntryKeyView{webExternalURL: webExternalURL} +} + +// RenderGenerateEntryKey renders the response for the GenerateEntryKey endpoint. +func (g GenerateEntryKeyView) Render(w http.ResponseWriter, r *http.Request, response GenerateEntryKeyResponseData) { + w.Header().Add("x-entry-uuid", response.UUID) + w.Header().Add("x-entry-key", response.Key) + w.Header().Add("x-entry-expire", response.Expire.Format(time.RFC3339)) + + if r.Header.Get("Accept") == "application/json" { + w.Header().Set("Content-Type", "application/json") + + json.NewEncoder(w).Encode(response) + } else { + newURL, err := uuid.GetUUIDUrlWithSecret(g.webExternalURL, response.UUID, response.Key) + + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + fmt.Fprint(w, newURL.String()) + } +} + +// RenderGenerateEntryKeyError renders the error response for the GenerateEntryKey endpoint. +func (v GenerateEntryKeyView) RenderError(w http.ResponseWriter, r *http.Request, err error) { + if errors.Is(err, parsers.ErrInvalidUUID) { + http.Error(w, "Invalid UUID", http.StatusBadRequest) + return + } else if errors.Is(err, parsers.ErrInvalidKey) { + http.Error(w, "Invalid key", http.StatusBadRequest) + return + } else if errors.Is(err, parsers.ErrInvalidExpirationDate) { + http.Error(w, "Invalid expiration", http.StatusBadRequest) + return + } else if errors.Is(err, parsers.ErrInvalidMaxRead) { + http.Error(w, "Invalid max read", http.StatusBadRequest) + return + } else { + http.Error(w, "Internal error", http.StatusInternalServerError) + } +} diff --git a/internal/views/entryread.go b/internal/views/entryread.go index 9a5c455..066500d 100644 --- a/internal/views/entryread.go +++ b/internal/views/entryread.go @@ -12,7 +12,7 @@ import ( "github.com/Ajnasz/sekret.link/internal/services" ) -type SecretResponse struct { +type EntryReadResponse struct { UUID string Key string Data string @@ -22,9 +22,10 @@ type SecretResponse struct { DeleteKey string } -func buildSecretResponse(meta services.Entry) SecretResponse { - return SecretResponse{ +func BuildEntryReadResponse(meta services.Entry, key string) EntryReadResponse { + return EntryReadResponse{ UUID: meta.UUID, + Key: key, Created: meta.Created, Expire: meta.Expire, Accessed: meta.Accessed, @@ -39,19 +40,16 @@ func NewEntryReadView() EntryReadView { return EntryReadView{} } -func (e EntryReadView) RenderReadEntry(w http.ResponseWriter, r *http.Request, entry *services.Entry, keyString string) { +func (e EntryReadView) Render(w http.ResponseWriter, r *http.Request, response EntryReadResponse) { if r.Header.Get("Accept") == "application/json" { - response := buildSecretResponse(*entry) - - response.Key = keyString json.NewEncoder(w).Encode(response) } else { w.WriteHeader(http.StatusOK) - w.Write(entry.Data) + w.Write([]byte(response.Data)) } } -func (e EntryReadView) RenderReadEntryError(w http.ResponseWriter, r *http.Request, err error) { +func (e EntryReadView) RenderError(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, services.ErrEntryExpired) { http.Error(w, "Gone", http.StatusGone) return @@ -62,6 +60,11 @@ func (e EntryReadView) RenderReadEntryError(w http.ResponseWriter, r *http.Reque return } + if errors.Is(err, services.ErrEntryNoRemainingReads) { + http.Error(w, "Gone", http.StatusGone) + return + } + if errors.Is(err, parsers.ErrInvalidUUID) { http.Error(w, "Bad request", http.StatusBadRequest) return diff --git a/internal/views/views.go b/internal/views/views.go new file mode 100644 index 0000000..b19de71 --- /dev/null +++ b/internal/views/views.go @@ -0,0 +1,8 @@ +package views + +import "net/http" + +type View[T any] interface { + Render(w http.ResponseWriter, r *http.Request, data T) + RenderError(w http.ResponseWriter, r *http.Request, err error) +}