From 8c52c4cc8dfc12bccdd11be892c56a72f73be62b Mon Sep 17 00:00:00 2001 From: KarnerTh Date: Wed, 28 Dec 2022 20:28:57 +0100 Subject: [PATCH] Multi schema support (#27) * WIP: added support for multiple schemas; added support for mysql * add test setup for postgres and mssql; fix tests * check for schema name in mysql * fix tests * simplify tests with schema * add support for postgres and mssql; add tests * cleanup * add useAllSchemas flag * update changelog --- analyzer/analyzer.go | 64 +++++++++---- analyzer/analyzer_test.go | 79 ++++++++++------ analyzer/questioner.go | 12 +-- changelog.md | 3 +- cmd/root.go | 2 + config/config.go | 12 ++- config/config_test.go | 2 +- database/connector.go | 6 +- database/database_integration_test.go | 94 ++++++++++++++----- database/mssql.go | 37 +++++--- database/mysql.go | 35 ++++--- database/mysql_test.go | 2 +- database/postgres.go | 32 ++++--- database/postgres_test.go | 2 +- database/result.go | 7 +- database/table_name_util.go | 24 +++++ database/value_sanitizer.go | 2 +- database/value_sanitizer_test.go | 1 + diagram/diagram.go | 2 +- exampleRunConfig.yaml | 7 +- mocks/Analyzer.go | 36 +++---- mocks/Connector.go | 30 +++--- mocks/MermerdConfig.go | 26 ++++- mocks/Questioner.go | 10 +- readme.md | 9 +- test/docker-compose.yaml | 3 + test/mssql/entrypoint.sh | 1 + test/mssql/mssql-multiple-databases.sql | 26 +++++ test/mysql/mysql-multiple-databases.sql | 23 +++++ test/postgres/postgres-multiple-databases.sql | 22 +++++ util/map_util.go | 10 ++ 31 files changed, 449 insertions(+), 172 deletions(-) create mode 100644 database/table_name_util.go create mode 100644 test/mssql/mssql-multiple-databases.sql create mode 100644 test/mysql/mysql-multiple-databases.sql create mode 100644 test/postgres/postgres-multiple-databases.sql create mode 100644 util/map_util.go diff --git a/analyzer/analyzer.go b/analyzer/analyzer.go index ec4ecb8..55115b9 100644 --- a/analyzer/analyzer.go +++ b/analyzer/analyzer.go @@ -2,6 +2,7 @@ package analyzer import ( "errors" + "fmt" "github.com/sirupsen/logrus" @@ -19,11 +20,10 @@ type analyzer struct { type Analyzer interface { Analyze() (*database.Result, error) - GetConnectionString() (string, error) - GetSchema(db database.Connector) (string, error) - GetTables(db database.Connector, selectedSchema string) ([]string, error) - GetColumnsAndConstraints(db database.Connector, selectedTables []string) ([]database.TableResult, error) + GetSchemas(db database.Connector) ([]string, error) + GetTables(db database.Connector, selectedSchemas []string) ([]database.TableDetail, error) + GetColumnsAndConstraints(db database.Connector, selectedTables []database.TableDetail) ([]database.TableResult, error) } func NewAnalyzer(config config.MermerdConfig, connectorFactory database.ConnectorFactory, questioner Questioner) Analyzer { @@ -49,12 +49,12 @@ func (a analyzer) Analyze() (*database.Result, error) { defer db.Close() a.loadingSpinner.Stop() - selectedSchema, err := a.GetSchema(db) + selectedSchemas, err := a.GetSchemas(db) if err != nil { return nil, err } - selectedTables, err := a.GetTables(db, selectedSchema) + selectedTables, err := a.GetTables(db, selectedSchemas) if err != nil { return nil, err } @@ -75,8 +75,8 @@ func (a analyzer) GetConnectionString() (string, error) { return a.questioner.AskConnectionQuestion(a.config.ConnectionStringSuggestions()) } -func (a analyzer) GetSchema(db database.Connector) (string, error) { - if selectedSchema := a.config.Schema(); selectedSchema != "" { +func (a analyzer) GetSchemas(db database.Connector) ([]string, error) { + if selectedSchema := a.config.Schemas(); len(selectedSchema) > 0 { return selectedSchema, nil } @@ -85,44 +85,72 @@ func (a analyzer) GetSchema(db database.Connector) (string, error) { a.loadingSpinner.Stop() if err != nil { logrus.Error("Getting schemas failed", " | ", err) - return "", err + return []string{}, err } logrus.WithField("count", len(schemas)).Info("Got schemas") + if a.config.UseAllSchemas() { + return schemas, nil + } switch len(schemas) { case 0: - return "", errors.New("no schemas available") + return []string{}, errors.New("no schemas available") case 1: - return schemas[0], nil + return schemas, nil default: return a.questioner.AskSchemaQuestion(schemas) } } -func (a analyzer) GetTables(db database.Connector, selectedSchema string) ([]string, error) { +func (a analyzer) GetTables(db database.Connector, selectedSchemas []string) ([]database.TableDetail, error) { if selectedTables := a.config.SelectedTables(); len(selectedTables) > 0 { - return selectedTables, nil + return util.Map2(selectedTables, func(value string) database.TableDetail { + res, err := database.ParseTableName(value, selectedSchemas) + if err != nil { + logrus.Error("Could not parse table name", value) + } + + return res + }), nil } a.loadingSpinner.Start("Getting tables") - tables, err := db.GetTables(selectedSchema) + tables, err := db.GetTables(selectedSchemas) a.loadingSpinner.Stop() if err != nil { logrus.Error("Getting tables failed", " | ", err) return nil, err } + if len(tables) == 0 { + logrus.Error("No tables found") + } + logrus.WithField("count", len(tables)).Info("Got tables") if a.config.UseAllTables() { return tables, nil - } else { - return a.questioner.AskTableQuestion(tables) } + + tableNames := util.Map2(tables, func(table database.TableDetail) string { + return fmt.Sprintf("%s.%s", table.Schema, table.Name) + }) + surveyResult, err := a.questioner.AskTableQuestion(tableNames) + if err != nil { + return []database.TableDetail{}, err + } + return util.Map2(surveyResult, func(value string) database.TableDetail { + res, err := database.ParseTableName(value, selectedSchemas) + if err != nil { + logrus.Error("Could not parse table name", value) + } + + return res + }), nil } -func (a analyzer) GetColumnsAndConstraints(db database.Connector, selectedTables []string) ([]database.TableResult, error) { +func (a analyzer) GetColumnsAndConstraints(db database.Connector, selectedTables []database.TableDetail) ([]database.TableResult, error) { var tableResults []database.TableResult a.loadingSpinner.Start("Getting columns and constraints") for _, table := range selectedTables { @@ -138,7 +166,7 @@ func (a analyzer) GetColumnsAndConstraints(db database.Connector, selectedTables return nil, err } - tableResults = append(tableResults, database.TableResult{TableName: table, Columns: columns, Constraints: constraints}) + tableResults = append(tableResults, database.TableResult{Table: table, Columns: columns, Constraints: constraints}) } a.loadingSpinner.Stop() columnCount, constraintCount := getTableResultStats(tableResults) diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go index 1dc4b4e..6d17b6f 100644 --- a/analyzer/analyzer_test.go +++ b/analyzer/analyzer_test.go @@ -53,27 +53,46 @@ func TestAnalyzer_GetSchema(t *testing.T) { // Arrange analyzer, configMock, _, _ := getAnalyzerWithMocks() connectorMock := mocks.Connector{} - configMock.On("Schema").Return("configuredSchema").Once() + configMock.On("Schemas").Return([]string{"configuredSchema"}).Once() // Act - result, err := analyzer.GetSchema(&connectorMock) + result, err := analyzer.GetSchemas(&connectorMock) // Assert configMock.AssertExpectations(t) connectorMock.AssertExpectations(t) assert.Nil(t, err) - assert.Equal(t, "configuredSchema", result) + assert.ElementsMatch(t, []string{"configuredSchema"}, result) + }) + + t.Run("Use all available schema", func(t *testing.T) { + // Arrange + analyzer, configMock, _, _ := getAnalyzerWithMocks() + connectorMock := mocks.Connector{} + configMock.On("UseAllSchemas").Return(true).Once() + configMock.On("Schemas").Return([]string{}).Once() + connectorMock.On("GetSchemas").Return([]string{"schema1", "schema2"}, nil).Once() + + // Act + result, err := analyzer.GetSchemas(&connectorMock) + + // Assert + configMock.AssertExpectations(t) + connectorMock.AssertExpectations(t) + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"schema1", "schema2"}, result) }) t.Run("No schema available return error", func(t *testing.T) { // Arrange analyzer, configMock, _, _ := getAnalyzerWithMocks() connectorMock := mocks.Connector{} - configMock.On("Schema").Return("").Once() + configMock.On("Schemas").Return([]string{}).Once() + configMock.On("UseAllSchemas").Return(false).Once() connectorMock.On("GetSchemas").Return([]string{}, nil).Once() // Act - result, err := analyzer.GetSchema(&connectorMock) + result, err := analyzer.GetSchemas(&connectorMock) // Assert configMock.AssertExpectations(t) @@ -86,36 +105,38 @@ func TestAnalyzer_GetSchema(t *testing.T) { // Arrange analyzer, configMock, _, _ := getAnalyzerWithMocks() connectorMock := mocks.Connector{} - configMock.On("Schema").Return("").Once() + configMock.On("Schemas").Return([]string{}).Once() + configMock.On("UseAllSchemas").Return(false).Once() connectorMock.On("GetSchemas").Return([]string{"onlyItem"}, nil).Once() // Act - result, err := analyzer.GetSchema(&connectorMock) + result, err := analyzer.GetSchemas(&connectorMock) // Assert configMock.AssertExpectations(t) connectorMock.AssertExpectations(t) assert.Nil(t, err) - assert.Equal(t, "onlyItem", result) + assert.ElementsMatch(t, []string{"onlyItem"}, result) }) t.Run("Use value from questioner", func(t *testing.T) { // Arrange analyzer, configMock, _, questionerMock := getAnalyzerWithMocks() connectorMock := mocks.Connector{} - configMock.On("Schema").Return("").Once() + configMock.On("Schemas").Return([]string{}).Once() + configMock.On("UseAllSchemas").Return(false).Once() connectorMock.On("GetSchemas").Return([]string{"first", "second"}, nil).Once() - questionerMock.On("AskSchemaQuestion", []string{"first", "second"}).Return("first", nil).Once() + questionerMock.On("AskSchemaQuestion", []string{"first", "second"}).Return([]string{"first"}, nil).Once() // Act - result, err := analyzer.GetSchema(&connectorMock) + result, err := analyzer.GetSchemas(&connectorMock) // Assert configMock.AssertExpectations(t) connectorMock.AssertExpectations(t) questionerMock.AssertExpectations(t) assert.Nil(t, err) - assert.Equal(t, "first", result) + assert.ElementsMatch(t, []string{"first"}, result) }) } @@ -127,13 +148,14 @@ func TestAnalyzer_GetTables(t *testing.T) { configMock.On("SelectedTables").Return([]string{"configuredTable"}).Once() // Act - result, err := analyzer.GetTables(&connectorMock, "validSchema") + result, err := analyzer.GetTables(&connectorMock, []string{"validSchema"}) // Assert configMock.AssertExpectations(t) connectorMock.AssertExpectations(t) assert.Nil(t, err) - assert.ElementsMatch(t, []string{"configuredTable"}, result) + assert.Len(t, result, 1) + assert.Equal(t, "configuredTable", result[0].Name) }) t.Run("Use all available tables", func(t *testing.T) { @@ -141,17 +163,19 @@ func TestAnalyzer_GetTables(t *testing.T) { analyzer, configMock, _, _ := getAnalyzerWithMocks() connectorMock := mocks.Connector{} configMock.On("SelectedTables").Return([]string{}).Once() - connectorMock.On("GetTables", "validSchema").Return([]string{"tableA", "tableB"}, nil).Once() + connectorMock.On("GetTables", []string{"validSchema"}).Return([]database.TableDetail{{Schema: "validSchema", Name: "tableA"}, {Schema: "validSchema", Name: "tableB"}}, nil).Once() configMock.On("UseAllTables").Return(true).Once() // Act - result, err := analyzer.GetTables(&connectorMock, "validSchema") + result, err := analyzer.GetTables(&connectorMock, []string{"validSchema"}) // Assert configMock.AssertExpectations(t) connectorMock.AssertExpectations(t) assert.Nil(t, err) - assert.ElementsMatch(t, []string{"tableA", "tableB"}, result) + assert.Len(t, result, 2) + assert.Equal(t, "tableA", result[0].Name) + assert.Equal(t, "tableB", result[1].Name) }) t.Run("Use value from questioner", func(t *testing.T) { @@ -159,19 +183,20 @@ func TestAnalyzer_GetTables(t *testing.T) { analyzer, configMock, _, questionerMock := getAnalyzerWithMocks() connectorMock := mocks.Connector{} configMock.On("SelectedTables").Return([]string{}).Once() - connectorMock.On("GetTables", "validSchema").Return([]string{"tableA", "tableB"}, nil).Once() + connectorMock.On("GetTables", []string{"validSchema"}).Return([]database.TableDetail{{Schema: "validSchema", Name: "tableA"}, {Schema: "validSchema", Name: "tableB"}}, nil).Once() configMock.On("UseAllTables").Return(false).Once() - questionerMock.On("AskTableQuestion", []string{"tableA", "tableB"}).Return([]string{"tableA"}, nil).Once() + questionerMock.On("AskTableQuestion", []string{"validSchema.tableA", "validSchema.tableB"}).Return([]string{"validSchema.tableA"}, nil).Once() // Act - result, err := analyzer.GetTables(&connectorMock, "validSchema") + result, err := analyzer.GetTables(&connectorMock, []string{"validSchema"}) // Assert configMock.AssertExpectations(t) connectorMock.AssertExpectations(t) questionerMock.AssertExpectations(t) assert.Nil(t, err) - assert.ElementsMatch(t, []string{"tableA"}, result) + assert.Len(t, result, 1) + assert.Equal(t, "tableA", result[0].Name) }) } @@ -184,9 +209,9 @@ func TestAnalyzer_Analyze(t *testing.T) { connectionFactoryMock.On("NewConnector", "validConnectionString").Return(&connectorMock, nil).Once() connectorMock.On("Connect").Return(nil).Once() connectorMock.On("Close").Return().Once() - configMock.On("Schema").Return("validSchema").Once() - configMock.On("SelectedTables").Return([]string{"tableA", "tableB"}).Once() - connectorMock.On("GetColumns", "tableA").Return([]database.ColumnResult{ + configMock.On("Schemas").Return([]string{"validSchema"}).Once() + configMock.On("SelectedTables").Return([]string{"validSchema.tableA", "validSchema.tableB"}).Once() + connectorMock.On("GetColumns", database.TableDetail{Schema: "validSchema", Name: "tableA"}).Return([]database.ColumnResult{ { Name: "fieldA", DataType: "int", @@ -196,7 +221,7 @@ func TestAnalyzer_Analyze(t *testing.T) { DataType: "string", }, }, nil).Once() - connectorMock.On("GetColumns", "tableB").Return([]database.ColumnResult{ + connectorMock.On("GetColumns", database.TableDetail{Schema: "validSchema", Name: "tableB"}).Return([]database.ColumnResult{ { Name: "fieldC", DataType: "int", @@ -206,14 +231,14 @@ func TestAnalyzer_Analyze(t *testing.T) { DataType: "string", }, }, nil).Once() - connectorMock.On("GetConstraints", "tableA").Return([]database.ConstraintResult{{ + connectorMock.On("GetConstraints", database.TableDetail{Schema: "validSchema", Name: "tableA"}).Return([]database.ConstraintResult{{ FkTable: "tableA", PkTable: "tableB", ConstraintName: "testConstraint", IsPrimary: false, HasMultiplePK: false, }}, nil).Once() - connectorMock.On("GetConstraints", "tableB").Return([]database.ConstraintResult{{ + connectorMock.On("GetConstraints", database.TableDetail{Schema: "validSchema", Name: "tableB"}).Return([]database.ConstraintResult{{ FkTable: "tableA", PkTable: "tableB", ConstraintName: "testConstraint", diff --git a/analyzer/questioner.go b/analyzer/questioner.go index 5cb6853..cf6c857 100644 --- a/analyzer/questioner.go +++ b/analyzer/questioner.go @@ -10,7 +10,7 @@ type questioner struct{} type Questioner interface { AskConnectionQuestion(suggestions []string) (string, error) - AskSchemaQuestion(schemas []string) (string, error) + AskSchemaQuestion(schemas []string) ([]string, error) AskTableQuestion(tables []string) ([]string, error) } @@ -35,14 +35,14 @@ func (q questioner) AskConnectionQuestion(suggestions []string) (string, error) return os.ExpandEnv(result), nil } -func (q questioner) AskSchemaQuestion(schemas []string) (string, error) { - var result string - question := &survey.Select{ - Message: "Choose a schema:", +func (q questioner) AskSchemaQuestion(schemas []string) ([]string, error) { + var result []string + question := &survey.MultiSelect{ + Message: "Choose schemas:", Options: schemas, } - err := survey.AskOne(question, &result) + err := survey.AskOne(question, &result, survey.WithValidator(survey.MinItems(1))) return result, err } diff --git a/changelog.md b/changelog.md index b28e3c5..9c1c923 100644 --- a/changelog.md +++ b/changelog.md @@ -4,9 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html) (after version 0.0.5). -## [0.5.0] - 2022-12-xx +## [0.5.0] - 2022-12-28 ### Added - Support enum description ([Issue #15](https://github.com/KarnerTh/mermerd/issues/15)) +- Support multiple schemas ([Issue #23](https://github.com/KarnerTh/mermerd/issues/23)) ## [0.4.1] - 2022-09-28 ### Fixed diff --git a/cmd/root.go b/cmd/root.go index d60400a..980af23 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -65,6 +65,7 @@ func init() { rootCmd.Flags().StringVar(&runConfig, "runConfig", "", "run configuration (replaces global configuration)") rootCmd.Flags().Bool(config.ShowAllConstraintsKey, false, "show all constraints, even though the table of the resulting constraint was not selected") rootCmd.Flags().Bool(config.UseAllTablesKey, false, "use all available tables") + rootCmd.Flags().Bool(config.UseAllSchemasKey, false, "use all available schemas") rootCmd.Flags().Bool(config.DebugKey, false, "show debug logs") rootCmd.Flags().Bool(config.OmitConstraintLabelsKey, false, "omit the constraint labels") rootCmd.Flags().Bool(config.OmitAttributeKeysKey, false, "omit the attribute keys (PK, FK)") @@ -77,6 +78,7 @@ func init() { bindFlagToViper(config.ShowAllConstraintsKey) bindFlagToViper(config.UseAllTablesKey) + bindFlagToViper(config.UseAllSchemasKey) bindFlagToViper(config.DebugKey) bindFlagToViper(config.OmitConstraintLabelsKey) bindFlagToViper(config.OmitAttributeKeysKey) diff --git a/config/config.go b/config/config.go index b7a9bac..62d4ad0 100644 --- a/config/config.go +++ b/config/config.go @@ -15,6 +15,7 @@ const ( OmitConstraintLabelsKey = "omitConstraintLabels" OmitAttributeKeysKey = "omitAttributeKeys" ShowEnumValuesKey = "showEnumValues" + UseAllSchemasKey = "useAllSchemas" ) type config struct{} @@ -22,7 +23,7 @@ type config struct{} type MermerdConfig interface { ShowAllConstraints() bool UseAllTables() bool - Schema() string + Schemas() []string ConnectionString() string OutputFileName() string ConnectionStringSuggestions() []string @@ -32,6 +33,7 @@ type MermerdConfig interface { OmitConstraintLabels() bool OmitAttributeKeys() bool ShowEnumValues() bool + UseAllSchemas() bool } func NewConfig() MermerdConfig { @@ -46,8 +48,8 @@ func (c config) UseAllTables() bool { return viper.GetBool(UseAllTablesKey) } -func (c config) Schema() string { - return viper.GetString(SchemaKey) +func (c config) Schemas() []string { + return viper.GetStringSlice(SchemaKey) } func (c config) ConnectionString() string { @@ -85,3 +87,7 @@ func (c config) OmitAttributeKeys() bool { func (c config) ShowEnumValues() bool { return viper.GetBool(ShowEnumValuesKey) } + +func (c config) UseAllSchemas() bool { + return viper.GetBool(UseAllSchemasKey) +} diff --git a/config/config_test.go b/config/config_test.go index 0327a85..8716d14 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -39,7 +39,7 @@ connectionStringSuggestions: // Assert assert.Nil(t, err) assert.Equal(t, "connectionStringExample", config.ConnectionString()) - assert.Equal(t, "public", config.Schema()) + assert.ElementsMatch(t, []string{"public"}, config.Schemas()) assert.Equal(t, false, config.UseAllTables()) assert.ElementsMatch(t, []string{"city", "customer"}, config.SelectedTables()) assert.Equal(t, true, config.ShowAllConstraints()) diff --git a/database/connector.go b/database/connector.go index 43d60cc..3e1be35 100644 --- a/database/connector.go +++ b/database/connector.go @@ -13,7 +13,7 @@ type Connector interface { Close() GetDbType() DbType GetSchemas() ([]string, error) - GetTables(schemaName string) ([]string, error) - GetColumns(tableName string) ([]ColumnResult, error) - GetConstraints(tableName string) ([]ConstraintResult, error) + GetTables(schemaNames []string) ([]TableDetail, error) + GetColumns(tableName TableDetail) ([]ColumnResult, error) + GetConstraints(tableName TableDetail) ([]ConstraintResult, error) } diff --git a/database/database_integration_test.go b/database/database_integration_test.go index 46ebda9..e0cb6bf 100644 --- a/database/database_integration_test.go +++ b/database/database_integration_test.go @@ -2,9 +2,9 @@ package database import ( "fmt" - "github.com/sirupsen/logrus" "testing" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) @@ -21,7 +21,7 @@ type connectionParameter struct { var ( testConnectionPostgres connectionParameter = connectionParameter{connectionString: "postgresql://user:password@localhost:5432/mermerd_test", schema: "public"} - testConnectionMySql connectionParameter = connectionParameter{connectionString: "mysql://user:password@tcp(127.0.0.1:3306)/mermerd_test", schema: "mermerd_test"} + testConnectionMySql connectionParameter = connectionParameter{connectionString: "mysql://root:password@tcp(127.0.0.1:3306)/mermerd_test", schema: "mermerd_test"} testConnectionMsSql connectionParameter = connectionParameter{connectionString: "sqlserver://sa:securePassword1!@localhost:1433?database=mermerd_test", schema: "dbo"} ) @@ -95,18 +95,19 @@ func TestDatabaseIntegrations(t *testing.T) { schema := testCase.schema // Act - tables, err := connector.GetTables(schema) + tables, err := connector.GetTables([]string{schema}) // Assert - expectedResult := []string{ - "article", - "article_detail", - "article_comment", - "label", - "article_label", - "test_1_a", - "test_1_b", - "test_2_enum", + expectedResult := []TableDetail{ + {Schema: schema, Name: "article"}, + {Schema: schema, Name: "article_detail"}, + {Schema: schema, Name: "article_comment"}, + {Schema: schema, Name: "label"}, + {Schema: schema, Name: "article_label"}, + {Schema: schema, Name: "test_1_a"}, + {Schema: schema, Name: "test_1_b"}, + {Schema: schema, Name: "test_2_enum"}, + {Schema: schema, Name: "test_3_a"}, } assert.Nil(t, err) assert.ElementsMatch(t, expectedResult, tables) @@ -114,7 +115,7 @@ func TestDatabaseIntegrations(t *testing.T) { t.Run("GetColumns", func(t *testing.T) { connector := getConnectionAndConnect(t) - testCases := []struct { + subTestCases := []struct { tableName string expectedColumns []columnTestResult }{ @@ -149,10 +150,10 @@ func TestDatabaseIntegrations(t *testing.T) { }}, } - for index, testCase := range testCases { + for index, subTestCase := range subTestCases { t.Run(fmt.Sprintf("run #%d", index), func(t *testing.T) { // Arrange - tableName := testCase.tableName + tableName := TableDetail{Schema: testCase.schema, Name: subTestCase.tableName} var columnResult []columnTestResult // Act @@ -168,7 +169,7 @@ func TestDatabaseIntegrations(t *testing.T) { } assert.Nil(t, err) - assert.ElementsMatch(t, testCase.expectedColumns, columnResult) + assert.ElementsMatch(t, subTestCase.expectedColumns, columnResult) }) } }) @@ -178,7 +179,7 @@ func TestDatabaseIntegrations(t *testing.T) { t.Run("One-to-one relation", func(t *testing.T) { // Arrange - tableName := "article_detail" + tableName := TableDetail{Schema: testCase.schema, Name: "article_detail"} // Act constraintResults, err := connector.GetConstraints(tableName) @@ -193,7 +194,7 @@ func TestDatabaseIntegrations(t *testing.T) { t.Run("Many-to-one relation #1", func(t *testing.T) { // Arrange - tableName := "article_comment" + tableName := TableDetail{Schema: testCase.schema, Name: "article_comment"} // Act constraintResults, err := connector.GetConstraints(tableName) @@ -208,8 +209,8 @@ func TestDatabaseIntegrations(t *testing.T) { t.Run("Many-to-one relation #2", func(t *testing.T) { // Arrange - pkTableName := "article" - fkTableName := "article_label" + pkTableName := TableDetail{Schema: testCase.schema, Name: "article"} + fkTableName := TableDetail{Schema: testCase.schema, Name: "article_label"} // Act constraintResults, err := connector.GetConstraints(pkTableName) @@ -218,7 +219,7 @@ func TestDatabaseIntegrations(t *testing.T) { assert.Nil(t, err) var constraint *ConstraintResult for _, item := range constraintResults { - if item.FkTable == fkTableName { + if item.FkTable == fkTableName.Name { constraint = &item break } @@ -229,9 +230,9 @@ func TestDatabaseIntegrations(t *testing.T) { }) // Multiple primary keys (https://github.com/KarnerTh/mermerd/issues/8) - t.Run("Test 1 (Issue #8)", func(t *testing.T) { + t.Run("Multiple primary keys (Issue #8)", func(t *testing.T) { // Arrange - pkTableName := "test_1_b" + pkTableName := TableDetail{Schema: testCase.schema, Name: "test_1_b"} // Act constraintResults, err := connector.GetConstraints(pkTableName) @@ -248,6 +249,53 @@ func TestDatabaseIntegrations(t *testing.T) { assert.Equal(t, constraintResults[1].ColumnName, "bid") }) }) + + t.Run("Multiple schemas (Issue #23)", func(t *testing.T) { + connector := getConnectionAndConnect(t) + + t.Run("GetTables", func(t *testing.T) { + // Arrange + secondSchema := "other_db" + schemas := []string{testCase.schema, secondSchema} + + // Act + tables, err := connector.GetTables(schemas) + + // Assert + expectedResult := []TableDetail{ + {Schema: testCase.schema, Name: "article"}, + {Schema: testCase.schema, Name: "article_detail"}, + {Schema: testCase.schema, Name: "article_comment"}, + {Schema: testCase.schema, Name: "label"}, + {Schema: testCase.schema, Name: "article_label"}, + {Schema: testCase.schema, Name: "test_1_a"}, + {Schema: testCase.schema, Name: "test_1_b"}, + {Schema: testCase.schema, Name: "test_2_enum"}, + {Schema: testCase.schema, Name: "test_3_a"}, + {Schema: secondSchema, Name: "test_3_b"}, + {Schema: secondSchema, Name: "test_3_c"}, + } + assert.Nil(t, err) + assert.ElementsMatch(t, expectedResult, tables) + }) + + t.Run("GetCrossSchemaConstraints", func(t *testing.T) { + // Arrange + tableName := TableDetail{Schema: "other_db", Name: "test_3_b"} + + // Act + constraintResults, err := connector.GetConstraints(tableName) + + // Assert + assert.Nil(t, err) + assert.Len(t, constraintResults, 1) + assert.False(t, constraintResults[0].IsPrimary) + assert.False(t, constraintResults[0].HasMultiplePK) + assert.Equal(t, constraintResults[0].ColumnName, "aid") + assert.Equal(t, constraintResults[0].FkTable, "test_3_b") + assert.Equal(t, constraintResults[0].PkTable, "test_3_a") + }) + }) }) } } diff --git a/database/mssql.go b/database/mssql.go index 30b57ef..6747c12 100644 --- a/database/mssql.go +++ b/database/mssql.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "strings" _ "github.com/denisenkom/go-mssqldb" ) @@ -53,31 +54,39 @@ func (c *mssqlConnector) GetSchemas() ([]string, error) { return schemas, nil } -func (c *mssqlConnector) GetTables(schemaName string) ([]string, error) { +func (c *mssqlConnector) GetTables(schemaNames []string) ([]TableDetail, error) { + args := make([]any, len(schemaNames)) + searchPlaceholder := make([]string, len(schemaNames)) + for i, schemaName := range schemaNames { + args[i] = schemaName + searchPlaceholder[i] = fmt.Sprintf("@p%d", i+1) + } rows, err := c.db.Query(` - select table_name + select table_schema, table_name from information_schema.tables where table_type = 'BASE TABLE' - and table_schema = @p1 - `, schemaName) + and table_schema in(`+strings.Join(searchPlaceholder, ",")+`) + `, args...) if err != nil { return nil, err } - var tables []string + var tables []TableDetail for rows.Next() { - var table string - if err = rows.Scan(&table); err != nil { + var table TableDetail + if err = rows.Scan(&table.Schema, &table.Name); err != nil { return nil, err } - tables = append(tables, SanitizeValue(table)) + table.Name = SanitizeValue(table.Name) + + tables = append(tables, table) } return tables, nil } -func (c *mssqlConnector) GetColumns(tableName string) ([]ColumnResult, error) { +func (c *mssqlConnector) GetColumns(tableName TableDetail) ([]ColumnResult, error) { rows, err := c.db.Query(` select c.column_name, c.data_type, @@ -94,9 +103,9 @@ func (c *mssqlConnector) GetColumns(tableName string) ([]ColumnResult, error) { and cu.table_name = c.table_name and tc.constraint_type = 'FOREIGN KEY') as is_foreign from information_schema.columns c - where c.table_name = @p1 + where c.table_name = @p1 and c.TABLE_SCHEMA = @p2 order by c.ordinal_position; - `, tableName) + `, tableName.Name, tableName.Schema) if err != nil { return nil, err } @@ -117,7 +126,7 @@ func (c *mssqlConnector) GetColumns(tableName string) ([]ColumnResult, error) { return columns, nil } -func (c *mssqlConnector) GetConstraints(tableName string) ([]ConstraintResult, error) { +func (c *mssqlConnector) GetConstraints(tableName TableDetail) ([]ConstraintResult, error) { rows, err := c.db.Query(` select fk.table_name, pk.table_name, @@ -144,8 +153,8 @@ from information_schema.referential_constraints c inner join information_schema.table_constraints fk on c.constraint_name = fk.constraint_name inner join information_schema.table_constraints pk on c.unique_constraint_name = pk.constraint_name inner join information_schema.key_column_usage kcu on c.constraint_name = kcu.constraint_name -where fk.table_name = @p1 or pk.table_name = @p1; - `, tableName) +where c.CONSTRAINT_SCHEMA = @p1 and (fk.table_name = @p2 or pk.table_name = @p2); + `, tableName.Schema, tableName.Name) if err != nil { return nil, err } diff --git a/database/mysql.go b/database/mysql.go index f1cdb91..a22322e 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "strings" _ "github.com/go-sql-driver/mysql" ) @@ -53,31 +54,37 @@ func (c *mySqlConnector) GetSchemas() ([]string, error) { return schemas, nil } -func (c *mySqlConnector) GetTables(schemaName string) ([]string, error) { +func (c *mySqlConnector) GetTables(schemaNames []string) ([]TableDetail, error) { + args := make([]any, len(schemaNames)) + for i, schemaName := range schemaNames { + args[i] = schemaName + } rows, err := c.db.Query(` - select table_name + select table_schema, table_name from information_schema.tables where table_type = 'BASE TABLE' - and table_schema = ? - `, schemaName) + and table_schema in (?`+strings.Repeat(",?", len(schemaNames)-1)+`) + `, args...) if err != nil { return nil, err } - var tables []string + var tables []TableDetail for rows.Next() { - var table string - if err = rows.Scan(&table); err != nil { + var table TableDetail + if err = rows.Scan(&table.Schema, &table.Name); err != nil { return nil, err } - tables = append(tables, SanitizeValue(table)) + table.Name = SanitizeValue(table.Name) + + tables = append(tables, table) } return tables, nil } -func (c *mySqlConnector) GetColumns(tableName string) ([]ColumnResult, error) { +func (c *mySqlConnector) GetColumns(tableName TableDetail) ([]ColumnResult, error) { rows, err := c.db.Query(` select c.column_name, c.data_type, @@ -94,9 +101,9 @@ func (c *mySqlConnector) GetColumns(tableName string) ([]ColumnResult, error) { and tc.constraint_type = 'FOREIGN KEY') as is_foreign, case when c.data_type = 'enum' then REPLACE(REPLACE(REPLACE(REPLACE(c.column_type, 'enum', ''), '\'', ''), '(', ''), ')', '') else '' end as enum_values from information_schema.columns c - where c.table_name = ? + where c.table_name = ? and c.TABLE_SCHEMA = ? order by c.ordinal_position; - `, tableName) + `, tableName.Name, tableName.Schema) if err != nil { return nil, err } @@ -117,7 +124,7 @@ func (c *mySqlConnector) GetColumns(tableName string) ([]ColumnResult, error) { return columns, nil } -func (c *mySqlConnector) GetConstraints(tableName string) ([]ConstraintResult, error) { +func (c *mySqlConnector) GetConstraints(tableName TableDetail) ([]ConstraintResult, error) { rows, err := c.db.Query(` select c.TABLE_NAME, c.REFERENCED_TABLE_NAME, @@ -139,8 +146,8 @@ func (c *mySqlConnector) GetConstraints(tableName string) ([]ConstraintResult, e ) "hasMultiplePk" from information_schema.REFERENTIAL_CONSTRAINTS c inner join information_schema.KEY_COLUMN_USAGE kcu on c.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME - where c.TABLE_NAME = ? or c.REFERENCED_TABLE_NAME = ? - `, tableName, tableName) + where c.CONSTRAINT_SCHEMA = ? and (c.TABLE_NAME = ? or c.REFERENCED_TABLE_NAME = ?) + `, tableName.Schema, tableName.Name, tableName.Name) if err != nil { return nil, err } diff --git a/database/mysql_test.go b/database/mysql_test.go index 8ffce22..d9ad82d 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -21,7 +21,7 @@ func TestMysqlEnums(t *testing.T) { logrus.Error(err) t.FailNow() } - columns, err := connector.GetColumns("test_2_enum") + columns, err := connector.GetColumns(TableDetail{Schema: "mermerd_test", Name: "test_2_enum"}) // Assert for _, column := range columns { diff --git a/database/postgres.go b/database/postgres.go index 6bc5f14..3e71856 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "strings" _ "github.com/jackc/pgx/v4/stdlib" ) @@ -53,31 +54,34 @@ func (c *postgresConnector) GetSchemas() ([]string, error) { return schemas, nil } -func (c *postgresConnector) GetTables(schemaName string) ([]string, error) { +func (c *postgresConnector) GetTables(schemaNames []string) ([]TableDetail, error) { + schemaSearch := "{" + strings.Join(schemaNames, ",") + "}" rows, err := c.db.Query(` - select table_name + select table_schema, table_name from information_schema.tables where table_type = 'BASE TABLE' - and table_schema = $1 - `, schemaName) + and table_schema = ANY($1::varchar[]) + `, schemaSearch) if err != nil { return nil, err } - var tables []string + var tables []TableDetail for rows.Next() { - var table string - if err = rows.Scan(&table); err != nil { + var table TableDetail + if err = rows.Scan(&table.Schema, &table.Name); err != nil { return nil, err } - tables = append(tables, SanitizeValue(table)) + table.Name = SanitizeValue(table.Name) + + tables = append(tables, table) } return tables, nil } -func (c *postgresConnector) GetColumns(tableName string) ([]ColumnResult, error) { +func (c *postgresConnector) GetColumns(tableName TableDetail) ([]ColumnResult, error) { rows, err := c.db.Query(` select c.column_name, (case @@ -101,14 +105,14 @@ func (c *postgresConnector) GetColumns(tableName string) ([]ColumnResult, error) from information_schema.columns c left join pg_type typ on c.udt_name = typ.typname left join pg_enum enu on typ.oid = enu.enumtypid - where c.table_name = $1 + where c.table_name = $1 and c.table_schema = $2 group by c.column_name, c.table_name, c.data_type, c.udt_name, c.ordinal_position order by c.ordinal_position; - `, tableName) + `, tableName.Name, tableName.Schema) if err != nil { return nil, err } @@ -129,7 +133,7 @@ func (c *postgresConnector) GetColumns(tableName string) ([]ColumnResult, error) return columns, nil } -func (c *postgresConnector) GetConstraints(tableName string) ([]ConstraintResult, error) { +func (c *postgresConnector) GetConstraints(tableName TableDetail) ([]ConstraintResult, error) { rows, err := c.db.Query(` select fk.table_name, pk.table_name, @@ -157,8 +161,8 @@ func (c *postgresConnector) GetConstraints(tableName string) ([]ConstraintResult inner join information_schema.table_constraints fk on c.constraint_name = fk.constraint_name inner join information_schema.table_constraints pk on c.unique_constraint_name = pk.constraint_name inner join information_schema.key_column_usage kcu on c.constraint_name = kcu.constraint_name - where fk.table_name = $1 or pk.table_name = $1; - `, tableName) + where c.constraint_schema = $1 and (fk.table_name = $2 or pk.table_name = $2); + `, tableName.Schema, tableName.Name) if err != nil { return nil, err } diff --git a/database/postgres_test.go b/database/postgres_test.go index 30f1d82..9780fa0 100644 --- a/database/postgres_test.go +++ b/database/postgres_test.go @@ -21,7 +21,7 @@ func TestPostgresEnums(t *testing.T) { logrus.Error(err) t.FailNow() } - columns, err := connector.GetColumns("test_2_enum") + columns, err := connector.GetColumns(TableDetail{Schema: "public", Name: "test_2_enum"}) // Assert for _, column := range columns { diff --git a/database/result.go b/database/result.go index 6db1b7d..47bb344 100644 --- a/database/result.go +++ b/database/result.go @@ -5,11 +5,16 @@ type Result struct { } type TableResult struct { - TableName string + Table TableDetail Columns []ColumnResult Constraints ConstraintResultList } +type TableDetail struct { + Schema string + Name string +} + type ColumnResult struct { Name string DataType string diff --git a/database/table_name_util.go b/database/table_name_util.go new file mode 100644 index 0000000..d72b6bd --- /dev/null +++ b/database/table_name_util.go @@ -0,0 +1,24 @@ +package database + +import ( + "errors" + "strings" +) + +func ParseTableName(value string, selectedSchemas []string) (TableDetail, error) { + parts := strings.Split(value, ".") + + if len(parts) == 1 { + if len(selectedSchemas) != 1 { + return TableDetail{}, errors.New("If table names do not specify the schema, exactly one selected schema should be present") + } + + return TableDetail{Schema: selectedSchemas[0], Name: parts[0]}, nil + } + + if len(parts) == 2 { + return TableDetail{Schema: parts[0], Name: parts[1]}, nil + } + + return TableDetail{}, errors.New("Could not parse table name") +} diff --git a/database/value_sanitizer.go b/database/value_sanitizer.go index 4ea6f7a..509ec38 100644 --- a/database/value_sanitizer.go +++ b/database/value_sanitizer.go @@ -8,7 +8,7 @@ import ( func SanitizeValue(value string) string { result := strings.ReplaceAll(value, " ", "_") - reg := regexp.MustCompile("[^a-zA-Z0-9_-]+") + reg := regexp.MustCompile("[^a-zA-Z0-9_.-]+") result = reg.ReplaceAllString(result, "") return result diff --git a/database/value_sanitizer_test.go b/database/value_sanitizer_test.go index 9a7ab35..e47a352 100644 --- a/database/value_sanitizer_test.go +++ b/database/value_sanitizer_test.go @@ -16,6 +16,7 @@ func TestSanitizeValue(t *testing.T) { {inputValue: "numbers are allowed 7", expectedResult: "numbers_are_allowed_7"}, {inputValue: "valid_stays_valid", expectedResult: "valid_stays_valid"}, {inputValue: "valid-stays-valid", expectedResult: "valid-stays-valid"}, + {inputValue: "dots.are.allowed", expectedResult: "dots.are.allowed"}, {inputValue: "symbols_$_are_&_not_ยง_allowed", expectedResult: "symbols__are__not__allowed"}, {inputValue: "also: not allowed", expectedResult: "also_not_allowed"}, } diff --git a/diagram/diagram.go b/diagram/diagram.go index a7ad97d..6c39adb 100644 --- a/diagram/diagram.go +++ b/diagram/diagram.go @@ -68,7 +68,7 @@ func (d diagram) Create(result *database.Result) error { } tableData[tableIndex] = ErdTableData{ - Name: table.TableName, + Name: table.Table.Name, Columns: columnData, } } diff --git a/exampleRunConfig.yaml b/exampleRunConfig.yaml index aff978e..e278d9b 100644 --- a/exampleRunConfig.yaml +++ b/exampleRunConfig.yaml @@ -1,6 +1,11 @@ # Connection properties connectionString: "postgresql://user:password@localhost:5432/mermerd_test" -schema: "public" + +# Define what tables should be used +#useAllSchemas: true +schema: + - "public" + - "other_db" # Define what tables should be used #useAllTables: true diff --git a/mocks/Analyzer.go b/mocks/Analyzer.go index e507b94..d306e08 100644 --- a/mocks/Analyzer.go +++ b/mocks/Analyzer.go @@ -36,11 +36,11 @@ func (_m *Analyzer) Analyze() (*database.Result, error) { } // GetColumnsAndConstraints provides a mock function with given fields: db, selectedTables -func (_m *Analyzer) GetColumnsAndConstraints(db database.Connector, selectedTables []string) ([]database.TableResult, error) { +func (_m *Analyzer) GetColumnsAndConstraints(db database.Connector, selectedTables []database.TableDetail) ([]database.TableResult, error) { ret := _m.Called(db, selectedTables) var r0 []database.TableResult - if rf, ok := ret.Get(0).(func(database.Connector, []string) []database.TableResult); ok { + if rf, ok := ret.Get(0).(func(database.Connector, []database.TableDetail) []database.TableResult); ok { r0 = rf(db, selectedTables) } else { if ret.Get(0) != nil { @@ -49,7 +49,7 @@ func (_m *Analyzer) GetColumnsAndConstraints(db database.Connector, selectedTabl } var r1 error - if rf, ok := ret.Get(1).(func(database.Connector, []string) error); ok { + if rf, ok := ret.Get(1).(func(database.Connector, []database.TableDetail) error); ok { r1 = rf(db, selectedTables) } else { r1 = ret.Error(1) @@ -79,15 +79,17 @@ func (_m *Analyzer) GetConnectionString() (string, error) { return r0, r1 } -// GetSchema provides a mock function with given fields: db -func (_m *Analyzer) GetSchema(db database.Connector) (string, error) { +// GetSchemas provides a mock function with given fields: db +func (_m *Analyzer) GetSchemas(db database.Connector) ([]string, error) { ret := _m.Called(db) - var r0 string - if rf, ok := ret.Get(0).(func(database.Connector) string); ok { + var r0 []string + if rf, ok := ret.Get(0).(func(database.Connector) []string); ok { r0 = rf(db) } else { - r0 = ret.Get(0).(string) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } } var r1 error @@ -100,22 +102,22 @@ func (_m *Analyzer) GetSchema(db database.Connector) (string, error) { return r0, r1 } -// GetTables provides a mock function with given fields: db, selectedSchema -func (_m *Analyzer) GetTables(db database.Connector, selectedSchema string) ([]string, error) { - ret := _m.Called(db, selectedSchema) +// GetTables provides a mock function with given fields: db, selectedSchemas +func (_m *Analyzer) GetTables(db database.Connector, selectedSchemas []string) ([]database.TableDetail, error) { + ret := _m.Called(db, selectedSchemas) - var r0 []string - if rf, ok := ret.Get(0).(func(database.Connector, string) []string); ok { - r0 = rf(db, selectedSchema) + var r0 []database.TableDetail + if rf, ok := ret.Get(0).(func(database.Connector, []string) []database.TableDetail); ok { + r0 = rf(db, selectedSchemas) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) + r0 = ret.Get(0).([]database.TableDetail) } } var r1 error - if rf, ok := ret.Get(1).(func(database.Connector, string) error); ok { - r1 = rf(db, selectedSchema) + if rf, ok := ret.Get(1).(func(database.Connector, []string) error); ok { + r1 = rf(db, selectedSchemas) } else { r1 = ret.Error(1) } diff --git a/mocks/Connector.go b/mocks/Connector.go index a49da9e..58a3ed9 100644 --- a/mocks/Connector.go +++ b/mocks/Connector.go @@ -32,11 +32,11 @@ func (_m *Connector) Connect() error { } // GetColumns provides a mock function with given fields: tableName -func (_m *Connector) GetColumns(tableName string) ([]database.ColumnResult, error) { +func (_m *Connector) GetColumns(tableName database.TableDetail) ([]database.ColumnResult, error) { ret := _m.Called(tableName) var r0 []database.ColumnResult - if rf, ok := ret.Get(0).(func(string) []database.ColumnResult); ok { + if rf, ok := ret.Get(0).(func(database.TableDetail) []database.ColumnResult); ok { r0 = rf(tableName) } else { if ret.Get(0) != nil { @@ -45,7 +45,7 @@ func (_m *Connector) GetColumns(tableName string) ([]database.ColumnResult, erro } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { + if rf, ok := ret.Get(1).(func(database.TableDetail) error); ok { r1 = rf(tableName) } else { r1 = ret.Error(1) @@ -55,11 +55,11 @@ func (_m *Connector) GetColumns(tableName string) ([]database.ColumnResult, erro } // GetConstraints provides a mock function with given fields: tableName -func (_m *Connector) GetConstraints(tableName string) ([]database.ConstraintResult, error) { +func (_m *Connector) GetConstraints(tableName database.TableDetail) ([]database.ConstraintResult, error) { ret := _m.Called(tableName) var r0 []database.ConstraintResult - if rf, ok := ret.Get(0).(func(string) []database.ConstraintResult); ok { + if rf, ok := ret.Get(0).(func(database.TableDetail) []database.ConstraintResult); ok { r0 = rf(tableName) } else { if ret.Get(0) != nil { @@ -68,7 +68,7 @@ func (_m *Connector) GetConstraints(tableName string) ([]database.ConstraintResu } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { + if rf, ok := ret.Get(1).(func(database.TableDetail) error); ok { r1 = rf(tableName) } else { r1 = ret.Error(1) @@ -114,22 +114,22 @@ func (_m *Connector) GetSchemas() ([]string, error) { return r0, r1 } -// GetTables provides a mock function with given fields: schemaName -func (_m *Connector) GetTables(schemaName string) ([]string, error) { - ret := _m.Called(schemaName) +// GetTables provides a mock function with given fields: schemaNames +func (_m *Connector) GetTables(schemaNames []string) ([]database.TableDetail, error) { + ret := _m.Called(schemaNames) - var r0 []string - if rf, ok := ret.Get(0).(func(string) []string); ok { - r0 = rf(schemaName) + var r0 []database.TableDetail + if rf, ok := ret.Get(0).(func([]string) []database.TableDetail); ok { + r0 = rf(schemaNames) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) + r0 = ret.Get(0).([]database.TableDetail) } } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(schemaName) + if rf, ok := ret.Get(1).(func([]string) error); ok { + r1 = rf(schemaNames) } else { r1 = ret.Error(1) } diff --git a/mocks/MermerdConfig.go b/mocks/MermerdConfig.go index 541560b..8198b69 100644 --- a/mocks/MermerdConfig.go +++ b/mocks/MermerdConfig.go @@ -109,15 +109,17 @@ func (_m *MermerdConfig) OutputFileName() string { return r0 } -// Schema provides a mock function with given fields: -func (_m *MermerdConfig) Schema() string { +// Schemas provides a mock function with given fields: +func (_m *MermerdConfig) Schemas() []string { ret := _m.Called() - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { r0 = rf() } else { - r0 = ret.Get(0).(string) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } } return r0 @@ -167,6 +169,20 @@ func (_m *MermerdConfig) ShowEnumValues() bool { return r0 } +// UseAllSchemas provides a mock function with given fields: +func (_m *MermerdConfig) UseAllSchemas() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // UseAllTables provides a mock function with given fields: func (_m *MermerdConfig) UseAllTables() bool { ret := _m.Called() diff --git a/mocks/Questioner.go b/mocks/Questioner.go index 31a35a1..112ee2e 100644 --- a/mocks/Questioner.go +++ b/mocks/Questioner.go @@ -31,14 +31,16 @@ func (_m *Questioner) AskConnectionQuestion(suggestions []string) (string, error } // AskSchemaQuestion provides a mock function with given fields: schemas -func (_m *Questioner) AskSchemaQuestion(schemas []string) (string, error) { +func (_m *Questioner) AskSchemaQuestion(schemas []string) ([]string, error) { ret := _m.Called(schemas) - var r0 string - if rf, ok := ret.Get(0).(func([]string) string); ok { + var r0 []string + if rf, ok := ret.Get(0).(func([]string) []string); ok { r0 = rf(schemas) } else { - r0 = ret.Get(0).(string) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } } var r1 error diff --git a/readme.md b/readme.md index 5f20f3b..baa2909 100644 --- a/readme.md +++ b/readme.md @@ -80,6 +80,7 @@ via `mermerd -h` --selectedTables strings tables to include --showAllConstraints show all constraints, even though the table of the resulting constraint was not selected --showEnumValues show enum values in description column + --useAllSchemas use all available schemas --useAllTables use all available tables ``` @@ -117,7 +118,13 @@ shown below) and start mermerd via `mermerd --runConfig yourRunConfig.yaml` ```yaml # Connection properties connectionString: "postgresql://user:password@localhost:5432/yourDb" -schema: "public" + +# Define what schemas should be used +useAllSchemas: true +# or +schema: + - "public" + - "other_db" # Define what tables should be used useAllTables: true diff --git a/test/docker-compose.yaml b/test/docker-compose.yaml index 9540efe..006db24 100644 --- a/test/docker-compose.yaml +++ b/test/docker-compose.yaml @@ -11,6 +11,7 @@ services: volumes: - ./db-table-setup.sql:/docker-entrypoint-initdb.d/1.sql - ./postgres/postgres-enum-setup.sql:/docker-entrypoint-initdb.d/2.sql + - ./postgres/postgres-multiple-databases.sql:/docker-entrypoint-initdb.d/3.sql mermerd-mysql-test-db: image: mysql:8.0 command: --default-authentication-plugin=mysql_native_password @@ -24,6 +25,7 @@ services: volumes: - ./db-table-setup.sql:/docker-entrypoint-initdb.d/1.sql - ./mysql/mysql-enum-setup.sql:/docker-entrypoint-initdb.d/2.sql + - ./mysql/mysql-multiple-databases.sql:/docker-entrypoint-initdb.d/3.sql mermerd-mssql-test-db: image: mcr.microsoft.com/mssql/server:2019-latest environment: @@ -35,6 +37,7 @@ services: - ./db-table-setup.sql:/usr/src/app/db-table-setup.sql - ./mssql/mssql-setup.sql:/usr/src/app/mssql-setup.sql - ./mssql/mssql-enum-setup.sql:/usr/src/app/mssql-enum-setup.sql + - ./mssql/mssql-multiple-databases.sql:/usr/src/app/mssql-multiple-databases.sql - ./mssql/entrypoint.sh:/usr/src/app/entrypoint.sh working_dir: /usr/src/app command: sh -c './entrypoint.sh & /opt/mssql/bin/sqlservr;' diff --git a/test/mssql/entrypoint.sh b/test/mssql/entrypoint.sh index ea3d11b..8376279 100755 --- a/test/mssql/entrypoint.sh +++ b/test/mssql/entrypoint.sh @@ -12,5 +12,6 @@ echo importing data... /opt/mssql-tools/bin/sqlcmd -S 0.0.0.0 -U sa -P $password -i ./mssql-setup.sql /opt/mssql-tools/bin/sqlcmd -S 0.0.0.0 -U sa -P $password -d mermerd_test -i ./db-table-setup.sql /opt/mssql-tools/bin/sqlcmd -S 0.0.0.0 -U sa -P $password -d mermerd_test -i ./mssql-enum-setup.sql +/opt/mssql-tools/bin/sqlcmd -S 0.0.0.0 -U sa -P $password -d mermerd_test -i ./mssql-multiple-databases.sql echo importing done diff --git a/test/mssql/mssql-multiple-databases.sql b/test/mssql/mssql-multiple-databases.sql new file mode 100644 index 0000000..16733a6 --- /dev/null +++ b/test/mssql/mssql-multiple-databases.sql @@ -0,0 +1,26 @@ +-- Test case for https://github.com/KarnerTh/mermerd/issues/23 +create table dbo.test_3_a +( + id int not null primary key, + title varchar(255) not null +); + +GO + +create schema other_db; + +GO + +create table other_db.test_3_b +( + id int not null primary key, + aid int, + foreign key (aid) references dbo.test_3_a (id) +); + +create table other_db.test_3_c +( + id int not null primary key, + title varchar(255) not null +); + diff --git a/test/mysql/mysql-multiple-databases.sql b/test/mysql/mysql-multiple-databases.sql new file mode 100644 index 0000000..10789e4 --- /dev/null +++ b/test/mysql/mysql-multiple-databases.sql @@ -0,0 +1,23 @@ +-- Test case for https://github.com/KarnerTh/mermerd/issues/23 +create table test_3_a +( + id int not null primary key, + title varchar(255) not null +); + +create database other_db; +use other_db; + +create table test_3_b +( + id int not null primary key, + aid int, + foreign key (aid) references mermerd_test.test_3_a (id) +); + +create table test_3_c +( + id int not null primary key, + title varchar(255) not null +); + diff --git a/test/postgres/postgres-multiple-databases.sql b/test/postgres/postgres-multiple-databases.sql new file mode 100644 index 0000000..c32879b --- /dev/null +++ b/test/postgres/postgres-multiple-databases.sql @@ -0,0 +1,22 @@ +-- Test case for https://github.com/KarnerTh/mermerd/issues/23 +create table test_3_a +( + id int not null primary key, + title varchar(255) not null +); + +create schema other_db; + +create table other_db.test_3_b +( + id int not null primary key, + aid int, + foreign key (aid) references public.test_3_a (id) +); + +create table other_db.test_3_c +( + id int not null primary key, + title varchar(255) not null +); + diff --git a/util/map_util.go b/util/map_util.go new file mode 100644 index 0000000..4e151cb --- /dev/null +++ b/util/map_util.go @@ -0,0 +1,10 @@ +package util + +func Map2[T, U any](data []T, f func(T) U) []U { + res := make([]U, 0, len(data)) + for _, e := range data { + res = append(res, f(e)) + } + + return res +}