diff --git a/tests/integration/test_iam.py b/tests/integration/test_iam.py index f8d7c3b1..3c4d6ebc 100644 --- a/tests/integration/test_iam.py +++ b/tests/integration/test_iam.py @@ -22,29 +22,3 @@ def test_scim_get_user_as_dict(w): user = w.users.get(first_user.id) # should not throw user.as_dict() - - -@pytest.mark.parametrize( - "path,call", - [("/api/2.0/preview/scim/v2/Users", lambda w: w.users.list(count=10)), - ("/api/2.0/preview/scim/v2/Groups", lambda w: w.groups.list(count=4)), - ("/api/2.0/preview/scim/v2/ServicePrincipals", lambda w: w.service_principals.list(count=1)), ]) -def test_workspace_users_list_pagination(w, path, call): - raw = w.api_client.do('GET', path) - total = raw['totalResults'] - all = call(w) - found = len(list(all)) - assert found == total - - -@pytest.mark.parametrize( - "path,call", - [("/api/2.0/accounts/%s/scim/v2/Users", lambda a: a.users.list(count=3000)), - ("/api/2.0/accounts/%s/scim/v2/Groups", lambda a: a.groups.list(count=5)), - ("/api/2.0/accounts/%s/scim/v2/ServicePrincipals", lambda a: a.service_principals.list(count=1000)), ]) -def test_account_users_list_pagination(a, path, call): - raw = a.api_client.do('GET', path.replace("%s", a.config.account_id)) - total = raw['totalResults'] - all = call(a) - found = len(list(all)) - assert found == total diff --git a/tests/test_iam.py b/tests/test_iam.py new file mode 100644 index 00000000..6b86c013 --- /dev/null +++ b/tests/test_iam.py @@ -0,0 +1,39 @@ +import pytest + +from databricks.sdk import AccountClient, WorkspaceClient + + +@pytest.mark.parametrize( + "path,call", + [("http://localhost/api/2.0/preview/scim/v2/Users", lambda w: w.users.list()), + ("http://localhost/api/2.0/preview/scim/v2/Groups", lambda w: w.groups.list()), + ("http://localhost/api/2.0/preview/scim/v2/ServicePrincipals", lambda w: w.service_principals.list()), ], +) +def test_workspace_iam_list(config, requests_mock, path, call): + requests_mock.get(path, request_headers={"Accept": "application/json", }, text="null", ) + w = WorkspaceClient(config=config) + for _ in call(w): + pass + assert requests_mock.call_count == 1 + assert requests_mock.called + + +@pytest.mark.parametrize("path,call", [ + ("http://localhost/api/2.0/accounts/%s/scim/v2/Users", lambda a: a.users.list()), + ("http://localhost/api/2.0/accounts/%s/scim/v2/Groups", lambda a: a.groups.list()), + ("http://localhost/api/2.0/accounts/%s/scim/v2/ServicePrincipals", lambda a: a.service_principals.list()), +], + ) +def test_account_iam_list(config, requests_mock, path, call): + config.account_id = "test_account_id" + requests_mock.get(path.replace("%s", config.account_id), + request_headers={ + "Accept": "application/json", + }, + text="null", + ) + a = AccountClient(config=config) + for _ in call(a): + pass + assert requests_mock.call_count == 1 + assert requests_mock.called