Skip to content

Commit

Permalink
Support attribute keys (PK and FK) (#22)
Browse files Browse the repository at this point in the history
* add attribute keys for all supported databases

* added tests for attribute key

* add flag and docs

* update mocks
  • Loading branch information
KarnerTh authored Sep 11, 2022
1 parent 23bb526 commit fd38b79
Show file tree
Hide file tree
Showing 15 changed files with 218 additions and 34 deletions.
3 changes: 3 additions & 0 deletions .mermerd.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
showAllConstraints: true
encloseWithMermaidBackticks: false
outputFileName: "my-db.mmd"
debug: false
omitConstraintLabels: false
omitAttributeKeys: false

# These connection strings are available as suggestions in the cli (use tab to access)
connectionStringSuggestions:
Expand Down
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func init() {
rootCmd.Flags().Bool(config.UseAllTablesKey, false, "use all available tables")
rootCmd.Flags().Bool(config.DebugKey, false, "show debug logs")
rootCmd.Flags().Bool(config.OmitConstraintLabelsKey, false, "omit the constraint labels")
rootCmd.Flags().Bool(config.OmitAttributeKeysKey, false, "omit the attribute keys (PK, FK)")
rootCmd.Flags().BoolP(config.EncloseWithMermaidBackticksKey, "e", false, "enclose output with mermaid backticks (needed for e.g. in markdown viewer)")
rootCmd.Flags().StringP(config.ConnectionStringKey, "c", "", "connection string that should be used")
rootCmd.Flags().StringP(config.SchemaKey, "s", "", "schema that should be used")
Expand All @@ -77,6 +78,7 @@ func init() {
bindFlagToViper(config.UseAllTablesKey)
bindFlagToViper(config.DebugKey)
bindFlagToViper(config.OmitConstraintLabelsKey)
bindFlagToViper(config.OmitAttributeKeysKey)
bindFlagToViper(config.EncloseWithMermaidBackticksKey)
bindFlagToViper(config.ConnectionStringKey)
bindFlagToViper(config.SchemaKey)
Expand Down
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const (
EncloseWithMermaidBackticksKey = "encloseWithMermaidBackticks"
DebugKey = "debug"
OmitConstraintLabelsKey = "omitConstraintLabels"
OmitAttributeKeysKey = "omitAttributeKeys"
)

type config struct{}
Expand All @@ -28,6 +29,7 @@ type MermerdConfig interface {
EncloseWithMermaidBackticks() bool
Debug() bool
OmitConstraintLabels() bool
OmitAttributeKeys() bool
}

func NewConfig() MermerdConfig {
Expand Down Expand Up @@ -73,3 +75,7 @@ func (c config) Debug() bool {
func (c config) OmitConstraintLabels() bool {
return viper.GetBool(OmitConstraintLabelsKey)
}

func (c config) OmitAttributeKeys() bool {
return viper.GetBool(OmitAttributeKeysKey)
}
54 changes: 43 additions & 11 deletions database/database_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import (
"github.com/stretchr/testify/assert"
)

type columnTestResult struct {
Name string
isPrimary bool
isForeign bool
}

func TestDatabaseIntegrations(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
Expand Down Expand Up @@ -103,33 +109,59 @@ func TestDatabaseIntegrations(t *testing.T) {
connector := getConnectionAndConnect(t)
testCases := []struct {
tableName string
expectedColumns []string
expectedColumns []columnTestResult
}{
{tableName: "article", expectedColumns: []string{"id", "title"}},
{tableName: "article_detail", expectedColumns: []string{"id", "created_at"}},
{tableName: "article_comment", expectedColumns: []string{"id", "article_id", "comment"}},
{tableName: "label", expectedColumns: []string{"id", "label"}},
{tableName: "article_label", expectedColumns: []string{"article_id", "label_id"}},
{tableName: "test_1_a", expectedColumns: []string{"id", "xid"}},
{tableName: "test_1_b", expectedColumns: []string{"aid", "bid"}},
{tableName: "article", expectedColumns: []columnTestResult{
{Name: "id", isPrimary: true, isForeign: false},
{Name: "title", isPrimary: false, isForeign: false},
}},
{tableName: "article_detail", expectedColumns: []columnTestResult{
{Name: "id", isPrimary: true, isForeign: true},
{Name: "created_at", isPrimary: false, isForeign: false},
}},
{tableName: "article_comment", expectedColumns: []columnTestResult{
{Name: "id", isPrimary: true, isForeign: false},
{Name: "article_id", isPrimary: false, isForeign: true},
{Name: "comment", isPrimary: false, isForeign: false},
}},
{tableName: "label", expectedColumns: []columnTestResult{
{Name: "id", isPrimary: true, isForeign: false},
{Name: "label", isPrimary: false, isForeign: false},
}},
{tableName: "article_label", expectedColumns: []columnTestResult{
{Name: "article_id", isPrimary: true, isForeign: true},
{Name: "label_id", isPrimary: true, isForeign: true},
}},
{tableName: "test_1_a", expectedColumns: []columnTestResult{
{Name: "id", isPrimary: true, isForeign: false},
{Name: "xid", isPrimary: true, isForeign: false},
}},
{tableName: "test_1_b", expectedColumns: []columnTestResult{
{Name: "aid", isPrimary: true, isForeign: true},
{Name: "bid", isPrimary: true, isForeign: true},
}},
}

for index, testCase := range testCases {
t.Run(fmt.Sprintf("run #%d", index), func(t *testing.T) {
// Arrange
tableName := testCase.tableName
var columnNames []string
var columnResult []columnTestResult

// Act
columns, err := connector.GetColumns(tableName)

// Assert
for _, column := range columns {
columnNames = append(columnNames, column.Name)
columnResult = append(columnResult, columnTestResult{
Name: column.Name,
isPrimary: column.IsPrimary,
isForeign: column.IsForeign,
})
}

assert.Nil(t, err)
assert.ElementsMatch(t, testCase.expectedColumns, columnNames)
assert.ElementsMatch(t, testCase.expectedColumns, columnResult)
})
}
})
Expand Down
23 changes: 18 additions & 5 deletions database/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,23 @@ func (c *mssqlConnector) GetTables(schemaName string) ([]string, error) {

func (c *mssqlConnector) GetColumns(tableName string) ([]ColumnResult, error) {
rows, err := c.db.Query(`
select column_name, data_type
from information_schema.columns
where table_name = @p1
order by ordinal_position
select c.column_name,
c.data_type,
(select count(*)
from information_schema.key_column_usage cu
left join information_schema.table_constraints tc on tc.constraint_name = cu.constraint_name
where cu.column_name = c.column_name
and cu.table_name = c.table_name
and tc.constraint_type = 'PRIMARY KEY') as is_primary,
(select count(*)
from information_schema.key_column_usage cu
left join information_schema.table_constraints tc on tc.constraint_name = cu.constraint_name
where cu.column_name = c.column_name
and cu.table_name = c.table_name
and tc.constraint_type = 'FOREIGN KEY') as is_foreign
from information_schema.columns c
where c.table_name = @p1
order by c.ordinal_position;
`, tableName)
if err != nil {
return nil, err
Expand All @@ -91,7 +104,7 @@ func (c *mssqlConnector) GetColumns(tableName string) ([]ColumnResult, error) {
var columns []ColumnResult
for rows.Next() {
var column ColumnResult
if err = rows.Scan(&column.Name, &column.DataType); err != nil {
if err = rows.Scan(&column.Name, &column.DataType, &column.IsPrimary, &column.IsForeign); err != nil {
return nil, err
}

Expand Down
22 changes: 17 additions & 5 deletions database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,22 @@ func (c *mySqlConnector) GetTables(schemaName string) ([]string, error) {

func (c *mySqlConnector) GetColumns(tableName string) ([]ColumnResult, error) {
rows, err := c.db.Query(`
select column_name, data_type
from information_schema.columns
where table_name = ?
order by ordinal_position
select c.column_name,
c.data_type,
(select count(*)
from information_schema.KEY_COLUMN_USAGE
where table_name = c.table_name
and column_name = c.column_name
and constraint_name = 'PRIMARY') as is_primary,
(select count(*)
from information_schema.key_column_usage cu
left join information_schema.table_constraints tc on tc.constraint_name = cu.constraint_name
where cu.column_name = c.column_name
and cu.table_name = c.table_name
and tc.constraint_type = 'FOREIGN KEY') as is_foreign
from information_schema.columns c
where c.table_name = ?
order by c.ordinal_position;
`, tableName)
if err != nil {
return nil, err
Expand All @@ -91,7 +103,7 @@ func (c *mySqlConnector) GetColumns(tableName string) ([]ColumnResult, error) {
var columns []ColumnResult
for rows.Next() {
var column ColumnResult
if err = rows.Scan(&column.Name, &column.DataType); err != nil {
if err = rows.Scan(&column.Name, &column.DataType, &column.IsPrimary, &column.IsForeign); err != nil {
return nil, err
}

Expand Down
23 changes: 18 additions & 5 deletions database/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,23 @@ func (c *postgresConnector) GetTables(schemaName string) ([]string, error) {

func (c *postgresConnector) GetColumns(tableName string) ([]ColumnResult, error) {
rows, err := c.db.Query(`
select column_name, data_type
from information_schema.columns
where table_name = $1
order by ordinal_position
select c.column_name,
c.data_type,
(select count(*)
from information_schema.key_column_usage cu
left join information_schema.table_constraints tc on tc.constraint_name = cu.constraint_name
where cu.column_name = c.column_name
and cu.table_name = c.table_name
and tc.constraint_type = 'PRIMARY KEY') as is_primary,
(select count(*)
from information_schema.key_column_usage cu
left join information_schema.table_constraints tc on tc.constraint_name = cu.constraint_name
where cu.column_name = c.column_name
and cu.table_name = c.table_name
and tc.constraint_type = 'FOREIGN KEY') as is_foreign
from information_schema.columns c
where c.table_name = $1
order by c.ordinal_position;
`, tableName)
if err != nil {
return nil, err
Expand All @@ -91,7 +104,7 @@ func (c *postgresConnector) GetColumns(tableName string) ([]ColumnResult, error)
var columns []ColumnResult
for rows.Next() {
var column ColumnResult
if err = rows.Scan(&column.Name, &column.DataType); err != nil {
if err = rows.Scan(&column.Name, &column.DataType, &column.IsPrimary, &column.IsForeign); err != nil {
return nil, err
}

Expand Down
6 changes: 4 additions & 2 deletions database/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ type TableResult struct {
}

type ColumnResult struct {
Name string
DataType string
Name string
DataType string
IsPrimary bool
IsForeign bool
}

type ConstraintResultList []ConstraintResult
Expand Down
22 changes: 20 additions & 2 deletions diagram/diagram.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ func (d diagram) Create(result *database.Result) error {

columnData := make([]ErdColumnData, len(table.Columns))
for columnIndex, column := range table.Columns {
attributeKey := getAttributeKey(column)
if d.config.OmitAttributeKeys() {
attributeKey = none
}

columnData[columnIndex] = ErdColumnData{
Name: column.Name,
DataType: column.DataType,
Name: column.Name,
DataType: column.DataType,
AttributeKey: attributeKey,
}
}

Expand Down Expand Up @@ -110,3 +116,15 @@ func tableNameInSlice(slice []ErdTableData, tableName string) bool {

return false
}

func getAttributeKey(column database.ColumnResult) ErdAttributeKey {
if column.IsPrimary {
return primaryKey
}

if column.IsForeign {
return foreignKey
}

return none
}
13 changes: 11 additions & 2 deletions diagram/diagram_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ const (
relationManyToOne ErdRelationType = "}o--||"
)

type ErdAttributeKey string

const (
primaryKey ErdAttributeKey = "PK"
foreignKey ErdAttributeKey = "FK"
none ErdAttributeKey = ""
)

type ErdDiagramData struct {
EncloseWithMermaidBackticks bool
Tables []ErdTableData
Expand All @@ -19,8 +27,9 @@ type ErdTableData struct {
}

type ErdColumnData struct {
Name string
DataType string
Name string
DataType string
AttributeKey ErdAttributeKey
}

type ErdConstraintData struct {
Expand Down
57 changes: 57 additions & 0 deletions diagram/diagram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,60 @@ func TestGetRelation(t *testing.T) {
})
}
}

func TestGetAttributeKey(t *testing.T) {
testCases := []struct {
column database.ColumnResult
expectedAttributeResult ErdAttributeKey
}{
{
column: database.ColumnResult{
Name: "",
DataType: "",
IsPrimary: true,
IsForeign: false,
},
expectedAttributeResult: primaryKey,
},
{
column: database.ColumnResult{
Name: "",
DataType: "",
IsPrimary: false,
IsForeign: true,
},
expectedAttributeResult: foreignKey,
},
{
column: database.ColumnResult{
Name: "",
DataType: "",
IsPrimary: true,
IsForeign: true,
},
expectedAttributeResult: primaryKey,
},
{
column: database.ColumnResult{
Name: "",
DataType: "",
IsPrimary: false,
IsForeign: false,
},
expectedAttributeResult: none,
},
}

for index, testCase := range testCases {
t.Run(fmt.Sprintf("run #%d", index), func(t *testing.T) {
// Arrange
column := testCase.column

// Act
result := getAttributeKey(column)

// Assert
assert.Equal(t, testCase.expectedAttributeResult, result)
})
}
}
2 changes: 1 addition & 1 deletion diagram/erd_template.gommd
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ erDiagram
{{- range .Tables}}
{{.Name}} {
{{- range .Columns}}
{{.DataType}} {{.Name}}
{{.DataType}} {{.Name}} {{.AttributeKey}}
{{- end}}
}
{{end -}}
Expand Down
1 change: 1 addition & 0 deletions exampleRunConfig.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ outputFileName: "my-db.mmd"
encloseWithMermaidBackticks: false
debug: false
omitConstraintLabels: false
omitAttributeKeys: false
Loading

0 comments on commit fd38b79

Please sign in to comment.